Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
20f1c8e3
"examples/pytorch/appnp/train.py" did not exist on "2758c249559cbd6cb7c582018a453d324913df5f"
Unverified
Commit
20f1c8e3
authored
Apr 19, 2025
by
Yubo Wang
Committed by
GitHub
Apr 19, 2025
Browse files
Fix sampler nan check when calling top_k_top_p_sampling_from_probs (#5546)
parent
613b197e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
9 deletions
+8
-9
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+3
-4
sgl-kernel/python/sgl_kernel/sampling.py
sgl-kernel/python/sgl_kernel/sampling.py
+5
-5
No files found.
python/sglang/srt/layers/sampler.py
View file @
20f1c8e3
...
...
@@ -100,17 +100,16 @@ class Sampler(nn.Module):
probs
,
sampling_info
.
min_ps
)
else
:
# Check Nan will throw exception, only check when crash_on_warnings is True
check_nan
=
self
.
use_nan_detection
and
crash_on_warnings
()
batch_next_token_ids
=
top_k_top_p_sampling_from_probs
(
probs
,
sampling_info
.
top_ks
,
sampling_info
.
top_ps
,
filter_apply_order
=
"joint"
,
check_nan
=
check_nan
,
)
if
self
.
use_nan_detection
:
logger
.
warning
(
"Detected errors during sampling!"
)
batch_next_token_ids
=
torch
.
zeros_like
(
batch_next_token_ids
)
elif
global_server_args_dict
[
"sampling_backend"
]
==
"pytorch"
:
# A slower fallback implementation with torch native operations.
batch_next_token_ids
=
top_k_top_p_min_p_sampling_from_probs_torch
(
...
...
sgl-kernel/python/sgl_kernel/sampling.py
View file @
20f1c8e3
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
torch
from
sgl_kernel.utils
import
_to_tensor_scalar_tuple
,
get_cuda_stream
...
...
@@ -109,7 +109,7 @@ def _top_p_sampling_from_probs_internal(
top_p_val
:
float
,
deterministic
:
bool
,
generator
:
Optional
[
torch
.
Generator
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
with
probs
.
device
as
device
:
probs
=
probs
.
float
()
maybe_top_p_arr
=
(
...
...
@@ -135,7 +135,7 @@ def top_p_sampling_from_probs(
deterministic
:
bool
=
True
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
check_nan
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
r
"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
...
...
@@ -194,7 +194,7 @@ def _top_k_top_p_sampling_from_probs_internal(
top_p_val
:
float
,
deterministic
:
bool
,
generator
:
Optional
[
torch
.
Generator
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
with
probs
.
device
as
device
:
probs
=
probs
.
float
()
maybe_top_k_arr
=
maybe_top_k_arr
.
int
()
if
maybe_top_k_arr
is
not
None
else
None
...
...
@@ -225,7 +225,7 @@ def top_k_top_p_sampling_from_probs(
deterministic
:
bool
=
True
,
generator
:
Optional
[
torch
.
Generator
]
=
None
,
check_nan
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
)
->
torch
.
Tensor
:
r
"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment