Commit 7a3bd55f authored by Tri Dao's avatar Tri Dao
Browse files

[Gen] Fix decode function not using top_p during iterative decoding

parent 847abe65
......@@ -173,7 +173,7 @@ def decode(
teacher_outputs is None
or teacher_output_len <= inference_params.sequence_len_offset + 1
):
next_token = sample(logits, top_k=top_k, temperature=temperature)
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
sequences.append(next_token)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment