sglang · Commit 3dc7dc6c
Authored Nov 14, 2025 by maxiao1
Parent: 842b423a

Operator fusion

Showing 7 changed files with 190 additions and 62 deletions (+190 −62)
python/sglang/srt/environ.py                                     +8   −0
python/sglang/srt/layers/attention/lightop_concat.py             +32  −0
python/sglang/srt/layers/linear.py                               +26  −4
python/sglang/srt/layers/moe/fused_moe_triton/layer.py           +16  −1
python/sglang/srt/layers/quantization/slimquant_w4a8.py          +10  −7
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py   +34  −4
python/sglang/srt/models/deepseek_v2.py                          +64  −46
python/sglang/srt/environ.py

@@ -167,6 +167,14 @@ class Envs:

     # DCU Lightop
     SGLANG_USE_LIGHTOP = EnvBool(False)
+    # Fused
+    SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD = EnvBool(False)
+    SGLANG_USE_OPT_CAT = EnvBool(False)
+    SGLANG_USE_FUSED_RMS_QUANT = EnvBool(False)
+    SGLANG_USE_FUSED_SILU_MUL_QUANT = EnvBool(False)

     # Quantization
     SGLANG_INT4_WEIGHT = EnvBool(False)
     SGLANG_CPU_QUANTIZATION = EnvBool(False)
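The new switches are plain boolean environment variables. A minimal sketch of how a launch script would opt in, assuming sglang's usual `envs.<NAME>.get()` accessor on `EnvBool` (the flag names come straight from this hunk; the accessor is an assumption about the environ module's API):

import os

# Opt in to the DCU Lightop fusions before sglang reads its environment.
os.environ["SGLANG_USE_LIGHTOP"] = "1"
os.environ["SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD"] = "1"
os.environ["SGLANG_USE_OPT_CAT"] = "1"
os.environ["SGLANG_USE_FUSED_RMS_QUANT"] = "1"
os.environ["SGLANG_USE_FUSED_SILU_MUL_QUANT"] = "1"

from sglang.srt.environ import envs  # import after setting the variables

print(envs.SGLANG_USE_OPT_CAT.get())  # True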
python/sglang/srt/layers/attention/lightop_concat.py (new file, mode 100644)

from __future__ import annotations

import warnings

import torch

from sglang.srt.utils import get_bool_env_var

_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")

if _USE_OPT_CAT:
    try:
        from lightop import ds_cat  # type: ignore
    except ImportError:  # pragma: no cover
        ds_cat = None
        warnings.warn(
            "SGLANG_USE_OPT_CAT is enabled but lightop.ds_cat could not be "
            "imported; falling back to torch.cat"
        )
else:
    ds_cat = None


def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    mode = 0
    if dim != 0:
        ds_cat(A, B, C, mode)
        return C
    assert False, "not supported"
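A hedged usage sketch for the new helper: the tensors mirror the 3-D decode shapes the assert expects, and the `ds_cat is None` check exercises the documented torch.cat fallback on machines without the DCU-only lightop package (sizes here are illustrative):

import torch

from sglang.srt.layers.attention.lightop_concat import concat_decode_opt, ds_cat

A = torch.randn(8, 16, 512)  # e.g. [tokens, heads, nope_dim]
B = torch.randn(8, 16, 64)   # e.g. [tokens, heads, rope_dim]

if ds_cat is not None:
    C = concat_decode_opt(A, B, dim=2)  # fused concat kernel
else:
    C = torch.cat([A, B], dim=2)        # fallback when lightop is absent

assert C.shape == (8, 16, 576)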
python/sglang/srt/layers/linear.py

@@ -44,6 +44,18 @@ _is_hip = is_hip()

 _disable_hip_linear_quant = _is_hip and get_bool_env_var(
     "SGLANG_ROCM_DISABLE_LINEARQUANT"
 )
+_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
+_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
+if _use_fused_rms_quant:
+    try:
+        from lmslim.quantize.quant_ops import lm_faster_rmsquant
+    except Exception as e:
+        print(f"Error: Import fused rmsquant error: {e}")
+if _use_fused_silu_mul_quant:
+    try:
+        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
+    except Exception as e:
+        print(f"Error: Import fused silu_mul_quant error: {e}")

 logger = logging.getLogger(__name__)

@@ -1358,7 +1370,7 @@ class RowParallelLinear(LinearBase):

         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_, skip_all_reduce=False):
+    def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
         if self.input_is_parallel:
             input_parallel = input_
         else:

@@ -1372,9 +1384,19 @@ class RowParallelLinear(LinearBase):

         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
-            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
-            sm.tag(output_parallel)
+        if use_fused_silu_mul_quant:
+            xq, xs = lm_fuse_silu_mul_quant(input_parallel)
+            silu_quant_args = [xq, xs]
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                output_parallel = self.quant_method.apply(
+                    self, input_parallel, bias=bias_, silu_quant_args=silu_quant_args
+                )
+                sm.tag(output_parallel)
+        else:
+            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+                output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+                sm.tag(output_parallel)

         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
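`lm_fuse_silu_mul_quant` is a closed lmslim kernel, so the following pure-PyTorch reference is an assumption about its semantics rather than its implementation: it should fuse the SwiGLU activation (`silu(gate) * up`) with per-token symmetric int8 quantization, producing the `(xq, xs)` pair that `forward()` passes on as `silu_quant_args`:

import torch
import torch.nn.functional as F

def silu_mul_quant_reference(gate_up: torch.Tensor):
    # gate_up_proj packs [gate | up] along the last dimension.
    gate, up = gate_up.chunk(2, dim=-1)
    y = F.silu(gate) * up  # SwiGLU activation
    # Per-token symmetric int8 quantization.
    scale = y.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    y_q = torch.round(y / scale).clamp(-128, 127).to(torch.int8)
    return y_q, scale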
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -42,6 +42,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod

 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
+from sglang.srt.environ import envs
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     cpu_has_amx_support,

@@ -58,6 +59,7 @@ if is_flashinfer_available():

 _is_hip = is_hip()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_user_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")

 # Try to import FP4 TRTLLM function if flashinfer is available

@@ -221,6 +223,7 @@ class FusedMoE(torch.nn.Module):

         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
         self.use_presharded_weights = use_presharded_weights
+        # self.global_num_experts = self.num_experts

         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()

@@ -877,9 +880,21 @@ class FusedMoE(torch.nn.Module):

                 f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
             )

-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput = None, shared_output: torch.Tensor = None, **kwargs):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None

+        if _user_lightop_moe_sum_mul_add:
+            final_hidden_states = self.quant_method.apply_with_shared_output(
+                layer=self,
+                x=hidden_states,
+                activation=getattr(self, "moe_runner_config", None)
+                and self.moe_runner_config.activation
+                or "silu",
+                shared_output=shared_output,
+                topk_output=topk_output,
+            )
+            if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
+                final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+            return final_hidden_states
+
         dispatch_output = self.dispatcher.dispatch(
             hidden_states=hidden_states, topk_output=topk_output
         )
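The `SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD` path hands the shared-expert output to `apply_with_shared_output` so one kernel can do the weighted expert sum and the residual add in a single pass. A pure-PyTorch reference of the assumed math (the real work happens inside the Marlin kernel below; tensor names here are illustrative, not the kernel's API):

from typing import Optional

import torch

def moe_sum_mul_add_reference(
    expert_outputs: torch.Tensor,  # [tokens, topk, hidden]
    topk_weights: torch.Tensor,    # [tokens, topk]
    shared_output: Optional[torch.Tensor],
) -> torch.Tensor:
    # Weighted sum over the routed experts' outputs ("sum_mul") ...
    combined = (expert_outputs * topk_weights.unsqueeze(-1)).sum(dim=1)
    # ... plus the shared-expert residual ("add"), fused into one kernel on DCU.
    if shared_output is not None:
        combined = combined + shared_output
    return combined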
python/sglang/srt/layers/quantization/slimquant_w4a8.py

@@ -19,6 +19,9 @@ from vllm.utils import W8a8GetCacheJSON

 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 import os
+from sglang.srt.utils import get_bool_env_var
+
+_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
+_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")

 class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
     """

@@ -163,13 +166,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):

         input_quant_args: Optional[list[torch.Tensor]] = None,
         silu_quant_args: Optional[list[torch.Tensor]] = None
     ):
-        # if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
-        #     assert len(input_quant_args) == 2
-        #     x_q, x_scale = input_quant_args
-        # elif envs.USE_FUSED_SILU_MUL_QUANT and silu_quant_args is not None:
-        #     x_q, x_scale = silu_quant_args
-        # else:
+        if _use_fused_rms_quant and input_quant_args is not None:
+            assert len(input_quant_args) == 2
+            x_q, x_scale = input_quant_args
+        elif _use_fused_silu_mul_quant and silu_quant_args is not None:
+            x_q, x_scale = silu_quant_args
+        else:
             x_q, x_scale = per_token_quant_int8(x)

         if self.w8a8_strategy == 1:
             m = x_q.shape[0]
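When neither fused producer is active, `apply()` falls back to `per_token_quant_int8`. A quick round-trip sanity check of that fallback (a sketch only: the import path and tolerance are assumptions, and sglang's kernel expects a GPU tensor):

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
x_q, x_scale = per_token_quant_int8(x)  # int8 values plus one scale per token
x_deq = (x_q.float() * x_scale).to(x.dtype)
# 8-bit round-trip error is bounded by roughly half a quantization step.
assert torch.allclose(x, x_deq, atol=2e-2, rtol=0)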
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py

@@ -252,6 +252,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod:

             use_nn_moe=False,
         )
         return StandardCombineInput(hidden_states=output)

+    def apply_with_shared_output(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        activation: str = "silu",
+        shared_output: Optional[torch.Tensor] = None,
+        topk_output=None,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+        workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
+        return fused_experts_impl_w4a8_marlin(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            workspace=workspace,
+            global_reduce_buffer=global_reduce_buffer,
+            inplace=True,
+            use_int4_w4a8=True,
+            per_channel_quant=True,
+            activation=activation,
+            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
+            global_num_experts=layer.moe_runner_config.num_experts,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            use_nn_moe=False,
+            shared_output=shared_output,
+        )
+
     # def _apply(
     #     self,
     #     layer: torch.nn.Module,

@@ -317,9 +350,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:

     #             a2_scale=layer.w2_input_scale,
     #             use_nn_moe=use_nn_moe,
     #         )
-    #
     def apply_ep(
         self,
         x: torch.Tensor,

@@ -368,4 +398,4 @@ class SlimQuantW4A8Int8MarlinMoEMethod:

             #config_select_bs=config_select_bs,
             #q_scales=scales
         )
\ No newline at end of file
python/sglang/srt/models/deepseek_v2.py

@@ -141,6 +141,7 @@ from sglang.srt.utils import (

     make_layers,
     use_intel_amx_backend,
 )
+from sglang.srt.layers.attention.lightop_concat import concat_decode_opt

 _is_hip = is_hip()
 _is_cuda = is_cuda()

@@ -151,8 +152,10 @@ _is_cpu_amx_available = cpu_has_amx_support()

 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
 _is_gfx95_supported = is_gfx95_supported()
+_user_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
+_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")

 _use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+_use_opt_cat_decode = get_bool_env_var("SGLANG_USE_OPT_CAT")

 if _use_aiter_gfx95:
     from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights

@@ -456,10 +459,13 @@ class DeepseekV2MLP(nn.Module):

             x = (x, None, y)
         gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(
-            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
-        )
+        if _use_fused_silu_mul_quant:
+            x, _ = self.down_proj(
+                gate_up,
+                skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
+                use_fused_silu_mul_quant=True,
+            )
+        else:
+            x = self.act_fn(gate_up)
+            x, _ = self.down_proj(
+                x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+            )
         return x

@@ -757,49 +763,58 @@ class DeepseekV2MoE(nn.Module):

             self.shared_experts.gate_up_proj
         ):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

+        if _user_lightop_moe_sum_mul_add:
+            if hidden_states.shape[0] > 0:
+                if not self._fuse_shared_experts_inside_sbo:
+                    shared_output = self._forward_shared_experts(
+                        hidden_states, gemm_output_zero_allocator
+                    )
+                else:
+                    shared_output = None
+                # router_logits: (num_tokens, n_experts)
+                router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
+                topk_output = self.topk(hidden_states, router_logits)
+                final_hidden_states = self.experts(
+                    hidden_states, topk_output, shared_output=shared_output
+                )
+            else:
+                shared_output = None
+                topk_output = self.topk.empty_topk_output(hidden_states.device)
+        else:
             if hidden_states.shape[0] > 0:
                 if not self._fuse_shared_experts_inside_sbo:
                     shared_output = self._forward_shared_experts(
                         hidden_states, gemm_output_zero_allocator
                     )
                 # router_logits: (num_tokens, n_experts)
                 router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
                 topk_output = self.topk(hidden_states, router_logits)
             else:
                 shared_output = None
                 topk_output = self.topk.empty_topk_output(hidden_states.device)

             def _forward_shared_experts_and_put_results():
                 nonlocal shared_output
                 shared_output = self._forward_shared_experts(
                     hidden_states, gemm_output_zero_allocator
                 )

             final_hidden_states = self.experts(
                 hidden_states,
                 topk_output,
                 **(
                     dict(
                         forward_shared_experts=_forward_shared_experts_and_put_results,
                         alt_stream=self.alt_stream,
                     )
                     if self._fuse_shared_experts_inside_sbo
                     else {}
                 ),
             )

             if not _is_cuda and not _use_aiter:
                 # fused in biased_grouped_topk so we can skip here
                 final_hidden_states *= self.routed_scaling_factor
             if shared_output is not None:
                 with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
                     final_hidden_states_out = torch.empty_like(final_hidden_states)
                     torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
                     final_hidden_states = final_hidden_states_out
                     sm.tag(final_hidden_states)

         if (
             self.tp_size > 1
             and not should_allreduce_fusion

@@ -1696,7 +1711,10 @@ class DeepseekV2AttentionMLA(nn.Module):

                 self.rotary_emb.is_neox_style,
             )
         else:
-            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            if _use_opt_cat_decode and q_nope_out.shape[0] < 1024:
+                q = concat_decode_opt(q_nope_out, q_pe, dim=2)
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
             k = torch.cat([k_nope, k_pe], dim=-1)

         attn_output = self.attn_mqa(
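In this MLA decode path, `q_nope_out` and `q_pe` are 3-D `[tokens, heads, head_dim]` tensors, so concatenating on `dim=2` is exactly the old `dim=-1` behavior, and the `shape[0] < 1024` guard restricts the fused kernel to small decode batches. A shape-level sketch of that equivalence (dimensions illustrative, not DeepSeek's real ones):

import torch

q_nope_out = torch.randn(4, 16, 512)  # [tokens, heads, nope_dim]
q_pe = torch.randn(4, 16, 64)         # [tokens, heads, rope_dim]

ref = torch.cat([q_nope_out, q_pe], dim=-1)
opt = torch.cat([q_nope_out, q_pe], dim=2)  # what concat_decode_opt computes via ds_cat
assert torch.equal(ref, opt) and ref.shape == (4, 16, 576)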