Commit aaa7aa32 authored by zihanl

remove finetune part

parent a87777bf
@@ -2,116 +2,10 @@
"""Model evaluation"""
from megatron import get_args
from megatron import get_timers
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.checkpointing import load_checkpoint
from tasks.finetune_utils import build_data_loader
from tasks.knwl_dialo.data import build_test_dataset
from tasks.knwl_dialo.data import build_test_dataset_for_prompting
from tasks.knwl_dialo.finetune import model_provider
from tasks.knwl_dialo.finetune import process_batch
from tasks.knwl_dialo.finetune import loss_func
from tasks.knwl_dialo.finetune import forward_step
from tasks.knwl_dialo.metrics import F1Metric
from tqdm import tqdm


def test_dataset_provider():
    """Build the test dataset"""
    args = get_args()

    print_rank_0('> building the test dataset for %s module ...' \
        % args.module)
    if args.prompt_type != "":
        print_rank_0('> evaluating ppl for prompting')
        test_ds = build_test_dataset_for_prompting(
            test_data_path=args.test_data_path,
            prompt_file=args.prompt_file,
            module=args.module,
            max_seq_len=args.seq_length,
            num_prompt_examples=args.num_prompt_examples,
            three_turns=args.three_turns,
            dynamic_prompt=args.dynamic_prompt)
    else:
        print_rank_0('> evaluating ppl for finetuning')
        test_ds = build_test_dataset(
            test_data_path=args.test_data_path,
            module=args.module,
            max_seq_len=args.seq_length,
            last_turn=args.last_turn,
            no_control_code=args.no_control_code,
            add_separator=args.add_separator,
            add_ctrl_code_to_dialog=args.add_ctrl_code_to_dialog,
            remove_ctrl_sent=args.remove_ctrl_sent)
    print_rank_0('> finished creating the test dataset for %s module ...' \
        % args.module)

    print_rank_0('> test set size: %d' % len(test_ds))
    # Number of evaluation iterations needed to cover the test set.
    args.eval_iters = len(test_ds) // args.global_batch_size
    print_rank_0('> evaluation iterations: %d' % args.eval_iters)
    return test_ds


def _build_test_iterator(test_dataset, task_collate_fn=None):
    """Test dataloader."""
    args = get_args()

    print_rank_0('building test dataloader ...')
    # Test loader.
    test_dataloader = build_data_loader(test_dataset, args.micro_batch_size,
                                        args.num_workers, not args.keep_last,
                                        task_collate_fn)
    test_iterator = iter(test_dataloader)

    return test_iterator


def evaluate_ppl(test_dataset_provider, model_provider, forward_step):
    """Evaluating perplexity"""
    args = get_args()
    timers = get_timers()

    # Test dataloader.
    timers('test dataset/dataloader').start()
    test_dataset = test_dataset_provider()
    test_iterator = _build_test_iterator(test_dataset)
    timers('test dataset/dataloader').stop()

    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    timers('pretrained checkpoint').start()
    # Default to iteration 0 so the prefix below is defined even when no
    # pretrained checkpoint is provided.
    iteration = 0
    if args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        iteration = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only the model is loaded. We should make sure
        # main parameters are also updated.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['test dataset/dataloader', 'model and optimizer',
                'pretrained checkpoint'])

    print_rank_0('evaluating ...')
    prefix = 'iteration {}'.format(iteration)
    evaluate_and_print_results(prefix, forward_step,
                               test_iterator, model,
                               iteration, False)
    print_rank_0('done :-)')


def evaluate_f1(guess_file, answer_file):
    """Evaluating F1 Score"""
@@ -146,9 +40,5 @@ def evaluate_f1(guess_file, answer_file):
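
# NOTE: the body of evaluate_f1 is elided from this diff. Purely as a
# hypothetical sketch (not the actual F1Metric implementation imported
# above), a token-overlap F1 between one guess line and one answer line
# could look like:
#
#     from collections import Counter
#
#     def _token_f1_sketch(guess, answer):
#         guess_tokens = guess.split()
#         answer_tokens = answer.split()
#         common = Counter(guess_tokens) & Counter(answer_tokens)
#         num_same = sum(common.values())
#         if num_same == 0:
#             return 0.0
#         precision = num_same / len(guess_tokens)
#         recall = num_same / len(answer_tokens)
#         return 2 * precision * recall / (precision + recall)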


def main():
    args = get_args()

    if 'PPL' in args.task:
        evaluate_ppl(test_dataset_provider, model_provider, forward_step)
    elif 'F1' in args.task:
        evaluate_f1(args.guess_file, args.answer_file)
@@ -12,37 +12,6 @@ from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
    """
    Build attention masks and position ids for a left-to-right model.
    Unlike the existing get_ltor_masks_and_position_ids function, this
    variant expects input sequences that were padded to a common length,
    and it masks out the pad tokens that follow the EOD token.
    """
    micro_batch_size, seq_length = data.size()

    # Attention mask: lower-triangular for causal (left-to-right) attention.
    attention_mask = torch.tril(torch.ones(
        (micro_batch_size, seq_length, seq_length), device=data.device)).view(
            micro_batch_size, 1, seq_length, seq_length)

    # Mask padded tokens.
    for b in range(micro_batch_size):
        for idx in range(seq_length - 1):
            if data[b, idx] == eod_token_id:
                # Mask the pad tokens that come after the eod token.
                attention_mask[b, 0, idx + 1:, :] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    # Convert the attention mask to boolean: True marks positions that
    # must NOT be attended to.
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids
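
# A minimal sanity check for the helper above (hypothetical token ids;
# id 2 stands in for the tokenizer's real EOD id):
#
#     data = torch.tensor([[5, 7, 2, 0, 0]])  # one padded sequence, EOD at index 2
#     attention_mask, position_ids = \
#         get_ltor_attention_masks_and_position_ids(data, eod_token_id=2)
#     # attention_mask is True where attention is blocked: the causal upper
#     # triangle plus every row after the EOD position (the pad tokens).
#     assert attention_mask.shape == (1, 1, 5, 5)
#     assert position_ids.tolist() == [[0, 1, 2, 3, 4]]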


def switch(val1, val2, boolean):
    """Return either val1 or val2 depending on boolean."""
......