Commit 045959cb authored by Mostofa Patwary

added this function for evaluation

parent f32a638d
@@ -190,6 +190,41 @@ def generate_samples_input_from_file(model):
            raw_text = None
            context_count += 1
def generate_samples_eval(model, context, max_gen_length, eos_token_id, do_sample):
    # Generate samples for lm evaluation
    # NEED TO THINK ABOUT eos token
    args = get_args()
    tokenizer = get_tokenizer()

    raw_text_len = len(context)
    model.eval()

    context_tokens = tokenizer.tokenize(context)
    args.out_seq_length = max_gen_length + len(context_tokens)
    args.recompute = True  # set this default value
    args.eos_id = eos_token_id

    if not do_sample:
        args.greedy = True
    else:
        # set similar to huggingface
        args.top_p = 1.0
        args.temperature = 1.0
        args.top_k = 50

    with torch.no_grad():
        token_stream = get_token_stream(model, [context_tokens])
        for counter, decode_tokens in enumerate(token_stream):
            decode_tokens, _ = decode_tokens
            decode_tokens = decode_tokens[0].cpu().numpy().tolist()
            trim_decode_tokens = tokenizer.detokenize(
                decode_tokens)[raw_text_len:]
            if counter == args.out_seq_length:
                break

    return trim_decode_tokens
def generate_samples_interactive(model, print_frequency=24):
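For context, a minimal usage sketch (not part of this commit) of how an evaluation harness might call the new helper. It assumes Megatron has already been initialized (global args, tokenizer, and a built GPT model); the import path and the generation length are illustrative:

# Illustrative only; assumes initialize_megatron() and model setup have already run.
from megatron import get_tokenizer
from megatron.text_generation_utils import generate_samples_eval

def greedy_completion(model, prompt, max_new_tokens=64):
    # Greedy decoding (do_sample=False); stop on the tokenizer's end-of-document id.
    eod = get_tokenizer().eod
    return generate_samples_eval(model, prompt, max_new_tokens, eod,
                                 do_sample=False)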
@@ -438,6 +473,10 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
...
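The new guard in sample_sequence_batch exists so that generate_samples_eval can override the stop token: the eval path sets args.eos_id before streaming tokens, while the interactive and file-based paths never set that attribute and therefore keep stopping on tokenizer.eod. A standalone sketch of the same fallback pattern (all names below are hypothetical stand-ins, not Megatron code):

# Hypothetical illustration of the hasattr-based override used above.
class _Args:
    pass

args = _Args()
tokenizer_eod = 50256            # stand-in for tokenizer.eod

eos_id = args.eos_id if hasattr(args, 'eos_id') else tokenizer_eod
assert eos_id == tokenizer_eod   # attribute absent -> default stop token

args.eos_id = 0                  # what generate_samples_eval does with eos_token_id
eos_id = args.eos_id if hasattr(args, 'eos_id') else tokenizer_eod
assert eos_id == 0               # eval path's explicit stop token wins

One consequence of this design: since args here is Megatron's global argument namespace, once the eval path has set eos_id, later calls through the other generation paths will also see it.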