Commit 5c2ce593 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

addressed comments

parent 44bfcb32
...@@ -207,13 +207,14 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id): ...@@ -207,13 +207,14 @@ def generate_samples_eval(model, context, max_gen_length, eos_token_id):
with torch.no_grad(): with torch.no_grad():
token_stream = get_token_stream(model, [context_tokens]) token_stream = get_token_stream(model, [context_tokens])
for counter, decode_tokens in enumerate(token_stream): for counter, decode_tokens in enumerate(token_stream):
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
if counter == args.out_seq_length: if counter == args.out_seq_length:
break break
decode_tokens, _ = decode_tokens
decode_tokens = decode_tokens[0].cpu().numpy().tolist()
trim_decode_tokens = tokenizer.detokenize(
decode_tokens)[raw_text_len:]
return trim_decode_tokens return trim_decode_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment