Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
734daedd
Unverified
Commit
734daedd
authored
Jan 31, 2025
by
Byron Hsu
Committed by
GitHub
Jan 31, 2025
Browse files
[fix] Clamp logprob with dtype min to prevent `-inf` (#3224)
parent
3ee62235
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+5
-2
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
.../sampling/penaltylib/test_srt_endpoint_with_penalizers.py
+3
-4
No files found.
python/sglang/srt/layers/sampler.py
View file @
734daedd
...
...
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
# https://github.com/flashinfer-ai/flashinfer/issues/708
# so we use the torch implementation.
# clamp to avoid -inf
logprobs
=
torch
.
log
(
top_p_normalize_probs_torch
(
probs
,
sampling_info
.
top_ps
)
)
)
.
clamp
(
min
=
torch
.
finfo
(
probs
.
dtype
).
min
)
max_top_k_round
,
batch_size
=
32
,
probs
.
shape
[
0
]
uniform_samples
=
torch
.
rand
(
...
...
@@ -109,9 +111,10 @@ class Sampler(nn.Module):
sampling_info
.
need_min_p_sampling
,
)
if
return_logprob
:
# clamp to avoid -inf
logprobs
=
torch
.
log
(
top_p_normalize_probs_torch
(
probs
,
sampling_info
.
top_ps
)
)
)
.
clamp
(
min
=
torch
.
finfo
(
probs
.
dtype
).
min
)
else
:
raise
ValueError
(
f
"Invalid sampling backend:
{
global_server_args_dict
[
'sampling_backend'
]
}
"
...
...
test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py
View file @
734daedd
...
...
@@ -36,7 +36,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
def
run_decode
(
self
,
return_logprob
=
True
,
top_logprobs_num
=
3
,
top_logprobs_num
=
5
,
return_text
=
True
,
n
=
1
,
**
sampling_params
,
...
...
@@ -58,8 +58,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
"logprob_start_len"
:
0
,
},
)
print
(
json
.
dumps
(
response
.
json
()))
print
(
"="
*
100
)
assert
response
.
status_code
==
200
,
"Request failed: "
+
response
.
text
def
test_default_values
(
self
):
self
.
run_decode
()
...
...
@@ -112,4 +111,4 @@ class TestBatchPenalizerE2E(unittest.TestCase):
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
(
verbosity
=
3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment