# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks.

    Registers a 'tasks' argument group on ``parser`` holding the options
    used by the downstream tasks selected through ``--task``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument_group`` method).

    Returns:
        The same ``parser`` object, with the 'tasks' group added.
    """
    group = parser.add_argument_group(title='tasks')

    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
    # NOTE: adjacent string literals are concatenated verbatim, so the
    # separating space must be part of one of the fragments.
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated paths or corpora names '
                       'for training.')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='path(s) to the validation data.')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')

    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether create the FaissMIPSIndex on GPU')
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type')
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval')

    # finetune for retriever
    group.add_argument('--eval-micro-batch-size', type=int, default=None,
                       help='Eval Batch size per model instance (local batch '
                       'size). Global batch size is local batch size '
                       'times data parallel size.')
    group.add_argument('--train-with-neg', action='store_true',
                       help='Whether to use negative examples during model '
                       'training')
    group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
                       'training')

    # parameters for Av.rank validation method
    # Following options/arguments have been taken directly from DPR codebase
    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                       help='Av.rank validation: how many hard negatives to'
                       ' take from each question pool')
    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                       help='Av.rank validation: how many other negatives to'
                       ' take from each question pool')

    # parameters for the knowledgeable dialogue generation
    group.add_argument("--sample-input-file", type=str, default=None,
                       help='Get input from file instead of interactive mode, '
                       'each line is an input.')
    group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file got from --sample-input-file')
    group.add_argument('--prompt-file', type=str, default=None,
                       help='prompting file')
    group.add_argument('--prompt-type', type=str, default=None,
                       help='prompt type (knowledge or response)')
    group.add_argument('--num-prompt-examples', type=int, default=10,
                       help='number of prompt examples')
    group.add_argument('--guess-file', type=str, default=None,
                       help='datapath for generated sentences')
    group.add_argument('--answer-file', type=str, default=None,
                       help='datapath for golden sentences')
    group.add_argument('--out-seq-length', type=int, default=100,
                       help='output sequence length')
    group.add_argument('--api-prompt', default=False, action="store_true",
                       help='setup model api for prompting')
    group.add_argument('--megatron-api-url', type=str, default=None,
                       help='url of the megatron api')

    return parser


if __name__ == '__main__':

    # Initialize Megatron, registering the task-specific command-line
    # options through get_tasks_args.
    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
        # sys.exit() instead of the bare exit() helper: exit() is injected by
        # the `site` module for interactive use and is not guaranteed to
        # exist (e.g. under `python -S`). Exit status is unchanged.
        sys.exit()

    # Select the entry point for the requested task. Each import is deferred
    # to its branch so only the selected task's module is loaded.
    if args.task == 'RACE':
        from race.finetune import main
    elif args.task in ['MNLI', 'QQP']:
        from glue.finetune import main
    elif args.task in ['LAMBADA', 'WIKITEXT103']:
        from zeroshot_gpt.evaluate import main
    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
        from orqa.evaluate_orqa import main
    elif args.task in ['RET-FINETUNE-NQ']:
        from orqa.supervised.finetune import main
    elif args.task == 'KNWL-DIALO-PROMPT':
        from knwl_dialo.prompt import main
    elif args.task == 'KNWL-DIALO-EVAL-F1':
        from knwl_dialo.evaluate import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main()