evaluate.py

"""Model evaluation"""

from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.checkpointing import load_checkpoint
from tasks.finetune_utils import build_data_loader
from tasks.knwl_dialo.data import build_test_dataset
from tasks.knwl_dialo.data import build_test_dataset_for_prompting
from tasks.knwl_dialo.finetune import model_provider
from tasks.knwl_dialo.finetune import process_batch
from tasks.knwl_dialo.finetune import loss_func
from tasks.knwl_dialo.finetune import forward_step
from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm

def test_dataset_provider():
    """Build the test dataset"""
    args = get_args()
    print_rank_0('> building the test dataset for %s module ...' \
                    % args.module)

    if args.prompt_type != "":
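        # Prompting mode: wrap each test sample with in-context examples from
        # args.prompt_file; num_prompt_examples, three_turns and dynamic_prompt
        # control how the prompt is assembled.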
        print_rank_0('> evaluating ppl for prompting')
        test_ds = build_test_dataset_for_prompting(
            test_data_path=args.test_data_path,
            prompt_file=args.prompt_file,
            module=args.module,
            max_seq_len=args.seq_length,
            num_prompt_examples=args.num_prompt_examples,
            three_turns=args.three_turns,
            dynamic_prompt=args.dynamic_prompt)

    else:
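        # Finetuning mode: build the plain test set, with optional control-code
        # and separator handling.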
        print_rank_0('> evaluating ppl for finetuning')
        test_ds = build_test_dataset(
            test_data_path=args.test_data_path,
            module=args.module,
            max_seq_len=args.seq_length,
            last_turn=args.last_turn,
            no_control_code=args.no_control_code,
            add_separator=args.add_separator,
            add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
            remove_ctrl_sent=args.remove_ctrl_sent)

    print_rank_0("> finished creating the test dataset for %s module ..." \
                    % args.module)

    print_rank_0('> test set size: %d' % len(test_ds))
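    # Set eval_iters so a single evaluation pass covers (almost) the whole
    # test set.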
    args.eval_iters = len(test_ds) // args.global_batch_size
    print_rank_0('> evaluation iteration: %d' % args.eval_iters)

    return test_ds


def _build_test_iterator(test_dataset, task_collate_fn=None):
    """Test dataloader."""
    args = get_args()

    print_rank_0('building test dataloader ...')
    # Test loader
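    # The fourth argument of build_data_loader is drop_last, so the last
    # incomplete batch is dropped unless args.keep_last is set.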
    test_dataloader = build_data_loader(test_dataset, args.micro_batch_size,
                                        args.num_workers, not args.keep_last,
                                        task_collate_fn)
    test_iterator = test_dataloader.__iter__()
    return test_iterator


def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
    """Evaluating perplexity"""
    args = get_args()
    timers = get_timers()

    # test dataloader.
    timers('test dataset/dataloader').start()
    test_dataset = test_dataset_provider()
    test_iterator = _build_test_iterator(test_dataset)
    timers('test dataset/dataloader').stop()

    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    timers('pretrained checkpoint').start()
    # Default when no pretrained checkpoint is given; load_checkpoint below
    # overwrites this with the iteration stored in the checkpoint.
    iteration = 0
    if args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        iteration = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only model is loaded. We should make sure
        # main parameters are also updated.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['test dataset/dataloader', 'model and optimizer',
                'pretrained checkpoint'])
    
    print_rank_0('evaluating ...')
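    # evaluate_and_print_results runs the standard Megatron evaluation loop
    # over the test iterator and reports the loss (and perplexity); the final
    # argument disables verbose per-iteration logging.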
    prefix = 'iteration {}'.format(iteration)
    evaluate_and_print_results(prefix, forward_step, 
                               test_iterator, model,
                               iteration, False)
    
    print_rank_0('done :-)')


def evaluate_f1(guess_file, answer_file):
    """Evaluating F1 Score"""

    guess_list = []
    print_rank_0('reading %s' % guess_file)
    with open(guess_file, "r") as f:
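        # Model generations: strip the GPT "<|endoftext|>" token so it does
        # not count toward the token-overlap F1.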
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if "<|endoftext|>" in line:
                line = line.replace("<|endoftext|>", "")
            guess_list.append(line)

    answer_list = []
    print_rank_0('reading %s' % answer_file)
    with open(answer_file, "r") as f:
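        # References: samples marked "no_passages_used" carry no grounding
        # sentence, so they are scored against an empty reference.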
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if line == "no_passages_used":
                line = ""
            answer_list.append(line)

    assert len(guess_list) == len(answer_list), \
        "lengths of guess and answer are different!"

    precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
    print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))

    print_rank_0('done :-)')


def main():
    args = get_args()
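    # Dispatch on args.task: the 'PPL' tasks need the model and test dataset,
    # while the 'F1' task only compares a generation file with an answer file.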
    
    if 'PPL' in args.task: 
        evaluate_ppl(test_dataset_provider, model_provider, forward_step)
    
    elif 'F1' in args.task:
        evaluate_f1(args.guess_file, args.answer_file)