Unverified Commit 8f4b1559 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Temporary fix invalid sample results (#668)

parent e3046ea3
......@@ -673,6 +673,16 @@ class Batch:
batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
probs, uniform_samples, self.top_ks, self.top_ps
)
# FIXME: This is a temporary fix for the illegal token ids in sampling.
illegal_mask = (
batch_next_token_ids < 0 or batch_next_token_ids >= probs.shape[-1]
)
if torch.any(illegal_mask):
warnings.warn("Illegal token ids in sampling.")
batch_next_token_ids = torch.where(
illegal_mask, torch.argmax(probs, dim=-1), batch_next_token_ids
)
except RuntimeError as e:
warnings.warn(f"Ignore errors in sampling: {e}")
batch_next_token_ids = torch.argmax(probs, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment