"audioset/vggish_postprocess.py" did not exist on "a29310729739c7cf41e1c429fe80f8dbbe01b11e"
evaluate.py 7.75 KB
Newer Older
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Mohammad's avatar
Mohammad committed
2

3
"""GPT zero-shot evaluation."""
Mohammad's avatar
Mohammad committed
4
5
6
7
8

import math

import torch

Neel Kant's avatar
Neel Kant committed
9
from megatron import get_args
10
from megatron import print_rank_0, is_last_rank
Mohammad's avatar
Mohammad committed
11
from megatron import get_tokenizer
12
from megatron.core import parallel_state, tensor_parallel
Mohammad's avatar
Mohammad committed
13
from megatron.checkpointing import load_checkpoint
14
from megatron.model import GPTModel
Jared Casper's avatar
Jared Casper committed
15
from megatron.training import get_model
16
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
liangjing's avatar
v1  
liangjing committed
17
18
from megatron.core.pipeline_parallel.p2p_communication import recv_forward, send_forward
from megatron.arguments import core_transformer_config_from_args
Mohammad's avatar
Mohammad committed
19
20
from tasks.finetune_utils import build_data_loader

Raul Puri's avatar
Raul Puri committed
21
from .datasets import build_dataset
Mohammad's avatar
Mohammad committed
22

23
24
25
26
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
Mohammad's avatar
Mohammad committed
27
28
29
30
31

def get_model_provider(eval_metric):
    """Return a model-provider callable configured for ``eval_metric``.

    The returned closure builds a ``GPTModel`` whose ``parallel_output`` flag
    depends on the metric: ``'loss'`` sets it to ``True`` (the loss path feeds
    the output to ``tensor_parallel.vocab_parallel_cross_entropy``), while
    ``'accuracy'`` sets it to ``False``.

    Args:
        eval_metric: Either ``'loss'`` or ``'accuracy'``.

    Returns:
        A function ``model_provider(pre_process=True, post_process=True)``.
        Validation is lazy: an unsupported metric raises
        ``NotImplementedError`` only when that function is called, not when
        this factory is called.
    """

    def model_provider(pre_process=True, post_process=True):
        """Build the GPT model; the flags are forwarded to ``GPTModel``."""
        config = core_transformer_config_from_args(get_args())

        if eval_metric not in ('loss', 'accuracy'):
            raise NotImplementedError('output type for {} evaluation metric '
                                      'is not supported.'.format(eval_metric))
        # Only the loss metric keeps the output tensor-parallel.
        parallel_output = eval_metric == 'loss'

        print_rank_0('building GPT model ...')
        return GPTModel(config,
                        num_tokentypes=0,
                        parallel_output=parallel_output,
                        pre_process=pre_process,
                        post_process=post_process)

    return model_provider


def process_batch(batch):
    """Turn a raw data-loader batch into model inputs on the CUDA device.

    Reads ``batch['text']`` and ``batch['pad_mask']``, moves both to the
    current CUDA device, and builds next-token inputs and targets by shifting
    the token sequence by one position.

    Returns:
        ``(tokens, labels, attention_mask, position_ids, loss_mask)``, where
        ``tokens`` is ``text[:, :-1]``, ``labels`` is ``text[:, 1:]``, and
        ``loss_mask`` is ``pad_mask`` converted to a uint8 tensor.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
    text = batch['text'].long().cuda().contiguous()

    # Targets are the sequence shifted left by one; inputs drop the last token.
    targets = text[:, 1:].contiguous()
    inputs = text[:, :-1].contiguous()

    # Attention mask and position ids from the shared helper. Its middle
    # return value is discarded; presumably a helper-computed loss mask,
    # which pad_mask replaces here. TODO confirm against megatron.utils.
    attention_mask, _discarded, position_ids = get_ltor_masks_and_position_ids(
        inputs,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )

    return inputs, targets, attention_mask, position_ids, loss_mask


liangjing's avatar
v1  
liangjing committed
75
def forward_step(batch, model, eval_metric, config):
    """Run one evaluation forward pass, including pipeline send/receive.

    Receives activations from the previous pipeline stage, runs ``model``,
    and sends its output onward. Only the last pipeline stage computes a
    metric value.

    Args:
        batch: Raw batch dict consumed by ``process_batch``.
        model: The (possibly wrapped) GPT model.
        eval_metric: ``'loss'`` or ``'accuracy'``.
        config: Transformer config passed to ``recv_forward``/``send_forward``.

    Returns:
        On the last pipeline stage: for ``'loss'``, the sum of per-token
        cross-entropy over masked positions; for ``'accuracy'``, the number of
        samples whose masked tokens were all predicted correctly. On any other
        stage, ``None``.

    Raises:
        NotImplementedError: On the last pipeline stage, for an unknown metric.
    """
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(batch)

    # Publish the actual batch size (the final batch may be smaller) before
    # it is used to size the inter-stage activation tensor.
    args = get_args()
    args.micro_batch_size = len(labels)

    activation_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    incoming = recv_forward(activation_shape, config)

    # set_input_tensor lives on the innermost module, so strip the wrappers.
    inner = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    inner.set_input_tensor(incoming)
    output = model(tokens, position_ids, attention_mask)

    send_forward(output, config)

    if not parallel_state.is_pipeline_last_stage():
        return None

    if eval_metric == 'loss':
        # Per-token losses, zeroed outside the mask and summed (not averaged).
        token_losses = tensor_parallel.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels.contiguous())
        flat_mask = loss_mask.contiguous().view(-1).float()
        return torch.sum(token_losses.view(-1) * flat_mask)

    if eval_metric == 'accuracy':
        predicted = torch.argmax(output, -1)
        matches = (predicted == labels).float()
        # Unmasked positions count as matches so they cannot zero the product.
        matches[(1 - loss_mask).bool()] = 1
        return matches.prod(-1).sum()

    raise NotImplementedError('forward method for evaluation metric {} '
                              'is not implemented.'.format(eval_metric))
Mohammad's avatar
Mohammad committed
117
118
119
120
121


def evaluate(data_loader, model, eval_metric):
    """Accumulate ``eval_metric`` over every batch of ``data_loader``.

    Puts ``model`` in eval mode and runs ``forward_step`` on each batch under
    ``torch.no_grad()``. On the last pipeline stage, each per-batch result is
    all-reduced across the data-parallel group and added to a running total.

    Returns:
        The running total (a tensor) on the last pipeline stage; ``0.0`` on
        every other stage.
    """
    args = get_args()
    config = core_transformer_config_from_args(args)

    # Eval mode disables dropout.
    model.eval()

    running_total = 0.0
    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            if step % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(step))

            batch_result = forward_step(batch, model, eval_metric, config)

            # Only the last pipeline stage produces a metric to reduce.
            if not parallel_state.is_pipeline_last_stage():
                continue
            torch.distributed.all_reduce(
                batch_result, group=parallel_state.get_data_parallel_group())
            running_total += batch_result

    return running_total


def evaluate_and_print_results(task, data_loader, model, eval_metric):
    """Evaluate ``model`` on ``data_loader`` and print a framed summary line.

    ``evaluate`` runs on every rank; only the last rank (``is_last_rank``)
    formats and prints the result between two dashed rules.

    For ``'loss'``, the summary reports:
    - the average loss, i.e. the total divided by ``num_tokenized_tokens - 1``;
    - its perplexity;
    - a perplexity rescaled by the ratio
      ``(num_tokenized_tokens - 1) / (num_original_tokens - 1)``.
    Both exponents are capped at 20.

    For ``'accuracy'``, it reports the number correct, the dataset length,
    and their ratio.

    Raises:
        NotImplementedError: On the last rank, for any other ``eval_metric``.
    """
    result = evaluate(data_loader, model, eval_metric)
    header = ' validation results on {} | '.format(task)

    if not is_last_rank():
        return

    if eval_metric == 'loss':
        dataset = data_loader.dataset
        n_tokenized = dataset.num_tokenized_tokens
        n_original = dataset.num_original_tokens
        avg_loss = result / (n_tokenized - 1)
        ppl = math.exp(min(20, avg_loss))
        ratio = (n_tokenized - 1) / (n_original - 1)
        adjusted_ppl = math.exp(min(20, avg_loss * ratio))
        pieces = [
            'avg loss: {:.4E} | '.format(avg_loss),
            'ppl: {:.4E} | '.format(ppl),
            'adjusted ppl: {:.4E} | '.format(adjusted_ppl),
            'token ratio: {} |'.format(ratio),
        ]
    elif eval_metric == 'accuracy':
        n_examples = len(data_loader.dataset)
        accuracy = result / n_examples
        pieces = [
            'number correct: {:.4E} | '.format(result),
            'total examples: {:.4E} | '.format(n_examples),
            'avg accuracy: {:.4E}'.format(accuracy),
        ]
    else:
        raise NotImplementedError('evaluation method for {} metric is not '
                                  'implemented yet.'.format(eval_metric))

    summary = header + ''.join(pieces)
    rule = '-' * (len(summary) + 1)
    print(rule)
    print(summary)
    print(rule)
Mohammad's avatar
Mohammad committed
181
182
183
184
185
186


def main():
    """Entry point for GPT zero-shot evaluation.

    Maps ``args.task`` to a metric (``'LAMBADA'`` -> accuracy,
    ``'WIKITEXT103'`` -> loss). It then builds the model, loads a checkpoint
    when ``args.load`` is set, and evaluates on the task's dataset, printing
    the results.

    Exits with status 1 when an interleaved (virtual) pipeline schedule is
    configured, since this script does not support it.

    Raises:
        NotImplementedError: If ``args.task`` is not a supported task.
    """
    import sys

    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for "
              "zero-shot evaluation.")
        # The site-module `exit()` is meant for the interactive shell (it is
        # absent under `python -S`) and exited with status 0, hiding that
        # nothing was evaluated. Fail explicitly instead.
        sys.exit(1)

    task_to_metric = {'LAMBADA': 'accuracy', 'WIKITEXT103': 'loss'}
    eval_metric = task_to_metric.get(args.task)
    if eval_metric is None:
        raise NotImplementedError('{} task is not implemented.'.format(
            args.task))

    # Set up model and load checkpoint.
    model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # get_model appears to return a list of model chunks. Without a virtual
    # pipeline schedule (rejected above) exactly one chunk is expected.
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # Data stuff.
    dataset = build_dataset(args.task)
    dataloader = build_data_loader(dataset, args.micro_batch_size,
                                   args.num_workers, drop_last=False)

    # Run evaluation.
    evaluate_and_print_results(args.task, dataloader, model, eval_metric)

    print_rank_0('done :-)')