Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58a36508
Commit
58a36508
authored
Mar 06, 2026
by
wujl5
Browse files
perf:Deepseek v2模型增加rmsQuant和siluMulQuant融合
parent
7826240b
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
217 additions
and
23 deletions
+217
-23
vllm/envs.py
vllm/envs.py
+2
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+39
-8
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+8
-1
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
...executor/layers/fused_moe/unquantized_fused_moe_method.py
+1
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+25
-0
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+6
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+4
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+128
-11
No files found.
vllm/envs.py
View file @
58a36508
...
...
@@ -1808,7 +1808,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"USE_FUSED_SILU_MUL_QUANT"
,
"0"
))),
lambda
:
(
os
.
getenv
(
"USE_FUSED_SILU_MUL_QUANT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT"
:
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
58a36508
...
...
@@ -6,7 +6,7 @@ import os
from
collections.abc
import
Callable
,
Iterable
from
contextlib
import
nullcontext
from
enum
import
Enum
from
typing
import
Literal
,
cast
,
get_args
,
overload
from
typing
import
Literal
,
cast
,
get_args
,
overload
,
Optional
import
torch
import
torch.nn.functional
as
F
...
...
@@ -1669,6 +1669,8 @@ class FusedMoE(CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
og_hidden_states
=
hidden_states
.
shape
[
-
1
]
if
self
.
hidden_size
!=
og_hidden_states
:
...
...
@@ -1720,7 +1722,9 @@ class FusedMoE(CustomOp):
)
else
:
shared_output
,
fused_output
=
torch
.
ops
.
vllm
.
moe_forward_shared
(
hidden_states
,
router_logits
,
encode_layer_name
()
hidden_states
,
router_logits
,
encode_layer_name
(),
i_q
=
i_q
,
i_s
=
i_s
)
return
(
reduce_output
(
shared_output
)[...,
:
og_hidden_states
],
...
...
@@ -1737,8 +1741,10 @@ class FusedMoE(CustomOp):
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
forward_native
(
hidden_states
,
router_logits
)
return
self
.
forward_native
(
hidden_states
,
router_logits
,
i_q
=
i_q
,
i_s
=
i_s
)
def
forward_impl_chunked
(
self
,
...
...
@@ -1880,6 +1886,8 @@ class FusedMoE(CustomOp):
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
use_fused_gate
:
bool
|
None
=
False
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
quant_method
is
not
None
...
...
@@ -2004,13 +2012,25 @@ class FusedMoE(CustomOp):
if
self
.
capture
is
not
None
:
self
.
capture
(
topk_ids
)
final_hidden_states
=
self
.
quant_method
.
apply
(
if
envs
.
USE_FUSED_RMS_QUANT
:
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
self
,
x
=
x
,
# The type signture of this is wrong due to the hack.
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
use_nn_moe
=
self
.
use_nn_moe
,
i_q
=
i_q
,
i_s
=
i_s
)
else
:
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
self
,
x
=
x
,
# The type signture of this is wrong due to the hack.
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
use_nn_moe
=
self
.
use_nn_moe
,
use_nn_moe
=
self
.
use_nn_moe
)
if
has_separate_shared_experts
:
...
...
@@ -2133,16 +2153,20 @@ def moe_forward(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
self
=
get_layer_from_name
(
layer_name
)
assert
self
.
shared_experts
is
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
i_q
=
i_q
,
i_s
=
i_s
)
def
moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -2160,16 +2184,23 @@ def moe_forward_shared(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
=
get_layer_from_name
(
layer_name
)
assert
self
.
shared_experts
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
if
envs
.
USE_FUSED_RMS_QUANT
:
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
i_q
=
i_q
,
i_s
=
i_s
)
else
:
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
def
moe_forward_shared_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
shared_out
=
torch
.
empty_like
(
hidden_states
)
fused_out
=
torch
.
empty_like
(
hidden_states
)
...
...
vllm/model_executor/layers/fused_moe/shared_fused_moe.py
View file @
58a36508
...
...
@@ -60,10 +60,13 @@ class SharedFusedMoE(FusedMoE):
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
not
self
.
use_overlapped
:
if
self
.
_shared_experts
is
not
None
:
shared_out
=
self
.
_shared_experts
(
hidden_states
)
shared_out
=
self
.
_shared_experts
(
hidden_states
,
i_q
=
iqis
[
0
]
if
iqis
is
not
None
else
None
,
i_s
=
iqis
[
1
]
if
iqis
is
not
None
else
None
)
# Reduce shared expert outputs if necessary, since the MLP
# should have been created with reduce_results=False.
...
...
@@ -79,11 +82,15 @@ class SharedFusedMoE(FusedMoE):
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
i_q
=
iqis
[
0
]
if
iqis
is
not
None
else
None
,
i_s
=
iqis
[
1
]
if
iqis
is
not
None
else
None
,
)
else
:
shared_out
,
fused_out
=
super
().
forward
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
i_q
=
iqis
[
0
]
if
iqis
is
not
None
else
None
,
i_s
=
iqis
[
1
]
if
iqis
is
not
None
else
None
,
)
# ensure early TP reduction of shared expert outputs when required
if
(
...
...
vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
View file @
58a36508
...
...
@@ -370,7 +370,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_nn_moe
:
bool
|
None
=
False
,
**
_
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
self
.
forward
(
layer
=
layer
,
...
...
vllm/model_executor/layers/linear.py
View file @
58a36508
...
...
@@ -711,6 +711,31 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear.
"""
def
forward
(
self
,
input_
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
Parameter
|
None
]:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
assert
self
.
quant_method
is
not
None
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
input_quant_args
=
iqis
)
else
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
and
self
.
tp_size
>
1
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
else
:
output
=
output_parallel
if
not
self
.
return_bias
:
return
output
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
def
__init__
(
self
,
...
...
vllm/model_executor/layers/mla.py
View file @
58a36508
...
...
@@ -8,6 +8,7 @@ from vllm.attention.layer import MLAAttention
from
vllm.config
import
CacheConfig
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm
import
envs
@
dataclass
...
...
@@ -115,6 +116,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
llama_4_scaling
:
torch
.
Tensor
|
None
=
None
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
q_c
=
None
kv_lora
=
None
...
...
@@ -129,7 +131,10 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
assert
self
.
q_b_proj
is
not
None
,
(
"q_b_proj is required when q_lora_rank is not None"
)
qkv_lora
=
self
.
fused_qkv_a_proj
(
hidden_states
)[
0
]
if
envs
.
USE_FUSED_RMS_QUANT
and
iqis
is
not
None
:
qkv_lora
=
self
.
fused_qkv_a_proj
(
hidden_states
,
iqis
=
iqis
)[
0
]
else
:
qkv_lora
=
self
.
fused_qkv_a_proj
(
hidden_states
)[
0
]
q_c
,
kv_lora
=
qkv_lora
.
split
(
[
self
.
q_lora_rank
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
],
dim
=-
1
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
58a36508
...
...
@@ -1255,6 +1255,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -1271,6 +1273,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
use_nn_moe
=
use_nn_moe
,
i_q
=
i_q
,
i_s
=
i_s
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
58a36508
...
...
@@ -398,6 +398,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
i_q
:
torch
.
Tensor
|
None
=
None
,
i_s
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -420,6 +422,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
i_q
=
i_q
,
i_s
=
i_s
)
def
select_gemm_impl
(
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
58a36508
...
...
@@ -94,6 +94,8 @@ from .utils import (
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.model_executor.layers.layernorm
import
FusedRMSNormQuant
logger
=
init_logger
(
__name__
)
...
...
@@ -169,6 +171,7 @@ class DeepseekAttention(nn.Module):
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
...
...
@@ -218,10 +221,23 @@ class DeepseekV2MLP(nn.Module):
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
def
forward
(
self
,
x
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
):
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
iqis
=
iqis
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
from
lmslim.quantize.quant_ops
import
lm_fuse_silu_mul_quant
xq
,
xs
=
lm_fuse_silu_mul_quant
(
gate_up
)
x
,
_
=
self
.
down_proj
(
gate_up
,
iqis
=
(
xq
,
xs
))
else
:
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
else
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
...
...
@@ -334,7 +350,9 @@ class DeepseekV2MoE(nn.Module):
else
None
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
...
@@ -528,6 +546,7 @@ class DeepseekV2Attention(nn.Module):
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
llama_4_scaling
:
torch
.
Tensor
|
None
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
if
self
.
q_lora_rank
is
not
None
:
q
=
self
.
q_a_proj
(
hidden_states
)[
0
]
...
...
@@ -907,8 +926,9 @@ class DeepseekV2MLAAttention(nn.Module):
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
llama_4_scaling
:
torch
.
Tensor
|
None
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
return
self
.
mla_attn
(
positions
,
hidden_states
,
llama_4_scaling
)
return
self
.
mla_attn
(
positions
,
hidden_states
,
llama_4_scaling
,
iqis
=
iqis
)
class
DeepseekV2DecoderLayer
(
nn
.
Module
):
...
...
@@ -989,13 +1009,91 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
if
not
envs
.
USE_FUSED_RMS_QUANT
:
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
input_layernorm
=
FusedRMSNormQuant
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
FusedRMSNormQuant
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
routed_scaling_factor
=
getattr
(
config
,
"routed_scaling_factor"
,
1.0
)
def
forward
(
def
forward_RQ
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
llama_4_scaling
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
assert
self
.
input_layernorm
.
has_weight
is
True
if
residual
is
None
:
residual
=
hidden_states
.
clone
()
i_q
,
i_s
,
_
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
None
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
)
residual_fix_overflow
=
True
else
:
i_q
,
i_s
,
residual
=
self
.
input_layernorm
(
x
=
hidden_states
,
residual
=
residual
,
quant_dtype
=
torch
.
int8
,
update_input
=
False
)
attn_kwargs
=
{
"positions"
:
positions
,
"hidden_states"
:
hidden_states
,
"iqis"
:
(
i_q
,
i_s
)
}
if
not
self
.
use_mha
:
attn_kwargs
[
"llama_4_scaling"
]
=
llama_4_scaling
hidden_states
=
self
.
self_attn
(
**
attn_kwargs
)
if
(
not
isinstance
(
self
.
self_attn
,
DeepseekAttention
)
and
hidden_states
.
dtype
==
torch
.
float16
):
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.0
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.0
/
self
.
routed_scaling_factor
# Fully Connected
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
assert
self
.
post_attention_layernorm
.
has_weight
is
True
_i_q
,
_i_s
,
residual
=
self
.
post_attention_layernorm
(
x
=
hidden_states
,
residual
=
residual
,
quant_dtype
=
torch
.
int8
,
update_input
=
update_hs
)
new_resi
=
residual
hidden_states
=
self
.
mlp
(
hidden_states
,
# iqis=(_i_q, _i_s) # TODO:wjl
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.0
/
self
.
routed_scaling_factor
return
hidden_states
,
new_resi
def
forward_default
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -1048,6 +1146,25 @@ class DeepseekV2DecoderLayer(nn.Module):
return
hidden_states
,
residual
def
choose_forward
(
self
):
if
envs
.
USE_FUSED_RMS_QUANT
:
return
self
.
forward_RQ
else
:
return
self
.
forward_default
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
,
llama_4_scaling
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
forward_func
=
self
.
choose_forward
()
return
forward_func
(
positions
=
positions
,
hidden_states
=
hidden_states
,
residual
=
residual
,
llama_4_scaling
=
llama_4_scaling
)
@
support_torch_compile
class
DeepseekV2Model
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment