Commit 9ce8b1a3 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-blockshape' into 'v0.15.1-dev'

fix: resolve block_shape conflicts between DeepEP MoE and non-DeepEP quantization

See merge request dcutoolkit/deeplearing/vllm!507
parents 22890a8e f9a04c97
......@@ -94,6 +94,9 @@ def _quant_flags_to_group_shape(
# dim should be 1.
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
elif block_shape is not None and quant_dtype == torch.int8:
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
else:
w_shape = None
a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR
......@@ -248,7 +251,6 @@ class FusedMoEQuantConfig:
def block_shape(self) -> list[int] | None:
# if self.use_int8_w8a8:
# return [256, 256]
if (
self._a1.shape is not None
and self._a1.shape != GroupShape.PER_TENSOR
......@@ -516,9 +518,10 @@ class FusedMoEQuantConfig:
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
assert quant_config.block_shape == block_shape
if quant_dtype != torch.int8:
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
assert quant_config.block_shape == block_shape
return quant_config
......@@ -560,6 +563,7 @@ def int8_w8a8_moe_quant_config(
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
per_act_token_quant: bool = False,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for int8 activations and int8 weights.
......@@ -572,7 +576,7 @@ def int8_w8a8_moe_quant_config(
a2_scale=a2_scale,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=None,
block_shape=block_shape,
)
......
......@@ -311,6 +311,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
block_shape=[256, 256] if self.use_deepep else None,
)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment