Commit 515ef4fa
Fuse routed scaling factor in topk_reduce kernel (#6220)
Authored Jun 08, 2025 by Xiaoyu Zhang; committed via GitHub on Jun 07, 2025
Parent: f5599ef1

Showing 10 changed files with 331 additions and 9 deletions (+331 -9)
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py                            +199 -0
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py                           +124 -8
python/sglang/srt/layers/moe/fused_moe_triton/layer.py                               +1 -0
python/sglang/srt/layers/quantization/blockwise_int8.py                              +1 -0
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py   +1 -0
python/sglang/srt/layers/quantization/fp8.py                                         +1 -0
python/sglang/srt/layers/quantization/moe_wna16.py                                   +1 -0
python/sglang/srt/layers/quantization/w8a8_fp8.py                                    +1 -0
python/sglang/srt/layers/quantization/w8a8_int8.py                                   +1 -0
python/sglang/srt/models/deepseek_v2.py                                              +1 -1
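For orientation: the change folds the routed scaling multiply into the same pass that reduces the per-token top-k expert outputs, instead of scaling the combined tensor in a separate elementwise pass afterwards. A minimal reference sketch of the equivalence the kernel relies on (pure PyTorch; expert_out and s are illustrative names, not identifiers from the patch):

import torch

def combine_unfused(expert_out: torch.Tensor, s: float) -> torch.Tensor:
    # old path: reduce over the top-k experts, then a separate scaling pass over the result
    out = expert_out.sum(dim=1)
    out *= s
    return out

def combine_fused(expert_out: torch.Tensor, s: float) -> torch.Tensor:
    # new path (reference semantics of the fused kernel): accumulate and scale in one sweep,
    # so the output tensor is written only once
    return expert_out.sum(dim=1) * s

expert_out = torch.randn(8, 9, 4096)  # (tokens, topk, hidden)
assert torch.allclose(combine_unfused(expert_out, 0.3), combine_fused(expert_out, 0.3))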
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py (new file, mode 100644)
import torch
import triton
import triton.language as tl
from triton.testing import do_bench


# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    for token_index in range(token_start, token_end):
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )


def moe_sum_reduce(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 8

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
    return


def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out


@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x * routed_scaling_factor, dim=1, out=out)
    return out


def get_benchmark():
    num_tokens_range = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=num_tokens_range,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton"],
            line_names=["Original", "TorchCompile", "TritonKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="sum_scaled_performance",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        topk = 9
        hidden_size = 4096
        dtype = torch.bfloat16
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Warmup
        for _ in range(3):
            if version == "baseline":
                compute_sum_scaled_baseline(x, out, scaling_factor)
            elif version == "compiled":
                compute_sum_scaled_compiled(x, out, scaling_factor)
            else:
                moe_sum_reduce(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
        if version == "baseline":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
                quantiles=quantiles,
            )
        elif version == "compiled":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = do_bench(
                lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def verify_correctness(num_tokens=1024):
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_triton = torch.empty_like(out_baseline)
    moe_sum_reduce(x, out_triton, scaling_factor)

    if torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2) and torch.allclose(
        out_baseline, out_triton, atol=1e-2, rtol=1e-2
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        print(f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}")
        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")


if __name__ == "__main__":
    print("Running correctness verification...")
    verify_correctness()

    print("\nRunning performance benchmark...")
    benchmark = get_benchmark()
    benchmark.run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )
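On a CUDA machine the script can be run directly to reproduce the correctness check and the latency sweep. The Triton reduce can also be exercised on its own; a small illustrative call, assuming moe_sum_reduce from the file above is in scope and a GPU is available (shapes are arbitrary here):

import torch
# assumes moe_sum_reduce from benchmark_sum_scale.py above has been imported or defined
x = torch.randn(16, 9, 4096, device="cuda", dtype=torch.bfloat16)  # (tokens, topk, hidden)
out = torch.empty(16, 4096, device="cuda", dtype=torch.bfloat16)   # (tokens, hidden)
moe_sum_reduce(x, out, routed_scaling_factor=0.3)                  # out ~= x.sum(dim=1) * 0.3 within bf16 tolerance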
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1155,6 +1155,7 @@ def inplace_fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,

@@ -1177,6 +1178,8 @@ def inplace_fused_experts(
         a1_scale,
         a2_scale,
         block_shape,
+        False,
+        routed_scaling_factor,
     )

@@ -1200,6 +1203,7 @@ def inplace_fused_experts_fake(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     pass

@@ -1233,6 +1237,7 @@ def outplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,

@@ -1256,6 +1261,7 @@ def outplace_fused_experts(
         a2_scale,
         block_shape,
         no_combine=no_combine,
+        routed_scaling_factor=routed_scaling_factor,
     )

@@ -1280,6 +1286,7 @@ def outplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -1314,7 +1321,9 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(

@@ -1337,6 +1346,7 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
+            routed_scaling_factor,
         )
         return hidden_states
     else:

@@ -1361,9 +1371,102 @@ def fused_experts(
             a2_scale,
             block_shape,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
+
+
+# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
+@triton.jit
+def _moe_sum_reduce_kernel(
+    input_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    output_ptr,
+    output_stride_0,
+    output_stride_1,
+    token_num: int,
+    topk_num: int,
+    hidden_dim: int,
+    routed_scaling_factor: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DIM: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+
+    token_block_id = tl.program_id(0)
+    dim_block_id = tl.program_id(1)
+
+    token_start = token_block_id * BLOCK_M
+    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+
+    dim_start = dim_block_id * BLOCK_DIM
+    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+
+    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+
+    for token_index in range(token_start, token_end):
+        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
+        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
+        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+            tmp = tl.load(
+                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
+            )
+            accumulator += tmp
+        accumulator = accumulator * routed_scaling_factor
+        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
+        tl.store(
+            store_t_ptr,
+            accumulator.to(input_ptr.dtype.element_ty),
+            mask=offs_dim < dim_end,
+        )
+
+
+def moe_sum_reduce_triton(
+    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
+):
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+
+    token_num, topk_num, hidden_dim = input.shape
+    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
+
+    BLOCK_M = 1
+    BLOCK_DIM = 2048
+    NUM_STAGE = 1
+    num_warps = 8
+
+    grid = (
+        triton.cdiv(token_num, BLOCK_M),
+        triton.cdiv(hidden_dim, BLOCK_DIM),
+    )
+
+    _moe_sum_reduce_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        token_num=token_num,
+        topk_num=topk_num,
+        hidden_dim=hidden_dim,
+        routed_scaling_factor=routed_scaling_factor,
+        BLOCK_M=BLOCK_M,
+        BLOCK_DIM=BLOCK_DIM,
+        NUM_STAGE=NUM_STAGE,
+        num_warps=num_warps,
+    )
+    return
+
+
+@torch.compile
+def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
+    torch.sum(x, dim=1, out=out)
+    out.mul_(routed_scaling_factor)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,

@@ -1386,6 +1489,7 @@ def fused_experts_impl(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
     if (

@@ -1562,6 +1666,9 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
+        if routed_scaling_factor is None:
+            routed_scaling_factor = 1.0
         if no_combine:
             pass
         elif _is_hip:

@@ -1570,20 +1677,28 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            if topk_ids.shape[1] == 1:
+            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
                 pass  # we write directly into out_hidden_states
-            elif topk_ids.shape[1] == 2:
+            elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
                 torch.add(
                     intermediate_cache3[:, 0],
                     intermediate_cache3[:, 1],
                     out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 ).squeeze(dim=1)
-            elif topk_ids.shape[1] > 2:
-                torch.sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    dim=1,
-                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+            else:
+                # According to micro benchmark results, torch.compile can get better performance for small token.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )

     return out_hidden_states

@@ -1695,4 +1810,5 @@ def fused_moe(
         a2_scale=a2_scale,
         block_shape=block_shape,
         no_combine=no_combine,
+        routed_scaling_factor=routed_scaling_factor,
     )
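One detail worth noting from the hunks above: the pre-existing topk_ids.shape[1] == 1 and == 2 fast paths never multiply by the scale, so they are now gated on routed_scaling_factor == 1.0 and only taken when scaling is a no-op. A small runnable illustration of that invariant (shapes are illustrative, not from the patch):

import torch

cache = torch.randn(4, 2, 8)             # (tokens, topk=2, hidden)
s = 1.0
shortcut = cache[:, 0] + cache[:, 1]      # what the topk == 2 fast path computes
fused = cache.sum(dim=1) * s              # what the fused combine computes
assert torch.allclose(shortcut, fused)    # holds only because s == 1.0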
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -225,6 +225,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def forward_cpu(
python/sglang/srt/layers/quantization/blockwise_int8.py

@@ -411,4 +411,5 @@ class BlockInt8MoEMethod:
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -317,6 +317,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
+            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/fp8.py

@@ -1030,6 +1030,7 @@ class Fp8MoEMethod:
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def maybe_apply_hip_fused_experts(
python/sglang/srt/layers/quantization/moe_wna16.py

@@ -388,6 +388,7 @@ class MoeWNA16Method:
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     @staticmethod
python/sglang/srt/layers/quantization/w8a8_fp8.py

@@ -328,4 +328,5 @@ class W8A8FP8MoEMethod:
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -268,4 +268,5 @@ class W8A8Int8MoEMethod:
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/models/deepseek_v2.py

@@ -346,7 +346,7 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
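With the scale now applied inside the fused combine (see the fused_moe.py hunks above), keeping the model-level multiply would apply it a second time, which is why the line is removed here. A small runnable illustration of the double scaling it avoids (shapes and the factor are illustrative):

import torch

s = 2.5
expert_out = torch.randn(4, 9, 16)                # (tokens, topk, hidden)
fused = expert_out.sum(dim=1) * s                 # fused combine output: scale already applied
double_scaled = fused * s                         # effect of re-applying the removed multiply
assert not torch.allclose(fused, double_scaled)   # the old model-level multiply would now double-scale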