sglang / Commits / a5317b2f

Unverified commit a5317b2f, authored Jun 28, 2025 by Chunyuan WU, committed by GitHub Jun 27, 2025

[CPU] add optimizations for INT8 and FP8 DeepSeek (#6769)

Co-authored-by: Zheng, Beilei <beilei.zheng@intel.com>

Parent: eb6c2c16
Showing 5 changed files with 179 additions and 3 deletions.
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   (+1, -1)
python/sglang/srt/layers/quantization/fp8.py             (+43, -0)
python/sglang/srt/layers/quantization/moe_wna16.py       (+1, -1)
python/sglang/srt/layers/quantization/w8a8_int8.py       (+51, -1)
python/sglang/srt/models/deepseek_v2.py                  (+83, -0)
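Most of the hunks below share one pattern: weight post-processing marks a module for the Intel AMX CPU backend, and the hot path branches on that marker before falling back to the existing CUDA/Triton code. A minimal, self-contained sketch of that dispatch pattern (plain PyTorch; the two helper paths are stand-ins, not sglang kernels):

import torch

def amx_cpu_path(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # stand-in for the torch.ops.sgl_kernel.*_cpu ops this commit calls
    return x @ w.t()

def default_path(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # stand-in for the existing CUDA/Triton path
    return x @ w.t()

def apply_linear(layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # same guard the diff adds to Fp8LinearMethod, W8A8Int8LinearMethod and the MoE methods
    if getattr(layer, "use_intel_amx_backend", False):
        return amx_cpu_path(x, layer.weight)
    return default_path(x, layer.weight)

layer = torch.nn.Linear(8, 16, bias=False)
layer.use_intel_amx_backend = True  # presumably set during weight post-processing (see the note further down)
print(apply_linear(layer, torch.randn(2, 8)).shape)  # torch.Size([2, 16])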
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -291,7 +291,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 torch.float
             ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
             topk_ids,
-            True,  # inplace
+            False,  # inplace # See [Note] inplace should be False in fused_experts.
             False,  # use_int8_w8a8
             False,  # use_fp8_w8a16
             None,  # w1_scale
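The flip from True to False for inplace matters because the same hidden_states tensor is consumed again by the shared experts after the routed experts run. A quick illustration of the aliasing problem, in plain PyTorch rather than sglang code:

import torch

def routed_experts(x: torch.Tensor, inplace: bool) -> torch.Tensor:
    # stand-in for fused_experts: pretend the kernel writes its result into x when inplace=True
    out = x * 2.0
    if inplace:
        x.copy_(out)
        return x
    return out

hidden_states = torch.ones(4)
routed_experts(hidden_states, inplace=True)
print(hidden_states + 1.0)  # tensor([3., 3., 3., 3.]) -- shared experts would see clobbered input

hidden_states = torch.ones(4)
routed_experts(hidden_states, inplace=False)
print(hidden_states + 1.0)  # tensor([2., 2., 2., 2.]) -- original activations preserved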
python/sglang/srt/layers/quantization/fp8.py
@@ -64,6 +64,7 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
+    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
@@ -330,6 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
                 )
                 layer.input_scale = None
+            elif _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
+                _process_weight_after_loading(layer, ["weight"])
+                return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
             layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -426,6 +433,17 @@ class Fp8LinearMethod(LinearMethodBase):
         )

         if self.block_quant:
+            if getattr(layer, "use_intel_amx_backend", False):
+                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
+                    x,
+                    layer.weight,
+                    layer.weight_scale_inv,
+                    self.quant_config.weight_block_size,
+                    bias,
+                    x.dtype,
+                    True,  # is_vnni
+                )
+
             return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
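In the block-quantized FP8 case the CPU op receives the FP8 weight, its block-wise weight_scale_inv, and quant_config.weight_block_size, while the activations stay in x.dtype (a w8a16-style GEMM). A reference sketch of the math, not the sgl_kernel implementation; the [block_n, block_k] ordering of weight_block_size and the scale-tensor layout are assumptions:

import torch

def fp8_block_scaled_mm_ref(x, w_fp8, weight_scale_inv, weight_block_size, bias=None):
    # dequantize the FP8 weight block by block, then run a plain matmul
    block_n, block_k = weight_block_size  # e.g. [128, 128]
    w = w_fp8.to(torch.float32)
    n, k = w.shape
    for i in range(0, n, block_n):
        for j in range(0, k, block_k):
            w[i:i + block_n, j:j + block_k] *= weight_scale_inv[i // block_n, j // block_k]
    out = x.to(torch.float32) @ w.t()
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)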
@@ -746,6 +764,13 @@ class Fp8MoEMethod:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )

+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            return
+
         # If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +996,24 @@ class Fp8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )

+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                True,  # use_fp8_w8a16
+                layer.w13_weight_scale_inv,  # w1_scale
+                layer.w2_weight_scale_inv,  # w2_scale
+                self.quant_config.weight_block_size,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
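fused_experts_cpu mirrors the argument order of the Triton fused_experts path; with use_fp8_w8a16=True the routed activations stay in BF16/FP16 and only w13/w2 are FP8 with block-wise scales. As a semantic reference only (not the kernel, and with weights already dequantized), the routed computation per token and per selected expert is a SwiGLU MLP weighted by the router:

import torch
import torch.nn.functional as F

def fused_experts_ref(x, w13, w2, topk_weights, topk_ids):
    # w13: [num_experts, 2 * intermediate, hidden] (gate and up fused); w2: [num_experts, hidden, intermediate]
    xf = x.float()
    out = torch.zeros_like(xf)
    num_tokens, top_k = topk_ids.shape
    for t in range(num_tokens):
        for slot in range(top_k):
            e = int(topk_ids[t, slot])
            gate, up = (xf[t] @ w13[e].float().t()).chunk(2, dim=-1)
            out[t] += topk_weights[t, slot].float() * ((F.silu(gate) * up) @ w2[e].float().t())
    return out.to(x.dtype)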
python/sglang/srt/layers/quantization/moe_wna16.py
@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = get_device_capability()
         device_capability = (
             -1
-            if capability_tuple is None
+            if all(capability is None for capability in capability_tuple)
             else capability_tuple[0] * 10 + capability_tuple[1]
         )
         # Avoid circular import
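On CPU, get_device_capability() evidently returns a tuple of Nones rather than None itself, so the old is-None check fell through to the arithmetic branch; the new condition catches that case. A quick worked example of the expression (the helper below is a stand-in, not the sglang code path):

def device_capability(capability_tuple):
    return (
        -1
        if all(capability is None for capability in capability_tuple)
        else capability_tuple[0] * 10 + capability_tuple[1]
    )

print(device_capability((None, None)))  # -1  (CPU: no CUDA capability)
print(device_capability((9, 0)))        # 90  (e.g. H100)
print(device_capability((8, 6)))        # 86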
python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -11,9 +11,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    set_weight_attrs,
+)

 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
@@ -72,6 +80,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         self.quantization_config = quantization_config

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
+            _process_weight_after_loading(layer, ["weight"])
+            return
+
         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
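The diff never shows _process_weight_after_loading itself; judging from its call sites here and from the True  # is_vnni arguments passed to the CPU kernels, it presumably repacks the listed weights into the VNNI layout AMX expects and marks the layer so apply() takes the AMX branch. A rough sketch of that assumed contract (the body is an assumption, not sglang code):

import torch

def _process_weight_after_loading_sketch(layer: torch.nn.Module, weight_names: list) -> None:
    for name in weight_names:
        weight = getattr(layer, name)
        packed = weight.data.contiguous()  # placeholder for the real VNNI repacking done in sgl_kernel
        setattr(layer, name, torch.nn.Parameter(packed, requires_grad=False))
    layer.use_intel_amx_backend = True  # consumed by the getattr(layer, "use_intel_amx_backend", False) checks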
@@ -112,6 +127,16 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
+                x,
+                layer.weight,
+                layer.weight_scale,
+                bias,
+                x.dtype,
+                True,  # is_vnni
+            )
+
         x_q, x_scale = per_token_quant_int8(x)

         return int8_scaled_mm(
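int8_scaled_mm_with_quant folds the per-token activation quantization into the GEMM, so the CPU path never materializes x_q/x_scale the way the CUDA path below does. For reference, per-token symmetric INT8 quantization amounts to the following (plain-PyTorch sketch, not the sglang kernel):

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # one scale per row/token: scale = max|x| / 127, x_q = round(x / scale)
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 16)
x_q, s = per_token_quant_int8_ref(x)
print((x_q.float() * s - x).abs().max())  # small quantization error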
@@ -206,6 +231,13 @@ class W8A8Int8MoEMethod:
         layer.register_parameter("w2_input_scale", w2_input_scale)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            return
+
         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
@@ -252,6 +284,24 @@ class W8A8Int8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )

+        if getattr(layer, "use_intel_amx_backend", False):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                True,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                layer.w13_weight_scale,  # w1_scale
+                layer.w2_weight_scale,  # w2_scale
+                None,  # block_size
+                layer.w13_input_scale,  # a1_scale
+                layer.w2_input_scale,  # a2_scale
+                True,  # is_vnni
+            )
+
         return fused_experts(
             x,
             layer.w13_weight,
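Both MoE methods drive the same fused_experts_cpu entry point, just with different flag and scale combinations. An annotated summary of the two calls in this diff (documentation only, not additional API):

FUSED_EXPERTS_CPU_CONFIGS = {
    "Fp8MoEMethod": dict(
        use_int8_w8a8=False,
        use_fp8_w8a16=True,
        w_scales="w13_weight_scale_inv / w2_weight_scale_inv (block-wise)",
        block_size="self.quant_config.weight_block_size",
        a_scales=None,  # activations stay in bf16/fp16
    ),
    "W8A8Int8MoEMethod": dict(
        use_int8_w8a8=True,
        use_fp8_w8a16=False,
        w_scales="w13_weight_scale / w2_weight_scale",
        block_size=None,
        a_scales="w13_input_scale / w2_input_scale (static activation scales)",
    ),
}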
python/sglang/srt/models/deepseek_v2.py
@@ -300,6 +300,9 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )

+        self.shared_experts_is_int8 = False
+        self.shared_experts_is_fp8 = False
+        self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
@@ -316,6 +319,20 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
+            self.shared_experts_is_int8 = (
+                self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            )
+            self.shared_experts_is_fp8 = (
+                self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            )
+            if self.shared_experts_is_fp8:
+                assert (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
+                )
+                self.shared_experts_weight_block_size = (
+                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
+                )

         self.top_k = config.num_experts_per_tok
@@ -394,6 +411,11 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "shared_experts") and getattr(
+            self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
+        ):
+            return self.forward_cpu(hidden_states)
+
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -409,6 +431,59 @@ class DeepseekV2MoE(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

+    def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        fused_experts_out = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+
+        assert getattr(
+            self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
+        ) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False)
+        # [Note] inplace should be False in fused_experts.
+        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts,
+        # while hidden_states is still needed in shared_expert.
+        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
+            hidden_states,
+            self.shared_experts.gate_up_proj.weight,
+            self.shared_experts.down_proj.weight,
+            fused_experts_out,
+            self.routed_scaling_factor,
+            True,  # inplace
+            self.shared_experts_is_int8,  # use_int8_w8a8
+            self.shared_experts_is_fp8,  # use_fp8_w8a16
+            (
+                self.shared_experts.gate_up_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.gate_up_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w1_scale
+            (
+                self.shared_experts.down_proj.weight_scale
+                if self.shared_experts_is_int8
+                else (
+                    self.shared_experts.down_proj.weight_scale_inv
+                    if self.shared_experts_is_fp8
+                    else None
+                )
+            ),  # w2_scale
+            (
+                self.shared_experts_weight_block_size
+                if self.shared_experts_is_fp8
+                else None
+            ),  # block_size
+            None,  # a1_scale
+            None,  # a2_scale
+            True,  # is_vnni
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
     def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
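shared_expert_cpu fuses the shared-expert MLP with combining it against the already-computed routed output, which is why hidden_states must not have been overwritten by self.experts (the [Note] above). A rough functional sketch of what the call appears to compute in the unquantized case, assuming the usual DeepSeek convention that routed_scaling_factor scales the routed output before the shared output is added (an assumption about the kernel's semantics, not its source; the INT8/FP8 variants additionally dequantize gate_up/down with the supplied scales):

import torch
import torch.nn.functional as F

def shared_expert_ref(hidden_states, gate_up_w, down_w, fused_experts_out, routed_scaling_factor):
    gate, up = (hidden_states @ gate_up_w.t()).chunk(2, dim=-1)
    shared = (F.silu(gate) * up) @ down_w.t()  # shared-expert SwiGLU MLP on the untouched input
    return shared + routed_scaling_factor * fused_experts_out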
@@ -2107,6 +2182,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 if _is_hip:
                     self_attn.w_scale *= 2.0
+                # TODO: remove this after adding FP8 support in bmm cpu kernel
+                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
+                    self_attn.w_kc = (
+                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
+                    )
+                    self_attn.w_vc = (
+                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
+                    )
             else:
                 num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                 num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
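Since the CPU bmm kernel cannot consume FP8 yet (hence the TODO), the MLA projection weights w_kc/w_vc are dequantized once at load time: cast to BF16 and multiplied by their scale. A minimal illustration of that one-off dequantization (plain PyTorch; the shapes and scale value are made up):

import torch

w_scale = 0.05
w_kc_fp8 = (torch.randn(16, 8) / w_scale).to(torch.float8_e4m3fn)  # pretend checkpoint weight

w_kc_bf16 = w_kc_fp8.to(torch.bfloat16) * w_scale  # what the workaround stores instead of FP8
print(w_kc_bf16.dtype)  # torch.bfloat16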