Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c70cf0fe
Unverified
Commit
c70cf0fe
authored
Apr 10, 2025
by
Michael Goin
Committed by
GitHub
Apr 10, 2025
Browse files
[Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (#16038)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
a5d11a54
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
254 additions
and
15 deletions
+254
-15
.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml
...harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+7
-4
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+234
-8
No files found.
.buildkite/lm-eval-harness/configs/Qwen1.5-MoE-W4A16-compressed-tensors.yaml
0 → 100644
View file @
c70cf0fe
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
model_name
:
"
nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.31
-
name
:
"
exact_match,flexible-extract"
value
:
0.47
limit
:
1319
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
c70cf0fe
...
@@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
...
@@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base-FP8
.yaml
Qwen1.5-MoE-W4A16-compressed-tensors
.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml
Meta-Llama-3-8B-QQQ.yaml
vllm/model_executor/layers/fused_moe/layer.py
View file @
c70cf0fe
...
@@ -512,7 +512,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -512,7 +512,9 @@ class FusedMoE(torch.nn.Module):
}
}
# need full intermediate size pre-sharding for WNA16 act order
# need full intermediate size pre-sharding for WNA16 act order
if
(
self
.
quant_method
.
__class__
.
__name__
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)):
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)):
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
...
@@ -648,9 +650,10 @@ class FusedMoE(torch.nn.Module):
...
@@ -648,9 +650,10 @@ class FusedMoE(torch.nn.Module):
# compressed-tensors checkpoints with packed weights are stored flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
# against known CompressionFormat enum values that have this quality
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
if
self
.
quant_method
.
__class__
.
__name__
in
(
self
.
quant_method
.
__class__
.
__name__
"CompressedTensorsWNA16MarlinMoEMethod"
,
==
"CompressedTensorsWNA16MoEMethod"
)
else
loaded_weight
"CompressedTensorsWNA16MoEMethod"
):
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
c70cf0fe
...
@@ -96,8 +96,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -96,8 +96,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
if
isinstance
(
layer
,
FusedMoE
):
return
CompressedTensorsMoEMethod
.
get_moe_method
(
return
CompressedTensorsMoEMethod
.
get_moe_method
(
self
,
layer
)
self
,
layer
.
activation
,
layer
.
expert_map
)
return
None
return
None
@
classmethod
@
classmethod
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
c70cf0fe
...
@@ -6,7 +6,8 @@ from typing import Callable, List, Optional
...
@@ -6,7 +6,8 @@ from typing import Callable, List, Optional
import
torch
import
torch
from
compressed_tensors
import
CompressionFormat
from
compressed_tensors
import
CompressionFormat
from
compressed_tensors.quantization
import
QuantizationStrategy
from
compressed_tensors.quantization
import
(
ActivationOrdering
,
QuantizationStrategy
)
import
vllm.model_executor.layers.fused_moe
# noqa
import
vllm.model_executor.layers.fused_moe
# noqa
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):
...
@@ -30,9 +31,11 @@ class GPTQMarlinState(Enum):
__all__
=
[
__all__
=
[
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsW8A8Fp8MoECutlassMethod"
,
"CompressedTensorsW8A8Fp8MoECutlassMethod"
,
"CompressedTensorsWNA16MoEMethod"
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
,
]
]
...
@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -41,8 +44,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@
staticmethod
@
staticmethod
def
get_moe_method
(
def
get_moe_method
(
quant_config
:
"CompressedTensorsConfig"
,
# type: ignore # noqa E501
quant_config
:
"CompressedTensorsConfig"
,
# type: ignore # noqa E501
activation
:
str
,
layer
:
torch
.
nn
.
Module
,
expert_map
:
Optional
[
torch
.
Tensor
],
)
->
"CompressedTensorsMoEMethod"
:
)
->
"CompressedTensorsMoEMethod"
:
# TODO: @dsikka: refactor this to use schemes as other kernels
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
# are supported + check if the layer is being ignored.
...
@@ -51,9 +53,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -51,9 +53,21 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"input_activations"
)
"input_activations"
)
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
# Prefer to use the non-marlin kernel when:
# 1. Many experts (MarlinMoE gives poor performance when >= 16)
# 2. Non-FP16 dtype (MarlinMoE only supports FP16)
# 3. Actorder is not group/dynamic (g_idx is unsupported)
# 4. Scales are grouped (channelwise is unsupported)
if
((
layer
.
local_num_experts
>=
16
or
layer
.
params_dtype
!=
torch
.
float16
)
and
weight_quant
.
actorder
not
in
(
ActivationOrdering
.
GROUP
,
ActivationOrdering
.
DYNAMIC
)
and
weight_quant
.
strategy
in
QuantizationStrategy
.
GROUP
):
return
CompressedTensorsWNA16MoEMethod
(
quant_config
)
return
CompressedTensorsWNA16MoEMethod
(
quant_config
)
else
:
return
CompressedTensorsWNA16MarlinMoEMethod
(
quant_config
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
and
activation
==
"silu"
and
expert_map
is
None
):
and
layer
.
activation
==
"silu"
and
layer
.
expert_map
is
None
):
return
CompressedTensorsW8A8Fp8MoECutlassMethod
(
quant_config
)
return
CompressedTensorsW8A8Fp8MoECutlassMethod
(
quant_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
...
@@ -482,7 +496,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -482,7 +496,7 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)
)
class
CompressedTensorsWNA16MoEMethod
(
CompressedTensorsMoEMethod
):
class
CompressedTensorsWNA16M
arlinM
oEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -823,3 +837,215 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -823,3 +837,215 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
num_bits
=
self
.
num_bits
,
num_bits
=
self
.
num_bits
,
is_k_full
=
self
.
is_k_full
)
is_k_full
=
self
.
is_k_full
)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
    """Weight-only (WNA16) MoE method that runs through ``fused_experts``.

    The constructor reads the ``"weights"`` scheme of the ``"Linear"`` target
    from the compressed-tensors config and only accepts symmetric,
    group-quantized, ``pack_quantized`` checkpoints whose bit width is listed
    in ``WNA16_SUPPORTED_BITS``.
    """

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        """Cache the quantization settings and validate they are supported.

        Args:
            quant_config: compressed-tensors config; its
                ``target_scheme_map["Linear"]["weights"]`` entry supplies
                ``num_bits``, ``strategy``, ``group_size``, ``actorder`` and
                ``symmetric``.

        Raises:
            AssertionError: if the strategy is not ``"group"``, the actorder
                is ``"group"``, or the quantization is not symmetric.
            ValueError: if the checkpoint format is not ``pack_quantized`` or
                the bit width is not in ``WNA16_SUPPORTED_BITS``.
        """
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        config = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.num_bits = config.num_bits
        # Number of quantized values packed into one int32 word.
        self.packed_factor = 32 // config.num_bits
        self.strategy = config.strategy
        # channelwise is not supported by this kernel
        assert config.strategy == "group", (
            "Only group quantization strategy is supported by this kernel")
        self.group_size = config.group_size
        # grouped actorder isn't supported by this kernel
        assert config.actorder != "group", (
            "Grouped actorder is not supported by this kernel")
        assert config.symmetric, (
            "Only symmetric quantization is supported for MoE")

        format_ok = (self.quant_config.quant_format ==
                     CompressionFormat.pack_quantized.value)
        if not (format_ok and self.num_bits in WNA16_SUPPORTED_BITS):
            # Build ONE message string: passing the fragments as separate
            # positional arguments would make the exception render as a
            # tuple of strings instead of a sentence.
            raise ValueError(
                "For Fused MoE layers, only "
                f"{CompressionFormat.pack_quantized.value} "
                "is supported for the following bits: "
                f"{WNA16_SUPPORTED_BITS}")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register the quantized expert parameters on ``layer``.

        Registers packed int32 weights, per-group scales, weight-shape
        tensors, and g_idx / g_idx-sort-index tensors for both the fused
        w13 projection and the w2 projection, and sets ``layer.a13_scale``
        and ``layer.a2_scale`` to ``None``.

        Args:
            layer: module that receives the parameters.
            num_experts: leading dimension of every registered parameter.
            hidden_size: model hidden size.
            intermediate_size_per_partition: intermediate size of this shard.
            params_dtype: dtype used for the scale tensors.
            **extra_weight_attrs: attributes attached to every parameter via
                ``set_weight_attrs`` (updated in place with the keys below).
        """
        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": self.strategy
        })

        def _register(name: str, data: torch.Tensor) -> torch.nn.Parameter:
            # Wrap, register under `name`, then attach the shared attributes.
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)
            return param

        _register(
            "w13_weight_packed",
            torch.empty(num_experts,
                        hidden_size // self.packed_factor,
                        2 * intermediate_size_per_partition,
                        dtype=torch.int32))
        _register(
            "w2_weight_packed",
            torch.empty(num_experts,
                        intermediate_size_per_partition // self.packed_factor,
                        hidden_size,
                        dtype=torch.int32))

        w2_scales_size = intermediate_size_per_partition

        # NOTE(review): __init__ asserts strategy == "group", so the
        # "channel" branch is only reachable when asserts are disabled.
        if self.strategy == "channel":
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = w2_scales_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        _register(
            "w13_weight_scale",
            torch.ones(num_experts,
                       num_groups_w13,
                       2 * intermediate_size_per_partition,
                       dtype=params_dtype))
        w2_scale = _register(
            "w2_weight_scale",
            torch.ones(num_experts,
                       num_groups_w2,
                       hidden_size,
                       dtype=params_dtype))
        set_weight_attrs(w2_scale, {"load_full_w2": False})

        _register("w2_weight_shape", torch.empty(num_experts, 2))
        _register("w13_weight_shape", torch.empty(num_experts, 2))

        _register(
            "w13_weight_g_idx",
            torch.empty(num_experts, hidden_size, dtype=torch.int32))
        _register(
            "w2_weight_g_idx",
            torch.empty(num_experts,
                        intermediate_size_per_partition,
                        dtype=torch.int32))
        _register(
            "w13_g_idx_sort_indices",
            torch.empty(num_experts, hidden_size, dtype=torch.int32))
        _register(
            "w2_g_idx_sort_indices",
            torch.empty(num_experts,
                        intermediate_size_per_partition,
                        dtype=torch.int32))

        layer.a13_scale = None
        layer.a2_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-lay out the loaded weights and scales for the moe_wna16 kernel.

        Each packed weight and each scale has dims 1 and 2 swapped and is made
        contiguous; the packed weights are additionally reinterpreted as
        ``torch.uint8``. The parameters on ``layer`` are replaced in place.
        """
        # Reconfigure packed weights and scales to match moe_wna16 format
        layer.w13_weight_packed = torch.nn.Parameter(
            layer.w13_weight_packed.transpose(1, 2).contiguous().view(
                torch.uint8),
            requires_grad=False)
        layer.w2_weight_packed = torch.nn.Parameter(
            layer.w2_weight_packed.transpose(1, 2).contiguous().view(
                torch.uint8),
            requires_grad=False)
        layer.w13_weight_scale = torch.nn.Parameter(
            layer.w13_weight_scale.transpose(1, 2).contiguous(),
            requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(
            layer.w2_weight_scale.transpose(1, 2).contiguous(),
            requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        """Route tokens to experts and run the quantized fused-experts kernel.

        Expert selection is delegated to ``FusedMoE.select_experts``; the
        resulting top-k weights/ids are forwarded to ``fused_experts`` together
        with the packed weights and scales stored on ``layer``.

        Raises:
            AssertionError: if ``activation`` is not ``"silu"``.
        """
        # Imported here rather than at module level, as in the original.
        from vllm.model_executor.layers.fused_moe import fused_experts

        assert activation == "silu", "Only SiLU activation is supported."

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            # Exactly one of these flags is set for 4-bit / 8-bit weights.
            use_int4_w4a16=self.num_bits == 4,
            use_int8_w8a16=self.num_bits == 8,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            # No zero points are passed (only symmetric quantization is
            # accepted by __init__).
            w1_zp=None,
            w2_zp=None,
            # NOTE(review): presumably conveys the quantization group size
            # to the kernel -- confirm against fused_experts.
            block_shape=[0, self.group_size])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment