eval_utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Evaluation utilities."""

import os
import time
from functools import partial

import torch

from megatron.training import get_args
from megatron.training import print_rank_last, is_last_rank
from megatron.core import mpu
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch


def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies."""
    args = get_args()

    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        if output_predictions:
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            correct += correct_ans
            total += total_count
        if is_last_rank():
            percent = float(correct) * 100.0 / float(total)
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func
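

# A minimal usage sketch (assumed wiring, not part of this module): downstream
# tasks typically pass `accuracy_func_provider` to the finetuning driver as the
# end-of-epoch callback provider, together with a task-specific dataset
# provider, roughly like this:
#
#     from tasks.finetune_utils import finetune
#
#     def single_dataset_provider(datapath):
#         # Build and return a validation dataset for `datapath`; the object
#         # is expected to expose a `dataset_name` attribute (used above).
#         ...
#
#     finetune(train_valid_datasets_provider, model_provider,
#              end_of_epoch_callback_provider=accuracy_func_provider)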


def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, that means each
        # "sample" from the dataset actually contains multiple samples that
        # will collapse into the batch dimension (for example in the RACE
        # dataset, which has several answer options per question), so we need
        # to account for that when setting the micro batch size.
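        # A concrete example with assumed numbers: in RACE each question
        # typically expands into one row per answer option, so with
        # sample_multiplier == 4 an orig_micro_batch_size of 8 yields an
        # effective micro batch size of 32.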
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
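    # A worked example with made-up numbers: orig_micro_batch_size = 8 and
    # data_parallel_size = 2 give micro_batch_size_times_data_parallel = 16;
    # with orig_global_batch_size = 64, num_micro_batches = 64 // 16 = 4,
    # i.e. four micro batches per rank per global batch.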

    def loss_func(output_predictions, labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            # Collecting predictions inside the loss function is currently
            # disabled; the assert below guards this path.
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        # The loss value itself is unused during evaluation; only the
        # counters in loss_dict matter.
        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except Exception:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # In evaluation-only mode we use drop_last = False to get all the
            # samples, which means the last batch might not be full, so we
            # adjust the batch size here to the actual batch size of the data
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
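            # E.g. (hypothetical numbers): a final partial batch of 5 examples
            # with sample_multiplier == 4 gives micro_batch_size == 20 and
            # global_batch_size == 20 * num_micro_batches for this step.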

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']


    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.tensor([correct, total], dtype=torch.long, device='cuda')
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.

        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count,
                            percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0