Commit 716a3243 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_generate' into 'main'

lm evaluation

See merge request ADLR/megatron-lm!262
parents 7a5768ac e5ec27d7
...@@ -190,6 +190,37 @@ def generate_samples_input_from_file(model): ...@@ -190,6 +190,37 @@ def generate_samples_input_from_file(model):
raw_text = None raw_text = None
context_count += 1 context_count += 1
# We added this function to support the tasks evaluation such as squad
# and drop in the https://github.com/EleutherAI/lm-evaluation-harness
# codebase. The lm-evaluation-harness code can now call this function
# similar to their current generate function call used for gpt style models.
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
    """Generate a continuation of ``context`` for LM-evaluation tasks.

    The prompt is tokenized, fed through ``get_token_stream`` and the last
    item produced by the stream is detokenized.  The returned string is the
    detokenized text with the first ``len(context)`` characters removed.

    Args:
        model: model handed to ``get_token_stream``; ``model.eval()`` is
            called on it before generation.
        context: prompt text to continue.
        max_gen_length: number of tokens allowed on top of the tokenized
            prompt; the sum is published as ``args.out_seq_length`` while
            generating.
        eos_token_id: token id published as ``args.eos_id`` while
            generating, so the sampling loop can stop on it.

    Returns:
        The detokenized output with the first ``len(context)`` characters
        dropped, or an empty string if the token stream produced nothing.
    """
    # NEED TO THINK ABOUT eos token
    args = get_args()
    tokenizer = get_tokenizer()

    # The prompt is trimmed from the result by *character* count.
    # NOTE(review): this assumes the detokenized output starts with the
    # prompt text verbatim -- confirm for the tokenizer in use.
    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    out_seq_length = max_gen_length + len(context_tokens)

    # The generation settings travel through the shared args object.
    # Remember what was there so this call does not leak its per-request
    # values (in particular eos_id, which the sampling loop picks up via
    # hasattr(args, 'eos_id')) into later, unrelated generation calls.
    missing = object()
    saved_settings = (
        ('out_seq_length', getattr(args, 'out_seq_length', missing)),
        ('eos_id', getattr(args, 'eos_id', missing)),
    )
    args.out_seq_length = out_seq_length
    args.eos_id = eos_token_id

    last_item = None
    try:
        with torch.no_grad():
            token_stream = get_token_stream(model, [context_tokens])
            # Only the most recent stream item matters; intermediate items
            # are skipped rather than detokenized one by one.
            # NOTE(review): reconstructed as "decode once after the loop";
            # confirm against the original indentation.
            for counter, last_item in enumerate(token_stream):
                if counter == out_seq_length:
                    break
    finally:
        for name, saved in saved_settings:
            if saved is missing:
                delattr(args, name)
            else:
                setattr(args, name, saved)

    if last_item is None:
        # The stream yielded nothing, so there is no generated text.
        return ''

    decode_tokens, _ = last_item
    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
    return tokenizer.detokenize(decode_tokens)[raw_text_len:]
def generate_samples_interactive(model, print_frequency=24): def generate_samples_interactive(model, print_frequency=24):
...@@ -438,7 +469,13 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...@@ -438,7 +469,13 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
context_length = context_lengths.min().item() context_length = context_lengths.min().item()
eos_id = tokenizer.eod
# added eos_id to support the function generate_samples_eval that passes
# eos_id as an argument and needs termination when that id is found.
if hasattr(args, 'eos_id'):
eos_id = args.eos_id
else:
eos_id = tokenizer.eod
counter = 0 counter = 0
org_context_length = context_length org_context_length = context_length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment