Unverified Commit 934acdde authored by Matthias Gehre's avatar Matthias Gehre Committed by GitHub
Browse files

[Perf] fused_moe: add int4_w4a16 benchmark support and tuning config (#34130)


Signed-off-by: default avatarMatthias Gehre <matthias.gehre@amd.com>
Co-authored-by: default avatarTJian <tunjian.tan@embeddedllm.com>
parent 742d214d
...@@ -100,13 +100,38 @@ def benchmark_config( ...@@ -100,13 +100,38 @@ def benchmark_config(
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool = False,
num_iters: int = 100, num_iters: int = 100,
block_quant_shape: list[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> float: ) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16: if use_int4_w4a16:
# Int4 packed weights: 2 int4 values per uint8 byte
# K dimension is packed (halved)
intermediate_size = shard_intermediate_size // 2 # after silu_and_mul
w1 = torch.randint(
0,
255,
(
num_experts,
shard_intermediate_size,
hidden_size // 2, # int4 packing
),
dtype=torch.uint8,
)
w2 = torch.randint(
0,
255,
(
num_experts,
hidden_size,
intermediate_size // 2, # int4 packing
),
dtype=torch.uint8,
)
elif use_int8_w8a16:
w1 = torch.randint( w1 = torch.randint(
-127, -127,
127, 127,
...@@ -140,7 +165,20 @@ def benchmark_config( ...@@ -140,7 +165,20 @@ def benchmark_config(
w2_scale = None w2_scale = None
a1_scale = None a1_scale = None
a2_scale = None a2_scale = None
if use_int8_w8a16: if use_int4_w4a16:
if block_quant_shape is None:
raise ValueError("block_quant_shape is required for int4_w4a16")
group_size = block_quant_shape[1]
# Scales shape: (E, N, K // group_size) in fp16
w1_scale = torch.rand(
(num_experts, shard_intermediate_size, hidden_size // group_size),
dtype=dtype,
)
w2_scale = torch.rand(
(num_experts, hidden_size, intermediate_size // group_size),
dtype=dtype,
)
elif use_int8_w8a16:
w1_scale = torch.randn( w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32 (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
) )
...@@ -199,6 +237,7 @@ def benchmark_config( ...@@ -199,6 +237,7 @@ def benchmark_config(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_quant_shape, block_shape=block_quant_shape,
weight_dtype="int4" if use_int4_w4a16 else None,
) )
deep_gemm_experts = None deep_gemm_experts = None
...@@ -481,6 +520,7 @@ class BenchmarkWorker: ...@@ -481,6 +520,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool = False,
block_quant_shape: list[int] = None, block_quant_shape: list[int] = None,
use_deep_gemm: bool = False, use_deep_gemm: bool = False,
) -> tuple[dict[str, int], float]: ) -> tuple[dict[str, int], float]:
...@@ -488,7 +528,10 @@ class BenchmarkWorker: ...@@ -488,7 +528,10 @@ class BenchmarkWorker:
set_random_seed(self.seed) set_random_seed(self.seed)
dtype_str = _get_config_dtype_str( dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int4_w4a16=use_int4_w4a16,
) )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
...@@ -519,6 +562,7 @@ class BenchmarkWorker: ...@@ -519,6 +562,7 @@ class BenchmarkWorker:
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
num_iters=100, num_iters=100,
block_quant_shape=block_quant_shape, block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm, use_deep_gemm=use_deep_gemm,
...@@ -535,6 +579,7 @@ class BenchmarkWorker: ...@@ -535,6 +579,7 @@ class BenchmarkWorker:
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool,
search_space: list[dict[str, int]], search_space: list[dict[str, int]],
block_quant_shape: list[int], block_quant_shape: list[int],
use_deep_gemm: bool, use_deep_gemm: bool,
...@@ -545,7 +590,7 @@ class BenchmarkWorker: ...@@ -545,7 +590,7 @@ class BenchmarkWorker:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
if current_platform.is_rocm(): if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
search_space = prune_rocm_search_space( search_space = prune_rocm_search_space(
num_tokens, num_tokens,
shard_intermediate_size, shard_intermediate_size,
...@@ -574,6 +619,7 @@ class BenchmarkWorker: ...@@ -574,6 +619,7 @@ class BenchmarkWorker:
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
num_iters=20, num_iters=20,
block_quant_shape=block_quant_shape, block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm, use_deep_gemm=use_deep_gemm,
...@@ -621,6 +667,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: ...@@ -621,6 +667,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
else {} else {}
), ),
**({"kpack": config["kpack"]} if "kpack" in config else {}), **({"kpack": config["kpack"]} if "kpack" in config else {}),
**({"SPLIT_K": config["SPLIT_K"]} if "SPLIT_K" in config else {}),
} }
...@@ -633,11 +680,15 @@ def save_configs( ...@@ -633,11 +680,15 @@ def save_configs(
dtype: torch.dtype, dtype: torch.dtype,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_quant_shape: list[int], block_quant_shape: list[int],
save_dir: str, save_dir: str,
) -> None: ) -> None:
dtype_str = _get_config_dtype_str( dtype_str = _get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int4_w4a16=use_int4_w4a16,
) )
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
...@@ -739,6 +790,38 @@ def get_model_params(config): ...@@ -739,6 +790,38 @@ def get_model_params(config):
return E, topk, intermediate_size, hidden_size return E, topk, intermediate_size, hidden_size
def get_quantization_group_size(config) -> int | None:
"""Extract the quantization group size from the HF model config.
This reads directly from the HuggingFace config object (as returned by
``get_config()``), not from vLLM's quantization config classes.
Supports AWQ/GPTQ-style configs (direct 'group_size' key) and
compressed-tensors configs (nested inside 'config_groups').
"""
quantization_config = getattr(config, "quantization_config", {})
if not isinstance(quantization_config, dict):
return None
# AWQ / GPTQ style: group_size is a top-level key
gs = quantization_config.get("group_size")
if gs is not None:
return gs
# compressed-tensors style: group_size is nested in config_groups
config_groups = quantization_config.get("config_groups", {})
if not isinstance(config_groups, dict):
return None
for group_cfg in config_groups.values():
if not isinstance(group_cfg, dict):
continue
weights = group_cfg.get("weights", {})
if not isinstance(weights, dict):
continue
gs = weights.get("group_size")
if gs is not None:
return gs
return None
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
...@@ -757,7 +840,20 @@ def main(args: argparse.Namespace): ...@@ -757,7 +840,20 @@ def main(args: argparse.Namespace):
dtype = torch.float16 if current_platform.is_rocm() else config.dtype dtype = torch.float16 if current_platform.is_rocm() else config.dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
use_int4_w4a16 = args.dtype == "int4_w4a16"
block_quant_shape = get_weight_block_size_safety(config) block_quant_shape = get_weight_block_size_safety(config)
if use_int4_w4a16:
group_size = get_quantization_group_size(config)
if group_size is None:
raise ValueError(
"Could not determine group_size from model config. "
"The model's quantization_config must contain a 'group_size' "
"field (AWQ/GPTQ) or 'config_groups.*.weights.group_size' "
"(compressed-tensors)."
)
# For int4_w4a16, block_shape = [0, group_size]
# block_shape[0]=0 means no block quantization on N dimension
block_quant_shape = [0, group_size]
if args.batch_size is None: if args.batch_size is None:
batch_sizes = [ batch_sizes = [
...@@ -811,8 +907,20 @@ def main(args: argparse.Namespace): ...@@ -811,8 +907,20 @@ def main(args: argparse.Namespace):
return ray.get(outputs) return ray.get(outputs)
if args.tune: if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) # int4_w4a16 weights are uint8-packed, not fp16; treat like fp8 for
search_space = get_configs_compute_bound(is_fp16, block_quant_shape) # search space generation (no matrix_instr_nonkdim/kpack exploration).
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16 or use_int4_w4a16)
# For int4_w4a16, the group_size constraint on BLOCK_SIZE_K does not
# apply: the gptq_awq kernel handles arbitrary BLOCK_SIZE_K regardless
# of group_size. Skip block_quant_shape filtering to keep the full
# search space (e.g. BLOCK_SIZE_K=64 with group_size=128).
tune_block_quant_shape = None if use_int4_w4a16 else block_quant_shape
search_space = get_configs_compute_bound(is_fp16, tune_block_quant_shape)
if use_int4_w4a16:
# SPLIT_K is a required kernel constexpr for gptq_awq kernel;
# only SPLIT_K=1 is used at runtime, so fix it during tuning.
for cfg in search_space:
cfg["SPLIT_K"] = 1
print(f"Start tuning over {len(search_space)} configurations...") print(f"Start tuning over {len(search_space)} configurations...")
if use_deep_gemm: if use_deep_gemm:
raise ValueError( raise ValueError(
...@@ -832,6 +940,7 @@ def main(args: argparse.Namespace): ...@@ -832,6 +940,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
search_space, search_space,
block_quant_shape, block_quant_shape,
use_deep_gemm, use_deep_gemm,
...@@ -851,6 +960,7 @@ def main(args: argparse.Namespace): ...@@ -851,6 +960,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
block_quant_shape, block_quant_shape,
args.save_dir, args.save_dir,
) )
...@@ -869,6 +979,7 @@ def main(args: argparse.Namespace): ...@@ -869,6 +979,7 @@ def main(args: argparse.Namespace):
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16,
block_quant_shape, block_quant_shape,
use_deep_gemm, use_deep_gemm,
) )
...@@ -891,7 +1002,10 @@ if __name__ == "__main__": ...@@ -891,7 +1002,10 @@ if __name__ == "__main__":
) )
parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true") parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" "--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int4_w4a16"],
default="auto",
) )
parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument( parser.add_argument(
......
{
"triton_version": "3.6.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 4
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"SPLIT_K": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 2
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"SPLIT_K": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"SPLIT_K": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment