finetune.py 1.8 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2

3
4
"""Finetuning entry point for the RACE multiple-choice task."""

xingjinliang's avatar
xingjinliang committed
5
6
7
8
from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_tokenizer
from megatron.legacy.model.multiple_choice import MultipleChoice
9
10
11
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
xingjinliang's avatar
xingjinliang committed
12
from megatron.training.arguments import core_transformer_config_from_args
13
14


Mohammad's avatar
Mohammad committed
15
def train_valid_datasets_provider():
    """Build the RACE training and validation datasets.

    The data paths (``args.train_data``, ``args.valid_data``) and the
    sequence length (``args.seq_length``) are read from the global
    arguments; the tokenizer is the globally configured one.

    Returns:
        A ``(train_dataset, valid_dataset)`` tuple of ``RaceDataset``
        instances.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def build_split(split_name, datapaths):
        # Both splits share the tokenizer and the maximum sequence length.
        return RaceDataset(split_name, datapaths, tokenizer, args.seq_length)

    train_dataset = build_split('training', args.train_data)
    valid_dataset = build_split('validation', args.valid_data)
    return train_dataset, valid_dataset


Jared Casper's avatar
Jared Casper committed
28
def model_provider(pre_process=True, post_process=True):
    """Construct the multiple-choice model used for RACE.

    The transformer config is derived from the global arguments.

    Args:
        pre_process: Forwarded unchanged to ``MultipleChoice``.
        post_process: Forwarded unchanged to ``MultipleChoice``.

    Returns:
        A ``MultipleChoice`` model built with two token types.
    """
    config = core_transformer_config_from_args(get_args())

    print_rank_0('building multichoice model for RACE ...')
    return MultipleChoice(
        config=config,
        num_tokentypes=2,
        pre_process=pre_process,
        post_process=post_process,
    )
38
39


Mohammad's avatar
Mohammad committed
40
def metrics_func_provider():
    """Provide the metrics callback function.

    Returns:
        Whatever ``accuracy_func_provider`` returns when given a factory
        that builds one ``RaceDataset`` per data path.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        # Name the dataset after whatever follows the last 'RACE' in the
        # path, with surrounding slashes removed and inner ones turned into
        # dashes (e.g. '.../RACE/dev/high' -> 'dev-high').
        tail = datapath.rpartition('RACE')[2]
        name = tail.strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)
50
51


Mohammad's avatar
Mohammad committed
52
def main():
    """Run RACE finetuning using this module's provider functions.

    The dataset and model providers are passed positionally;
    ``metrics_func_provider`` is supplied as the end-of-epoch callback
    provider.
    """
    finetune(
        train_valid_datasets_provider,
        model_provider,
        end_of_epoch_callback_provider=metrics_func_provider,
    )