Unverified Commit 97abeb1d authored by Duncan Moss's avatar Duncan Moss Committed by GitHub
Browse files

[feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640)


Signed-off-by: default avatarDuncan Moss <djm.moss@gmail.com>
parent 34dad19e
......@@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
return out.to(dtype=out_dtype)
def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor,
def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
w2: torch.Tensor) -> bool:
def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
return M >= 128 and N % 128 == 0 and K % 128 == 0
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
return N % 128 == 0 and K % 128 == 0
m = hidden_states.size(0)
_, K, N = w2.size()
if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
logger.debug(
"CutlassBlockScaledGroupedGemm disabled: unalinged problem size.")
return False
......
......@@ -1180,7 +1180,7 @@ def fused_experts(
apply_router_weight_on_input=apply_router_weight_on_input,
)
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
assert apply_router_weight_on_input is False
return run_cutlass_block_scaled_fused_experts(
a=hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment