Unverified Commit 39c57317 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Revert "Temporary fix invalid sample results" (#673)

parent 9592a1f3
...@@ -673,16 +673,6 @@ class Batch: ...@@ -673,16 +673,6 @@ class Batch:
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs( batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
probs, uniform_samples, self.top_ks, self.top_ps probs, uniform_samples, self.top_ks, self.top_ps
) )
# FIXME: This is a temporary fix for the illegal token ids in sampling.
illegal_mask = (
batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1]
)
if torch.any(illegal_mask):
warnings.warn("Illegal token ids in sampling.")
batch_next_token_ids = torch.where(
illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
)
except RuntimeError as e: except RuntimeError as e:
warnings.warn(f"Ignore errors in sampling: {e}") warnings.warn(f"Ignore errors in sampling: {e}")
batch_next_token_ids = torch.argmax(probs, dim=-1) batch_next_token_ids = torch.argmax(probs, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment