Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
68972532
Commit
68972532
authored
Nov 26, 2025
by
wujl5
Committed by
zhuwenwen
Nov 26, 2025
Browse files
[pref]: DS_v2_w8a8模型融掉moe.quant
parent
7ff48a6c
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
93 additions
and
40 deletions
+93
-40
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+25
-9
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+21
-8
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+12
-4
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+5
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+6
-3
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+1
-1
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+1
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+1
-1
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+20
-11
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
68972532
...
@@ -1394,13 +1394,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -1394,13 +1394,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
None
:
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
True
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_int4_w4a16
,
use_int4_w4a8
,
per_channel_quant
,
global_num_experts
,
expert_map
,
per_channel_quant
,
global_num_experts
,
expert_map
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
)
block_shape
,
use_nn_moe
,
shared_output
,
routed_scaling_factor
,
i_q
=
i_q
,
i_s
=
i_s
)
def
inplace_fused_experts_fake
(
def
inplace_fused_experts_fake
(
...
@@ -1428,7 +1430,9 @@ def inplace_fused_experts_fake(
...
@@ -1428,7 +1430,9 @@ def inplace_fused_experts_fake(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
None
:
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
pass
pass
...
@@ -1466,7 +1470,9 @@ def outplace_fused_experts(
...
@@ -1466,7 +1470,9 @@ def outplace_fused_experts(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
torch
.
Tensor
:
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
return
fused_experts_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
False
,
activation
,
apply_router_weight_on_input
,
False
,
activation
,
apply_router_weight_on_input
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
use_fp8_w8a8
,
use_int8_w8a8
,
use_int8_w8a16
,
...
@@ -1500,7 +1506,9 @@ def outplace_fused_experts_fake(
...
@@ -1500,7 +1506,9 @@ def outplace_fused_experts_fake(
block_shape
:
Optional
[
List
[
int
]]
=
None
,
block_shape
:
Optional
[
List
[
int
]]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
torch
.
Tensor
:
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
@@ -1559,7 +1567,9 @@ def fused_experts(
...
@@ -1559,7 +1567,9 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
allow_cutlass_block_scaled_grouped_gemm
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
)
->
torch
.
Tensor
:
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
# For now, disable DeepGemm for small N (<= 512) until better
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
# permute/unpermute ops are available.
N
=
w1
.
size
(
1
)
N
=
w1
.
size
(
1
)
...
@@ -1594,6 +1604,7 @@ def fused_experts(
...
@@ -1594,6 +1604,7 @@ def fused_experts(
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
)
topk_ids
=
topk_ids
)
else
:
else
:
# Fused MoE quantization is currently supported only for DS (DeepSeek) w8a8.
return
dispatch_fused_experts_func
(
inplace
)(
return
dispatch_fused_experts_func
(
inplace
)(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
w1
=
w1
,
w1
=
w1
,
...
@@ -1619,7 +1630,9 @@ def fused_experts(
...
@@ -1619,7 +1630,9 @@ def fused_experts(
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
use_nn_moe
,
use_nn_moe
=
use_nn_moe
,
shared_output
=
shared_output
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
routed_scaling_factor
=
routed_scaling_factor
,
i_q
=
i_q
,
i_s
=
i_s
)
def
fused_experts_impl
(
def
fused_experts_impl
(
...
@@ -1649,6 +1662,8 @@ def fused_experts_impl(
...
@@ -1649,6 +1662,8 @@ def fused_experts_impl(
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
routed_scaling_factor
:
Optional
[
float
]
=
1.0
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
num_tokens
=
hidden_states
.
size
(
0
)
num_tokens
=
hidden_states
.
size
(
0
)
if
use_nn_moe
:
if
use_nn_moe
:
...
@@ -1695,8 +1710,9 @@ def fused_experts_impl(
...
@@ -1695,8 +1710,9 @@ def fused_experts_impl(
block_shape
=
block_shape
,
block_shape
=
block_shape
,
use_nn_moe
=
False
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
routed_scaling_factor
=
routed_scaling_factor
,
)
i_q
=
i_q
,
i_s
=
i_s
)
elif
use_int4_w4a8
is
True
:
elif
use_int4_w4a8
is
True
:
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
return
fused_experts_impl_w4a8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w1
=
w1
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
68972532
...
@@ -8,7 +8,7 @@ import importlib
...
@@ -8,7 +8,7 @@ import importlib
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Callable
,
Literal
,
Optional
,
overload
from
typing
import
Callable
,
Literal
,
Optional
,
overload
,
Tuple
,
List
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -1435,14 +1435,19 @@ class FusedMoE(torch.nn.Module):
...
@@ -1435,14 +1435,19 @@ class FusedMoE(torch.nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
):
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
# switch to using the moe_forward custom op.
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
assert
i_q
is
None
and
i_s
is
None
,
"moe.quant fused not support TPU now"
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
else
:
else
:
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
return
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
,
shared_output
)
self
.
layer_name
,
shared_output
,
i_q
,
i_s
)
def
forward_impl_chunked
(
self
,
full_hidden_states
:
torch
.
Tensor
,
def
forward_impl_chunked
(
self
,
full_hidden_states
:
torch
.
Tensor
,
full_router_logits
:
torch
.
Tensor
):
full_router_logits
:
torch
.
Tensor
):
...
@@ -1522,7 +1527,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -1522,7 +1527,9 @@ class FusedMoE(torch.nn.Module):
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward_impl
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
):
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
):
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
if
(
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
):
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
):
...
@@ -1559,7 +1566,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -1559,7 +1566,9 @@ class FusedMoE(torch.nn.Module):
shared_output
=
shared_output
,
shared_output
=
shared_output
,
use_nn_moe
=
self
.
use_nn_moe
,
use_nn_moe
=
self
.
use_nn_moe
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
routed_scaling_factor
=
self
.
routed_scaling_factor
,
use_fused_gate
=
self
.
use_fused_gate
use_fused_gate
=
self
.
use_fused_gate
,
i_q
=
i_q
,
i_s
=
i_s
)
)
if
do_naive_dispatch_combine
:
if
do_naive_dispatch_combine
:
...
@@ -1630,16 +1639,20 @@ class FusedMoE(torch.nn.Module):
...
@@ -1630,16 +1639,20 @@ class FusedMoE(torch.nn.Module):
def
moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
forward_context
.
no_compile_layers
[
layer_name
]
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
shared_output
)
return
self
.
forward_impl
(
hidden_states
,
router_logits
,
shared_output
,
i_q
,
i_s
)
def
moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
def
moe_forward_fake
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
layer_name
:
str
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
...
...
vllm/model_executor/layers/linear.py
View file @
68972532
...
@@ -406,7 +406,9 @@ class ReplicatedLinear(LinearBase):
...
@@ -406,7 +406,9 @@ class ReplicatedLinear(LinearBase):
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
quant_args
:
Optional
[
list
]
=
None
,
quant_args
:
Optional
[
list
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
],
list
[
torch
.
Tensor
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
(
rms_weight
is
not
None
or
quant_args
is
not
None
):
if
envs
.
USE_FUSED_RMS_QUANT
and
(
rms_weight
is
not
None
or
quant_args
is
not
None
):
if
quant_args
is
not
None
:
if
quant_args
is
not
None
:
input_quant_args
=
quant_args
input_quant_args
=
quant_args
...
@@ -601,7 +603,9 @@ class ColumnParallelLinear(LinearBase):
...
@@ -601,7 +603,9 @@ class ColumnParallelLinear(LinearBase):
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
update_hd
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
input_quant_args
=
None
assert
rms_weight
is
not
None
assert
rms_weight
is
not
None
...
@@ -680,7 +684,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -680,7 +684,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
,
update_hd
:
Optional
[
bool
]
=
True
,
xqxs
:
Optional
[
tuple
]
=
None
xqxs
:
Optional
[
tuple
]
=
None
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]],
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
Parameter
]],
]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
input_quant_args
=
None
assert
residual
is
not
None
and
rms_weight
is
not
None
assert
residual
is
not
None
and
rms_weight
is
not
None
...
@@ -707,7 +714,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -707,7 +714,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
return_bias
:
if
not
self
.
return_bias
:
return
output
return
output
return
output
,
new_residual
,
output_bias
return
output
,
new_residual
,
i_q
,
_scales
,
output_bias
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
68972532
...
@@ -670,7 +670,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -670,7 +670,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
):
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
):
"""
"""
Use the output of create_weights and the CompressedTensorsScheme
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
associated with the layer to apply the forward pass with the
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
68972532
...
@@ -1096,6 +1096,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1096,6 +1096,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -1137,7 +1139,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1137,7 +1139,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
routed_scaling_factor
=
routed_scaling_factor
,
i_q
=
i_q
,
i_s
=
i_s
)
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
68972532
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
enum
import
enum
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
,
List
import
torch
import
torch
from
compressed_tensors.quantization
import
(
QuantizationStrategy
)
from
compressed_tensors.quantization
import
(
QuantizationStrategy
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -163,6 +163,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -163,6 +163,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
i_q
:
Optional
[
torch
.
Tensor
]
=
None
,
i_s
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -203,5 +205,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
...
@@ -203,5 +205,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
routed_scaling_factor
=
routed_scaling_factor
,
i_q
=
i_q
,
i_s
=
i_s
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
68972532
...
@@ -113,7 +113,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -113,7 +113,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
],
bias
:
Optional
[
torch
.
Tensor
],
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# return self.kernel.apply_weights(layer, x, bias)
# return self.kernel.apply_weights(layer, x, bias)
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
68972532
...
@@ -156,7 +156,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -156,7 +156,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
):
):
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
assert
len
(
input_quant_args
)
==
2
assert
len
(
input_quant_args
)
==
2
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
68972532
...
@@ -407,7 +407,7 @@ def apply_int8_linear(
...
@@ -407,7 +407,7 @@ def apply_int8_linear(
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w8a8_strategy
:
Optional
[
int
]
=
0
,
w8a8_strategy
:
Optional
[
int
]
=
0
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
):
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * dynamic, layer.input_scale is None and x_scale computed from x.
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
68972532
...
@@ -97,16 +97,18 @@ class DeepseekV2MLP(nn.Module):
...
@@ -97,16 +97,18 @@ class DeepseekV2MLP(nn.Module):
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
False
,
update_hd
:
Optional
[
bool
]
=
False
,
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
):
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
gate_up
,
new_resi
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
gate_up
,
new_resi
,
i_q
,
_scales
,
_
=
self
.
gate_up_proj
(
x
,
rms_weight
,
residual
,
update_hd
=
update_hd
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
x
,
_
=
self
.
down_proj
(
gate_up
,
use_fused_silu_mul_quant
=
True
)
x
,
_
=
self
.
down_proj
(
gate_up
,
use_fused_silu_mul_quant
=
True
)
else
:
else
:
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
,
new_resi
return
x
,
new_resi
,
i_q
,
_scales
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
xqxs
=
xqxs
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
,
xqxs
=
xqxs
)
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
if
envs
.
USE_FUSED_SILU_MUL_QUANT
:
...
@@ -210,7 +212,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -210,7 +212,8 @@ class DeepseekV2MoE(nn.Module):
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
xqxs
:
Optional
[
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
if
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
xqxs
is
not
None
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
@@ -255,9 +258,10 @@ class DeepseekV2MoE(nn.Module):
...
@@ -255,9 +258,10 @@ class DeepseekV2MoE(nn.Module):
else
:
else
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
i_q
,
i_s
=
None
,
None
if
self
.
n_shared_experts
is
not
None
:
if
self
.
n_shared_experts
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
shared_output
,
new_resi
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
shared_output
,
new_resi
,
i_q
,
i_s
=
self
.
shared_experts
(
hidden_states
,
rms_weight
,
residual
,
update_hd
=
True
)
else
:
else
:
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
...
@@ -268,15 +272,18 @@ class DeepseekV2MoE(nn.Module):
...
@@ -268,15 +272,18 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
shared_output
=
shared_output
)
shared_output
=
shared_output
,
i_q
=
i_q
,
i_s
=
i_s
)
else
:
else
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
router_logits
=
router_logits
,
i_q
=
i_q
,
i_s
=
i_s
)
*
self
.
routed_scaling_factor
else
:
else
:
# Fix FP16 overflow
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
# The fp16 path does not use fused quantization (i_q / i_s are not passed).
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
router_logits
=
router_logits
)
...
@@ -298,7 +305,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -298,7 +305,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states
))
final_hidden_states
))
if
envs
.
USE_FUSED_RMS_QUANT
:
if
envs
.
USE_FUSED_RMS_QUANT
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
),
new_resi
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
),
new_resi
,
i_q
,
i_s
else
:
else
:
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
...
@@ -614,8 +621,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -614,8 +621,7 @@ class DeepseekV2MLAAttention(nn.Module):
update_input
:
Optional
[
bool
]
=
True
update_input
:
Optional
[
bool
]
=
True
)
->
Union
[
torch
.
Tensor
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]]:
]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
if
self
.
q_lora_rank
is
not
None
:
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
q_c
,
new_residual
,
_
,
input_quant_args
=
self
.
q_a_proj
(
hidden_states
,
rms_weight
=
rms_weight
,
residual
=
residual
,
update_hd
=
False
)
...
@@ -816,7 +822,10 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -816,7 +822,10 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer.
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
hidden_states
,
new_resi
=
self
.
mlp
(
hidden_states
,
self
.
post_attention_layernorm
.
weight
.
data
,
residual
)
hidden_states
,
new_resi
,
_i_q
,
_scales
=
self
.
mlp
(
hidden_states
,
rms_weight
=
self
.
post_attention_layernorm
.
weight
.
data
,
residual
=
residual
,
)
if
isinstance
(
self
.
mlp
,
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment