change / sglang · Commits

Commit f194e14f (unverified), authored May 16, 2025 by fzyzcjy, committed by GitHub May 15, 2025
Parent: cfc9f9ab

Reduce MoE memory usage (#6147)

Showing 4 changed files with 75 additions and 40 deletions (+75 -40)
python/sglang/srt/layers/moe/ep_moe/kernels.py    +10  -2
python/sglang/srt/layers/moe/ep_moe/layer.py      +58  -35
python/sglang/srt/models/deepseek_v2.py           +3   -3
python/sglang/srt/utils.py                        +4   -0
python/sglang/srt/layers/moe/ep_moe/kernels.py

@@ -3,10 +3,9 @@ from typing import List, Optional
 import torch
 import triton
 import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)

@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
     if block_shape is not None:
+        a_original = a
+
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)

@@ -667,6 +669,8 @@ def grouped_gemm_triton(
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
 
+        dispose_tensor(a_original)
+
     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {

@@ -680,6 +684,10 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )
 
+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
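The kernel change above follows a lazy-output pattern: grouped_gemm_triton now accepts c_dtype and allocates c itself when the caller passes c=None, so callers no longer have to keep a pre-allocated output alive alongside the input. A minimal sketch of the same idea, using a hypothetical toy_grouped_gemm with a plain torch.matmul standing in for the Triton grouped GEMM (illustrative only, not the sglang kernel):

import torch

def toy_grouped_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor = None, c_dtype=None):
    # Allocate the output only when the caller did not supply one, in the dtype
    # the caller asked for. This mirrors the new `c is None` branch above.
    if c is None:
        assert c_dtype is not None
        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
    torch.matmul(a, b, out=c)  # stand-in for the Triton grouped GEMM
    return c

a = torch.randn(8, 4)
b = torch.randn(4, 6)
out = toy_grouped_gemm(a, b, c=None, c_dtype=torch.float32)
print(out.shape)  # torch.Size([8, 6])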
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
 
 _is_hip = is_hip()

@@ -92,6 +92,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer

@@ -119,6 +120,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c
@@ -210,6 +212,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
 
         if self.grouped_gemm_runner is None:

@@ -265,25 +271,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)
 
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,

@@ -297,6 +299,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input
 
         # Act
         down_input = torch.empty(

@@ -306,14 +309,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":

@@ -340,13 +343,14 @@ class EPMoE(torch.nn.Module):
             )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output
 
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,

@@ -365,10 +369,13 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del down_input
 
         # PostReorder
-        output = torch.empty_like(hidden_states)
-        post_reorder_triton_kernel[(hidden_states.size(0),)](
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
             down_output,
             output,
             src2dst,

@@ -377,7 +384,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-            hidden_states.size(1),
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
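The EPMoE.forward changes above cache the input's shape, dtype, and device up front, dispose of hidden_states as soon as the pre-reorder kernel has copied its data, and rebuild the output from the cached metadata rather than from the (now empty) input. A minimal sketch of that pattern, assuming dispose_tensor behaves as defined in utils.py below; toy_forward and its clone-based "reorder" are illustrative stand-ins, not the real kernels:

import torch

def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))

def toy_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    # Save metadata up front; the tensor itself may be freed before the end.
    shape, dtype, device = hidden_states.shape, hidden_states.dtype, hidden_states.device

    # Stand-in for the pre-reorder step that copies data into a new layout.
    reordered = hidden_states.clone()
    dispose_tensor(hidden_states)  # the original storage is no longer needed

    # ... expert GEMMs would run on `reordered` here ...

    # The output is rebuilt from the cached metadata, not from hidden_states.
    output = torch.empty(shape, dtype=dtype, device=device)
    output.copy_(reordered)
    return output

print(toy_forward(torch.randn(4, 8)).shape)  # torch.Size([4, 8])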
@@ -881,6 +888,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:

@@ -903,18 +913,12 @@ class DeepEPMoE(EPMoE):
         )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,

@@ -928,6 +932,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
 
         # Act
         down_input = torch.empty(

@@ -937,14 +948,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":

@@ -961,12 +972,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output
 
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(

@@ -1007,11 +1020,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
-        gather_out = torch.empty_like(
-            hidden_states_fp8,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype
 
         input_tensor = [
             torch.empty(

@@ -1049,16 +1060,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)
 
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,

@@ -1068,14 +1081,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),

@@ -1083,7 +1098,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
 
         return gather_out

@@ -1107,6 +1128,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])
 
         # Act
         down_input = torch.empty(

@@ -1135,6 +1157,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output
 
         # GroupGemm-1
         n = self.w2_weight.size(1)
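Throughout the DeepEPMoE paths the commit also drops each intermediate with del as soon as the next stage has consumed it (gate-up output, activation buffer, quantized inputs), so peak memory holds only the tensors one stage actually needs. A rough sketch of that staging idea with generic stand-in GEMMs rather than the real grouped kernels:

import torch

def staged_pipeline(x: torch.Tensor) -> torch.Tensor:
    gateup = x @ torch.randn(x.shape[1], 2 * x.shape[1])   # stage 0 (stand-in GEMM)
    act = torch.nn.functional.silu(gateup[:, : x.shape[1]]) * gateup[:, x.shape[1]:]
    del gateup            # stage-0 output is dead once the activation is computed
    down = act @ torch.randn(x.shape[1], x.shape[1])        # stage 1
    del act               # activation buffer is dead once stage 1 has consumed it
    return down

print(staged_pipeline(torch.randn(4, 8)).shape)  # torch.Size([4, 8])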
python/sglang/srt/models/deepseek_v2.py

@@ -311,10 +311,10 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        final_hidden_states = (
-            self.experts(hidden_states=hidden_states, router_logits=router_logits)
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
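The deepseek_v2.py change replaces an out-of-place scale with an in-place one: t = t * s materializes a second tensor the size of the hidden states while the old one is still alive, whereas t *= s reuses the existing buffer. A small demonstration of the difference (the variable names are illustrative):

import torch

t = torch.randn(1024, 1024)
ptr_before = t.data_ptr()

t *= 2.5                     # in-place: the existing storage is reused
assert t.data_ptr() == ptr_before

t = t * 2.5                  # out-of-place: a fresh buffer is allocated while the
assert t.data_ptr() != ptr_before  # old tensor is still referenced on the RHS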
python/sglang/srt/utils.py

@@ -2100,3 +2100,7 @@ def log_info_on_rank0(logger, msg):
     if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
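dispose_tensor swaps a tensor's storage for a zero-element one, so the underlying buffer can be returned to the allocator even while Python references to the tensor object remain (for example a reference still held by the caller). A usage sketch; the memory accounting assumes a CUDA device and PyTorch's caching allocator:

import torch

def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))

if torch.cuda.is_available():
    before = torch.cuda.memory_allocated()
    x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB of float32
    y = x  # an extra reference that would normally keep the buffer alive
    dispose_tensor(x)
    after = torch.cuda.memory_allocated()
    # x (and y, the same object) now point at a 0-element storage, so the
    # 4 MiB buffer is released even though the tensor object is still referenced.
    print(x.numel(), y.numel(), after - before)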