Unverified Commit 28b0a62b authored by zifeitong, committed by GitHub

Bug: Fix min_p sampling crash when using flashinfer backend (#3207)


Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
parent 566d61d9
@@ -85,7 +85,7 @@ class Sampler(nn.Module):
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
+                batch_next_token_ids = min_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
@@ -97,9 +97,9 @@ class Sampler(nn.Module):
                     filter_apply_order="joint",
                 )
-            if self.use_nan_detectioin and not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+                if self.use_nan_detectioin and not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
...
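
Note on the fix: newer flashinfer releases return only the sampled token ids from min_p_sampling_from_probs, so the old two-value unpacking (batch_next_token_ids, success = ...) crashed on the min-p path; the success-based NaN check is kept only on the top-k/top-p branch, where success is still returned. The sketch below mirrors the corrected call pattern in isolation. It assumes a flashinfer build matching this commit (where the function still takes uniform_samples and returns a single tensor); the toy shapes and parameter values are made up for illustration, and the exact signature can vary across flashinfer releases.

import torch
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_p_renorm_prob,
)

# Toy batch: 2 sequences over an 8-token vocabulary (illustrative shapes only).
probs = torch.softmax(torch.randn(2, 8, device="cuda"), dim=-1)
uniform_samples = torch.rand(2, device="cuda")
top_ks = torch.full((2,), 4, dtype=torch.int32, device="cuda")
top_ps = torch.full((2,), 0.9, device="cuda")
min_ps = torch.full((2,), 0.05, device="cuda")

# Renormalize under top-k and top-p before applying min-p, mirroring the Sampler code above.
probs = top_k_renorm_prob(probs, top_ks)
probs = top_p_renorm_prob(probs, top_ps)

# New behavior: a single tensor of token ids comes back; unpacking a second
# `success` value from this call is what crashed the old code.
batch_next_token_ids = min_p_sampling_from_probs(probs, uniform_samples, min_ps)
print(batch_next_token_ids)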