change / sglang · Commits

Commit 3774f078 (Unverified)
Authored Jun 19, 2025 by Stefan He; committed by GitHub on Jun 19, 2025
Parent: 9179ea15

Multi-Stage Awake: Support Resume and Pause KV Cache and Weights separately (#7099)

Showing 14 changed files with 308 additions and 119 deletions (+308 −119)
python/pyproject.toml                               +1   −1
python/sglang/srt/constants.py                      +3   −0
python/sglang/srt/disaggregation/decode.py          +2   −1
python/sglang/srt/entrypoints/engine.py             +5   −8
python/sglang/srt/managers/io_struct.py             +6   −2
python/sglang/srt/managers/scheduler.py             +32  −16
python/sglang/srt/mem_cache/memory_pool.py          +6   −4
python/sglang/srt/model_executor/model_runner.py    +3   −1
python/sglang/srt/torch_memory_saver_adapter.py     +19  −15
python/sglang/test/test_utils.py                    +1   −0
test/srt/run_suite.py                               +1   −1
test/srt/test_release_memory_occupation.py          +225 −68
test/srt/test_verl_engine_2_gpu.py                  +2   −1
test/srt/test_verl_engine_4_gpu.py                  +2   −1
python/pyproject.toml

@@ -98,7 +98,7 @@ srt_npu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
-torch_memory_saver = ["torch_memory_saver>=0.0.4"]
+torch_memory_saver = ["torch_memory_saver>=0.0.8"]
 decord = ["decord"]
 test = [
     "accelerate",
python/sglang/srt/constants.py (new file, 0 → 100644)

# GPU Memory Types
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_WEIGHTS = "weights"
python/sglang/srt/disaggregation/decode.py

@@ -31,6 +31,7 @@ import numpy as np
 import torch
 from torch.distributed import ProcessGroup
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,

@@ -90,7 +91,7 @@ class DecodeReqToTokenPool:
         self.max_context_len = max_context_len
         self.device = device
         self.pre_alloc_size = pre_alloc_size
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size + pre_alloc_size, max_context_len),
                 dtype=torch.int32,
python/sglang/srt/entrypoints/engine.py

@@ -479,17 +479,15 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )

-    def release_memory_occupation(self):
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
         """Release GPU occupation temporarily."""
-        obj = ReleaseMemoryOccupationReqInput()
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )

-    def resume_memory_occupation(self):
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
         """Resume GPU occupation."""
-        obj = ResumeMemoryOccupationReqInput()
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)

@@ -670,11 +668,9 @@ def _launch_subprocesses(
     scheduler_procs = []
     if server_args.dp_size == 1:
         # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
         scheduler_pipe_readers = []
         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)

@@ -710,6 +706,7 @@ def _launch_subprocesses(
                     writer,
                 ),
             )
             with memory_saver_adapter.configure_subprocess():
                 proc.start()
             scheduler_procs.append(proc)
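With the new `tags` parameter on the public Engine API, an RL trainer can stage the release and resume steps around a weight update instead of dropping and restoring everything at once. A sketch of the ordering that the new multi-stage test below exercises (the helper names and the source of the trained weights are placeholders, not part of this commit):

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS

def offload_for_training(engine):
    # KV cache first, then weights, freeing GPU memory for the training step.
    engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
    engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])

def reload_after_training(engine, trained_hf_model):
    # Weights come back first so the new checkpoint can be loaded into them ...
    engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
    engine.update_weights_from_tensor(list(trained_hf_model.named_parameters()))
    # ... and the KV cache is re-materialized last, just before the next rollout.
    engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])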
python/sglang/srt/managers/io_struct.py

@@ -812,7 +812,9 @@ class GetWeightsByNameReqOutput:
 @dataclass
 class ReleaseMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None

 @dataclass

@@ -822,7 +824,9 @@ class ReleaseMemoryOccupationReqOutput:
 @dataclass
 class ResumeMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None

 @dataclass
python/sglang/srt/managers/scheduler.py

@@ -36,6 +36,7 @@ from torch.distributed import barrier
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,

@@ -450,8 +451,6 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()

         # Init memory saver
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )

@@ -2227,23 +2226,40 @@ class Scheduler(
         return GetWeightsByNameReqOutput(parameter)

     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(
-            caller_name="release_memory_occupation"
-        )
-        self.stashed_model_static_state = _export_static_state(
-            self.tp_worker.worker.model_runner.model
-        )
-        self.memory_saver_adapter.pause()
-        self.flush_cache()
+        tags = recv_req.tags
+        import subprocess
+
+        if tags is None:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+            self.flush_cache()
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.stashed_model_static_state = _export_static_state(
+                self.tp_worker.worker.model_runner.model
+            )
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-        self.memory_saver_adapter.check_validity(
-            caller_name="resume_memory_occupation"
-        )
-        self.memory_saver_adapter.resume()
-        _import_static_state(
-            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
-        )
-        del self.stashed_model_static_state
+        tags = recv_req.tags
+
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            _import_static_state(
+                self.tp_worker.worker.model_runner.model,
+                self.stashed_model_static_state,
+            )
+            del self.stashed_model_static_state
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
         return ResumeMemoryOccupationReqOutput()

     def slow_down(self, recv_req: SlowDownReqInput):
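The bodies of `_export_static_state` and `_import_static_state` are not shown in this diff; the scheduler only calls them around pausing and resuming the weights region. A minimal sketch of what such helpers could look like, under the assumption that they snapshot the model's non-parameter buffers and restore them after the region is resumed (this is an illustration, not the code in this commit):

import torch

def _export_static_state(model):
    # Assumption: keep copies of buffers that are not covered by the paused region.
    return {
        "buffers": [
            (name, buf.detach().to("cpu", copy=True))
            for name, buf in model.named_buffers()
        ]
    }

def _import_static_state(model, static_state):
    # Copy the stashed values back into the re-materialized buffers.
    named_buffers = dict(model.named_buffers())
    for name, saved in static_state["buffers"]:
        named_buffers[name].copy_(saved)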
python/sglang/srt/mem_cache/memory_pool.py

@@ -35,6 +35,7 @@ import torch
 import triton
 import triton.language as tl
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2

@@ -54,6 +55,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )

@@ -61,7 +63,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size, max_context_len), dtype=torch.int32, device=device
             )

@@ -292,7 +294,7 @@ class MHATokenToKVPool(KVCache):
         )

     def _create_buffers(self):
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
                 torch.cuda.use_mem_pool(self.custom_mem_pool)
                 if self.enable_custom_mem_pool

@@ -610,7 +612,7 @@ class MLATokenToKVPool(KVCache):
         else:
             self.custom_mem_pool = None

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
                 torch.cuda.use_mem_pool(self.custom_mem_pool)
                 if self.custom_mem_pool

@@ -753,7 +755,7 @@ class DoubleSparseTokenToKVPool(KVCache):
             end_layer,
         )

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
python/sglang/srt/model_executor/model_runner.py

@@ -30,6 +30,7 @@ from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,

@@ -222,6 +223,7 @@ class ModelRunner:
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )

@@ -547,7 +549,7 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
python/sglang/srt/torch_memory_saver_adapter.py

 import logging
+import threading
+import time
 from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext

 try:
     import torch_memory_saver

-    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+    _memory_saver = torch_memory_saver.torch_memory_saver
     import_error = None
 except ImportError as e:
     import_error = e

@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError

-    def region(self):
+    def region(self, tag: str):
         raise NotImplementedError

-    def pause(self):
+    def pause(self, tag: str):
         raise NotImplementedError

-    def resume(self):
+    def resume(self, tag: str):
         raise NotImplementedError

     @property

@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    """Adapter for TorchMemorySaver with tag-based control"""
+
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()

-    def region(self):
-        return _primary_memory_saver.region()
+    def region(self, tag: str):
+        return _memory_saver.region(tag=tag)

-    def pause(self):
-        return _primary_memory_saver.pause()
+    def pause(self, tag: str):
+        return _memory_saver.pause(tag=tag)

-    def resume(self):
-        return _primary_memory_saver.resume()
+    def resume(self, tag: str):
+        return _memory_saver.resume(tag=tag)

     @property
     def enabled(self):
-        return _primary_memory_saver.enabled
+        return _memory_saver is not None and _memory_saver.enabled


 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):

@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield

     @contextmanager
-    def region(self):
+    def region(self, tag: str):
         yield

-    def pause(self):
+    def pause(self, tag: str):
         pass

-    def resume(self):
+    def resume(self, tag: str):
         pass

     @property
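The adapter now forwards a tag to torch_memory_saver, so allocations made inside a tagged `region()` can have their physical GPU memory released and re-materialized independently of other tags while the tensor handles stay valid. A small sketch of the intended usage pattern (the tensor shapes and the surrounding workflow are illustrative, not from this commit):

import torch
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=True)

# Allocate the two kinds of state under different tags.
with adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
    weights = torch.zeros(1024, 1024, device="cuda")
with adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
    kv_cache = torch.zeros(4096, 1024, device="cuda")

# The KV cache can now be dropped and re-created without touching the weights.
adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
# ... run something that needs the freed memory ...
adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)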
python/sglang/test/test_utils.py

@@ -37,6 +37,7 @@ from sglang.utils import get_exception_traceback
 # General test models
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
test/srt/run_suite.py

@@ -74,7 +74,6 @@ suites = {
         TestFile("test_radix_attention.py", 105),
         TestFile("test_reasoning_content.py", 89),
         TestFile("test_regex_constrained.py", 64),
-        TestFile("test_release_memory_occupation.py", 44),
         TestFile("test_request_length_validation.py", 31),
         TestFile("test_retract_decode.py", 54),
         TestFile("test_server_args.py", 1),

@@ -146,6 +145,7 @@ suites = {
         TestFile("test_patch_torch.py", 19),
         TestFile("test_update_weights_from_distributed.py", 103),
         TestFile("test_verl_engine_2_gpu.py", 64),
+        TestFile("test_release_memory_occupation.py", 44),
     ],
     "per-commit-2-gpu-amd": [
         TestFile("models/lora/test_lora_tp.py", 116),
test/srt/test_release_memory_occupation.py

This file is largely rewritten (+225 −68, hunk @@ -5,93 +34,221 @@). The previous module-level helper `_try_allocate_big_tensor(size: int = 20_000_000_000)` (which tried `torch.empty((size,), dtype=torch.uint8, device="cuda")` and caught `torch.cuda.OutOfMemoryError`) and its three assertions about whether big tensors could be allocated before releasing, after releasing, and after resuming are removed in favor of direct GPU memory measurements. The new file reads:

"""Test memory release and resume operations for SGLang engine in hybrid RL training.

This test suite evaluates the SGLang engine's memory management capabilities, focusing
on releasing and resuming memory occupation for KV cache and model weights. It simulates
an RL workflow where the SGLang engine acts as a rollout engine for experience collection.
The process involves initializing the engine, sending a small number of requests to simulate
rollout, releasing memory to mimic offloading during RL training, resuming memory occupation,
updating weights with a trained HuggingFace model, and verifying the updated weights.

Detailed in our proposal (https://github.com/sgl-project/sglang/pull/7099), these test cases
are included:

1. Basic Release and Resume: Uses a lower mem_fraction_static (0.6) to carefully control memory
allocation and avoid OOM errors. This test simulates a scenario without multi-stage memory
management, ensuring the engine can release and resume memory occupation while maintaining
functionality after weight updates.

2. Multi-Stage Release and Resume: Employs a higher mem_fraction_static (0.85) to simulate higher
memory pressure, leveraging multi-stage memory management. It sequentially releases and resumes
KV cache and model weights, verifying memory deallocation and reallocation at each stage, and
ensuring correct weight updates and text generation.

3. Tensor Parallel Tests: Tests memory release and resume operations with different tensor parallel
configurations (tp=1, tp=2) to ensure proper memory management in distributed settings. For
different data parallel sizes, we test in verl.
"""

import gc
import os
import time
import unittest

import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
    CustomTestCase,
)

# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
_DEBUG_EXTRA = False


def get_gpu_memory_gb():
    return torch.cuda.device_memory_used() / 1024**3


class TestReleaseMemoryOccupation(CustomTestCase):
    def _setup_engine(self, model_name, mem_fraction_static=0.8, tp_size=1):
        """Common setup for engine and HF model."""
        engine = sgl.Engine(
            model_path=model_name,
            random_seed=42,
            enable_memory_saver=True,
            mem_fraction_static=mem_fraction_static,
            tp_size=tp_size,
            # disable_cuda_graph=True,  # for debugging only
        )
        return engine

    def _common_test_params(self):
        """Common test parameters."""
        return {
            "prompt": "Today is a sunny day and I like",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
            "expect_output_before_update_weights": " to spend it outdoors. I decided to",
            "expect_output_after_update_weights": " to go for a walk. I like",
        }

    def _test_initial_generation(
        self, engine, prompt, sampling_params, expect_output_before_update_weights
    ):
        """Test initial generation and memory allocation."""
        print("generate (#1)")
        outputs = engine.generate(prompt, sampling_params)["text"]
        self.assertEqual(outputs, expect_output_before_update_weights)
        if _DEBUG_EXTRA:
            time.sleep(3)

    def test_release_and_resume_occupation(self):
        # Without multi-stage release and resume, we need to carefully control the memory fraction to avoid OOM
        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        assert (
            torch.cuda.device_count() >= 2
        ), "Need at least 2 GPUs for tensor parallel tests"

        for tp_size in [1, 2]:
            print(f"Testing tp_size={tp_size} for test_release_and_resume_occupation")
            engine = self._setup_engine(
                model_name=model_name, mem_fraction_static=0.6, tp_size=tp_size
            )
            params = self._common_test_params()
            self._test_initial_generation(
                engine,
                params["prompt"],
                params["sampling_params"],
                params["expect_output_before_update_weights"],
            )

            t = time.perf_counter()
            gpu_memory_usage_before_release = get_gpu_memory_gb()
            engine.release_memory_occupation()
            gpu_memory_usage_after_release = get_gpu_memory_gb()

            self.assertLess(
                gpu_memory_usage_after_release,
                gpu_memory_usage_before_release,
            )
            print(
                f"Release took {time.perf_counter() - t:.2f}s, memory: "
                f"{gpu_memory_usage_before_release:.1f} GB → {gpu_memory_usage_after_release:.1f} GB"
            )

            if _DEBUG_EXTRA:
                time.sleep(3)

            t = time.perf_counter()
            engine.resume_memory_occupation()
            print(
                f"Resume took {time.perf_counter() - t:.2f}s, memory: {get_gpu_memory_gb():.1f} GB"
            )

            # As if: PPO has updated hf model's weights, and now we sync it to SGLang
            hf_model_new = AutoModelForCausalLM.from_pretrained(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
                torch_dtype="bfloat16",
                device_map="cuda",
            )
            engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))

            # destroy the hf model
            del hf_model_new
            torch.cuda.empty_cache()

            print("generate (#2)")
            outputs = engine.generate(params["prompt"], params["sampling_params"])["text"]
            self.assertEqual(outputs, params["expect_output_after_update_weights"])

            engine.shutdown()

    def test_multi_stage_release_and_resume(self):
        # With multi-stage release and resume, we can set the memory fraction to 0.85 without concern of OOM
        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        for tp_size in [1, 2]:
            if tp_size == 2 and torch.cuda.device_count() < 2:
                continue
            print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume")
            engine = sgl.Engine(
                model_path=model_name,
                random_seed=42,
                enable_memory_saver=True,
                mem_fraction_static=0.85,  # Higher memory pressure
                tp_size=tp_size,
            )
            params = self._common_test_params()
            self._test_initial_generation(
                engine,
                params["prompt"],
                params["sampling_params"],
                params["expect_output_before_update_weights"],
            )

            t = time.perf_counter()
            gpu_memory_usage_before_release_kv_cache = get_gpu_memory_gb()
            engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
            gpu_memory_usage_after_release_kv_cache = get_gpu_memory_gb()
            self.assertLess(
                gpu_memory_usage_after_release_kv_cache,
                gpu_memory_usage_before_release_kv_cache,
            )

            engine.release_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
            gpu_memory_usage_after_release_weights = get_gpu_memory_gb()
            self.assertLess(
                gpu_memory_usage_after_release_weights,
                gpu_memory_usage_after_release_kv_cache,
            )

            print(f"Release took {time.perf_counter() - t:.2f}s")
            print(
                f"Memory: {gpu_memory_usage_before_release_kv_cache:.1f} → "
                f"{gpu_memory_usage_after_release_kv_cache:.1f} → "
                f"{gpu_memory_usage_after_release_weights:.1f} GB"
            )

            if _DEBUG_EXTRA:
                time.sleep(3)

            t = time.perf_counter()
            gpu_memory_usage_before_resume_weights = get_gpu_memory_gb()
            # gpu_memory_usage_after_release_weights and gpu_memory_usage_before_resume_weights should be close
            self.assertAlmostEqual(
                gpu_memory_usage_after_release_weights,
                gpu_memory_usage_before_resume_weights,
                delta=3.0,
            )
            print(f"Resume weights took {time.perf_counter() - t:.2f}s")

            engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_WEIGHTS])
            gpu_memory_usage_after_resume_weights = get_gpu_memory_gb()
            self.assertGreater(
                gpu_memory_usage_after_resume_weights,
                gpu_memory_usage_before_resume_weights,
            )

            # Update weights from a trained model to serving engine, and then destroy the trained model
            hf_model_new = AutoModelForCausalLM.from_pretrained(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE,
                torch_dtype="bfloat16",
                device_map="cuda",
            )
            gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb()
            engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))

            # destroy the hf model
            del hf_model_new
            torch.cuda.empty_cache()

            engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE])
            gpu_memory_usage_after_resume_kv_cache = get_gpu_memory_gb()
            self.assertGreater(
                gpu_memory_usage_after_resume_kv_cache,
                gpu_memory_usage_after_resume_weights,
            )

            print(f"Resume + update took {time.perf_counter() - t:.2f}s")
            print(
                f"Memory: {gpu_memory_usage_before_resume_weights:.1f} → "
                f"{gpu_memory_usage_after_resume_weights:.1f} → "
                f"{gpu_memory_usage_after_loaded_hf_model:.1f} → "
                f"{gpu_memory_usage_after_resume_kv_cache:.1f} GB"
            )

            print("generate (#2)")
            outputs = engine.generate(params["prompt"], params["sampling_params"])["text"]
            self.assertEqual(outputs, params["expect_output_after_update_weights"])

            engine.shutdown()


if __name__ == "__main__":
    ...
test/srt/test_verl_engine_2_gpu.py

@@ -235,7 +235,8 @@ def _run_subprocess(
         output_writer.send(execution_ok)
         output_writer.close()

-    engine.shutdown()
+    if "engine" in locals() and engine is not None:
+        engine.shutdown()
     print(f"subprocess[{rank=}] end", flush=True)
test/srt/test_verl_engine_4_gpu.py

@@ -249,7 +249,8 @@ def _run_subprocess(
         output_writer.send(execution_ok)
         output_writer.close()

-    engine.shutdown()
+    if "engine" in locals() and engine is not None:
+        engine.shutdown()
     print(f"subprocess[{rank=}] end", flush=True)