# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
import time
from functools import partial

import torch

from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch


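# A minimal usage sketch (an illustration, not code from this module): a task
# script passes a dataset-building function and invokes the returned callback
# at the end of each finetuning epoch. `MyValidDataset` is a hypothetical
# dataset class exposing a `dataset_name` attribute:
#
#     def single_dataset_provider(datapath):
#         return MyValidDataset('validation', [datapath])
#
#     metrics_func = accuracy_func_provider(single_dataset_provider)
#     metrics_func(model, epoch=0)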
def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies."""
    args = get_args()

    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
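        # Keep the last (possibly partial) batch only when there is a single
        # data-parallel rank; with more ranks, equal batch counts are needed
        # to keep the ranks in lockstep.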
        dataloader = build_data_loader(
            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        if output_predictions:
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            correct += correct_ans
            total += total_count
        if is_last_rank():
            percent = float(correct) * 100.0 / float(total)
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
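            # Predictions are written next to the loaded checkpoint;
            # args.load doubles as the output directory here.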
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func


def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
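    # get_forward_backward_func() selects the schedule that matches the
    # pipeline configuration (no pipelining, or pipelining with/without
    # interleaving).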
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, each "sample"
        # from the dataset actually contains multiple samples that will
        # collapse into the batch dimension (for example, the RACE dataset
        # has several options per question), and we need to account for
        # that when setting the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * \
        args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // \
        micro_batch_size_times_data_parallel
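    # Illustrative numbers (not from the source): with orig_global_batch_size
    # = 16, orig_micro_batch_size = 4, and data_parallel_size = 2, each rank
    # runs 16 // (4 * 2) = 2 micro-batches per step.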

    def loss_func(output_predictions, labels, ids, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = ids.cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        # Evaluation runs forward-only, so the loss slot of the returned
        # (loss, stats) pair is a dummy 0; only the counter dict is consumed.
        return 0, loss_dict

    # Defined inside to capture output_predictions.
    def correct_answers_forward_step(batch, model):
        # Pipeline schedules may hand over a data iterator rather than an
        # already materialized batch; advance it if so.
        try:
            batch_ = next(batch)
        except TypeError:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        # 'uid' identifies each sample and is only consumed when
        # output_predictions is enabled.
        return output_tensor, partial(loss_func, output_predictions, labels,
                                      batch_.get('uid'))

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # In evaluation-only mode we use drop_last = False to get all the
            # samples, which means the last batch may not be full, so we
            # adjust the batch size here to the actual size of the data ...
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
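            # Illustrative numbers (not from the source): a final batch of 3
            # RACE-style samples with sample_multiplier = 4 runs with
            # micro_batch_size = 3 * 4 = 12.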

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

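            # forward_backward_func returns one stats dict per micro-batch
            # (built by loss_func above); fold them into the counters.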
            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce the counters across the data-parallel group: each rank evaluated
    # a disjoint shard of the data, so the sum gives dataset-wide counts.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count,
                            percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
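    # Stages other than the last pipeline stage hold no logits; they return
    # zero counts, and callers read the real numbers from the last rank.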
    if output_predictions:
        return 0, 0, ()
    return 0, 0