Fix sampler nan check when calling top_k_top_p_sampling_from_probs (#5546)

Commit 20f1c8e3 (unverified), authored Apr 19, 2025 by Yubo Wang; committed via GitHub on Apr 19, 2025
Parent: 613b197e
Showing 2 changed files with 8 additions and 9 deletions:

- python/sglang/srt/layers/sampler.py (+3, -4)
- sgl-kernel/python/sgl_kernel/sampling.py (+5, -5)
python/sglang/srt/layers/sampler.py

```diff
@@ -100,17 +100,16 @@ class Sampler(nn.Module):
                     batch_next_token_ids = min_p_sampling_from_probs(
                         probs, sampling_info.min_ps
                     )
                 else:
+                    # Check Nan will throw exception, only check when crash_on_warnings is True
+                    check_nan = self.use_nan_detection and crash_on_warnings()
                     batch_next_token_ids = top_k_top_p_sampling_from_probs(
                         probs,
                         sampling_info.top_ks,
                         sampling_info.top_ps,
                         filter_apply_order="joint",
+                        check_nan=check_nan,
                     )
-
-                if self.use_nan_detection:
-                    logger.warning("Detected errors during sampling!")
-                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
```
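What the sampler change does: check_nan is now computed and actually passed through to the fused kernel call, so NaN probabilities raise inside top_k_top_p_sampling_from_probs, and only when crash_on_warnings() is true. The removed block had warned "Detected errors during sampling!" and zeroed batch_next_token_ids whenever use_nan_detection was enabled, even when the probabilities were perfectly valid. Below is a minimal pure-PyTorch sketch of the fixed gating; mock_top_k_top_p_sampling and the flag helpers are hypothetical stand-ins, not sglang or sgl-kernel APIs:

```python
# A minimal sketch of the fixed control flow. mock_top_k_top_p_sampling,
# use_nan_detection, and crash_on_warnings are hypothetical stand-ins for
# the fused sgl-kernel op and sglang's server flags.
import torch


def mock_top_k_top_p_sampling(probs: torch.Tensor, check_nan: bool = False) -> torch.Tensor:
    if check_nan and torch.any(torch.isnan(probs)):
        # Mirrors the kernel contract: NaN input raises instead of being sampled.
        raise ValueError("Input probs contains NaN.")
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


def crash_on_warnings() -> bool:
    return True  # stand-in; sglang derives this from the environment


use_nan_detection = True
probs = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.25, 0.25]])

# New behavior: the NaN check is pushed into the sampling call itself ...
check_nan = use_nan_detection and crash_on_warnings()
batch_next_token_ids = mock_top_k_top_p_sampling(probs, check_nan=check_nan)
print(batch_next_token_ids)  # e.g. tensor([0, 0])

# ... whereas the removed block warned and zeroed the sampled ids whenever
# use_nan_detection was on, regardless of whether any NaN was present.
```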
sgl-kernel/python/sgl_kernel/sampling.py

```diff
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 import torch
 from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
@@ -109,7 +109,7 @@ def _top_p_sampling_from_probs_internal(
     top_p_val: float,
     deterministic: bool,
     generator: Optional[torch.Generator],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     with probs.device as device:
         probs = probs.float()
         maybe_top_p_arr = (
@@ -135,7 +135,7 @@ def top_p_sampling_from_probs(
     deterministic: bool = True,
     generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
     Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
     this operator implements GPU-based rejection sampling without explicit sorting.
@@ -194,7 +194,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     top_p_val: float,
     deterministic: bool,
     generator: Optional[torch.Generator],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     with probs.device as device:
         probs = probs.float()
         maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
@@ -225,7 +225,7 @@ def top_k_top_p_sampling_from_probs(
     deterministic: bool = True,
     generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
     Fused GPU kernel for top-k and top-p sampling from probabilities,
     this operator implements GPU-based rejection sampling without explicit sorting.
```
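The sampling.py side is an annotation cleanup rather than a behavior change: only the return annotations and the typing import are touched. The four wrappers return a single samples tensor, so the Tuple[torch.Tensor, torch.Tensor] annotations (presumably left over from an older (samples, success) return convention) are corrected to torch.Tensor, and the now-unused Tuple import is dropped. Below is a signature-only sketch of the corrected contract; top_p_sampling_from_probs_stub is hypothetical, with a pure-torch top-p filter standing in for the fused GPU rejection-sampling kernel:

```python
# Hypothetical stub mirroring the corrected sgl-kernel signature; the body
# is a plain-torch top-p filter, not the fused rejection-sampling kernel.
from typing import Optional

import torch


def top_p_sampling_from_probs_stub(
    probs: torch.Tensor,
    top_p: float,
    deterministic: bool = True,  # ignored in this stub
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:  # was annotated Tuple[torch.Tensor, torch.Tensor] before this commit
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    # Keep the smallest prefix of high-probability tokens whose preceding
    # cumulative mass is below top_p, renormalize, then sample one id per row.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    keep = (cum - sorted_probs) < top_p
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(filtered, num_samples=1, generator=generator)
    return sorted_idx.gather(-1, choice).squeeze(-1)  # one tensor, not a (samples, success) tuple


probs = torch.tensor([[0.6, 0.3, 0.1], [0.25, 0.25, 0.5]])
print(top_p_sampling_from_probs_stub(probs, top_p=0.9))  # e.g. tensor([0, 2])
```

Callers that still unpacked two values from these wrappers would have been relying on the stale annotation; with this commit the annotation matches what the functions actually return.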