Commit 1fb76ebb (unverified)
Authored Jun 07, 2025 by Yineng Zhang; committed via GitHub on Jun 07, 2025

Revert "Fuse routed scaling factor in topk_reduce kernel (#6220)" (#6968)
Parent: c2c4f57f

Showing 10 changed files with 9 additions and 331 deletions (+9, −331)
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py  (+0, −199)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  (+8, −124)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py  (+0, −1)
python/sglang/srt/layers/quantization/blockwise_int8.py  (+0, −1)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py  (+0, −1)
python/sglang/srt/layers/quantization/fp8.py  (+0, −1)
python/sglang/srt/layers/quantization/moe_wna16.py  (+0, −1)
python/sglang/srt/layers/quantization/w8a8_fp8.py  (+0, −1)
python/sglang/srt/layers/quantization/w8a8_int8.py  (+0, −1)
python/sglang/srt/models/deepseek_v2.py  (+1, −1)
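Context for the revert: #6220 fused the routed scaling factor into the top-k sum-reduce Triton kernel, so expert outputs were scaled during the reduction; this revert restores the two-step form in which the model sums the expert outputs first and applies the factor afterwards (see the deepseek_v2.py hunk below). The two orderings agree up to floating-point rounding. A minimal sketch with made-up shapes (illustrative, not code from this commit):

import torch

tokens, topk, hidden = 4, 9, 16
s = 0.3  # routed_scaling_factor

x = torch.randn(tokens, topk, hidden)

fused = (x * s).sum(dim=1)    # factor folded into the reduction (the reverted approach)
unfused = x.sum(dim=1) * s    # reduce first, scale afterwards (the restored approach)

torch.testing.assert_close(fused, unfused)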
benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py  (deleted, 100644 → 0)
import torch
import triton
import triton.language as tl
from triton.testing import do_bench


# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    for token_index in range(token_start, token_end):
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )


def moe_sum_reduce(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 8

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
    return


def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out


@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x * routed_scaling_factor, dim=1, out=out)
    return out


def get_benchmark():
    num_tokens_range = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=num_tokens_range,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton"],
            line_names=["Original", "TorchCompile", "TritonKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="sum_scaled_performance",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        topk = 9
        hidden_size = 4096
        dtype = torch.bfloat16
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Warmup
        for _ in range(3):
            if version == "baseline":
                compute_sum_scaled_baseline(x, out, scaling_factor)
            elif version == "compiled":
                compute_sum_scaled_compiled(x, out, scaling_factor)
            else:
                moe_sum_reduce(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
        if version == "baseline":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
                quantiles=quantiles,
            )
        elif version == "compiled":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = do_bench(
                lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def verify_correctness(num_tokens=1024):
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_triton = torch.empty_like(out_baseline)
    moe_sum_reduce(x, out_triton, scaling_factor)

    if torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2) and torch.allclose(
        out_baseline, out_triton, atol=1e-2, rtol=1e-2
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        print(f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}")
        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")


if __name__ == "__main__":
    print("Running correctness verification...")
    verify_correctness()

    print("\nRunning performance benchmark...")
    benchmark = get_benchmark()
    benchmark.run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )
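A note on the launch configuration above: BLOCK_M = 1 assigns one program per token row, BLOCK_DIM = 2048 tiles the hidden dimension, and the kernel's offs_dim < dim_end mask covers a ragged final tile when hidden_dim is not a multiple of BLOCK_DIM. A quick sketch of the grid arithmetic, using the benchmark defaults (illustrative only):

import triton

token_num, hidden_dim = 1024, 4096
BLOCK_M, BLOCK_DIM = 1, 2048

grid = (triton.cdiv(token_num, BLOCK_M), triton.cdiv(hidden_dim, BLOCK_DIM))
print(grid)  # (1024, 2): one program per token row, two 2048-wide tiles per row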
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -1155,7 +1155,6 @@ def inplace_fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1178,8 +1177,6 @@ def inplace_fused_experts(
         a1_scale,
         a2_scale,
         block_shape,
-        False,
-        routed_scaling_factor,
     )
@@ -1203,7 +1200,6 @@ def inplace_fused_experts_fake(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     pass
@@ -1237,7 +1233,6 @@ def outplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1261,7 +1256,6 @@ def outplace_fused_experts(
         a2_scale,
         block_shape,
         no_combine=no_combine,
-        routed_scaling_factor=routed_scaling_factor,
     )
@@ -1286,7 +1280,6 @@ def outplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -1321,9 +1314,7 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ):
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -1346,7 +1337,6 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
-            routed_scaling_factor,
         )
         return hidden_states
     else:
@@ -1371,102 +1361,9 @@ def fused_experts(
             a2_scale,
             block_shape,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
-
-
-# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
-@triton.jit
-def _moe_sum_reduce_kernel(
-    input_ptr,
-    input_stride_0,
-    input_stride_1,
-    input_stride_2,
-    output_ptr,
-    output_stride_0,
-    output_stride_1,
-    token_num: int,
-    topk_num: int,
-    hidden_dim: int,
-    routed_scaling_factor: tl.constexpr,
-    BLOCK_M: tl.constexpr,
-    BLOCK_DIM: tl.constexpr,
-    NUM_STAGE: tl.constexpr,
-):
-    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
-    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
-    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
-
-    token_block_id = tl.program_id(0)
-    dim_block_id = tl.program_id(1)
-
-    token_start = token_block_id * BLOCK_M
-    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-
-    dim_start = dim_block_id * BLOCK_DIM
-    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-
-    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-
-    for token_index in range(token_start, token_end):
-        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-            tmp = tl.load(
-                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-            )
-            accumulator += tmp
-        accumulator = accumulator * routed_scaling_factor
-        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-        tl.store(
-            store_t_ptr,
-            accumulator.to(input_ptr.dtype.element_ty),
-            mask=offs_dim < dim_end,
-        )
-
-
-def moe_sum_reduce_triton(
-    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
-):
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-
-    token_num, topk_num, hidden_dim = input.shape
-    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
-
-    BLOCK_M = 1
-    BLOCK_DIM = 2048
-    NUM_STAGE = 1
-    num_warps = 8
-
-    grid = (
-        triton.cdiv(token_num, BLOCK_M),
-        triton.cdiv(hidden_dim, BLOCK_DIM),
-    )
-
-    _moe_sum_reduce_kernel[grid](
-        input,
-        *input.stride(),
-        output,
-        *output.stride(),
-        token_num=token_num,
-        topk_num=topk_num,
-        hidden_dim=hidden_dim,
-        routed_scaling_factor=routed_scaling_factor,
-        BLOCK_M=BLOCK_M,
-        BLOCK_DIM=BLOCK_DIM,
-        NUM_STAGE=NUM_STAGE,
-        num_warps=num_warps,
-    )
-    return
-
-
-@torch.compile
-def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
-    torch.sum(x, dim=1, out=out)
-    out.mul_(routed_scaling_factor)
-
-
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -1489,7 +1386,6 @@ def fused_experts_impl(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
     if (
@@ -1666,9 +1562,6 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
-        if routed_scaling_factor is None:
-            routed_scaling_factor = 1.0
-
         if no_combine:
             pass
         elif _is_hip:
@@ -1677,28 +1570,20 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
+            if topk_ids.shape[1] == 1:
                 pass  # we write directly into out_hidden_states
-            elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
+            elif topk_ids.shape[1] == 2:
                 torch.add(
                     intermediate_cache3[:, 0],
                     intermediate_cache3[:, 1],
                     out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 ).squeeze(dim=1)
-            else:
-                # According to micro benchmark results, torch.compile can get better performance for small token.
-                if tokens_in_chunk <= 32:
-                    moe_sum_reduce_torch_compile(
-                        intermediate_cache3.view(*intermediate_cache3.shape),
-                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                        routed_scaling_factor,
-                    )
-                else:
-                    moe_sum_reduce_triton(
-                        intermediate_cache3.view(*intermediate_cache3.shape),
-                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                        routed_scaling_factor,
-                    )
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
     return out_hidden_states
@@ -1810,5 +1695,4 @@ def fused_moe(
         a2_scale=a2_scale,
         block_shape=block_shape,
         no_combine=no_combine,
-        routed_scaling_factor=routed_scaling_factor,
     )
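The revert strips the parameter from inplace_fused_experts and inplace_fused_experts_fake (and from the outplace pair) in lockstep because these functions are dispatched through torch.ops.sglang, i.e. registered as PyTorch custom ops, and the _fake variants evidently serve as the shape-only meta implementations, whose signatures must match the real ops. A generic sketch of that pattern with illustrative names (not the sglang registration code):

import torch

@torch.library.custom_op("demo::scale_sum", mutates_args=())
def scale_sum(x: torch.Tensor, s: float) -> torch.Tensor:
    return x.sum(dim=1) * s

@scale_sum.register_fake
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing and compilation;
    # its signature must stay identical to the real op's.
    return x.new_empty(x.shape[0], x.shape[2])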
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -225,7 +225,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

     def forward_cpu(
python/sglang/srt/layers/quantization/blockwise_int8.py

@@ -411,5 +411,4 @@ class BlockInt8MoEMethod:
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -317,7 +317,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/fp8.py

@@ -1030,7 +1030,6 @@ class Fp8MoEMethod:
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

 def maybe_apply_hip_fused_experts(
python/sglang/srt/layers/quantization/moe_wna16.py

@@ -388,7 +388,6 @@ class MoeWNA16Method:
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

     @staticmethod
python/sglang/srt/layers/quantization/w8a8_fp8.py

@@ -328,5 +328,4 @@ class W8A8FP8MoEMethod:
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -268,5 +268,4 @@ class W8A8Int8MoEMethod:
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
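Each quantization backend above forwards the same keyword to fused_experts, so once the parameter leaves the signature, every call site has to be cleaned up in the same commit; a site left behind would fail at call time. A toy illustration (hypothetical, standalone names):

def fused_experts(hidden_states, no_combine=False):  # parameter removed, as in this revert
    return hidden_states

try:
    fused_experts([1.0], no_combine=False, routed_scaling_factor=0.3)  # a missed call site
except TypeError as e:
    print(e)  # got an unexpected keyword argument 'routed_scaling_factor'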
python/sglang/srt/models/deepseek_v2.py

@@ -346,7 +346,7 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1: