change / sglang · Commits · 70562969

Commit 70562969 (unverified)
Authored Nov 01, 2025 by 0xNullPath; committed by GitHub on Nov 01, 2025

[Bug] OOM (Out-of-Memory) errors for extreme testing scenarios (min_tokens=2) (#11757)

Signed-off-by: Yan Lu <luyan@nvidia.com>

Parent: b57dc169

Showing 4 changed files with 62 additions and 27 deletions:

    python/sglang/srt/sampling/penaltylib/frequency_penalty.py  (+6, -8)
    python/sglang/srt/sampling/penaltylib/min_new_tokens.py     (+7, -8)
    python/sglang/srt/sampling/penaltylib/orchestrator.py       (+43, -3)
    python/sglang/srt/sampling/penaltylib/presence_penalty.py   (+6, -8)
python/sglang/srt/sampling/penaltylib/frequency_penalty.py

 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
     Frequency penalizer penalizes tokens based on their frequency in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.frequency_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
             [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("frequency_penalties", "cumulated_frequency_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
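
The `_teardown` hook added here simply drops the penalizer's tensor attributes. A minimal runnable sketch of the idea, using a hypothetical `Holder` class rather than the real penalizer: once the attributes are deleted nothing else references the tensors, so PyTorch's caching allocator can reuse that memory instead of holding it for a batch that no longer exists.

import torch

# Hypothetical stand-in for a penalizer that owns large per-batch tensors.
class Holder:
    def __init__(self, batch_size: int, vocab_size: int):
        self.frequency_penalties = torch.zeros(batch_size, vocab_size)
        self.cumulated_frequency_penalties = torch.zeros(batch_size, vocab_size)

    def _teardown(self) -> None:
        # hasattr guards against a partially prepared penalizer
        for name in ("frequency_penalties", "cumulated_frequency_penalties"):
            if hasattr(self, name):
                delattr(self, name)

holder = Holder(batch_size=8, vocab_size=4096)
holder._teardown()
assert not hasattr(holder, "frequency_penalties")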
python/sglang/srt/sampling/penaltylib/min_new_tokens.py

 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
     Min new tokens penalizer penalizes tokens based on the length of the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
@@ -92,3 +85,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         self.len_output_tokens = torch.cat(
             [self.len_output_tokens, their.len_output_tokens], dim=0
         )
+
+    # Explicit resource cleanup to aid GC and free CUDA memory promptly
+    def _teardown(self) -> None:
+        for name in ("min_new_tokens", "stop_token_penalties", "len_output_tokens"):
+            if hasattr(self, name):
+                delattr(self, name)
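
For contrast with the cleanup path, the surviving `merge` hunk above shows how batch state is combined: each per-request tensor keeps requests along dim 0, so merging two batches is a single `torch.cat`. A small sketch with toy values (not the real scheduler state):

import torch

mine = torch.tensor([3, 5])    # len_output_tokens for a 2-request batch
theirs = torch.tensor([1])     # len_output_tokens for a 1-request batch

# Merging batches concatenates per-request state along dim 0.
merged = torch.cat([mine, theirs], dim=0)
print(merged)                  # tensor([3, 5, 1])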
python/sglang/srt/sampling/penaltylib/orchestrator.py

@@ -77,9 +77,8 @@ class BatchedPenalizerOrchestrator:
             return

         if len(keep_indices) == 0:
-            self.is_required = False
-            for penalizer in self.penalizers.values():
-                penalizer.teardown()
+            # No requests left in the batch, fully release orchestrator resources
+            self.release()
             return

         is_required = False
@@ -92,6 +91,23 @@ class BatchedPenalizerOrchestrator:
                 penalizer.teardown()
         self.is_required = is_required

+    # Resource management helpers
+    def release(self) -> None:
+        """Release all penalizers and break references so GC can reclaim promptly."""
+        for penalizer in self.penalizers.values():
+            penalizer.teardown()
+        self.penalizers.clear()
+        # Break reference to ScheduleBatch
+        self._batch_ref = None
+        self.is_required = False
+
+    # Context manager support
+    def __enter__(self) -> "BatchedPenalizerOrchestrator":
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> None:
+        self.release()
+
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
         Merge the penalizers of another orchestrator into this one.
@@ -116,6 +132,22 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """

+    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+        self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
+            weakref.ref(orchestrator)
+        )
+        self._is_prepared = False
+
+    @property
+    def orchestrator(self) -> BatchedPenalizerOrchestrator:
+        orch: Optional[BatchedPenalizerOrchestrator] = self._orchestrator_ref()
+        # This should never happen, but we need to handle it gracefully
+        if orch is None:
+            raise RuntimeError(
+                "BatchedPenalizerOrchestrator has been garbage-collected"
+            )
+        return orch
+
     def is_prepared(self) -> bool:
         return self._is_prepared
@@ -135,6 +167,7 @@ class _BatchedPenalizer(abc.ABC):
         return False

     def teardown(self):
+        self._teardown()
         self._is_prepared = False

     def cumulate_output_tokens(self, output_ids: torch.Tensor):
@@ -207,3 +240,10 @@ class _BatchedPenalizer(abc.ABC):
         Merge the penalizer with another penalizer.
         """
         pass
+
+    @abc.abstractmethod
+    def _teardown(self):
+        """
+        Teardown the penalizer.
+        """
+        pass
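
The base-class change is the heart of the fix: `_BatchedPenalizer` now holds only a weak reference to its orchestrator, which breaks the orchestrator → penalizer → orchestrator reference cycle that could keep a dead batch's tensors alive until a full GC pass. A minimal runnable sketch of the pattern, with hypothetical `Owner`/`Child` classes standing in for the orchestrator and penalizer:

import weakref

class Owner:
    pass

class Child:
    def __init__(self, owner: Owner):
        # Weak reference: does not keep the owner alive on its own.
        self._owner_ref = weakref.ref(owner)

    @property
    def owner(self) -> Owner:
        owner = self._owner_ref()
        if owner is None:
            raise RuntimeError("Owner has been garbage-collected")
        return owner

o = Owner()
c = Child(o)
assert c.owner is o
del o  # the last strong reference is gone; the weakref now returns None
try:
    c.owner
except RuntimeError as e:
    print(e)  # Owner has been garbage-collected

The new `__enter__`/`__exit__` pair on the orchestrator complements this: `with orchestrator: ...` guarantees `release()` runs even if sampling raises, so penalizer teardown does not depend on GC timing.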
python/sglang/srt/sampling/penaltylib/presence_penalty.py

 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
     Presence penalizer penalizes tokens based on their presence in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.presence_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("presence_penalties", "cumulated_presence_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
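
Taken together, the four files give every penalizer a `_teardown` hook and the orchestrator an explicit `release()` plus context-manager support. A runnable toy sketch of the resulting lifecycle (these classes are illustrative stand-ins, not sglang's):

class ToyPenalizer:
    def teardown(self) -> None:
        print("penalizer torn down")

class ToyOrchestrator:
    def __init__(self):
        self.penalizers = {"presence": ToyPenalizer()}
        self._batch_ref = object()  # stands in for the ScheduleBatch reference
        self.is_required = True

    def release(self) -> None:
        for penalizer in self.penalizers.values():
            penalizer.teardown()
        self.penalizers.clear()
        self._batch_ref = None  # break the reference so GC can reclaim promptly
        self.is_required = False

    def __enter__(self) -> "ToyOrchestrator":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.release()

with ToyOrchestrator() as orch:
    pass  # decode steps would run here

assert not orch.penalizers and orch._batch_ref is None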