Unverified commit 2c11f9c2, authored by Yineng Zhang and committed by GitHub
Browse files

chore: upgrade sgl-kernel 0.0.9.post2 (#5540)

parent a6f892e5
......@@ -47,7 +47,7 @@ runtime_common = [
srt = [
"sglang[runtime_common]",
"sgl-kernel==0.0.9.post1",
"sgl-kernel==0.0.9.post2",
"flashinfer_python==0.2.3",
"torch==2.5.1",
"torchvision==0.20.1",
......
......@@ -93,25 +93,21 @@ class Sampler(nn.Module):
).clamp(min=torch.finfo(probs.dtype).min)
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
probs, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
if self.use_nan_detection and not torch.all(success):
if self.use_nan_detection:
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
......
......@@ -20,7 +20,7 @@ pip install --upgrade pip
# Install flashinfer and sgl-kernel
pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
pip install sgl-kernel==0.0.9.post1 --no-cache-dir
pip install sgl-kernel==0.0.9.post2 --no-cache-dir
# Install the main package
pip install -e "python[all]" --find-links ${FLASHINFER_REPO}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment