# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

import os
import sys

# Put the parent of this file's directory on sys.path so that the
# `megatron` package (and the sibling task packages imported lazily in
# the __main__ block) resolve when this script is executed directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks.

    Adds a ``tasks`` argument group to ``parser`` holding the generic
    fine-tuning options plus the task-specific options for LAMBADA /
    overlapping evaluation, the retriever (QA data, Faiss, Av.rank
    validation), and controllable dialogue / generation fine-tuning.

    Args:
        parser: an ``argparse.ArgumentParser`` (anything exposing
            ``add_argument_group``); it is extended in place.

    Returns:
        The same ``parser`` object, as expected of the
        ``extra_args_provider`` passed to ``initialize_megatron``.
    """
    group = parser.add_argument_group(title='tasks')

    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
    # Note the trailing space: adjacent literals are concatenated verbatim.
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated paths or corpora names '
                       'for training.')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='path(s) to the validation data.')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')

    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    # The previous help used a backslash-newline inside a double-quoted
    # string, which embedded stray quotes and indentation in the text.
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type')
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval')

    # finetune for retriever
    group.add_argument('--eval-micro-batch-size', type=int, default=None,
                       help='Eval Batch size per model instance (local batch '
                            'size). Global batch size is local batch size '
                            'times data parallel size.')
    group.add_argument('--train-with-neg', action='store_true',
                       help='Whether to use negative examples during model '
                       'training')
    group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
                       'training')

    # parameters for Av.rank validation method
    # Following options/arguments have been taken directly from DPR codebase
    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                       help='Av.rank validation: how many hard negatives to'
                       ' take from each question pool')
    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                       help='Av.rank validation: how many other negatives to'
                       ' take from each question pool')

    # finetune for controllable dialogue
    group.add_argument('--train-module', type=str, default="",
                       help='either control module or dialogue model (control or dialog)')
    group.add_argument('--train-data-path', type=str, default="",
                       help='datapath for training set')
    group.add_argument('--test-data-path', type=str, default="",
                       help='datapath for test set')
    group.add_argument('--guess-file', type=str, default="",
                       help='datapath for generated sentences')
    group.add_argument('--answer-file', type=str, default="",
                       help='datapath for golden sentences')
    group.add_argument('--max-seq-len', type=int, default=1024,
                       help='maximum sequence length')
    group.add_argument('--spec-toks', type=str, default=None,
                       help='additional special tokens')
    group.add_argument('--last-turn', action='store_true',
                       help='only use last turn for control model')
    group.add_argument('--no-control-code', action='store_true',
                       help='removing control code in the training for control model')
    group.add_argument('--remove-stopwords', action='store_true',
                       help='removing stopwords when evaluating F1-score')
    group.add_argument('--add-separator', action='store_true',
                       help='add separator between turns and add colon before generation')
    group.add_argument('--add-ctrl-code-to-dialog', action='store_true',
                       help='add control code in the dialog modeling')
    group.add_argument('--remove-ctrl-sent', action='store_true',
                       help="don't use control sentence in dialog modeling")

    # finetune for controllable generation
    group.add_argument('--wiki-path', type=str, default="",
                       help='data path for the wikipedia corpus')
    group.add_argument('--tokenized-path', type=str, default="",
                       help='data path for the tokenized file')
    group.add_argument('--prop', type=float, default=1.0,
                       help='Proportion of data used for training')
    # NOTE(review): help was a copy-paste of --prop's; the wording below is
    # inferred from the option name and int type -- confirm against the
    # control_gen consumer of args.max_instance.
    group.add_argument('--max-instance', type=int, default=10000000,
                       help='Maximum number of instances used for training')

    return parser


if __name__ == '__main__':

    # Parse the command line (with the task-specific options added by
    # get_tasks_args) and set up Megatron before reading the global args.
    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
        # Use sys.exit rather than the site-injected exit() builtin, and a
        # non-zero status so launchers see the unsupported configuration
        # as a failed run instead of a successful one.
        sys.exit(1)

    # Each branch imports its task's `main` lazily, so only the selected
    # task's package (and its dependencies) is loaded.
    if args.task == 'RACE':
        from race.finetune import main
    elif args.task in ['MNLI', 'QQP']:
        from glue.finetune import main
    elif args.task in ['LAMBADA', 'WIKITEXT103']:
        from zeroshot_gpt.evaluate import main
    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
        from orqa.evaluate_orqa import main
    elif args.task in ['RET-FINETUNE-NQ']:
        from orqa.supervised.finetune import main
    elif args.task == 'control-gen':
        from control_gen.finetune import main
    elif args.task == 'dialctrl':
        from dialctrl.finetune import main
    elif args.task in ['dialctrl-eval-ppl', 'dialctrl-eval-f1']:
        from dialctrl.evaluate import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main()