change / sglang · Commits

Commit f194e14f (unverified)
Authored May 16, 2025 by fzyzcjy; committed by GitHub on May 15, 2025
Parent: cfc9f9ab

Reduce MoE memory usage (#6147)

Showing 4 changed files with 75 additions and 40 deletions (+75, -40):

    python/sglang/srt/layers/moe/ep_moe/kernels.py   +10   -2
    python/sglang/srt/layers/moe/ep_moe/layer.py     +58  -35
    python/sglang/srt/models/deepseek_v2.py           +3   -3
    python/sglang/srt/utils.py                        +4   -0
python/sglang/srt/layers/moe/ep_moe/kernels.py  (+10, -2)

@@ -3,10 +3,9 @@ from typing import List, Optional
 import torch
 import triton
 import triton.language as tl
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda

 logger = logging.getLogger(__name__)

@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True  # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
     if block_shape is not None:
+        a_original = a
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)

@@ -667,6 +669,8 @@ def grouped_gemm_triton(
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
+        dispose_tensor(a_original)

     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {

@@ -680,6 +684,10 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )

+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
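The kernels.py change introduces a lazy-output pattern: callers may pass c=None together with a c_dtype, and grouped_gemm_triton allocates the output itself, only after the block-quantized copy of `a` exists and the original activation has been released via dispose_tensor. Below is a minimal sketch of that ordering, not the actual kernel: the half-precision cast stands in for per_token_group_quant_fp8, and the Triton launch is elided.

from typing import Optional

import torch


def grouped_gemm_sketch(
    a: torch.Tensor,
    b: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    c_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # Quantize the input first, then release its original storage so the
    # high-precision activation and the GEMM output never coexist.
    a_q = a.half()  # stand-in for per_token_group_quant_fp8(a, block_k)
    a.set_(torch.empty((0,), device=a.device, dtype=a.dtype))  # what dispose_tensor does

    # Allocate the output only if the caller did not preallocate one.
    if c is None:
        assert c_dtype is not None, "c_dtype is required when c is not provided"
        c = torch.empty(a_q.shape[0], b.shape[1], device=a_q.device, dtype=c_dtype)

    # ... the real function launches the Triton grouped GEMM here, writing into c ...
    return c

Because the output is created after the input's block has gone back to the caching allocator, the same memory can often be reused for c instead of growing the peak.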
python/sglang/srt/layers/moe/ep_moe/layer.py  (+58, -35)

@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs

 _is_hip = is_hip()

@@ -92,6 +92,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer

@@ -119,6 +120,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c

@@ -210,6 +212,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None

     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None

         if self.grouped_gemm_runner is None:

@@ -265,25 +271,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)

         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )

         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,

@@ -297,6 +299,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input

         # Act
         down_input = torch.empty(

@@ -306,14 +309,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )

         if self.activation == "silu":

@@ -340,13 +343,14 @@ class EPMoE(torch.nn.Module):
            )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output

         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,

@@ -365,10 +369,13 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del down_input

         # PostReorder
-        output = torch.empty_like(hidden_states)
-        post_reorder_triton_kernel[(hidden_states.size(0),)](
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
             down_output,
             output,
             src2dst,

@@ -377,7 +384,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-            hidden_states.size(1),
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
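EPMoE.forward above applies the same idea end to end: the input's shape, dtype, and device are captured up front, hidden_states is disposed right after the pre-reorder step, and each intermediate (gateup_input, gateup_output, down_input) is deleted as soon as the next stage has consumed it. The following is a rough, runnable sketch of that ordering only; the _pre_reorder, _grouped_gemm, and _activation helpers are toy stand-ins for the real Triton kernels, not part of sglang.

import torch


def _pre_reorder(x: torch.Tensor) -> torch.Tensor:   # stand-in for pre_reorder_triton_kernel
    return x.clone()

def _grouped_gemm(x: torch.Tensor) -> torch.Tensor:  # stand-in for GroupedGemmRunner
    return x.clone()

def _activation(x: torch.Tensor) -> torch.Tensor:    # stand-in for silu_and_mul
    return torch.nn.functional.silu(x)


def moe_forward_sketch(hidden_states: torch.Tensor) -> torch.Tensor:
    # Capture metadata first; the tensor's storage may be released before these are used.
    shape, dtype, device = hidden_states.shape, hidden_states.dtype, hidden_states.device

    gateup_input = _pre_reorder(hidden_states)
    hidden_states.set_(torch.empty((0,), device=device, dtype=dtype))  # dispose early

    gateup_output = _grouped_gemm(gateup_input)
    del gateup_input          # drop each buffer once the next stage has consumed it

    down_input = _activation(gateup_output)
    del gateup_output

    down_output = _grouped_gemm(down_input)
    del down_input

    output = torch.empty(shape, dtype=dtype, device=device)
    output.copy_(down_output)  # stand-in for post_reorder_triton_kernel
    return output

At no point do more than two large buffers overlap, which is where the memory saving comes from.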
@@ -881,6 +888,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:

@@ -903,18 +913,12 @@ class DeepEPMoE(EPMoE):
         )

         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,

@@ -928,6 +932,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )

         # Act
         down_input = torch.empty(

@@ -937,14 +948,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )

         if self.activation == "silu":

@@ -961,12 +972,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output

         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(

@@ -1007,11 +1020,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128

-        gather_out = torch.empty_like(
-            hidden_states_fp8,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype

         input_tensor = [
             torch.empty(

@@ -1049,16 +1060,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)

         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,

@@ -1068,14 +1081,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),

@@ -1083,7 +1098,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)

         return gather_out

@@ -1107,6 +1128,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])

         # Act
         down_input = torch.empty(

@@ -1135,6 +1157,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output

         # GroupGemm-1
         n = self.w2_weight.size(1)
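Across these hunks, output buffers (gateup_output, down_output, gather_out) are now allocated only once the tensors they supersede can be freed, so the caching allocator can hand the just-released block straight back instead of growing the pool. The toy measurement below illustrates the effect on the high-water mark; it assumes a CUDA device, and the stand-in ops and sizes are arbitrary, chosen only to make the two peaks visibly different.

import torch


def peak_mib(allocate_early: bool, n: int = 4096) -> float:
    # Compare allocating the output before vs. after the input can be freed.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(n, n, device="cuda")
    if allocate_early:
        out = torch.empty(n, n, device="cuda")  # coexists with x and tmp: three buffers live
    tmp = 2 * x                                  # stand-in for the intermediate GEMM buffers
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))  # dispose_tensor(x)
    if not allocate_early:
        out = torch.empty(n, n, device="cuda")   # reuses the block x just released
    out.copy_(tmp)
    del tmp
    return torch.cuda.max_memory_allocated() / 2**20


if torch.cuda.is_available():
    print(f"allocate early: {peak_mib(True):.0f} MiB, allocate late: {peak_mib(False):.0f} MiB")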
python/sglang/srt/models/deepseek_v2.py  (+3, -3)

@@ -311,10 +311,10 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
+        final_hidden_states = (
+            self.experts(hidden_states=hidden_states, router_logits=router_logits)
+            * self.routed_scaling_factor
         )
-        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
python/sglang/srt/utils.py  (+4, -0)

@@ -2100,3 +2100,7 @@ def log_info_on_rank0(logger, msg):
     if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)
+
+
+def dispose_tensor(x: torch.Tensor):
+    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
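dispose_tensor is what makes the early frees work across function boundaries: `del x` inside a callee only removes the local name, and the caller's reference keeps the storage alive, whereas swapping in a zero-element storage with Tensor.set_ releases the memory immediately even though the Python object survives. A small demonstration (assumes a CUDA device is available; dispose_tensor is redefined locally with the same body as the new utility above):

import torch


def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))


def consume_with_del(x: torch.Tensor):
    del x                 # no effect: the caller's reference keeps the buffer alive


def consume_with_dispose(x: torch.Tensor):
    dispose_tensor(x)     # frees the storage even though the caller still holds `x`


if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    consume_with_del(a)
    print(a.element_size() * a.nelement())  # still ~4 MiB backing the tensor
    consume_with_dispose(a)
    print(a.nelement())                     # 0: the buffer has been released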