Commit f9a04c97 authored by chenhw5's avatar chenhw5
Browse files

fix: resolve block_shape conflicts between DeepEP MoE and non-DeepEP quantization

parent 22890a8e
...@@ -94,6 +94,9 @@ def _quant_flags_to_group_shape( ...@@ -94,6 +94,9 @@ def _quant_flags_to_group_shape(
# dim should be 1. # dim should be 1.
a_shape = GroupShape(row=block_shape[0], col=block_shape[1]) a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1]) w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
elif block_shape is not None and quant_dtype == torch.int8:
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
else: else:
w_shape = None w_shape = None
a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR
...@@ -248,7 +251,6 @@ class FusedMoEQuantConfig: ...@@ -248,7 +251,6 @@ class FusedMoEQuantConfig:
def block_shape(self) -> list[int] | None: def block_shape(self) -> list[int] | None:
# if self.use_int8_w8a8: # if self.use_int8_w8a8:
# return [256, 256] # return [256, 256]
if ( if (
self._a1.shape is not None self._a1.shape is not None
and self._a1.shape != GroupShape.PER_TENSOR and self._a1.shape != GroupShape.PER_TENSOR
...@@ -516,6 +518,7 @@ class FusedMoEQuantConfig: ...@@ -516,6 +518,7 @@ class FusedMoEQuantConfig:
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
), ),
) )
if quant_dtype != torch.int8:
assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant assert quant_config.per_out_ch_quant == per_out_ch_quant
assert quant_config.block_shape == block_shape assert quant_config.block_shape == block_shape
...@@ -560,6 +563,7 @@ def int8_w8a8_moe_quant_config( ...@@ -560,6 +563,7 @@ def int8_w8a8_moe_quant_config(
a1_scale: torch.Tensor | None, a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None, a2_scale: torch.Tensor | None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for int8 activations and int8 weights. Construct a quant config for int8 activations and int8 weights.
...@@ -572,7 +576,7 @@ def int8_w8a8_moe_quant_config( ...@@ -572,7 +576,7 @@ def int8_w8a8_moe_quant_config(
a2_scale=a2_scale, a2_scale=a2_scale,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=None, block_shape=block_shape,
) )
......
...@@ -311,6 +311,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -311,6 +311,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
per_act_token_quant=True, per_act_token_quant=True,
block_shape=[256, 256] if self.use_deepep else None,
) )
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment