Unverified Commit 145c2ff6 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[Bugfix] Revert MoE Triton Config Default (#12629)

SUMMARY:
* previous PR for pulling in block configs also changed defaults
(https://github.com/vllm-project/vllm/pull/11589/files

) for FP8
* this broke L4 MoE since there was not enough SHM for the default
configuration
* this reverts the non-block example to the default
Signed-off-by: default avatarrshaw@neuralmagic.com <rshaw@neuralmagic.com>
parent 415f1947
......@@ -660,26 +660,7 @@ def get_default_config(
is_marlin: bool,
block_shape: Optional[List[int]] = None,
) -> Dict[str, int]:
if dtype == "fp8_w8a8":
if block_shape is None:
config = {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
}
if M <= E:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
}
else:
if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
config = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment