# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
import time

import torch

23
from megatron import get_args, print_rank_0
24
from megatron import mpu
Mohammad's avatar
Mohammad committed
25
26
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
27
28


Mohammad's avatar
Mohammad committed
29
def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies.

    One dataloader is built for every path in ``args.valid_data`` and the
    returned closure evaluates a model on all of them.

    Args:
        single_dataset_provider: callable that takes a single data path and
            returns a dataset. The returned dataset must expose a
            ``dataset_name`` attribute, which is used to label the results.

    Returns:
        ``metrics_func(model, epoch, output_predictions=False)``. It prints
        the per-dataset and overall accuracy and returns ``None``. When
        ``output_predictions`` is true it also saves the collected
        predictions to ``<args.load>/predictions_<name1>_<name2>....pt``
        from global rank 0.
    """
    args = get_args()

    # Build dataloaders.
    dataloaders = []
    for datapath in args.valid_data:
        dataset = single_dataset_provider(datapath)
        # Incomplete trailing batches are dropped only when more than one
        # data parallel rank is in use.
        dataloader = build_data_loader(
            dataset, args.batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        """Evaluate `model` on every validation dataloader and report
        the accuracy.

        Args:
            model: model forwarded to `calculate_correct_answers`.
            epoch: value shown in the log lines.
            output_predictions: if true, also collect the predictions of
                every dataset and save them to disk. Only supported when
                the data parallel world size is 1.
        """
        print_rank_0('calculating metrics ...')
        if output_predictions:
            assert mpu.get_data_parallel_world_size() == 1

        correct = 0
        total = 0
        named_predictions = []
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if output_predictions:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
            else:
                correct_ans, total_count = output
            correct += correct_ans
            total += total_count

        # Guard against an empty evaluation (no datasets, or every batch
        # dropped) instead of failing with ZeroDivisionError.
        percent = float(correct) * 100.0 / float(total) if total > 0 else 0.0
        print_rank_0(' >> |epoch: {}| overall: correct / total = {} / {} = '
                     '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and torch.distributed.get_rank() == 0:
            assert args.load is not None
            # File stem is 'predictions' followed by every dataset name,
            # all joined with underscores.
            names = '_'.join(
                ['predictions'] + [name for name, _ in named_predictions])
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func


Mohammad's avatar
Mohammad committed
74
def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true.

    The model is put in eval mode for the duration of the loop and is always
    switched back to train mode afterwards, even if the loop raises.

    Args:
        name: dataset label used in the log line.
        model: callable invoked as ``model(tokens, attention_mask, types)``
            that returns logits; the prediction is the argmax over the last
            dimension.
        dataloader: iterable of batches accepted by `process_batch`. When
            `output_predictions` is true each batch must also provide a
            ``'uid'`` entry.
        epoch: value shown in the log line.
        output_predictions: if true, additionally collect softmax
            probabilities, labels and uids for every sample. Only supported
            when the data parallel world size is 1.

    Returns:
        ``(correct, total)`` summed over the data parallel group, or
        ``(correct, total, (softmaxes, labels, ids))`` when
        `output_predictions` is true.
    """
    start_time = time.time()

    if output_predictions:
        # This option is only possible when data parallel size is 1.
        assert mpu.get_data_parallel_world_size() == 1

    total = 0
    correct = 0
    softmaxes = []
    labels = []
    ids = []

    model.eval()
    try:
        with torch.no_grad():
            # For all the batches in the dataset.
            for batch in dataloader:
                # Run the model forward.
                tokens, types, labels_, attention_mask = process_batch(batch)
                logits = model(tokens, attention_mask, types)
                # Add output predictions.
                if output_predictions:
                    softmaxes.extend(
                        torch.softmax(logits.float(), dim=-1).cpu().tolist())
                    labels.extend(labels_.cpu().tolist())
                    ids.extend(batch['uid'].cpu().tolist())
                # Compute the correct answers.
                predicted = torch.argmax(logits, dim=-1)
                corrects = (predicted == labels_)
                # Add to the counters.
                total += labels_.size(0)
                correct += corrects.sum().item()
    finally:
        # Restore train mode even when evaluation fails part-way through.
        model.train()

    # Reduce.
    unreduced = torch.cuda.LongTensor([correct, total])
    torch.distributed.all_reduce(unreduced,
                                 group=mpu.get_data_parallel_group())

    # Print on screen.
    correct_ans = unreduced[0].item()
    total_count = unreduced[1].item()
    # An empty dataloader yields 0 / 0; report 0 % rather than raising
    # ZeroDivisionError.
    if total_count > 0:
        percent = float(correct_ans) * 100.0 / float(total_count)
    else:
        percent = 0.0
    elapsed_time = time.time() - start_time
    print_rank_0(' > |epoch: {}| metrics for {}: correct / total '
                 '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                     epoch, name, correct_ans, total_count,
                     percent, elapsed_time))

    if output_predictions:
        return correct_ans, total_count, (softmaxes, labels, ids)
    return correct_ans, total_count