finetune.py 1.62 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2

3
4
"""Race."""

Neel Kant's avatar
Neel Kant committed
5
6
from megatron import get_args
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
7
from megatron import get_tokenizer
8
from megatron import mpu
Jared Casper's avatar
Jared Casper committed
9
from megatron.model.multiple_choice import MultipleChoice
10
11
12
13
14
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset


Mohammad's avatar
Mohammad committed
15
def train_valid_datasets_provider():
    """Build the RACE training and validation datasets.

    Reads ``train_data``, ``valid_data`` and ``seq_length`` from the global
    arguments returned by ``get_args`` and uses the global tokenizer from
    ``get_tokenizer``.

    Returns:
        tuple: ``(train_dataset, valid_dataset)``, each a ``RaceDataset``.
    """
    cfg = get_args()
    tok = get_tokenizer()

    def _build(split_name, data_paths):
        # Both splits share the same tokenizer and sequence length.
        return RaceDataset(split_name, data_paths, tok, cfg.seq_length)

    # Tuple elements are evaluated left to right, so the training split is
    # still constructed before the validation split.
    return (_build('training', cfg.train_data),
            _build('validation', cfg.valid_data))


Jared Casper's avatar
Jared Casper committed
28
def model_provider(pre_process=True, post_process=True):
    """Construct the multiple-choice model used for RACE finetuning.

    Args:
        pre_process: forwarded unchanged to ``MultipleChoice``.
        post_process: forwarded unchanged to ``MultipleChoice``.

    Returns:
        The newly built ``MultipleChoice`` instance.
    """
    print_rank_0('building multichoice model for RACE ...')
    # NOTE(review): two token types, presumably one per input segment of a
    # RACE example -- confirm against how RaceDataset assigns token types.
    return MultipleChoice(
        num_tokentypes=2,
        pre_process=pre_process,
        post_process=post_process,
    )
37
38


Mohammad's avatar
Mohammad committed
39
def metrics_func_provider():
    """Provide the metrics callback used at the end of each epoch.

    Captures the global arguments and tokenizer, then hands
    ``accuracy_func_provider`` a factory that builds one ``RaceDataset`` per
    evaluation data path.

    Returns:
        Whatever ``accuracy_func_provider`` returns for that factory.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        """Build a ``RaceDataset`` for a single evaluation path."""
        # Name the dataset after the part of the path that follows the last
        # occurrence of 'RACE', with slashes turned into dashes, e.g.
        # '/data/RACE/dev/middle' -> 'dev-middle'.
        name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)
49
50


Mohammad's avatar
Mohammad committed
51
def main():
    """Run RACE finetuning.

    Passes the dataset and model providers to ``finetune`` and registers
    ``metrics_func_provider`` as the end-of-epoch callback provider.
    """
    finetune(
        train_valid_datasets_provider,
        model_provider,
        end_of_epoch_callback_provider=metrics_func_provider,
    )