# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
import time
Jared Casper's avatar
Jared Casper committed
20
from functools import partial
21
22
23

import torch

Neel Kant's avatar
Neel Kant committed
24
from megatron import get_args
25
from megatron import print_rank_last, is_last_rank
26
from megatron import mpu
Jared Casper's avatar
Jared Casper committed
27
from megatron.schedules import get_forward_backward_func
Mohammad's avatar
Mohammad committed
28
29
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
30
31


Mohammad's avatar
Mohammad committed
32
def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies."""
    args = get_args()

    # One (dataset-name, dataloader) pair per validation datapath.
    # Only drop the last (possibly partial) batch when data parallelism
    # would otherwise see ragged per-rank batches.
    use_drop_last = mpu.get_data_parallel_world_size() > 1
    dataloaders = []
    for path in args.valid_data:
        ds = single_dataset_provider(path)
        loader = build_data_loader(
            ds, args.orig_micro_batch_size, num_workers=args.num_workers,
            drop_last=use_drop_last)
        dataloaders.append((ds.dataset_name, loader))

    def metrics_func(model, epoch, output_predictions=False):
        """Evaluate `model` on every dataloader and report overall accuracy."""
        print_rank_last('calculating metrics ...')
        correct, total = 0, 0
        if output_predictions:
            # Prediction dumping is only supported without data parallelism.
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, loader in dataloaders:
            result = calculate_correct_answers(name, model, loader,
                                               epoch, output_predictions)
            if output_predictions:
                correct_ans, total_count, predictions = result
                named_predictions.append((name, predictions))
                names += '_' + name
            else:
                correct_ans, total_count = result
            correct += correct_ans
            total += total_count

        # Only the last rank prints the aggregated result.
        if is_last_rank():
            percent = float(correct) * 100.0 / float(total)
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
            # Predictions are saved next to the loaded checkpoint.
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func

Jared Casper's avatar
Jared Casper committed
77

Mohammad's avatar
Mohammad committed
78
def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true.

    Arguments:
        name: dataset name, used only for the printed report.
        model: list of model chunks; each is put in eval mode for the
            duration of the evaluation and restored to train mode after.
        dataloader: dataloader over the evaluation dataset.
        epoch: epoch number, used only for the printed report.
        output_predictions: if true, also collect (softmaxes, labels, ids).

    Returns:
        On the pipeline last stage: (correct, total) reduced over the data
        parallel group -- plus the predictions tuple when requested.
        On all other pipeline stages: zeros of the same arity.
    """
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    # args.micro/global_batch_size are temporarily overwritten per batch
    # below (drop_last=False can yield a short final batch); save the
    # originals so they can be restored before returning.
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # Some datasets expand each raw sample into several model samples.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel

    def loss_func(output_predictions, labels, output_tensor):
        """Count correct argmax predictions of `output_tensor` vs `labels`."""
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            # NOTE(review): this branch is disabled; `batch` below is not in
            # scope here, so the uids would need to be plumbed through
            # (e.g. via partial) before this can be enabled.
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        """Run one forward pass; return logits and the bound loss function."""
        try:
            batch_ = next(batch)
        except (StopIteration, TypeError):
            # `batch` is either not an iterator (TypeError) or an exhausted
            # one (StopIteration); use it directly.  Narrowed from the
            # original bare `except BaseException`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # For evaluation only mode we use drop_last = False to get all the
            # samples, which means we might not have a full batch, so we
            # adjust batch_size here to actual batch size of data
            actual_batch_size = len(batch['label'])
            # ... applying sample_multiplier if necessary
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']

    # Restore train mode and the saved batch-size args.
    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count,
                            percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    # Non-last pipeline stages return placeholder zeros.
    if output_predictions:
        return 0, 0, ()
    return 0, 0