Unverified commit 1f19d8f8, authored by Xin Yang, committed by GitHub
Browse files

[Perf] Set split_k to 1 for triton_kernels (#30528)


Signed-off-by: Xin Yang <xyangx@amazon.com>
parent cd7740ac
...@@ -57,12 +57,18 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): ...@@ -57,12 +57,18 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
mx_axis=1, num_warps=num_warps mx_axis=1, num_warps=num_warps
) )
) )
if current_platform.is_cuda() and current_platform.is_device_capability(100): if current_platform.is_cuda():
constraints = { if current_platform.is_device_capability(90):
"is_persistent": True, constraints = {
"epilogue_subtile": 1, "split_k": 1,
} }
opt_flags.update_opt_flags_constraints(constraints) opt_flags.update_opt_flags_constraints(constraints)
elif current_platform.is_device_capability(100):
constraints = {
"is_persistent": True,
"epilogue_subtile": 1,
}
opt_flags.update_opt_flags_constraints(constraints)
# transpose the tensor so that the quantization axis is on dim1 # transpose the tensor so that the quantization axis is on dim1
quant_tensor = quant_tensor.transpose(-2, -1) quant_tensor = quant_tensor.transpose(-2, -1)
scale = scale.transpose(-2, -1) scale = scale.transpose(-2, -1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment