"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "b5ef7c26dc2eda7bb8335277e1d62face0c24f26"
Unverified Commit 0ac94c36 authored by Ke Bao, committed by GitHub

Fallback when sampling failed (#678)

parent 2b4c6462
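This change replaces the temporary post-hoc check for illegal token ids with the per-row `success` mask returned by `top_k_top_p_sampling_from_probs`: rows whose sampling failed (e.g. due to NaNs in `probs`) fall back to greedy argmax selection, while successfully sampled rows are left untouched.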
```diff
@@ -668,18 +668,17 @@ class Batch:
         max_top_k_round, batch_size = 32, probs.shape[0]
         uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
             probs, uniform_samples, self.top_ks, self.top_ps
         )
-        # FIXME: this is a temporary fix for the illegal token ids
-        illegal_mask = torch.logical_or(
-            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
-        )
-        if torch.any(illegal_mask):
-            warnings.warn("Illegal sampled token ids")
+        if torch.any(~success):
+            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            batch_next_token_ids = torch.argmax(probs, dim=-1)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
 
         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
...
```
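For readers who want to see the pattern end to end, here is a minimal self-contained sketch of the same fallback in plain PyTorch. `naive_top_k_top_p_sample` is a hypothetical stand-in for the fused `top_k_top_p_sampling_from_probs` kernel (the `uniform_samples` argument is omitted); only its `(token_ids, success)` return contract matters for the fallback logic:

```python
import warnings

import torch


def naive_top_k_top_p_sample(probs, top_ks, top_ps):
    """Toy stand-in for a fused kernel like top_k_top_p_sampling_from_probs:
    per-row top-k/top-p sampling that also reports per-row success."""
    batch_size, vocab = probs.shape
    token_ids = torch.zeros(batch_size, dtype=torch.long, device=probs.device)
    success = torch.zeros(batch_size, dtype=torch.bool, device=probs.device)
    for i in range(batch_size):
        p = probs[i].clone()
        if not torch.isfinite(p).all():
            continue  # leave success[i] False; the caller will fall back
        k = int(top_ks[i])
        if 0 < k < vocab:
            kth_largest = torch.topk(p, k).values[-1]
            p[p < kth_largest] = 0.0  # top-k filter (ties at the threshold kept)
        sorted_p, sorted_idx = torch.sort(p, descending=True)
        cumulative = torch.cumsum(sorted_p, dim=-1)
        # top-p filter: keep tokens whose preceding cumulative mass is < top_p,
        # which retains at least the most likely token for any top_p > 0
        keep = (cumulative - sorted_p) < top_ps[i]
        p[sorted_idx[~keep]] = 0.0
        total = p.sum()
        if total > 0:
            token_ids[i] = torch.multinomial(p / total, 1)
            success[i] = True
    return token_ids, success


def sample_with_fallback(probs, top_ks, top_ps):
    """Mirrors the patched logic: rows where sampling failed are replaced
    with the argmax token (the "top_k=1 strategy")."""
    token_ids, success = naive_top_k_top_p_sample(probs, top_ks, top_ps)
    if torch.any(~success):
        warnings.warn("Sampling failed, fallback to top_k=1 strategy")
        probs = probs.masked_fill(torch.isnan(probs), 0.0)  # NaNs would poison argmax
        argmax_ids = torch.argmax(probs, dim=-1)
        # Patch only the failed rows; keep properly sampled ids elsewhere.
        token_ids = torch.where(success, token_ids, argmax_ids)
    return token_ids


# The second row contains a NaN, so only that row falls back to argmax.
probs = torch.tensor([[0.7, 0.2, 0.1], [float("nan"), 0.6, 0.4]])
ids = sample_with_fallback(probs, torch.tensor([2, 2]), torch.tensor([0.9, 0.9]))
print(ids)  # e.g. tensor([0, 1]); row 0 sampled, row 1 argmax fallback
```

Note the design change the patch makes: the old FIXME path overwrote the entire batch with argmax ids whenever any row produced an illegal token, whereas `torch.where(success, ...)` patches only the rows that actually failed and keeps the sampled tokens everywhere else.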