Commit 5a6431f5 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

addressed comments

parent 5c2ce593
......@@ -466,6 +466,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
with torch.no_grad():
context_length = context_lengths.min().item()
# added eos_id to support the function generate_samples_eval that passes
# eos_id as an argument and needs termination when that id id found.
if hasattr(args, 'eos_id'):
eos_id = args.eos_id
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment