import math

import os

import torch
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from tqdm import tqdm

from utils.global_vars import get_tensorboard_writer, get_timers

def evaluate(model, args, logger, global_step, criterion):
    """Run MLM evaluation over every shard under ``args.eval_data_path_prefix``.

    Puts ``model`` into eval mode, iterates each evaluation shard under
    ``torch.no_grad()``, accumulates the criterion loss over all batches seen
    so far (the average is cumulative across shards, not per-shard), and logs
    per-shard loss / perplexity / timing. The model is restored to train mode
    before returning, even if evaluation fails partway.

    Args:
        model: the BERT model; called with ``input_ids``, ``token_type_ids``
            and ``attention_mask`` keyword arguments and expected to return an
            object with a ``.logits`` attribute.
        args: config namespace; reads ``eval_data_path_prefix``,
            ``eval_micro_batch_size_per_gpu`` and ``wandb``.
        logger: logger receiving the per-shard summary lines.
        global_step: training step used to tag the tensorboard entries.
        criterion: loss function applied to ``output.logits`` and the MLM labels.

    Returns:
        The cumulative mean loss over all evaluated batches (0 if none were seen).
    """
    evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True)
    start_shard = 0

    model.eval()

    timers = get_timers()
    eval_step = 0      # total batches seen across *all* shards so far
    eval_loss = 0.0    # sum of batch losses across all shards so far
    cur_loss = 0.0     # running mean loss; also the return value
    world_size = torch.distributed.get_world_size()

    try:
        with torch.no_grad():

            for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))):

                timers('eval_shard_time').start()

                dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard)
                # evaluate_dataset_provider.prefetch_shard(shard + 1)

                # Only rank 0 draws a progress bar; other ranks iterate silently.
                if torch.distributed.get_rank() == 0:
                    iterator_data = tqdm(enumerate(dataset_iterator),
                                         total=(total_length // args.eval_micro_batch_size_per_gpu // world_size),
                                         colour='MAGENTA',
                                         smoothing=1)
                else:
                    iterator_data = enumerate(dataset_iterator)

                for step, batch_data in iterator_data:
                    # batch layout (from the dataset provider): input_ids,
                    # attention_mask, token_type_ids, mlm_label — TODO confirm
                    # against NvidiaBertDatasetProvider.
                    eval_step += 1
                    input_ids = batch_data[0].cuda()
                    attention_mask = batch_data[1].cuda()
                    token_type_ids = batch_data[2].cuda()
                    mlm_label = batch_data[3].cuda()
                    # nsp_label = batch_data[5].cuda()

                    output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

                    loss = criterion(output.logits, mlm_label)    # prediction_scores

                    # Overlap loading of the next batch with the loss bookkeeping.
                    evaluate_dataset_provider.prefetch_batch()

                    eval_loss += loss.float().item()

                elapsed_time = timers('eval_shard_time').elapsed()

                if eval_step == 0:
                    # Empty shard and nothing accumulated before it: skip the
                    # metrics to avoid a ZeroDivisionError.
                    logger.info(f'evaluation shard: {shard} | no batches seen, metrics skipped')
                    continue

                cur_loss = eval_loss / eval_step
                elapsed_time_per_iteration = elapsed_time / eval_step
                ppl = math.exp(cur_loss)

                if args.wandb and torch.distributed.get_rank() == 0:
                    tensorboard_log = get_tensorboard_writer()
                    # NOTE(review): 'mins_batch' is actually seconds/batch here,
                    # matching the log string below — key kept for dashboard
                    # compatibility.
                    tensorboard_log.log_eval({
                        'loss': cur_loss,
                        'ppl': ppl,
                        'mins_batch': elapsed_time_per_iteration
                    }, global_step)

                eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \
                                f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}'

                logger.info(eval_log_str)
                logger.info('-' * 100)
                logger.info('')
    finally:
        # Always release the shard and restore train mode, even if evaluation
        # raised partway through.
        evaluate_dataset_provider.release_shard()
        model.train()

    return cur_loss