evaluate.py 7.49 KB
Newer Older
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2

3
"""GPT zero-shot evaluation."""
Mohammad's avatar
Mohammad committed
4
5
6
7
8

import math

import torch

xingjinliang's avatar
xingjinliang committed
9
10
11
from megatron.training import get_args
from megatron.training import print_rank_0, is_last_rank
from megatron.training import get_tokenizer
12
from megatron.core import parallel_state, tensor_parallel
xingjinliang's avatar
xingjinliang committed
13
14
from megatron.training.checkpointing import load_checkpoint
from megatron.legacy.model import GPTModel
Jared Casper's avatar
Jared Casper committed
15
from megatron.training import get_model
xingjinliang's avatar
xingjinliang committed
16
from megatron.training.utils import get_ltor_masks_and_position_ids, unwrap_model
liangjing's avatar
v1  
liangjing committed
17
from megatron.core.pipeline_parallel.p2p_communication import recv_forward, send_forward
xingjinliang's avatar
xingjinliang committed
18
from megatron.training.arguments import core_transformer_config_from_args
Mohammad's avatar
Mohammad committed
19
20
from tasks.finetune_utils import build_data_loader

Raul Puri's avatar
Raul Puri committed
21
from .datasets import build_dataset
Mohammad's avatar
Mohammad committed
22
23
24
25
26
27


def get_model_provider(eval_metric):
    """Return a model-provider closure configured for ``eval_metric``.

    The returned ``model_provider(pre_process=True, post_process=True)``
    builds a ``GPTModel`` whose ``parallel_output`` flag is ``True`` for the
    ``'loss'`` metric and ``False`` for ``'accuracy'``. Any other metric
    raises ``NotImplementedError`` when the provider is *called*, not when
    this factory is called.
    """

    def model_provider(pre_process=True, post_process=True):
        """Build the GPT model for the captured evaluation metric."""

        transformer_config = core_transformer_config_from_args(get_args())

        if eval_metric not in ('loss', 'accuracy'):
            raise NotImplementedError('output type for {} evaluation metric '
                                      'is not supported.'.format(eval_metric))

        # 'loss' logits go to vocab_parallel_cross_entropy in forward_step,
        # so they can stay split; 'accuracy' takes an argmax over the last
        # dim, which presumably needs the gathered (non-parallel) logits.
        keep_output_parallel = eval_metric == 'loss'

        print_rank_0('building GPT model ...')
        return GPTModel(transformer_config,
                        num_tokentypes=0,
                        parallel_output=keep_output_parallel,
                        pre_process=pre_process,
                        post_process=post_process)

    return model_provider


def process_batch(batch):
    """Turn a raw dataset batch into model inputs.

    Reads ``batch['pad_mask']`` and ``batch['text']``, moves them to CUDA,
    and shifts the text by one position so that ``labels[:, i]`` is the
    token that follows ``tokens[:, i]``.

    Returns:
        Tuple ``(tokens, labels, attention_mask, position_ids, loss_mask)``
        where ``loss_mask`` is the dataset's pad mask as a byte tensor.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    pad_mask = batch['pad_mask'].long().cuda().contiguous().byte()
    text = batch['text'].long().cuda().contiguous()

    # Next-token prediction: targets drop the first token, inputs the last.
    targets = text[:, 1:].contiguous()
    inputs = text[:, :-1].contiguous()

    # Left-to-right attention mask and position ids. The helper's second
    # return value (presumably its own loss mask) is discarded in favour of
    # the dataset-provided pad mask.
    attention_mask, _discarded, position_ids = get_ltor_masks_and_position_ids(
        inputs,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )

    return inputs, targets, attention_mask, position_ids, pad_mask


liangjing's avatar
v1  
liangjing committed
71
def forward_step(batch, model, eval_metric, config):
    """Run one pipeline-aware forward pass and score it.

    Receives the activation from the previous pipeline stage, runs
    ``model``, and sends the output to the next stage. Only the last
    pipeline stage computes a score:

    * ``'loss'``: sum of per-token cross entropy weighted by ``loss_mask``.
    * ``'accuracy'``: number of samples whose every position with
      ``loss_mask`` set is predicted correctly (assumes a 0/1 mask).

    Returns:
        A scalar tensor on the last pipeline stage, ``None`` elsewhere.

    Raises:
        NotImplementedError: on the last stage, for an unknown metric.
    """
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(batch)

    # Record the actual batch size; the last batch can be short because the
    # loader is built with drop_last=False. Note: mutates the global args.
    args = get_args()
    args.micro_batch_size = len(labels)

    # Receive from the previous stage, run the model, pass the result on.
    recv_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    incoming = recv_forward(recv_shape, config)
    unwrap_model(model).set_input_tensor(incoming)
    logits = model(tokens, position_ids, attention_mask)
    send_forward(logits, config)

    if not parallel_state.is_pipeline_last_stage():
        return None

    if eval_metric == 'loss':
        # Not reduced across data-parallel ranks; evaluate() all-reduces.
        per_token = tensor_parallel.vocab_parallel_cross_entropy(
            logits.contiguous().float(), labels.contiguous())
        weights = loss_mask.contiguous().view(-1).float()
        return torch.sum(per_token.view(-1) * weights)

    if eval_metric == 'accuracy':
        # Unmasked positions are forced to "correct" so the per-sample
        # product is 1 exactly when every masked position matches.
        hits = (torch.argmax(logits, -1) == labels).float()
        hits[(1 - loss_mask).bool()] = 1
        return hits.prod(-1).sum()

    raise NotImplementedError('forward method for evaluation metric {} '
                              'is not implemented.'.format(eval_metric))
Mohammad's avatar
Mohammad committed
112
113
114
115
116


def evaluate(data_loader, model, eval_metric):
    """Accumulate ``eval_metric`` over every batch of ``data_loader``.

    Switches ``model`` to eval mode and runs without autograd. On the last
    pipeline stage, each batch result is all-reduced over the data-parallel
    group and added to a running total.

    Returns:
        The accumulated total; ``0.0`` on stages other than the last (and
        on the last stage if the loader yields no batches).
    """
    args = get_args()
    config = core_transformer_config_from_args(args)

    # Eval mode disables dropout.
    model.eval()

    running_total = 0.0
    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            if step % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(step))

            batch_result = forward_step(batch, model, eval_metric, config)

            # Only the last pipeline stage produces a result to reduce.
            if not parallel_state.is_pipeline_last_stage():
                continue

            torch.distributed.all_reduce(
                batch_result, group=parallel_state.get_data_parallel_group())
            running_total += batch_result

    return running_total


def evaluate_and_print_results(task, data_loader, model, eval_metric):
    """Run evaluation on every rank and print a summary on the last rank.

    ``evaluate`` is called unconditionally (it performs collectives); only
    the last rank formats and prints the report.

    For ``'loss'``, the report has the average loss, perplexity, perplexity
    rescaled by the tokenized/original token ratio, and that ratio. Both
    exponents are clipped at 20. For ``'accuracy'``, it has the number of
    correct samples, the dataset size, and their quotient.

    Raises:
        NotImplementedError: on the last rank, for an unknown metric.
    """
    result = evaluate(data_loader, model, eval_metric)
    prefix = ' validation results on {} | '.format(task)

    if not is_last_rank():
        return

    if eval_metric == 'loss':
        dataset = data_loader.dataset
        tokenized = dataset.num_tokenized_tokens
        original = dataset.num_original_tokens
        avg_loss = result / (tokenized - 1)
        ppl = math.exp(min(20, avg_loss))
        ratio = (tokenized - 1) / (original - 1)
        adjusted = math.exp(min(20, avg_loss * ratio))
        fields = [
            'avg loss: {:.4E} | '.format(avg_loss),
            'ppl: {:.4E} | '.format(ppl),
            'adjusted ppl: {:.4E} | '.format(adjusted),
            'token ratio: {} |'.format(ratio),
        ]
    elif eval_metric == 'accuracy':
        total = len(data_loader.dataset)
        fraction = result / total
        fields = [
            'number correct: {:.4E} | '.format(result),
            'total examples: {:.4E} | '.format(total),
            'avg accuracy: {:.4E}'.format(fraction),
        ]
    else:
        raise NotImplementedError('evaluation method for {} metric is not '
                                  'implemented yet.'.format(eval_metric))

    report = prefix + ''.join(fields)
    rule = '-' * (len(report) + 1)
    print(rule)
    print(report)
    print(rule)
Mohammad's avatar
Mohammad committed
176
177
178
179
180
181


def main():
    """Entry point: run zero-shot evaluation for ``args.task``.

    Supported tasks are ``'LAMBADA'`` (accuracy) and ``'WIKITEXT103'``
    (loss). The function does the following:

    * builds the model without DDP;
    * loads a checkpoint when ``args.load`` is set;
    * builds the task dataset and loader;
    * prints the results.

    Raises:
        NotImplementedError: for an unsupported ``args.task``.
        SystemExit: with a non-zero status when an interleaved (virtual)
            pipeline schedule is requested, which this script cannot run.
    """
    import sys

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        # The bare builtin `exit()` is an interactive helper added by the
        # `site` module (absent under `python -S`), and it exits with status
        # 0, so launchers would treat this unsupported setup as success.
        # sys.exit(msg) prints msg to stderr and exits with status 1.
        sys.exit("Interleaved pipeline schedule is not yet supported for "
                 "zero-shot evaluation.")

    if args.task == 'LAMBADA':
        eval_metric = 'accuracy'
    elif args.task == 'WIKITEXT103':
        eval_metric = 'loss'
    else:
        raise NotImplementedError('{} task is not implemented.'.format(
            args.task))

    # Set up model and load checkpoint.
    model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # Without virtual pipeline stages get_model returns a single chunk.
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # Data stuff.
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(dataset, args.micro_batch_size,
                                   args.num_workers, drop_last=False)

    # Run evaluation.
    evaluate_and_print_results(args.task, dataloader, model, eval_metric)

    print_rank_0('done :-)')