
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.checkpointing import load_checkpoint
from tasks.finetune_utils import build_data_loader
from tasks.knwl_dialo.data import build_test_dataset
from tasks.knwl_dialo.data import build_test_dataset_for_prompting
from tasks.knwl_dialo.finetune import model_provider
from tasks.knwl_dialo.finetune import process_batch
from tasks.knwl_dialo.finetune import loss_func
from tasks.knwl_dialo.finetune import forward_step
from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm


def test_dataset_provider():
    """Build the test dataset for the dialog/control module."""
    args = get_args()
    print_rank_0('> building the test dataset for %s module ...'
                 % args.train_module)

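    # Build the prompting-style test set when evaluating perplexity with
    # prompts; otherwise build the standard finetuning test set.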
    if args.eval_prompting:
        print_rank_0('> evaluating ppl for prompting')
        test_ds = build_test_dataset_for_prompting(
            test_data_path=args.test_data_path,
            prompt_file=args.prompt_file,
            train_module=args.train_module,
            max_seq_len=args.max_seq_len,
            num_prompt_examples=args.num_prompt_examples,
            three_turns=args.three_turns,
            dynamic_prompt=args.dynamic_prompt)

    else:
        test_ds = build_test_dataset(
            test_data_path=args.test_data_path,
            train_module=args.train_module,
            max_seq_len=args.max_seq_len,
            last_turn=args.last_turn,
            no_control_code=args.no_control_code,
            add_separator=args.add_separator,
            add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
            remove_ctrl_sent=args.remove_ctrl_sent)

    print_rank_0("> finished creating the test dataset for %s module ..." \
                    % args.train_module)

    print_rank_0('> test set size: %d' % len(test_ds))
    args.eval_iters = len(test_ds) // args.global_batch_size
    print_rank_0('> evaluation iteration: %d' % args.eval_iters)

    return test_ds


def _build_test_iterator(test_dataset, task_collate_fn=None):
    """Test dataloader."""
    args = get_args()

    print_rank_0('building test dataloader ...')
    # Test loader
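    # Passing `not args.keep_last` drops the last partial batch unless
    # --keep-last is set.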
    test_dataloader = build_data_loader(test_dataset, args.micro_batch_size,
                                        args.num_workers, not args.keep_last,
                                        task_collate_fn)
    test_iterator = iter(test_dataloader)
    return test_iterator


def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
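    """Evaluate perplexity of the model on the test set."""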
    args = get_args()
    timers = get_timers()

    # Test dataloader.
    timers('test dataset/dataloader').start()
    test_dataset = test_dataset_provider()
    test_iterator = _build_test_iterator(test_dataset)
    timers('test dataset/dataloader').stop()

    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    timers('pretrained checkpoint').start()
    # Default iteration when no pretrained checkpoint is loaded.
    iteration = 0
    if args.pretrained_checkpoint is not None:
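        # Temporarily point checkpoint loading at the pretrained model and
        # skip restoring the RNG state; the original args are restored below.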
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        iteration = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only the model is loaded. We should make sure
        # the main parameters are also updated.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['test dataset/dataloader', 'model and optimizer',
                'pretrained checkpoint'])
    
    print_rank_0('evaluating ...')
    prefix = 'iteration {}'.format(iteration)
    evaluate_and_print_results(prefix, forward_step, 
                               test_iterator, model,
                               iteration, False)
    
    print_rank_0('done :-)')


def evaluate_f1(guess_file, answer_file):
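    """Compute precision, recall, and F1 between generated outputs and references."""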

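    # Read model generations (one per line) and strip the end-of-text token.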
    guess_list = []
    print_rank_0('reading %s' % guess_file)
    with open(guess_file, "r") as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if "<|endoftext|>" in line:
                line = line.replace("<|endoftext|>", "")
            guess_list.append(line)

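    # Read ground-truth answers; treat "no_passages_used" as an empty reference.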
    answer_list = []
    print_rank_0('reading %s' % answer_file)
    with open(answer_file, "r") as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            if line == "no_passages_used":
                line = ""
            answer_list.append(line)

    assert len(guess_list) == len(answer_list), \
        "lengths of guess and answer are different!"

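    # Compute precision, recall, and F1 over all guess/answer pairs.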
    precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
    print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))

    print_rank_0('done :-)')


def main():
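    """Run the evaluation selected by args.task (PPL or F1)."""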
    args = get_args()

    if 'PPL' in args.task:
        evaluate_ppl(test_dataset_provider, model_provider, forward_step)
    elif 'F1' in args.task:
        evaluate_f1(args.guess_file, args.answer_file)