Unverified commit 5a499e70, authored by Hashem Hashemi, committed by GitHub
Browse files

[Kernel][Hardware][AMD] Bf16 mfma opt for ROCm skinny GEMMs (#17071)


Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
parent 6930a411
This diff is collapsed.
......@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
DTYPES = [torch.bfloat16, torch.float16]
M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192]
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0
K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0
N = [1, 2, 3, 4]
SEEDS = [0]
......
......@@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and 0 < n < 4:
if m > 8 and 0 < n <= 4:
out = ops.wvSplitK(weight, x_view, cu_count)
return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
......
......@@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id
@cache
def on_mi250_mi300() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment