"vscode:/vscode.git/clone" did not exist on "de02b07db4741cc9ed40b8262d7a67e6bce30211"
Unverified Commit 1166c31c authored by Dongjie Zou's avatar Dongjie Zou Committed by GitHub
Browse files

[Bugfix]: Fix glm46 awq marlin moe wna16 compatibility (#30210)


Signed-off-by: default avatarbaonudesifeizhai <baonudesifeizhai@gmail.com>
parent 03416ead
...@@ -895,6 +895,48 @@ def get_moe_configs( ...@@ -895,6 +895,48 @@ def get_moe_configs(
return None return None
def _ensure_block_size_k_divisible(
size_k: int, block_size_k: int, group_size: int
) -> int:
"""Ensure block_size_k is a divisor of size_k and divisible by group_size.
This ensures BLOCK_SIZE_K compatibility with MoeWNA16 CUDA kernel which
requires size_k % BLOCK_SIZE_K == 0 and BLOCK_SIZE_K % group_size == 0.
Args:
size_k: The size_k dimension that must be divisible by result.
block_size_k: Preferred block size (will be adjusted if needed).
group_size: The result must be divisible by this.
Returns:
A valid BLOCK_SIZE_K that divides size_k and is divisible by group_size.
"""
# Fast path: already valid
if size_k % block_size_k == 0 and block_size_k % group_size == 0:
return block_size_k
# Find the largest value that:
# 1. Divides size_k (size_k % candidate == 0)
# 2. Is divisible by group_size (candidate % group_size == 0)
# 3. Is <= block_size_k (prefer smaller values close to block_size_k)
#
# Strategy: Search from min(block_size_k, size_k) down to group_size,
# stepping by group_size to ensure divisibility by group_size
max_search = min(block_size_k, size_k)
start = (max_search // group_size) * group_size
for candidate in range(start, group_size - 1, -group_size):
if size_k % candidate == 0:
return candidate
# Fallback: if group_size divides size_k, use it
# This should always be true with correct group_size configuration
if size_k % group_size == 0:
return group_size
# This should not happen with correct group_size, but ensure divisibility
return size_k
def get_moe_wna16_block_config( def get_moe_wna16_block_config(
config: dict[str, int], config: dict[str, int],
use_moe_wna16_cuda: bool, use_moe_wna16_cuda: bool,
...@@ -960,6 +1002,9 @@ def get_moe_wna16_block_config( ...@@ -960,6 +1002,9 @@ def get_moe_wna16_block_config(
# at the same time. # at the same time.
block_size_n = 1024 block_size_n = 1024
# Ensure BLOCK_SIZE_K is a divisor of size_k for CUDA kernel compatibility
block_size_k = _ensure_block_size_k_divisible(size_k, block_size_k, group_size)
return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k}
......
...@@ -60,7 +60,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -60,7 +60,7 @@ class MoeWNA16Config(QuantizationConfig):
if self.linear_quant_method == "gptq": if self.linear_quant_method == "gptq":
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config) self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
elif self.linear_quant_method == "awq": elif self.linear_quant_method in ("awq", "awq_marlin"):
capability_tuple = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
device_capability = ( device_capability = (
-1 if capability_tuple is None else capability_tuple.to_int() -1 if capability_tuple is None else capability_tuple.to_int()
...@@ -107,7 +107,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -107,7 +107,7 @@ class MoeWNA16Config(QuantizationConfig):
if linear_quant_method == "gptq": if linear_quant_method == "gptq":
has_zp = not cls.get_from_keys(config, ["sym"]) has_zp = not cls.get_from_keys(config, ["sym"])
modules_to_not_convert = [] modules_to_not_convert = []
elif linear_quant_method == "awq": elif linear_quant_method in ("awq", "awq_marlin"):
has_zp = cls.get_from_keys(config, ["zero_point"]) has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys_or( modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None config, ["modules_to_not_convert"], None
...@@ -184,7 +184,7 @@ class MoeWNA16Config(QuantizationConfig): ...@@ -184,7 +184,7 @@ class MoeWNA16Config(QuantizationConfig):
return GPTQConfig.from_config(self.full_config).get_quant_method( return GPTQConfig.from_config(self.full_config).get_quant_method(
layer, prefix layer, prefix
) )
elif self.linear_quant_method == "awq": elif self.linear_quant_method in ("awq", "awq_marlin"):
if self.use_marlin and check_marlin_supports_layer( if self.use_marlin and check_marlin_supports_layer(
layer, self.group_size layer, self.group_size
): ):
...@@ -468,7 +468,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -468,7 +468,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
shard_size = layer.intermediate_size_per_partition shard_size = layer.intermediate_size_per_partition
# convert gptq and awq weight to a standard format # convert gptq and awq weight to a standard format
if layer.quant_config.linear_quant_method == "awq": # awq_marlin uses the same weight format as awq
if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
assert layer.quant_config.weight_bits == 4 assert layer.quant_config.weight_bits == 4
if "weight" in weight_name: if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qweight") loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment