change / sglang · Commits · fa46e2bd

Unverified commit fa46e2bd, authored Sep 14, 2025 by fzyzcjy, committed by GitHub on Sep 14, 2025.

Support offloading in fp8 (#9948)

Parent: b047b553
Changes: 4 changed files with 95 additions and 17 deletions (+95 / -17)
Changed files:
- python/sglang/srt/layers/moe/ep_moe/layer.py (+52 / -10)
- python/sglang/srt/layers/quantization/fp8_utils.py (+7 / -2)
- python/sglang/srt/models/deepseek_v2.py (+9 / -2)
- python/sglang/srt/offloader.py (+27 / -3)
python/sglang/srt/layers/moe/ep_moe/layer.py (view file @ fa46e2bd)

 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import triton
 import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe import (
...
@@ -31,7 +33,15 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.offloader import get_offloader
+from sglang.srt.utils import (
+    ceil_div,
+    dispose_tensor,
+    get_bool_env_var,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
...
@@ -535,6 +545,24 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128

+        # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
+        w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        w2_weight_fp8 = (
+            self.w2_weight,
+            (
+                self.w2_weight_scale_inv
+                if self.use_block_quant
+                else self.w2_weight_scale
+            ),
+        )
+
         hidden_states_fp8_shape = hidden_states_fp8.shape
         hidden_states_fp8_device = hidden_states_fp8.device
         hidden_states_fp8_dtype = hidden_states_fp8.dtype
...
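The new (weight, scale) tuples are built inside the forward pass rather than read from a cached `self.w13_weight_fp8` attribute. A plausible reason, given the offloading theme of this commit, is that the offloader replaces `param.data` when it moves weights between host and device, so a tuple captured once at load time would keep pointing at stale storage. A standalone sketch of that staleness effect (my illustration, not sglang code):

```python
import torch

# Standalone illustration (assumption about the motivation, not sglang code):
# a tuple that snapshots param.data once stops tracking the parameter after its
# storage is swapped, which is what an offloader does when it moves weights
# between host and device.
param = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
cached = (param.data,)            # snapshot taken once, e.g. at load time
param.data = torch.ones(4)        # offloader-style swap of the backing tensor
print(cached[0] is param.data)    # False: the cached tuple still points at old storage
print(cached[0][0].item(), param.data[0].item())  # 0.0 vs 1.0
```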
@@ -565,12 +593,17 @@
         )
         output_index = torch.empty_like(topk_idx)

-        num_recv_tokens_per_expert_gpu = torch.tensor(
-            num_recv_tokens_per_expert,
-            dtype=torch.int32,
-            pin_memory=True,
-            device="cpu",
-        ).cuda(non_blocking=True)
+        if get_offloader().forbid_copy_engine_usage:
+            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+                num_recv_tokens_per_expert
+            )
+        else:
+            num_recv_tokens_per_expert_gpu = torch.tensor(
+                num_recv_tokens_per_expert,
+                dtype=torch.int32,
+                pin_memory=True,
+                device="cpu",
+            ).cuda(non_blocking=True)
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)

         ep_scatter(
...
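For context on the new branch: the original path stages the Python list in pinned host memory and issues an asynchronous `.cuda(non_blocking=True)` copy, which is serviced by the GPU's DMA copy engine. My reading of `forbid_copy_engine_usage` is that, when weights are offloaded to CPU, the copy engine should stay free for the large weight transfers, so small control tensors take the kernel-based `copy_list_to_gpu_no_ce` path instead (defined at the end of this file). A minimal sketch of the default, copy-engine path in plain PyTorch (requires a CUDA device; the values are made up):

```python
import torch

# Default path from the diff: pinned host staging buffer plus an asynchronous
# host-to-device copy, which is handled by the DMA copy engine.
num_recv_tokens_per_expert = [3, 0, 7, 1]  # hypothetical per-expert token counts
cpu_t = torch.tensor(
    num_recv_tokens_per_expert, dtype=torch.int32, pin_memory=True, device="cpu"
)
gpu_t = cpu_t.cuda(non_blocking=True)  # async copy queued on the current CUDA stream
```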
@@ -595,7 +628,7 @@
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
+            input_tensor, w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
         down_input = torch.empty(
...
@@ -625,7 +658,7 @@
             down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
-            self.w2_weight_fp8,
+            w2_weight_fp8,
             down_output,
             m_indices,
         )
...
@@ -885,3 +918,12 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     if get_moe_expert_parallel_world_size() > 1:
         return EPMoE
     return FusedMoE
+
+
+def copy_list_to_gpu_no_ce(arr: List[int]):
+    from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+    return tensor_gpu
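A hypothetical usage sketch for the new helper, assuming the `sgl_kernel` package that provides `copy_to_gpu_no_ce` is installed and a CUDA device is available:

```python
# Hypothetical usage of copy_list_to_gpu_no_ce (requires sgl_kernel + CUDA).
counts = [3, 0, 7, 1]
counts_gpu = copy_list_to_gpu_no_ce(counts)
# counts_gpu is an int32 CUDA tensor with the same values, produced without
# occupying the DMA copy engine.
```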
python/sglang/srt/layers/quantization/fp8_utils.py (view file @ fa46e2bd)

...
@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
 import torch

+from sglang.srt import offloader
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
...
@@ -417,10 +418,14 @@ def block_quant_dequant(
 def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
     assert isinstance(weight, torch.nn.Parameter)
     assert isinstance(weight_scale_inv, torch.nn.Parameter)
-    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
-        weight, weight_scale_inv, weight_block_size
+    new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
+        weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
     )
+    offloader.update_param(weight, new_weight)
+    weight_scale_inv.data = new_weight_scale_inv


 def _requant_weight_ue8m0(
     weight: torch.Tensor,
...
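The requantization helper now writes the new weight back through `offloader.update_param` instead of assigning `weight.data` directly. The point, as I read it, is that a plain `.data` assignment would replace an offloaded parameter's pinned host buffer with ordinary pageable memory (or a device tensor), breaking the offloader's asynchronous onload path. A small sketch of the property that would be lost (illustration only; `pin_memory()` needs a CUDA-capable build):

```python
import torch

# Illustration (assumption, not sglang code): naively assigning param.data drops
# the pinned-memory property that the offloader relies on for async H2D copies.
param = torch.nn.Parameter(torch.zeros(8).pin_memory(), requires_grad=False)
print(param.data.is_pinned())   # True: offloader-friendly pinned host buffer
param.data = torch.ones(8)      # naive update: plain pageable CPU memory
print(param.data.is_pinned())   # False: the pinned property is silently lost
```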
python/sglang/srt/models/deepseek_v2.py (view file @ fa46e2bd)

...
@@ -2244,8 +2244,15 @@ class DeepseekV2Model(nn.Module):
                 [
                     "w13_weight",
                     "w2_weight",
-                    "w13_blockscale_swizzled",
-                    "w2_blockscale_swizzled",
+                    # only for nvfp4
+                    *(
+                        [
+                            "w13_blockscale_swizzled",
+                            "w2_blockscale_swizzled",
+                        ]
+                        if hasattr(module, "w13_blockscale_swizzled")
+                        else []
+                    ),
                 ]
                 if isinstance(module, FusedMoE)
                 else []
...
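The rewritten list uses conditional unpacking so the nvfp4-only block-scale tensors are named only when the module actually defines them, instead of unconditionally listing attributes that other quantization schemes never create. A generic illustration of the pattern with a stand-in module class:

```python
# Generic illustration of the conditional-unpacking pattern used in the hunk above.
class FakeModule:                       # stand-in; the real code inspects a FusedMoE module
    w13_blockscale_swizzled = None      # pretend this is the nvfp4-only attribute

module = FakeModule()
names = [
    "w13_weight",
    "w2_weight",
    *(
        ["w13_blockscale_swizzled", "w2_blockscale_swizzled"]
        if hasattr(module, "w13_blockscale_swizzled")
        else []
    ),
]
print(names)  # the swizzled names appear only when the attribute exists
```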
python/sglang/srt/offloader.py (view file @ fa46e2bd)

...
@@ -38,6 +38,10 @@ class BaseOffloader(ABC):
     def post_init(self):
         pass

+    @property
+    def forbid_copy_engine_usage(self):
+        return False
+

 class NoopOffloader(BaseOffloader):
     pass
...
@@ -233,6 +237,10 @@ class OffloaderV2(BaseOffloader):
         for i in range(self.prefetch_step):
             self.offloaders[i].start_onload()

+    @property
+    def forbid_copy_engine_usage(self):
+        return self.mode == "cpu"
+

 def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
     def _on_forward_end():
...
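Taken together with the hunk above, the new property gives callers a single query point: the base class answers False, and `OffloaderV2` answers True only when its mode is "cpu". A simplified, self-contained sketch of the same override pattern (not the real classes; only the "cpu" mode string comes from the diff):

```python
# Simplified sketch of the property override added in this commit (not the real classes).
class BaseOffloader:
    @property
    def forbid_copy_engine_usage(self) -> bool:
        return False                      # default: callers may use the copy engine


class OffloaderV2(BaseOffloader):
    def __init__(self, mode: str):
        self.mode = mode

    @property
    def forbid_copy_engine_usage(self) -> bool:
        return self.mode == "cpu"         # CPU offloading reserves the copy engine


print(BaseOffloader().forbid_copy_engine_usage)     # False
print(OffloaderV2("cpu").forbid_copy_engine_usage)  # True
```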
@@ -398,14 +406,30 @@ class _ShmCpuParamOffloader(_BaseParamOffloader):
         return self.shm_cpu_data.to("cuda", non_blocking=True)


+def update_param(param, new_tensor):
+    """Update parameter while keeping properties needed by Offloader (e.g. pinned host memory)."""
+    if param.device == new_tensor.device:
+        param.data = new_tensor
+    else:
+        assert param.device == torch.device(
+            "cpu"
+        ), f"{param.device=} {new_tensor.device=}"
+        param.data = _create_cpu_data(new_tensor, pin_memory=True)
+
+
 def _move_param_to_cpu(param, pin_memory: bool):
+    param.data = _create_cpu_data(param.data, pin_memory=pin_memory)
+
+
+def _create_cpu_data(data, pin_memory: bool):
     cpu_data = _empty_strided_like(
-        param.data,
+        data,
         device="cpu",
         pin_memory=pin_memory,
     )
-    cpu_data.copy_(param.data)
-    param.data = cpu_data
+    cpu_data.copy_(data)
+    return cpu_data


 def _move_param_to_meta(module, param_name):
...
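A hypothetical usage sketch of the new `update_param` helper, covering both branches; the tensors and values are made up, and the cross-device branch needs a CUDA device:

```python
import torch
from sglang.srt.offloader import update_param

# Same-device update: the new tensor simply becomes the parameter's storage.
p = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
update_param(p, torch.ones(4))
assert torch.equal(p.data, torch.ones(4))

# Cross-device update (CPU parameter, CUDA result, e.g. from requantization):
# the values are copied into freshly allocated pinned host memory so the
# offloader keeps its fast asynchronous onload path.
if torch.cuda.is_available():
    p_cpu = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
    update_param(p_cpu, torch.ones(4, device="cuda"))
    assert p_cpu.data.is_pinned()
```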