Unverified Commit 70562969 authored by 0xNullPath, committed by GitHub

[Bug] OOM (Out-of-Memory) errors for extreme testing scenarios (min_tokens=2) (#11757)


Signed-off-by: Yan Lu <luyan@nvidia.com>
parent b57dc169
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
     Frequency penalizer penalizes tokens based on their frequency in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.frequency_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
             [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("frequency_penalties", "cumulated_frequency_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
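The new `_teardown` hook drops the penalizer's tensor attributes so nothing keeps the batch-sized penalty tensors alive after the batch goes away. A minimal standalone sketch of that pattern, with illustrative (non-sglang) names:

import torch


class TensorHolder:
    """Toy stand-in for a penalizer that owns per-batch tensors."""

    def prepare(self, batch_size: int, vocab_size: int) -> None:
        self.penalties = torch.zeros(batch_size, vocab_size)
        self.cumulated = torch.zeros(batch_size, vocab_size)

    def teardown(self) -> None:
        # hasattr() guards against teardown() running before prepare() ever did
        for name in ("penalties", "cumulated"):
            if hasattr(self, name):
                delattr(self, name)


holder = TensorHolder()
holder.prepare(batch_size=64, vocab_size=32_000)
holder.teardown()  # tensors are now unreferenced and can be freed
holder.teardown()  # idempotent thanks to the hasattr() guard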
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
     Min new tokens penalizer penalizes tokens based on the length of the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
@@ -92,3 +85,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         self.len_output_tokens = torch.cat(
             [self.len_output_tokens, their.len_output_tokens], dim=0
         )
+
+    # Explicit resource cleanup to aid GC and free CUDA memory promptly
+    def _teardown(self) -> None:
+        for name in ("min_new_tokens", "stop_token_penalties", "len_output_tokens"):
+            if hasattr(self, name):
+                delattr(self, name)
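The three attributes released here are the working state of the minimum-length penalty: a per-request minimum, a stop-token penalty row, and a count of tokens generated so far. A rough, self-contained sketch of how such state can suppress early termination (shapes and names are illustrative, not the sglang implementation):

import torch

batch_size, vocab_size, eos_token_id = 2, 8, 0

logits = torch.randn(batch_size, vocab_size)

# -inf on stop tokens makes them unsampleable once added to the logits
stop_token_penalties = torch.zeros(batch_size, vocab_size)
stop_token_penalties[:, eos_token_id] = float("-inf")

min_new_tokens = torch.tensor([4, 0])     # request 0 must emit >= 4 tokens
len_output_tokens = torch.tensor([1, 1])  # tokens generated so far

# Only requests still under their minimum get the stop tokens masked out
under_min = len_output_tokens < min_new_tokens
logits[under_min] += stop_token_penalties[under_min]

assert logits[0, eos_token_id] == float("-inf")  # request 0 cannot stop yet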
@@ -77,9 +77,8 @@ class BatchedPenalizerOrchestrator:
             return

         if len(keep_indices) == 0:
-            self.is_required = False
-            for penalizer in self.penalizers.values():
-                penalizer.teardown()
+            # No requests left in the batch, fully release orchestrator resources
+            self.release()
             return

         is_required = False
@@ -92,6 +91,23 @@
                 penalizer.teardown()
         self.is_required = is_required

+    # Resource management helpers
+    def release(self) -> None:
+        """Release all penalizers and break references so GC can reclaim promptly."""
+        for penalizer in self.penalizers.values():
+            penalizer.teardown()
+        self.penalizers.clear()
+        # Break reference to ScheduleBatch
+        self._batch_ref = None
+        self.is_required = False
+
+    # Context manager support
+    def __enter__(self) -> "BatchedPenalizerOrchestrator":
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> None:
+        self.release()
+
     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
         Merge the penalizers of another orchestrator into this one.
@@ -116,6 +132,22 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """

+    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+        self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
+            weakref.ref(orchestrator)
+        )
+        self._is_prepared = False
+
+    @property
+    def orchestrator(self) -> BatchedPenalizerOrchestrator:
+        orch: Optional[BatchedPenalizerOrchestrator] = self._orchestrator_ref()
+        # This should never happen, but we need to handle it gracefully
+        if orch is None:
+            raise RuntimeError(
+                "BatchedPenalizerOrchestrator has been garbage-collected"
+            )
+        return orch
+
     def is_prepared(self) -> bool:
         return self._is_prepared
@@ -135,6 +167,7 @@
             return False

     def teardown(self):
+        self._teardown()
         self._is_prepared = False

     def cumulate_output_tokens(self, output_ids: torch.Tensor):
@@ -207,3 +240,10 @@
         Merge the penalizer with another penalizer.
         """
         pass
+
+    @abc.abstractmethod
+    def _teardown(self):
+        """
+        Teardown the penalizer.
+        """
+        pass
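The last two hunks carry the heart of the fix: each penalizer now holds only a weak reference back to its orchestrator, which breaks the orchestrator -> penalizer -> orchestrator cycle, so dropping the last strong reference to the orchestrator frees both objects immediately by reference counting instead of waiting for a cyclic-GC pass. A self-contained sketch of that behavior (illustrative names; CPython refcounting semantics assumed):

import weakref


class Penalizer:
    def __init__(self, orchestrator: "Orchestrator") -> None:
        # Weak back-reference: does not keep the orchestrator alive
        self._orchestrator_ref = weakref.ref(orchestrator)

    @property
    def orchestrator(self) -> "Orchestrator":
        orch = self._orchestrator_ref()
        if orch is None:
            raise RuntimeError("orchestrator has been garbage-collected")
        return orch


class Orchestrator:
    def __init__(self) -> None:
        # Strong forward reference is fine; only the back edge must be weak
        self.penalizers = [Penalizer(self)]


orch = Orchestrator()
probe = weakref.ref(orch)
assert orch.penalizers[0].orchestrator is orch  # property resolves while alive

del orch                 # last strong reference gone
assert probe() is None   # freed immediately, no gc.collect() required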
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
     Presence penalizer penalizes tokens based on their presence in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.presence_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("presence_penalties", "cumulated_presence_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
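The commit also gives the orchestrator `__enter__`/`__exit__`, so callers can tie teardown to a scope and guarantee `release()` runs even when sampling raises. A toy sketch of that release-on-exit pattern (a hypothetical `MiniOrchestrator`, not the real class):

class MiniOrchestrator:
    def __init__(self) -> None:
        self.penalizers = {"presence": object()}  # stand-in for real penalizers
        self.is_required = True

    def release(self) -> None:
        # Mirrors the commit's release(): clear penalizers, reset state
        self.penalizers.clear()
        self.is_required = False

    def __enter__(self) -> "MiniOrchestrator":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Runs on normal exit and on exceptions alike, so nothing leaks
        self.release()


with MiniOrchestrator() as orch:
    assert orch.is_required

assert not orch.penalizers  # released the moment the scope ended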