sglang · commit 71a7f1d8

Offload tensors by sharding on GPU (#9536)

Authored Aug 25, 2025 by fzyzcjy; committed by GitHub on Aug 25, 2025. Parent commit: 433266c1.
Showing 1 changed file with 115 additions and 0 deletions:

python/sglang/srt/offloader.py (+115, -0)

@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
    @staticmethod
    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
        return {
            "meta": _MetaParamOffloader,
            "cpu": _CpuParamOffloader,
            "shm_cpu": _ShmCpuParamOffloader,
            "sharded_gpu": _ShardedGpuParamOffloader,
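
The create() factory dispatches on a mode string, so the new sharding behavior is opted into purely by name. As a rough usage sketch (the build_offloaders helper is not part of the commit, and it assumes create() forwards its keyword arguments to the selected class, whose constructors take module and param_name as shown below):

# Hypothetical wiring, not from this commit: build one offloader per direct
# parameter of a module, selecting the new "sharded_gpu" mode by name.
def build_offloaders(module, mode: str = "sharded_gpu"):
    return {
        name: _BaseParamOffloader.create(mode, module=module, param_name=name)
        for name, _ in module.named_parameters(recurse=False)
    }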

@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
        raise NotImplementedError


class _MetaParamOffloader(_BaseParamOffloader):
    """Usually used for debugging."""

    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        _move_param_to_meta(module, param_name)

    def create_device_tensor(self):
        return torch.empty_like(self._param.data, device="cuda")


class _CpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
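
_MetaParamOffloader leans on PyTorch's meta device: a meta tensor records shape, dtype, and strides but owns no storage, and create_device_tensor() turns it back into a real, uninitialized CUDA tensor. A minimal standalone illustration, independent of the commit:

import torch

# A meta tensor carries shape/dtype metadata but allocates no memory anywhere.
param = torch.empty(4096, 4096, dtype=torch.bfloat16, device="meta")
print(param.shape, param.dtype, param.device)  # torch.Size([4096, 4096]) torch.bfloat16 meta

# Re-materialize an uninitialized tensor of the same shape/dtype on the GPU,
# mirroring create_device_tensor() above (needs a CUDA device to run).
if torch.cuda.is_available():
    device_tensor = torch.empty_like(param, device="cuda")
    print(device_tensor.device)  # e.g. cuda:0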

@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
        device=device,
        pin_memory=pin_memory,
    )


# ----------------------------------------- ShardedGpu ------------------------------------------------------


# TODO unify with ShmCpu mode
class _ShardedGpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        self._rank = get_naive_distributed().get_rank()
        self._world_size = get_naive_distributed().get_world_size()

        from sglang.srt.distributed import get_tensor_model_parallel_world_size

        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

        if self._rank == 0:
            _move_param_to_cpu(self._param, pin_memory=True)
        else:
            _move_param_to_meta(self._module, self._param_name)

        self.sharded_param_handles = None

    def post_init(self):
        # check again since it may be changed
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

        scatter_src = self._param.data
        logger.info(
            f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
        )

        if self._rank == 0:
            scatter_src = scatter_src.to("cuda")
        scatter_list = _even_chunk(scatter_src, self._world_size)

        sharded_param = torch.empty(
            scatter_list[0].shape, dtype=scatter_list[0].dtype, device="cuda"
        )
        self.sharded_param_handles = _create_shared_buffer_tensors(
            local_tensor=sharded_param
        )

        get_naive_distributed().scatter(
            sharded_param, scatter_list if self._rank == 0 else None
        )

        _move_param_to_meta(self._module, self._param_name)

    def create_device_tensor(self):
        output = _empty_strided_like(self._param, device="cuda")
        output_chunks = output.chunk(self._world_size)

        for index in range(self._world_size):
            src_rank = (self._rank + index) % self._world_size
            src_buf = self.sharded_param_handles[src_rank]
            output_chunks[src_rank].copy_(src_buf)

        return output
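
Taken together, post_init() scatters one even chunk of the parameter to each rank and keeps handles to every rank's shard, while create_device_tensor() rebuilds the full tensor by copying each shard into the matching chunk of a freshly allocated output (the real loop staggers the copy order by rank, presumably so the ranks do not all read the same shard at once). A single-process sketch of that data movement, with plain tensors standing in for the cross-process shard handles:

import torch

world_size = 4
full = torch.arange(8 * 6, dtype=torch.float32).reshape(8, 6)

# "post_init": split into even chunks along dim 0; each rank would keep one.
shard_handles = list(full.chunk(world_size))
assert all(s.shape == shard_handles[0].shape for s in shard_handles)

# "create_device_tensor": rebuild the full tensor from the per-rank shards.
output = torch.empty_like(full)
output_chunks = output.chunk(world_size)
for src_rank in range(world_size):
    output_chunks[src_rank].copy_(shard_handles[src_rank])

torch.testing.assert_close(output, full)
print("round trip ok:", torch.equal(output, full))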

def _even_chunk(x: torch.Tensor, chunks: int):
    assert x.shape[0] % chunks == 0, f"{x.shape=} {chunks=}"
    return list(x.chunk(chunks))
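
The divisibility assert exists because torch.chunk silently returns unequal (and possibly fewer) pieces when dim 0 does not divide evenly, which would break the fixed-shape shard buffers above. A quick illustration:

import torch

print([c.shape[0] for c in torch.zeros(8, 3).chunk(4)])   # [2, 2, 2, 2] -- even split
print([c.shape[0] for c in torch.zeros(10, 3).chunk(4)])  # [3, 3, 3, 1] -- uneven tail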

def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
    self_rank = get_naive_distributed().get_rank()
    world_size = get_naive_distributed().get_world_size()

    object_list = get_naive_distributed().all_gather_object(
        dict(
            dup_serialized_local_tensor=[
                (
                    None
                    if interesting_rank == self_rank
                    else MultiprocessingSerializer.serialize(local_tensor)
                )
                for interesting_rank in range(world_size)
            ]
        )
    )

    output_tensors = []
    for output_rank in range(world_size):
        remote_serialized_tensor = object_list[output_rank][
            "dup_serialized_local_tensor"
        ][self_rank]
        if output_rank == self_rank:
            assert remote_serialized_tensor is None
            output_tensors.append(local_tensor)
        else:
            output_tensors.append(
                MultiprocessingSerializer.deserialize(remote_serialized_tensor)
            )

    return output_tensors
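
_create_shared_buffer_tensors gives every rank a directly usable handle to every other rank's shard buffer: each rank serializes its local buffer once per peer, the payloads are exchanged with all_gather_object, and rank r then deserializes slot [r] of each peer's list. A pure-Python simulation of that indexing (no real processes or CUDA IPC; strings stand in for the serialized handles):

world_size = 3

def make_payload(rank):
    # What each rank contributes to all_gather_object: one serialized copy of
    # its local buffer per peer, and None in its own slot.
    return {
        "dup_serialized_local_tensor": [
            None if peer == rank else f"serialized(buf{rank})"
            for peer in range(world_size)
        ]
    }

# After all_gather_object, every rank sees the same list of payloads.
object_list = [make_payload(rank) for rank in range(world_size)]

self_rank = 1
handles = []
for src in range(world_size):
    remote = object_list[src]["dup_serialized_local_tensor"][self_rank]
    handles.append(f"buf{self_rank}" if src == self_rank else f"deserialize({remote})")

print(handles)  # ['deserialize(serialized(buf0))', 'buf1', 'deserialize(serialized(buf2))']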