# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks.

    Adds a ``tasks`` argument group to ``parser`` with options for generic
    finetuning/evaluation, LAMBADA/overlapping zero-shot evaluation, the
    retriever (QA data files, Faiss index options, retriever finetuning and
    Av.rank validation) and knowledgeable dialogue prompting/generation.

    Args:
        parser: an ``argparse.ArgumentParser`` (or anything exposing
            ``add_argument_group``) to extend in place.

    Returns:
        The same ``parser`` object, which lets this function be passed as
        ``extra_args_provider`` to ``initialize_megatron``.
    """
    group = parser.add_argument_group(title='tasks')

    # Generic task / finetuning options.
    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
    # Note the explicit space: adjacent literals are concatenated verbatim.
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated paths or corpora names '
                       'for training.')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='path(s) to the validation data.')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')

    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type')
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval')

    # finetune for retriever
    group.add_argument('--eval-micro-batch-size', type=int, default=None,
                       help='Eval Batch size per model instance (local batch '
                            'size). Global batch size is local batch size '
                            'times data parallel size.')
    group.add_argument('--train-with-neg', action='store_true',
                       help='Whether to use negative examples during model '
                            'training')
    group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
                            'training')

    # parameters for Av.rank validation method
    # Following options/arguments have been taken directly from DPR codebase
    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                       help='Av.rank validation: how many hard negatives to'
                            ' take from each question pool')
    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                       help='Av.rank validation: how many other negatives to'
                            ' take from each question pool')

    # parameters for the knowledgeable dialogue generation
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument('--prompt-file', type=str, default="",
                       help='prompting file')
    group.add_argument('--prompt-type', type=str, default="",
                       help='prompt type (knowledge or response)')
    group.add_argument('--num-prompt-examples', type=int, default=10,
                       help='number of prompt examples')
    group.add_argument('--dynamic-prompt', action='store_true', default=False,
                       help='using different prompts for different test samples')
    group.add_argument('--module', type=str, default="",
                       help='either knowledge generation (knowledge) or '
                       'response generation (response)')
    group.add_argument('--guess-file', type=str, default="",
                       help='datapath for generated sentences')
    group.add_argument('--answer-file', type=str, default="",
                       help='datapath for golden sentences')
    group.add_argument('--spec-toks', type=str, default=None,
                       help='additional special tokens')
    group.add_argument('--out-seq-length', type=int, default=100,
                       help='output sequence length')

    return parser


if __name__ == '__main__':

    # Initialize Megatron, registering the task-specific arguments defined in
    # get_tasks_args alongside Megatron's own.
    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
        # The bare ``exit`` builtin is injected by the ``site`` module for
        # interactive use and is not guaranteed to exist; use sys.exit and
        # signal failure with a non-zero status instead of reporting success.
        sys.exit(1)

    # Import only the selected task's entry point; each branch binds ``main``.
    if args.task == 'RACE':
        from race.finetune import main
    elif args.task in ['MNLI', 'QQP']:
        from glue.finetune import main
    elif args.task in ['LAMBADA', 'WIKITEXT103']:
        from zeroshot_gpt.evaluate import main
    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
        from orqa.evaluate_orqa import main
    elif args.task in ['RET-FINETUNE-NQ']:
        from orqa.supervised.finetune import main
    elif args.task == 'KNWL-DIALO-PROMPT':
        from knwl_dialo.prompt import main
    elif args.task in ['KNWL-DIALO-FINETUNE', 'KNWL-DIALO-GEN']:
        from knwl_dialo.finetune import main
    elif args.task == 'KNWL-DIALO-EVAL-F1':
        from knwl_dialo.evaluate import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main()