sglang / Commits / 923f5183

Unverified commit 923f5183
Authored Jan 14, 2025 by fzyzcjy; committed by GitHub, Jan 13, 2025
CUDA-graph-compatible releasing and resuming KV cache and model weight memory (#2630)
Parent: d08c77c4

Changes: 12 changed files with 407 additions and 61 deletions (+407, -61)
python/pyproject.toml                               +1   -0
python/sglang/srt/managers/io_struct.py             +21  -3
python/sglang/srt/managers/scheduler.py             +43  -0
python/sglang/srt/managers/tokenizer_manager.py     +32  -0
python/sglang/srt/mem_cache/memory_pool.py          +82  -48
python/sglang/srt/model_executor/model_runner.py    +16  -6
python/sglang/srt/server.py                         +46  -2
python/sglang/srt/server_args.py                    +6   -1
python/sglang/torch_memory_saver_adapter.py         +59  -0
scripts/ci_install_dependency.sh                    +2   -1
test/srt/run_suite.py                               +1   -0
test/srt/test_release_memory_occupation.py          +98  -0
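Taken together, the commit threads a new --enable-memory-saver flag through the scheduler, memory pools, and model loader so that KV-cache and weight allocations go through torch_memory_saver, and it exposes release_memory_occupation / resume_memory_occupation on both the HTTP server and the offline Engine. A minimal sketch of the intended offline usage, modeled on the test added below (the model path is a placeholder, not part of the commit):

import sglang as sgl

# Placeholder model path for illustration; the new test uses DEFAULT_SMALL_MODEL_NAME_FOR_TEST.
engine = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    enable_memory_saver=True,
)

print(engine.generate("Today is a sunny day and I like", {"temperature": 0, "max_new_tokens": 8})["text"])

engine.release_memory_occupation()   # KV cache and weights are freed; CUDA graphs stay usable
# ... run another GPU workload here (e.g. a training step) ...
engine.resume_memory_occupation()    # memory is re-occupied

# The new test re-syncs weights via engine.update_weights_from_tensor(...) before generating again.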
python/pyproject.toml

@@ -44,6 +44,7 @@ srt_hpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]
 litellm = ["litellm>=1.0.0"]
+torch_memory_saver = ["torch_memory_saver"]
 test = [
   "jsonlines",
   "matplotlib",
python/sglang/srt/managers/io_struct.py

@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
+from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams

@@ -459,6 +457,26 @@ class GetWeightsByNameReqOutput:
     parameter: list
 
 
+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
python/sglang/srt/managers/scheduler.py

@@ -47,6 +47,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,

@@ -88,6 +92,7 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)

@@ -357,6 +362,10 @@ class Scheduler:
             t.start()
         self.parent_process = psutil.Process().parent()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None

@@ -519,6 +528,12 @@ class Scheduler:
         elif isinstance(recv_req, GetWeightsByNameReqInput):
             parameter = self.get_weights_by_name(recv_req)
             self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+        elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
+            self.release_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
+        elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
+            self.resume_memory_occupation()
+            self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()

@@ -1538,6 +1553,20 @@ class Scheduler:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
 
+    def release_memory_occupation(self):
+        self.stashed_model_static_state = _export_static_state(
+            self.tp_worker.worker.model_runner.model
+        )
+        self.memory_saver_adapter.pause()
+        self.flush_cache()
+
+    def resume_memory_occupation(self):
+        self.memory_saver_adapter.resume()
+        _import_static_state(
+            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
+        )
+        del self.stashed_model_static_state
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")

@@ -1576,6 +1605,20 @@ class Scheduler:
         del self.sessions[session_id]
 
 
+def _export_static_state(model):
+    return dict(
+        buffers=[
+            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
+        ]
+    )
+
+
+def _import_static_state(model, static_params):
+    self_named_buffers = dict(model.named_buffers())
+    for name, tensor in static_params["buffers"]:
+        self_named_buffers[name][...] = tensor
+
+
 def run_scheduler_process(
     server_args: ServerArgs,
     port_args: PortArgs,
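The _export_static_state / _import_static_state helpers above only stash and restore module buffers (via named_buffers()), not parameters; the new test re-syncs weights with update_weights_from_tensor after resuming. A toy, self-contained sketch of that round trip on a plain torch.nn.Module (illustrative only, not part of the commit):

import torch
from torch import nn

model = nn.BatchNorm1d(4)  # has buffers: running_mean, running_var, num_batches_tracked

# Stash buffer values (same logic as _export_static_state above).
stashed = dict(
    buffers=[(name, buf.detach().clone()) for name, buf in model.named_buffers()]
)

# Pretend the buffers were clobbered while their memory was released.
model.running_mean.fill_(123.0)

# Restore them in place (same logic as _import_static_state above).
named_buffers = dict(model.named_buffers())
for name, tensor in stashed["buffers"]:
    named_buffers[name][...] = tensor

assert torch.all(model.running_mean == 0)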
python/sglang/srt/managers/tokenizer_manager.py

@@ -53,6 +53,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,

@@ -188,6 +192,12 @@ class TokenizerManager:
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.release_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.resume_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         # Metrics
         if self.enable_metrics:

@@ -548,6 +558,22 @@ class TokenizerManager:
         else:
             return all_parameters
 
+    async def release_memory_occupation(
+        self,
+        obj: ReleaseMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.release_memory_occupation_communicator(obj)
+
+    async def resume_memory_occupation(
+        self,
+        obj: ResumeMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.resume_memory_occupation_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):

@@ -627,6 +653,8 @@ class TokenizerManager:
                 UpdateWeightsFromDistributedReqOutput,
                 GetWeightsByNameReqOutput,
                 InitWeightsUpdateGroupReqOutput,
+                ReleaseMemoryOccupationReqOutput,
+                ResumeMemoryOccupationReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
             if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):

@@ -709,6 +737,10 @@ class TokenizerManager:
                 self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
             elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 self.get_weights_by_name_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
+                self.release_memory_occupation_communicator.handle_recv(recv_obj)
+            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
+                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
             else:
                 raise ValueError(f"Invalid object: {recv_obj=}")
python/sglang/srt/mem_cache/memory_pool.py

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 """
 Memory pool.

@@ -42,13 +44,25 @@ GB = 1024 * 1024 * 1024
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        use_records: bool,
+        enable_memory_saver: bool,
+    ):
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.zeros(
-            (size, max_context_len), dtype=torch.int32, device=device
-        )
+        with memory_saver_adapter.region():
+            self.req_to_token = torch.zeros(
+                (size, max_context_len), dtype=torch.int32, device=device
+            )
         self.free_slots = list(range(size))
         self.write_records = []
         self.use_records = use_records

@@ -189,8 +203,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
+
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
         self.head_num = head_num
         self.head_dim = head_dim
         self.layer_num = layer_num

@@ -202,24 +222,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         )
 
     def _create_buffers(self):
-        # [size, head_num, head_dim] for each layer
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.k_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty(
-                (self.size + 1, self.head_num, self.head_dim),
-                dtype=self.store_dtype,
-                device=self.device,
-            )
-            for _ in range(self.layer_num)
-        ]
+        with self.memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty(
+                    (self.size + 1, self.head_num, self.head_dim),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
 
     def _clear_buffers(self):
         del self.k_buffer

@@ -307,19 +328,26 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
         device: str,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.kv_buffer = [
-            torch.empty(
-                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=self.store_dtype,
-                device=device,
-            )
-            for _ in range(layer_num)
-        ]
+
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.empty(
+                    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                    dtype=self.store_dtype,
+                    device=device,
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:

@@ -360,26 +388,32 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         layer_num: int,
         device: str,
         heavy_channel_num: int,
+        enable_memory_saver: bool,
     ):
         super().__init__(size, dtype, device)
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
-            for _ in range(layer_num)
-        ]
-
-        # [size, head_num, heavy_channel_num] for each layer
-        self.label_buffer = [
-            torch.empty(
-                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
-            )
-            for _ in range(layer_num)
-        ]
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
+        with memory_saver_adapter.region():
+            # [size, head_num, head_dim] for each layer
+            self.k_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+            self.v_buffer = [
+                torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+                for _ in range(layer_num)
+            ]
+
+            # [size, head_num, heavy_channel_num] for each layer
+            self.label_buffer = [
+                torch.empty(
+                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+                )
+                for _ in range(layer_num)
+            ]
 
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id]
python/sglang/srt/model_executor/model_runner.py

@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 logger = logging.getLogger(__name__)

@@ -166,6 +167,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()

@@ -272,11 +277,12 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )
 
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:

@@ -417,7 +423,7 @@ class ModelRunner:
         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )
 
         try:

@@ -590,6 +596,7 @@ class ModelRunner:
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
             use_records=False,
+            enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA

@@ -602,6 +609,7 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(

@@ -612,6 +620,7 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(

@@ -621,6 +630,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         logger.info(
             f"Memory pool end. "
python/sglang/srt/server.py

@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 import torch
 
+from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -57,6 +59,8 @@ from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -255,6 +259,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU occupation temporarily"""
+    try:
+        await tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU occupation"""
+    try:
+        await tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""

@@ -438,6 +464,10 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
+    memory_saver_adapter = TorchMemorySaverAdapter.create(
+        enable=server_args.enable_memory_saver
+    )
+
     if server_args.dp_size == 1:
         # Launch tensor parallel scheduler processes
         scheduler_procs = []

@@ -454,7 +484,8 @@ def launch_engine(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, None, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)

@@ -471,7 +502,8 @@ def launch_engine(
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        proc.start()
+        with memory_saver_adapter.configure_subprocess():
+            proc.start()
 
     # Launch detokenizer process
     detoken_proc = mp.Process(

@@ -897,6 +929,18 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
 
+    def release_memory_occupation(self):
+        """Release GPU occupation temporarily"""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
+
+    def resume_memory_occupation(self):
+        """Resume GPU occupation"""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+
 
 class Runtime:
     """
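The two new endpoints accept GET or POST and carry no required fields. A sketch of exercising them over HTTP (the host and port are assumptions and depend on how the server was launched):

import requests

base = "http://localhost:30000"  # assumed --host/--port of a running sglang server

# Ask the server to release KV-cache and weight memory.
print(requests.post(f"{base}/release_memory_occupation", json={}).status_code)

# ... use the freed GPU memory for something else ...

# Re-occupy the memory before serving traffic again (weights may need re-syncing).
print(requests.post(f"{base}/resume_memory_occupation", json={}).status_code)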
python/sglang/srt/server_args.py

@@ -23,7 +23,6 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,

@@ -157,6 +156,7 @@ class ServerArgs:
     triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
+    enable_memory_saver: bool = False
 
     def __post_init__(self):
         # Set missing default values

@@ -854,6 +854,11 @@ class ServerArgs:
             action="store_true",
             help="Delete the model checkpoint after loading the model.",
         )
+        parser.add_argument(
+            "--enable-memory-saver",
+            action="store_true",
+            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
python/sglang/torch_memory_saver_adapter.py (new file, mode 100644)

from abc import ABC
from contextlib import contextmanager

try:
    import torch_memory_saver

    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
except ImportError:
    pass


class TorchMemorySaverAdapter(ABC):
    @staticmethod
    def create(enable: bool):
        return (
            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
        )

    def configure_subprocess(self):
        raise NotImplementedError

    def region(self):
        raise NotImplementedError

    def pause(self):
        raise NotImplementedError

    def resume(self):
        raise NotImplementedError


class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
    def configure_subprocess(self):
        return torch_memory_saver.configure_subprocess()

    def region(self):
        return _primary_memory_saver.region()

    def pause(self):
        return _primary_memory_saver.pause()

    def resume(self):
        return _primary_memory_saver.resume()


class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
    @contextmanager
    def configure_subprocess(self):
        yield

    @contextmanager
    def region(self):
        yield

    def pause(self):
        pass

    def resume(self):
        pass
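The adapter mirrors the small torch_memory_saver surface (configure_subprocess, region, pause, resume) and degrades to no-ops when the flag is off, so call sites never need feature checks. A minimal sketch of how the rest of the commit uses it (assumes a CUDA device and the torch_memory_saver package when enable=True):

import torch
from sglang.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=True)  # enable=False yields the no-op adapter

# Tensors allocated inside region() can later have their GPU memory released.
with adapter.region():
    kv_pool = torch.zeros((1024, 1024), device="cuda")

adapter.pause()   # release the physical memory behind tracked allocations
adapter.resume()  # map it back; the commit re-fills buffers/weights afterwards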
scripts/ci_install_dependency.sh

@@ -12,8 +12,9 @@ bash "${SCRIPT_DIR}/killall_sglang.sh"
 pip install --upgrade pip
 pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
 
-# Force reinstall flashinfer
+# Force reinstall flashinfer and torch_memory_saver
 pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
+pip install torch_memory_saver --force-reinstall
 
 pip install transformers==4.45.2 sentence_transformers accelerate peft
test/srt/run_suite.py

@@ -29,6 +29,7 @@ suites = {
         "test_openai_server.py",
         "test_pytorch_sampling_backend.py",
         "test_radix_attention.py",
+        "test_release_memory_occupation.py",
         "test_retract_decode.py",
         "test_server_args.py",
         "test_session_control.py",
test/srt/test_release_memory_occupation.py (new file, mode 100644)

import time
import unittest

import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
_DEBUG_EXTRA = True


class TestReleaseMemoryOccupation(unittest.TestCase):
    def test_release_and_resume_occupation(self):
        prompt = "Today is a sunny day and I like"
        sampling_params = {"temperature": 0, "max_new_tokens": 8}
        model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        expect_output = " to spend it outdoors. I decided to"

        engine = sgl.Engine(
            model_path=model_name,
            random_seed=42,
            enable_memory_saver=True,
            # disable_cuda_graph=True,  # for debugging only
        )
        hf_model_new = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype="bfloat16"
        )

        print("generate (#1)")
        outputs = engine.generate(prompt, sampling_params)["text"]
        self.assertEqual(outputs, expect_output)

        if _DEBUG_EXTRA:
            time.sleep(3)

        self.assertEqual(
            _try_allocate_big_tensor(),
            False,
            "Should not be able to allocate big tensors before releasing",
        )

        print("release_memory_occupation start")
        t = time.time()
        engine.release_memory_occupation()
        if _DEBUG_EXTRA:
            print("release_memory_occupation", time.time() - t)

        if _DEBUG_EXTRA:
            time.sleep(5)

        self.assertEqual(
            _try_allocate_big_tensor(),
            True,
            "Should be able to allocate big tensors after releasing",
        )

        if _DEBUG_EXTRA:
            time.sleep(5)

        print("resume_memory_occupation start")
        t = time.time()
        engine.resume_memory_occupation()
        if _DEBUG_EXTRA:
            print("resume_memory_occupation", time.time() - t)

        self.assertEqual(
            _try_allocate_big_tensor(),
            False,
            "Should not be able to allocate big tensors after resuming",
        )

        print("update_weights_from_tensor")
        # As if: PPO has updated hf model's weights, and now we sync it to SGLang
        engine.update_weights_from_tensor(list(hf_model_new.named_parameters()))

        print("generate (#2)")
        outputs = engine.generate(prompt, sampling_params)["text"]
        self.assertEqual(outputs, expect_output)

        if _DEBUG_EXTRA:
            time.sleep(4)

        engine.shutdown()


def _try_allocate_big_tensor(size: int = 20_000_000_000):
    try:
        torch.empty((size,), dtype=torch.uint8, device="cuda")
        torch.cuda.empty_cache()
        return True
    except torch.cuda.OutOfMemoryError:
        return False


if __name__ == "__main__":
    unittest.main()
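The new test is also registered in test/srt/run_suite.py above; it can be run directly with "python3 test/srt/test_release_memory_occupation.py" on a machine with a CUDA GPU and the torch_memory_saver package installed (see the updated scripts/ci_install_dependency.sh).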