Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22481fbf
Unverified
Commit
22481fbf
authored
May 09, 2025
by
Michael Goin
Committed by
GitHub
May 09, 2025
Browse files
Update CT WNA16MarlinMoE integration (#16666)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
5c4c08f6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
81 deletions
+38
-81
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+38
-81
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
22481fbf
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
enum
import
enum
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
Optional
import
torch
import
torch
from
compressed_tensors
import
CompressionFormat
from
compressed_tensors
import
CompressionFormat
...
@@ -14,9 +14,12 @@ from vllm import _custom_ops as ops
...
@@ -14,9 +14,12 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
.compressed_tensors_wNa16
import
(
# noqa
WNA16_SUPPORTED_BITS
)
WNA16_SUPPORTED_BITS
,
WNA16_SUPPORTED_TYPES_MAP
)
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils
import
replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_moe_marlin_supports_layer
,
marlin_make_workspace_new
,
marlin_moe_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -54,18 +57,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -54,18 +57,19 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"input_activations"
)
"input_activations"
)
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
# Prefer to use the
non-m
arlin kernel when
:
# Prefer to use the
M
arlin
MoE
kernel when
it is supported.
# 1. Many experts (MarlinMoE gives poor performance when >= 16)
if
not
check_moe_marlin_supports_layer
(
layer
,
# 2. Non-FP16 dtype (MarlinMoE only supports FP16)
weight_quant
.
group_size
):
# 3. Actorder is not group/dynamic (g_idx is unsupported)
if
(
weight_quant
.
strategy
in
QuantizationStrategy
.
GROUP
and
# 4. Scaled are grouped (channelwise is unsupported)
weight_quant
.
actorder
in
(
ActivationOrdering
.
GROUP
,
if
((
layer
.
local_num_experts
>=
16
ActivationOrdering
.
DYNAMIC
)):
or
layer
.
params_dtype
!=
torch
.
float16
)
and
raise
ValueError
(
weight_quant
.
actorder
not
in
(
ActivationOrdering
.
GROUP
,
"WNA16MoE is not supported with actorder=group/dynamic."
ActivationOrdering
.
DYNAMIC
)
)
and
weight_quant
.
strategy
in
QuantizationStrategy
.
GROUP
):
logger
.
info_once
(
"Using CompressedTensorsWNA16MoEMethod"
)
return
CompressedTensorsWNA16MoEMethod
(
quant_config
)
return
CompressedTensorsWNA16MoEMethod
(
quant_config
)
else
:
else
:
logger
.
info_once
(
"Using CompressedTensorsWNA16MarlinMoEMethod"
)
return
CompressedTensorsWNA16MarlinMoEMethod
(
quant_config
)
return
CompressedTensorsWNA16MarlinMoEMethod
(
quant_config
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
elif
(
quant_config
.
_is_fp8_w8a8_sm90
(
weight_quant
,
input_quant
)
and
layer
.
activation
==
"silu"
):
and
layer
.
activation
==
"silu"
):
...
@@ -705,15 +709,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -705,15 +709,12 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
"is supported for the following bits: "
,
"is supported for the following bits: "
,
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
self
.
quant_type
=
WNA16_SUPPORTED_TYPES_MAP
[
self
.
num_bits
]
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
assert
params_dtype
==
torch
.
float16
,
(
"float16 is required for MoE compressed models. Set dtype=torch.float16"
# noqa: E501
)
intermediate_size_full
=
extra_weight_attrs
.
pop
(
intermediate_size_full
=
extra_weight_attrs
.
pop
(
"intermediate_size_full"
)
"intermediate_size_full"
)
...
@@ -837,50 +838,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -837,50 +838,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
replace_tensor
(
name
,
new_t
):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
def
get_scale_perms
(
num_bits
:
int
):
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
def
marlin_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
,
num_bits
:
int
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
def
marlin_moe_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
,
num_bits
:
int
):
num_experts
=
s
.
shape
[
0
]
output
=
torch
.
empty
((
num_experts
,
s
.
shape
[
1
],
s
.
shape
[
2
]),
device
=
s
.
device
,
dtype
=
s
.
dtype
)
for
e
in
range
(
num_experts
):
output
[
e
]
=
marlin_permute_scales
(
s
[
e
],
size_k
,
size_n
,
group_size
,
num_bits
)
return
output
size_k2
=
layer
.
w2_weight_packed
.
shape
[
2
]
size_k13
=
layer
.
w13_weight_packed
.
shape
[
2
]
num_experts
=
layer
.
w13_weight_g_idx
.
shape
[
0
]
num_experts
=
layer
.
w13_weight_g_idx
.
shape
[
0
]
device
=
layer
.
w13_weight_g_idx
.
device
device
=
layer
.
w13_weight_g_idx
.
device
...
@@ -938,7 +895,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -938,7 +895,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer
.
w13_weight_packed
.
shape
[
2
],
layer
.
w13_weight_packed
.
shape
[
2
],
self
.
num_bits
,
self
.
num_bits
,
)
)
replace_
tensor
(
"w13_weight_packed"
,
marlin_w13_qweight
)
replace_
parameter
(
layer
,
"w13_weight_packed"
,
marlin_w13_qweight
)
marlin_w2_qweight
=
ops
.
gptq_marlin_moe_repack
(
marlin_w2_qweight
=
ops
.
gptq_marlin_moe_repack
(
layer
.
w2_weight_packed
,
layer
.
w2_weight_packed
,
layer
.
w2_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
...
@@ -946,25 +903,25 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -946,25 +903,25 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_weight_packed
.
shape
[
2
],
layer
.
w2_weight_packed
.
shape
[
2
],
self
.
num_bits
,
self
.
num_bits
,
)
)
replace_
tensor
(
"w2_weight_packed"
,
marlin_w2_qweight
)
replace_
parameter
(
layer
,
"w2_weight_packed"
,
marlin_w2_qweight
)
# Repack scales
# Repack scales
marlin_w13_scales
=
marlin_moe_permute_scales
(
marlin_w13_scales
=
marlin_moe_permute_scales
(
layer
.
w13_weight_scale
,
s
=
layer
.
w13_weight_scale
,
size_k13
,
size_k
=
layer
.
w13_weight_packed
.
shape
[
2
],
layer
.
w13_weight_scale
.
shape
[
2
],
size_n
=
layer
.
w13_weight_scale
.
shape
[
2
],
self
.
group_size
,
group_size
=
self
.
group_size
,
self
.
num_bits
,
)
)
replace_
tensor
(
"w13_weight_scale"
,
marlin_w13_scales
)
replace_
parameter
(
layer
,
"w13_weight_scale"
,
marlin_w13_scales
)
marlin_w2_scales
=
marlin_moe_permute_scales
(
marlin_w2_scales
=
marlin_moe_permute_scales
(
layer
.
w2_weight_scale
,
s
=
layer
.
w2_weight_scale
,
layer
.
w2_weight_scale
.
shape
[
1
]
*
size_k
=
layer
.
w2_weight_scale
.
shape
[
1
]
*
(
self
.
group_size
if
self
.
group_size
!=
-
1
else
self
.
packed_factor
),
(
self
.
group_size
if
self
.
group_size
!=
-
1
else
self
.
packed_factor
),
size_k2
,
size_n
=
layer
.
w2_weight_scale
.
shape
[
2
],
self
.
group_size
,
group_size
=
self
.
group_size
,
self
.
num_bits
,
)
)
replace_tensor
(
"w2_weight_scale"
,
marlin_w2_scales
)
replace_parameter
(
layer
,
"w2_weight_scale"
,
marlin_w2_scales
)
layer
.
workspace
=
marlin_make_workspace_new
(
device
,
4
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -985,10 +942,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -985,10 +942,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
expert_map
is
not
None
:
raise
NotImplementedError
(
"Expert Parallelism is not supported for "
"fused Marlin MoE method."
)
if
apply_router_weight_on_input
:
if
apply_router_weight_on_input
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Apply router weight on input is not supported for "
"Apply router weight on input is not supported for "
...
@@ -1015,11 +968,14 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -1015,11 +968,14 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
router_logits
,
router_logits
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
quant_type_id
=
self
.
quant_type
.
id
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
g_idx1
=
layer
.
w13_weight_g_idx
,
g_idx1
=
layer
.
w13_weight_g_idx
,
g_idx2
=
layer
.
w2_weight_g_idx
,
g_idx2
=
layer
.
w2_weight_g_idx
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
num_bits
=
self
.
num_bits
,
workspace
=
layer
.
workspace
,
is_k_full
=
self
.
is_k_full
)
is_k_full
=
self
.
is_k_full
)
...
@@ -1203,7 +1159,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1203,7 +1159,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
...
@@ -1223,6 +1179,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1223,6 +1179,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
inplace
=
True
,
inplace
=
True
,
activation
=
activation
,
use_int4_w4a16
=
self
.
num_bits
==
4
,
use_int4_w4a16
=
self
.
num_bits
==
4
,
use_int8_w8a16
=
self
.
num_bits
==
8
,
use_int8_w8a16
=
self
.
num_bits
==
8
,
global_num_experts
=
global_num_experts
,
global_num_experts
=
global_num_experts
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment