# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Main tasks functionality."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))

from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks.

    Adds a ``tasks`` argument group to ``parser`` with options for task
    selection, finetuning, zero-shot evaluation, retriever data/Faiss
    settings, retriever finetuning and Av.rank validation.

    Args:
        parser: an ``argparse.ArgumentParser`` (or any object exposing
            ``add_argument_group``); it is extended in place.

    Returns:
        The same ``parser`` object, so this function can be passed as the
        ``extra_args_provider`` callback to ``initialize_megatron``.
    """
    group = parser.add_argument_group(title='tasks')

    group.add_argument('--task', type=str, required=True,
                       help='Task name.')
    group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
                       'evaluation only.')
    group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
    # The trailing space matters: the two literals are concatenated.
    group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader')
    group.add_argument('--train-data', nargs='+', default=None,
                       help='Whitespace separated paths or corpora names '
                       'for training.')
    # nargs='*' allows the flag to be given with no paths (yields []).
    group.add_argument('--valid-data', nargs='*', default=None,
                       help='path(s) to the validation data.')
    group.add_argument('--overlapping-eval', type=int, default=32,
                       help='Sliding window for overlapping evaluation.')
    group.add_argument('--strict-lambada', action='store_true',
                       help='Use more difficult formulation of lambada.')

    # Retriever args
    group.add_argument('--qa-data-dev', type=str, default=None,
                       help='Path to the QA dataset dev file.')
    group.add_argument('--qa-data-test', type=str, default=None,
                       help='Path to the QA dataset test file.')

    # Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type')
    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                       help='Number of blocks to use as top-k during retrieval')

    # finetune for retriever
    group.add_argument('--eval-micro-batch-size', type=int, default=None,
                       help='Eval Batch size per model instance (local batch '
                            'size). Global batch size is local batch size '
                            'times data parallel size.')
    group.add_argument('--train-with-neg', action='store_true',
                       help='Whether to use negative examples during model '
                       'training')
    group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
                       'training')

    # parameters for Av.rank validation method
    # Following options/arguments have been taken directly from DPR codebase
    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                       help='Av.rank validation: how many hard negatives to'
                       ' take from each question pool')
    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                       help='Av.rank validation: how many other negatives to'
                       ' take from each question pool')

    return parser


if __name__ == '__main__':

    # Parse Megatron's arguments plus the task-specific ones registered by
    # get_tasks_args, then fetch the resulting global args namespace.
    initialize_megatron(extra_args_provider=get_tasks_args)

    args = get_args()

    # Downstream task drivers do not support the interleaved schedule.
    if args.num_layers_per_virtual_pipeline_stage is not None:
        # Use sys.exit rather than the site-provided ``exit`` helper (meant
        # for the interactive shell and missing under ``python -S``); passing
        # the message reports it on stderr with a non-zero exit status.
        sys.exit("Interleaved pipeline schedule is not yet supported for "
                 "downstream tasks.")

    # Import the selected task's entry point lazily, inside its branch, so
    # only that task's module is imported.
    if args.task == 'RACE':
        from race.finetune import main
    elif args.task in ['MNLI', 'QQP']:
        from glue.finetune import main
    elif args.task in ['LAMBADA', 'WIKITEXT103']:
        from zeroshot_gpt.evaluate import main
    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
        from orqa.evaluate_orqa import main
    elif args.task in ['RET-FINETUNE-NQ']:
        from orqa.supervised.finetune import main
    else:
        raise NotImplementedError('Task {} is not implemented.'.format(
            args.task))

    main()