finetune.py 1.59 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2

3
4
"""Race."""

Neel Kant's avatar
Neel Kant committed
5
6
from megatron import get_args
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
7
from megatron import get_tokenizer
Jared Casper's avatar
Jared Casper committed
8
from megatron.model.multiple_choice import MultipleChoice
9
10
11
12
13
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset


Mohammad's avatar
Mohammad committed
14
def train_valid_datasets_provider():
    """Build the RACE training and validation datasets.

    Both datasets are created with the tokenizer returned by
    ``get_tokenizer()`` and with ``args.seq_length``; the data paths are
    read from ``args.train_data`` and ``args.valid_data``.

    Returns:
        A ``(train_dataset, valid_dataset)`` tuple of ``RaceDataset``
        objects.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def build(split_name, data_paths):
        # Single construction site shared by both splits.
        return RaceDataset(split_name, data_paths,
                           tokenizer, args.seq_length)

    training_set = build('training', args.train_data)
    validation_set = build('validation', args.valid_data)
    return training_set, validation_set


Jared Casper's avatar
Jared Casper committed
27
def model_provider(pre_process=True, post_process=True):
    """Construct the multiple-choice model used for RACE.

    Args:
        pre_process: Forwarded unchanged to ``MultipleChoice``.
        post_process: Forwarded unchanged to ``MultipleChoice``.

    Returns:
        A ``MultipleChoice`` instance created with ``num_tokentypes=2``.
    """
    print_rank_0('building multichoice model for RACE ...')

    model_kwargs = dict(num_tokentypes=2,
                        pre_process=pre_process,
                        post_process=post_process)
    return MultipleChoice(**model_kwargs)
36
37


Mohammad's avatar
Mohammad committed
38
def metrics_func_provider():
    """Provide the metrics callback function.

    Returns:
        The result of ``accuracy_func_provider`` when given a factory that
        builds one ``RaceDataset`` for a single data path.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        # Dataset name: the part of the path after the last 'RACE', with
        # leading/trailing slashes stripped and inner slashes as dashes.
        suffix = datapath.split('RACE')[-1]
        name = suffix.strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)
48
49


Mohammad's avatar
Mohammad committed
50
def main():
    """Run finetuning on RACE.

    Passes the dataset provider and model provider positionally to
    ``finetune`` and registers the metrics provider as the end-of-epoch
    callback provider.
    """
    finetune(
        train_valid_datasets_provider,
        model_provider,
        end_of_epoch_callback_provider=metrics_func_provider,
    )