main.py 4.29 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5
6
7
8
9

"""Main tasks functionality."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

xingjinliang's avatar
xingjinliang committed
10
11
from megatron.training import get_args
from megatron.training.initialize import initialize_megatron
Mohammad's avatar
Mohammad committed
12

13
14
15

def get_tasks_args(parser):
    """Provide extra arguments required for tasks.

    Adds a ``tasks`` argument group to ``parser`` covering task selection,
    data paths, zero-shot evaluation options, retriever / FAISS settings,
    retriever finetuning options and DPR-style Av.rank validation options.
    This function is passed to ``initialize_megatron`` as
    ``extra_args_provider``.

    Args:
        parser: an ``argparse.ArgumentParser`` (anything exposing
            ``add_argument_group``); it is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    group = parser.add_argument_group(title='tasks')

    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated paths or corpora names '
                       'for training.')
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='path(s) to the validation data.')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')

    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    # The previous help text spliced a backslash-newline into the literal,
    # which rendered stray quotes and indentation in --help output.
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type')
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval')

    # finetune for retriever
    group.add_argument('--eval-micro-batch-size', type=int, default=None,
                       help='Eval Batch size per model instance (local batch '
                       'size). Global batch size is local batch size '
                       'times data parallel size.')
    group.add_argument('--train-with-neg', action='store_true',
                       help='Whether to use negative examples during model '
                       'training')
    group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
                       'training')

    # parameters for Av.rank validation method
    # Following options/arguments have been taken directly from DPR codebase
    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                       help='Av.rank validation: how many hard negatives to'
                       ' take from each question pool')
    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                       help='Av.rank validation: how many other negatives to'
                       ' take from each question pool')

    return parser


if __name__ == '__main__':

    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()

    # Downstream tasks do not support the interleaved (virtual pipeline)
    # schedule. sys.exit(msg) writes msg to stderr and exits with status 1,
    # so launchers see a failure; the site-provided exit() used previously
    # exited with status 0 and is not guaranteed to exist (e.g. python -S).
    if args.num_layers_per_virtual_pipeline_stage is not None:
        sys.exit("Interleaved pipeline schedule is not yet supported "
                 "for downstream tasks.")

    # Task modules are imported inside each branch so only the selected
    # task's module is loaded.
    if args.task == 'RACE':
        from race.finetune import main
    elif args.task in ('MNLI', 'QQP'):
        from glue.finetune import main
    elif args.task in ('LAMBADA', 'WIKITEXT103'):
        from zeroshot_gpt.evaluate import main
    elif args.task in ('ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL'):
        from orqa.evaluate_orqa import main
    elif args.task == 'RET-FINETUNE-NQ':
        from orqa.supervised.finetune import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main()