# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

import os
import sys
# Put the parent of this file's directory on the module search path.
# NOTE(review): presumably this lets the `megatron` package imported below
# resolve when the file is run as a script -- confirm against the repo layout.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))

Mohammad's avatar
Mohammad committed
23
24
25
from megatron import get_args
from megatron.initialize import initialize_megatron

26
27
28

def get_tasks_args(parser):
    """Provide extra arguments required for tasks.

    Registers a 'tasks' argument group on ``parser`` covering task selection,
    finetuning, zero-shot evaluation, and retriever options.

    Args:
        parser: An ``argparse.ArgumentParser``-compatible object exposing
            ``add_argument_group``.

    Returns:
        The same ``parser`` object, with the 'tasks' group added.
    """
    group = parser.add_argument_group(title='tasks')

    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetunning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetunning.')
    # The first fragment must end with a space: adjacent string literals are
    # concatenated verbatim, and the original rendered "inthe data loader".
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated paths or corpora names '
                       'for training.')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='path(s) to the validation data.')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')

    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether create the FaissMIPSIndex on GPU')
    # A backslash continuation inside the original double-quoted literal
    # embedded stray quote characters and indentation in the help text;
    # a single plain literal avoids that.
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type')
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval')

    # finetune for retriever
    group.add_argument('--eval-micro-batch-size', type=int, default=None,
                       help='Eval Batch size per model instance (local batch '
                            'size). Global batch size is local batch size '
                            'times data parallel size.')
    group.add_argument('--train-with-neg', action='store_true',
                       help='Whether to use negative examples during model '
                       'training')
    group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative exmaples to use during '
                       'training')

    # parameters for Av.rank validation method
    # Following options/arguments have been taken directly from DPR codebase
    #group.add_argument("--val-av-rank-start-epoch", type=int, default=10000,
    #                    help="Av.rank validation: the epoch from which to enable this validation")
    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                       help='Av.rank validation: how many hard negatives to'
                       ' take from each question pool')
    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                       help='Av.rank validation: how many other negatives to'
                       ' take from each question pool')
    #group.add_argument("--val-av-rank-bsz", type=int, default=128,
    #                    help="Av.rank validation: batch size to process passages")
    #group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
    #                    help="Av.rank validation: max num of questions")

    return parser


if __name__ == '__main__':

    # Initialize Megatron, handing it get_tasks_args as the provider of the
    # extra (task-specific) command-line arguments.
    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()
    task = args.task

    # Each task's entry point is imported inside its own branch, so only the
    # module for the selected task is loaded.
    if task == 'RACE':
        from race.finetune import main
    elif task in ('MNLI', 'QQP'):
        from glue.finetune import main
    elif task in ('LAMBADA', 'WIKITEXT103'):
        from zeroshot_gpt.evaluate import main
    elif task == 'ICT-ZEROSHOT-NQ':
        from orqa.evaluate_orqa import main
    elif task == 'RET-FINETUNE-NQ':
        from orqa.supervised.finetune import main
    else:
        raise NotImplementedError(
            'Task {} is not implemented.'.format(task))

    # Run the entry point of whichever task was selected above.
    main()