sglang · Commit 71a7f1d8 (Unverified)
Authored Aug 25, 2025 by fzyzcjy; committed by GitHub on Aug 25, 2025
Parent: 433266c1

Offload tensors by sharding on GPU (#9536)

Changes: 1 changed file with 115 additions and 0 deletions

python/sglang/srt/offloader.py (+115, -0)
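Before the diff itself: the commit adds a fourth offload mode, "sharded_gpu", in which each rank keeps only an even 1/world_size slice of an offloaded parameter resident on the GPU and the full tensor is rebuilt on demand. Below is a minimal single-process sketch of that idea in plain PyTorch; it is not sglang code, and the cross-rank copies that happen over shared CUDA buffers in the real implementation are stood in for by a simple loop.

import torch

world_size = 4
param = torch.randn(8, 16)                 # full parameter; dim 0 splits evenly into 4 chunks
assert param.shape[0] % world_size == 0

shards = list(param.chunk(world_size))     # what each rank would keep resident ("sharded_gpu")
rank_shard = shards[1]                     # e.g. rank 1 stores only this slice

# Reassembly at use time: copy every rank's shard back into a full-size buffer.
output = torch.empty_like(param)
for rank, chunk in enumerate(output.chunk(world_size)):
    chunk.copy_(shards[rank])
assert torch.equal(output, param)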
@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
    @staticmethod
    def create(mode: str, **kwargs) -> "_BaseParamOffloader":
        return {
            "meta": _MetaParamOffloader,
            "cpu": _CpuParamOffloader,
            "shm_cpu": _ShmCpuParamOffloader,
            "sharded_gpu": _ShardedGpuParamOffloader,
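Aside: a hedged usage sketch of the extended factory. It assumes sglang is installed and that the elided remainder of create() instantiates the selected class with the forwarded kwargs (roughly {...}[mode](**kwargs)); _BaseParamOffloader is a private sglang API and the module and parameter names are illustrative only.

import torch
from sglang.srt.offloader import _BaseParamOffloader  # private API, shown only for illustration

linear = torch.nn.Linear(4096, 4096, device="cuda")
# Assumption: create() forwards **kwargs to the chosen class's __init__(module, param_name).
offloader = _BaseParamOffloader.create("sharded_gpu", module=linear, param_name="weight")
# post_init() must still run collectively across ranks before create_device_tensor()
# can rebuild the full tensor from the per-rank shards.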
@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
        raise NotImplementedError


class _MetaParamOffloader(_BaseParamOffloader):
    """Usually used for debugging."""

    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        _move_param_to_meta(module, param_name)

    def create_device_tensor(self):
        return torch.empty_like(self._param.data, device="cuda")


class _CpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
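Side note on the meta device used by _MetaParamOffloader (and, further down, by the sharded mode for non-zero ranks): a meta tensor keeps shape, dtype and strides but owns no storage, so the original weights can be dropped and a fresh buffer materialized later with torch.empty_like(..., device="cuda"). Illustrative only, not part of the diff.

import torch

w = torch.nn.Linear(8, 8).weight.data
w_meta = w.to("meta")                      # metadata only; no memory backs this tensor
assert w_meta.shape == w.shape and w_meta.dtype == w.dtype

device = "cuda" if torch.cuda.is_available() else "cpu"
materialized = torch.empty_like(w_meta, device=device)  # uninitialized buffer, same shape/dtype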
@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
        device=device,
        pin_memory=pin_memory,
    )


# ----------------------------------------- ShardedGpu ------------------------------------------------------


# TODO unify with ShmCpu mode
class _ShardedGpuParamOffloader(_BaseParamOffloader):
    def __init__(self, module, param_name):
        super().__init__(module, param_name)
        self._rank = get_naive_distributed().get_rank()
        self._world_size = get_naive_distributed().get_world_size()

        from sglang.srt.distributed import get_tensor_model_parallel_world_size

        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

        if self._rank == 0:
            _move_param_to_cpu(self._param, pin_memory=True)
        else:
            _move_param_to_meta(self._module, self._param_name)

        self.sharded_param_handles = None
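Note on the rank-0 branch above: _move_param_to_cpu(..., pin_memory=True) is assumed to place the parameter in page-locked host memory, which lets the later .to("cuda") in post_init() use a faster, potentially asynchronous host-to-device copy. A standalone illustration of that effect:

import torch

cpu_param = torch.randn(1024, 1024).pin_memory()         # page-locked host memory
if torch.cuda.is_available():
    gpu_copy = cpu_param.to("cuda", non_blocking=True)    # async H2D copy is possible from pinned memory
    torch.cuda.synchronize()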
    def post_init(self):
        # check again since it may be changed
        assert (
            self._param.data.is_contiguous()
        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"

        scatter_src = self._param.data
        logger.info(
            f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
        )

        if self._rank == 0:
            scatter_src = scatter_src.to("cuda")
        scatter_list = _even_chunk(scatter_src, self._world_size)

        sharded_param = torch.empty(
            scatter_list[0].shape, dtype=scatter_list[0].dtype, device="cuda"
        )
        self.sharded_param_handles = _create_shared_buffer_tensors(
            local_tensor=sharded_param
        )

        get_naive_distributed().scatter(
            sharded_param, scatter_list if self._rank == 0 else None
        )

        _move_param_to_meta(self._module, self._param_name)
    def create_device_tensor(self):
        output = _empty_strided_like(self._param, device="cuda")
        output_chunks = output.chunk(self._world_size)

        for index in range(self._world_size):
            src_rank = (self._rank + index) % self._world_size
            src_buf = self.sharded_param_handles[src_rank]
            output_chunks[src_rank].copy_(src_buf)

        return output
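Taken together, post_init() and create_device_tensor() implement a scatter-then-reassemble pattern on top of sglang's naive distributed layer and shared CUDA buffers. The sketch below reproduces the same pattern with plain torch.distributed (launch with torchrun --nproc-per-node=4); it uses all_gather for the rebuild where the real code reads each peer's shard directly through the CUDA IPC handles kept in sharded_param_handles, so it is an analogue, not the actual implementation.

import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group(backend="nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)

    full_shape = (1024, 1024)                  # dim 0 must split evenly across ranks
    shard = torch.empty(full_shape[0] // world_size, full_shape[1], device="cuda")

    # post_init(): rank 0 holds the full parameter and scatters even chunks to all ranks.
    scatter_list = None
    if rank == 0:
        full = torch.randn(full_shape, device="cuda")
        scatter_list = list(full.chunk(world_size))
    dist.scatter(shard, scatter_list, src=0)

    # create_device_tensor(): rebuild the full tensor from every rank's shard on demand.
    output = torch.empty(full_shape, device="cuda")
    dist.all_gather(list(output.chunk(world_size)), shard)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()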
def _even_chunk(x: torch.Tensor, chunks: int):
    assert x.shape[0] % chunks == 0, f"{x.shape=} {chunks=}"
    return list(x.chunk(chunks))
def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
    self_rank = get_naive_distributed().get_rank()
    world_size = get_naive_distributed().get_world_size()

    object_list = get_naive_distributed().all_gather_object(
        dict(
            dup_serialized_local_tensor=[
                (
                    None
                    if interesting_rank == self_rank
                    else MultiprocessingSerializer.serialize(local_tensor)
                )
                for interesting_rank in range(world_size)
            ]
        )
    )

    output_tensors = []
    for output_rank in range(world_size):
        remote_serialized_tensor = object_list[output_rank][
            "dup_serialized_local_tensor"
        ][self_rank]
        if output_rank == self_rank:
            assert remote_serialized_tensor is None
            output_tensors.append(local_tensor)
        else:
            output_tensors.append(
                MultiprocessingSerializer.deserialize(remote_serialized_tensor)
            )
    return output_tensors
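The list returned by _create_shared_buffer_tensors holds, at index i, rank i's shard buffer: the local rank's own tensor at position self_rank and deserialized handles to every peer's GPU buffer elsewhere, which is what lets create_device_tensor copy remote shards without extra communication at use time. MultiprocessingSerializer is assumed to wrap PyTorch's tensor reductions, so serializing a CUDA tensor and deserializing it in another process maps the same device memory via CUDA IPC. The standard torch.multiprocessing path below demonstrates that sharing behaviour; it is illustrative only and not how sglang wires the processes together.

import torch
import torch.multiprocessing as mp


def child(q):
    shared = q.get()      # rebuilt from a CUDA IPC handle; same device memory as the parent's tensor
    shared.fill_(42.0)    # the write is visible to the parent without any explicit copy


if __name__ == "__main__":
    mp.set_start_method("spawn")
    t = torch.zeros(4, device="cuda")
    q = mp.Queue()
    p = mp.Process(target=child, args=(q,))
    p.start()
    q.put(t)              # sending a CUDA tensor shares it instead of copying it
    p.join()
    torch.cuda.synchronize()
    print(t)              # expected: a tensor of 42s, updated by the child process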