evaluate.py 7.9 KB
Newer Older
Mohammad's avatar
Mohammad committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Mohammad's avatar
Mohammad committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""GPT zero-shot evaluation."""
Mohammad's avatar
Mohammad committed
17
18
19
20
21

import math

import torch

Neel Kant's avatar
Neel Kant committed
22
from megatron import get_args
23
from megatron import print_rank_0, is_last_rank
Mohammad's avatar
Mohammad committed
24
25
26
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
27
from megatron.model import GPTModel
Jared Casper's avatar
Jared Casper committed
28
from megatron.training import get_model
29
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
Jared Casper's avatar
Jared Casper committed
30
from megatron.p2p_communication import recv_forward, send_forward
Mohammad's avatar
Mohammad committed
31
32
from tasks.finetune_utils import build_data_loader

Raul Puri's avatar
Raul Puri committed
33
from .datasets import build_dataset
Mohammad's avatar
Mohammad committed
34

35
36
37
38
# Wrapper classes that unwrap_model strips off to reach the bare model.
# TODO: consider moving these imports into megatron.utils.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
Mohammad's avatar
Mohammad committed
39
40
41
42
43

def get_model_provider(eval_metric):
    """Return a model-provider callable configured for ``eval_metric``.

    The metric selects the ``parallel_output`` flag passed to ``GPTModel``:
    ``'loss'`` -> ``True`` (its output later goes through
    ``mpu.vocab_parallel_cross_entropy``), ``'accuracy'`` -> ``False`` (its
    output later goes through an argmax over the last dimension).
    Presumably ``True`` keeps logits partitioned over the vocabulary across
    tensor-parallel ranks -- confirm against ``GPTModel``.

    Args:
        eval_metric: ``'loss'`` or ``'accuracy'``.

    Returns:
        ``model_provider(pre_process=True, post_process=True)``, which builds
        the ``GPTModel``. The metric is validated only when that callable is
        invoked; an unsupported metric raises ``NotImplementedError`` then.
    """

    def model_provider(pre_process=True, post_process=True):
        """Build the GPT model used for zero-shot evaluation."""
        if eval_metric not in ('loss', 'accuracy'):
            raise NotImplementedError('output type for {} evaluation metric '
                                      'is not supported.'.format(eval_metric))
        parallel_output = eval_metric == 'loss'

        print_rank_0('building GPT model ...')
        return GPTModel(num_tokentypes=0,
                        parallel_output=parallel_output,
                        pre_process=pre_process,
                        post_process=post_process)

    return model_provider


def process_batch(batch):
    """Turn a raw data-loader batch into model inputs.

    Args:
        batch: mapping with a ``'text'`` tensor of token ids and a
            ``'pad_mask'`` tensor. Both are presumably 2-D (batch, seq);
            ``pad_mask`` is assumed to line up with the shifted labels --
            TODO confirm against the dataset.

    Returns:
        ``(tokens, labels, attention_mask, position_ids, loss_mask)``.
        ``tokens`` is ``text`` without its last column and ``labels`` is
        ``text`` without its first column (next-token targets).
        ``loss_mask`` is ``pad_mask`` moved to the GPU as a byte tensor.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
    text = batch['text'].long().cuda().contiguous()
    # Shift by one: position i of `tokens` predicts position i of `labels`.
    labels = text[:, 1:].contiguous()
    tokens = text[:, :-1].contiguous()

    # Attention mask and position ids for the inputs. The second returned
    # value is discarded; the dataset-provided pad mask is used as the
    # loss mask instead.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, attention_mask, position_ids, loss_mask


def forward_step(batch, model, eval_metric):
    """Run this pipeline stage's forward pass on ``batch`` and score it.

    Every stage receives its input activations via ``recv_forward``, runs
    the model, and passes the output on via ``send_forward``. Only the last
    pipeline stage turns the output into a metric value.

    Args:
        batch: raw data-loader batch (see ``process_batch``).
        model: the possibly wrapped model for this stage.
        eval_metric: ``'loss'`` or ``'accuracy'``.

    Returns:
        On the last pipeline stage: for ``'loss'``, the sum of per-token
        cross-entropy losses weighted by the loss mask; for ``'accuracy'``,
        the number of samples whose every unmasked token was predicted
        correctly. On all other stages: ``None``.

    Raises:
        NotImplementedError: on the last stage, for any other metric.
    """
    tokens, labels, attention_mask, position_ids, loss_mask = \
        process_batch(batch)

    # Record the real size of this batch (it may be a short final batch);
    # presumably recv_forward sizes its receive buffer from this value, so
    # it is set before receiving.
    args = get_args()
    args.micro_batch_size = len(labels)

    stage_input = recv_forward()

    # Hand the received activations to the model underneath the DDP /
    # fp16 wrappers, then run the (wrapped) model forward.
    inner_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    inner_model.set_input_tensor(stage_input)
    output = model(tokens, position_ids, attention_mask)

    send_forward(output)

    if not mpu.is_pipeline_last_stage():
        return None

    if eval_metric == 'loss':
        # Unreduced total; the caller all-reduces and normalizes it.
        per_token = mpu.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels.contiguous())
        weights = loss_mask.contiguous().view(-1).float()
        return torch.sum(per_token.view(-1) * weights)

    if eval_metric == 'accuracy':
        predictions = torch.argmax(output, -1)
        hits = (predictions == labels).float()
        # Masked-out positions count as hits so they never fail a sample.
        hits[(1 - loss_mask).bool()] = 1
        return hits.prod(-1).sum()

    raise NotImplementedError('forward method for evaluation metric {} '
                              'is not implemented.'.format(eval_metric))
Mohammad's avatar
Mohammad committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144


def evaluate(data_loader, model, eval_metric):
    """Accumulate ``eval_metric`` over every batch in ``data_loader``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    On the last pipeline stage each batch's value is all-reduced over the
    data-parallel group and added to a running total. Other stages run the
    forward passes but add nothing.

    Returns:
        The running total: ``0.0`` if nothing was added, otherwise the
        summed tensor produced by ``forward_step``.
    """
    args = get_args()

    # Eval mode disables dropout.
    model.eval()

    running_total = 0.0
    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            if step % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(step))

            batch_value = forward_step(batch, model, eval_metric)

            if not mpu.is_pipeline_last_stage():
                continue
            # Combine this batch's contribution from all data-parallel ranks.
            torch.distributed.all_reduce(batch_value,
                                         group=mpu.get_data_parallel_group())
            running_total += batch_value

    return running_total


def evaluate_and_print_results(task, data_loader, model, eval_metric):
    """Run ``evaluate`` and print a framed one-line summary on the last rank.

    For ``'loss'`` the summary shows the average loss over
    ``num_tokenized_tokens - 1`` tokens, its perplexity, and a perplexity
    rescaled by the tokenized/original token ratio. Both exponents are
    capped at 20 before ``math.exp``. For ``'accuracy'`` it shows the
    number of correct samples, the dataset size, and their ratio.

    Raises:
        NotImplementedError: on the last rank, for an unknown metric.
    """
    output = evaluate(data_loader, model, eval_metric)

    if not is_last_rank():
        return

    pieces = [' validation results on {} | '.format(task)]

    if eval_metric == 'loss':
        dataset = data_loader.dataset
        tokenized = dataset.num_tokenized_tokens
        original = dataset.num_original_tokens
        val_loss = output / (tokenized - 1)
        ppl = math.exp(min(20, val_loss))
        token_ratio = (tokenized - 1) / (original - 1)
        adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
        pieces.extend([
            'avg loss: {:.4E} | '.format(val_loss),
            'ppl: {:.4E} | '.format(ppl),
            'adjusted ppl: {:.4E} | '.format(adjusted_ppl),
            'token ratio: {} |'.format(token_ratio),
        ])
    elif eval_metric == 'accuracy':
        num_examples = len(data_loader.dataset)
        acc = output / num_examples
        pieces.extend([
            'number correct: {:.4E} | '.format(output),
            'total examples: {:.4E} | '.format(num_examples),
            'avg accuracy: {:.4E}'.format(acc),
        ])
    else:
        raise NotImplementedError('evaluation method for {} metric is not '
                                  'implemented yet.'.format(eval_metric))

    summary = ''.join(pieces)
    rule = '-' * (len(summary) + 1)
    print(rule)
    print(summary)
    print(rule)
Mohammad's avatar
Mohammad committed
189
190
191
192
193
194


def main():
    """Entry point for GPT zero-shot evaluation.

    Maps ``args.task`` to an evaluation metric (``LAMBADA`` -> accuracy,
    ``WIKITEXT103`` -> loss). It then builds the model, loading a checkpoint
    when ``args.load`` is set, builds the task dataset and loader, and prints
    the evaluation results.

    Raises:
        SystemExit: with a non-zero status if an interleaved (virtual)
            pipeline schedule is requested.
        NotImplementedError: if ``args.task`` is not supported.
    """
    import sys

    args = get_args()

    # get_model returns a list of model chunks and this script keeps exactly
    # one (see the assert below), so interleaved scheduling is rejected.
    # sys.exit(msg) prints msg to stderr and exits with status 1. The bare
    # exit() used before ended the run with status 0, making an unsupported
    # configuration look like a success; it also does not exist under
    # `python -S`.
    if args.num_layers_per_virtual_pipeline_stage is not None:
        sys.exit("Interleaved pipeline schedule is not yet supported for "
                 "zero-shot evaluation.")

    if args.task == 'LAMBADA':
        eval_metric = 'accuracy'
    elif args.task == 'WIKITEXT103':
        eval_metric = 'loss'
    else:
        raise NotImplementedError('{} task is not implemented.'.format(
            args.task))

    # Set up model and load checkpoint.
    model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # Data stuff. drop_last=False keeps the final (possibly short) batch;
    # forward_step updates args.micro_batch_size for it.
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(dataset, args.micro_batch_size,
                                   args.num_workers, drop_last=False)

    # Run evaluation.
    evaluate_and_print_results(args.task, dataloader, model, eval_metric)

    print_rank_0('done :-)')