zhaoyu6 / sglang · Commits · 55d336cb

Commit 55d336cb (unverified), authored Aug 21, 2025 by fzyzcjy, committed via GitHub on Aug 21, 2025
Parent: de4990a5

Refactor weight offloading logic (#8521)

Showing 3 changed files with 141 additions and 74 deletions:

- python/sglang/srt/model_executor/model_runner.py (+10, -4)
- python/sglang/srt/offloader.py (+122, -0)
- python/sglang/srt/utils.py (+9, -70)
python/sglang/srt/model_executor/model_runner.py

@@ -96,6 +96,11 @@ from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
+)
 from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs

@@ -118,7 +123,6 @@ from sglang.srt.utils import (
     is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
 from sglang.srt.weight_sync.tensor_bucket import (

@@ -222,9 +226,6 @@ class ModelRunner:
             }
         )
 
-        # CPU offload
-        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
-
         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
             self.init_threads_binding()

@@ -232,6 +233,9 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # CPU offload
+        set_offloader(create_offloader_from_server_args(server_args))
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

@@ -690,6 +694,8 @@ class ModelRunner:
             monkey_patch_vllm_parallel_state(reverse=True)
             monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
+        get_offloader().post_init()
+
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
                 if callable(getattr(self.model, "load_kv_cache_scales", None)):
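Taken together, the ModelRunner hunks establish the offloader lifecycle: an implementation is chosen from ServerArgs and installed before weights are loaded, and post_init() runs once loading has finished. A condensed sketch of that ordering follows; the initialize/load_model wrapper is illustrative and not part of the diff.

from sglang.srt.offloader import (
    create_offloader_from_server_args,
    get_offloader,
    set_offloader,
)


def initialize(server_args, load_model):
    # 1. Pick an implementation from the server configuration:
    #    OffloaderV1 when server_args.cpu_offload_gb > 0, otherwise a no-op.
    set_offloader(create_offloader_from_server_args(server_args))

    # 2. Load the model; layer construction goes through
    #    get_offloader().wrap_modules() (see make_layers in utils.py below).
    model = load_model()

    # 3. Give the offloader a chance to finalize once all weights exist
    #    (a pass-through for both NoopOffloader and OffloaderV1).
    get_offloader().post_init()
    return model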
python/sglang/srt/offloader.py (new file, 0 → 100644)

import logging
from abc import ABC
from typing import Callable, Generator, List, Optional

import torch
from torch.func import functional_call

from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_pin_memory_available

logger = logging.getLogger(__name__)

_SubmoduleAccessor = Callable[[torch.nn.Module], torch.nn.Module]
_WhitelistParamNamesCreator = Callable[[torch.nn.Module], List[str]]


class BaseOffloader(ABC):
    def wrap_modules(
        self,
        all_modules_generator: Generator[torch.nn.Module, None, None],
        submodule_accessor: Optional[_SubmoduleAccessor] = None,
        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
    ):
        return list(all_modules_generator)

    def post_init(self):
        pass


class NoopOffloader(BaseOffloader):
    pass


# For simplicity use singleton, but can surely support multi instance
_instance: Optional[BaseOffloader] = NoopOffloader()


def get_offloader():
    assert _instance is not None
    return _instance


def set_offloader(instance: BaseOffloader):
    global _instance
    _instance = instance


def create_offloader_from_server_args(server_args: ServerArgs):
    if server_args.cpu_offload_gb > 0:
        return OffloaderV1(
            cpu_offload_max_bytes=int(server_args.cpu_offload_gb * 1024**3)
        )
    return NoopOffloader()


class OffloaderV1(BaseOffloader):
    def __init__(self, cpu_offload_max_bytes: int):
        self._cpu_offload_bytes = 0
        self._cpu_offload_max_bytes = cpu_offload_max_bytes

    def wrap_modules(
        self,
        all_modules_generator: Generator[torch.nn.Module, None, None],
        submodule_accessor: Optional[_SubmoduleAccessor] = None,
        whitelist_param_names_creator: Optional[_WhitelistParamNamesCreator] = None,
    ):
        return [self.maybe_offload_to_cpu(module) for module in all_modules_generator]

    def maybe_offload_to_cpu(self, module: torch.nn.Module) -> torch.nn.Module:
        if (params := next(module.parameters(), None)) is None:
            return module

        device = params.device

        if device == torch.device("cpu"):
            return module

        if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
            return module

        pin_memory = is_pin_memory_available()
        # offload parameters to CPU
        # use pin_memory if possible, which helps cudagraph capture speed
        offloaded_parameters = False
        for p in module.parameters():
            if self._cpu_offload_bytes >= self._cpu_offload_max_bytes:
                # we use per-parameter offloading
                # one module might have some parameters offloaded and some not
                break

            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(
                size=p.data.size(),
                stride=p.data.stride(),
                dtype=p.data.dtype,
                layout=p.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(p.data)
            p.data = cpu_data
            self._cpu_offload_bytes += p.data.numel() * p.data.element_size()
            offloaded_parameters = True

        if offloaded_parameters:
            original_forward = module.forward

            def forward(*args, **kwargs):
                module.forward = original_forward
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                output = functional_call(
                    module, device_state, args=args, kwargs=kwargs
                )
                module.forward = forward
                return output

            module.forward = forward

        return module
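Since offloader.py is entirely new, a small self-contained sketch of how OffloaderV1 behaves may help; the toy Linear stack and the 1 GiB budget are illustrative, not taken from the commit. On a CUDA machine the wrapped layers keep their weights in (pinned) CPU memory and reload them per forward call; on a CPU-only machine maybe_offload_to_cpu returns each module unchanged.

import torch

from sglang.srt.offloader import OffloaderV1, get_offloader, set_offloader

dev = "cuda" if torch.cuda.is_available() else "cpu"


def layer_generator():
    # Stand-ins for transformer blocks (illustrative only).
    for _ in range(4):
        yield torch.nn.Linear(64, 64, device=dev)


# Allow up to 1 GiB of parameters to be parked in CPU memory.
set_offloader(OffloaderV1(cpu_offload_max_bytes=1 << 30))
layers = get_offloader().wrap_modules(layer_generator())

# Offloaded layers hold their tensors on the CPU; the patched `forward`
# materializes them on `dev` only for the duration of each call.
x = torch.randn(2, 64, device=dev)
for layer in layers:
    x = layer(x)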
python/sglang/srt/utils.py

@@ -438,72 +438,6 @@ def is_pin_memory_available() -> bool:
     return torch.cuda.is_available()
 
 
-_CPU_OFFLOAD_BYTES = 0
-_CPU_OFFLOAD_MAX_BYTES = 0
-
-
-def set_cpu_offload_max_bytes(max_bytes: int) -> None:
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    _CPU_OFFLOAD_BYTES = 0
-    _CPU_OFFLOAD_MAX_BYTES = max_bytes
-
-
-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    if (params := next(module.parameters(), None)) is None:
-        return module
-
-    device = params.device
-
-    if device == torch.device("cpu"):
-        return module
-
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-
-    pin_memory = is_pin_memory_available()
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(
-            size=p.data.size(),
-            stride=p.data.stride(),
-            dtype=p.data.dtype,
-            layout=p.data.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-
-    if offloaded_parameters:
-        original_forward = module.forward
-
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(
-                module, device_state, args=args, kwargs=kwargs
-            )
-            module.forward = forward
-            return output
-
-        module.forward = forward
-
-    return module
-
-
 class LayerFn(Protocol):
 
     def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...

@@ -516,11 +450,13 @@ def make_layers(
     pp_size: Optional[int] = None,
     prefix: str = "",
     return_tuple: bool = False,
+    offloader_kwargs: Dict[str, Any] = {},
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
     from sglang.srt.layers.utils import PPMissingLayer
+    from sglang.srt.offloader import get_offloader
 
     assert not pp_size or num_hidden_layers >= pp_size
     start_layer, end_layer = (

@@ -534,10 +470,13 @@ def make_layers(
     )
     modules = torch.nn.ModuleList(
         [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
-        + [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
-            for idx in range(start_layer, end_layer)
-        ]
+        + get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(start_layer, end_layer)
+            ),
+            **offloader_kwargs,
+        )
         + [
             PPMissingLayer(return_tuple=return_tuple)
             for _ in range(end_layer, num_hidden_layers)
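The make_layers hunks add an offloader_kwargs pass-through that is splatted into get_offloader().wrap_modules(). Below is a sketch of a caller, with two caveats: the leading parameters of make_layers are not shown in these hunks (the names used here follow how the function body refers to them), and the block constructor is hypothetical.

import torch

from sglang.srt.utils import make_layers


def layer_fn(idx: int, prefix: str) -> torch.nn.Module:
    # Hypothetical block constructor; real models pass their decoder-layer class.
    return torch.nn.Linear(64, 64)


result = make_layers(
    num_hidden_layers=8,
    layer_fn=layer_fn,
    # Forwarded verbatim to get_offloader().wrap_modules(); the keys match the
    # keyword parameters declared by BaseOffloader.wrap_modules.
    offloader_kwargs=dict(
        submodule_accessor=lambda block: block,
        whitelist_param_names_creator=lambda block: [
            name for name, _ in block.named_parameters()
        ],
    ),
)
# Per the annotation above, `result` carries the pipeline start/end indices
# together with the torch.nn.ModuleList of (possibly offloaded) layers.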