Unverified Commit 449de900 authored by Aleksandr Malyshev's avatar Aleksandr Malyshev Committed by GitHub
Browse files

[ROCm] triton fp8 kernel (#27058)


Signed-off-by: default avatarAleksandr Malyshev <maleksan@amd.com>
Co-authored-by: default avatarAleksandr Malyshev <maleksan@amd.com>
Co-authored-by: default avatarGregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
parent d4aa65c9
...@@ -69,30 +69,67 @@ def cutlass_scaled_mm( ...@@ -69,30 +69,67 @@ def cutlass_scaled_mm(
def rocm_aiter_gemm_w8a8_blockscale_impl( def rocm_aiter_gemm_w8a8_blockscale_impl(
A: torch.Tensor, input_2d: torch.Tensor,
B: torch.Tensor, weight: torch.Tensor,
As: torch.Tensor, input_scale: torch.Tensor,
Bs: torch.Tensor, weight_scale: torch.Tensor,
block_size: list[int], group_size: int,
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
def is_aiter_triton_kernel_tuned(n, k):
return (n, k) in [
(1024, 8192),
(2112, 7168),
(3072, 1536),
(32768, 8192),
(4096, 7168),
(4608, 7168),
(512, 7168),
(7168, 2048),
(7168, 256),
(8192, 1024),
(8192, 32768),
]
n, k = weight.shape
if input_scale is not None:
q_input = input_2d
elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
# MI350 case uses triton kernel
q_input, input_scale = per_token_group_quant_fp8(
input_2d,
group_size,
column_major_scales=False,
use_ue8m0=False,
)
else:
# MI300 uses tuned AITER ASM/C++ kernel
import aiter as rocm_aiter import aiter as rocm_aiter
from aiter import gemm_a8w8_blockscale, get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) return gemm_a8w8_blockscale(
q_input, weight, input_scale, weight_scale, dtype=output_dtype
)
def rocm_aiter_gemm_w8a8_blockscale_fake( def rocm_aiter_gemm_w8a8_blockscale_fake(
A: torch.Tensor, input_2d: torch.Tensor,
B: torch.Tensor, weight: torch.Tensor,
As: torch.Tensor, input_scale: torch.Tensor,
Bs: torch.Tensor, weight_scale: torch.Tensor,
block_size: list[int], group_size: int,
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
m = A.shape[0] m = input_2d.shape[0]
n = B.shape[0] n = weight.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device) return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
return Y
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -101,15 +138,6 @@ if current_platform.is_rocm(): ...@@ -101,15 +138,6 @@ if current_platform.is_rocm():
op_func=rocm_aiter_gemm_w8a8_blockscale_impl, op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
) )
if (
envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()
):
import aiter as rocm_aiter
from aiter import get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
# TODO we should be able to change the type of block_size to GroupShape # TODO we should be able to change the type of block_size to GroupShape
...@@ -293,7 +321,9 @@ class W8A8BlockFp8LinearOp: ...@@ -293,7 +321,9 @@ class W8A8BlockFp8LinearOp:
): ):
output = self._run_deepgemm(input_2d, weight, weight_scale) output = self._run_deepgemm(input_2d, weight, weight_scale)
else: else:
output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) output = self.w8a8_blockscale_op(
input_2d, weight, weight_scale, input_scale
)
if bias is not None: if bias is not None:
output = output + bias output = output + bias
...@@ -322,7 +352,9 @@ class W8A8BlockFp8LinearOp: ...@@ -322,7 +352,9 @@ class W8A8BlockFp8LinearOp:
input_2d: torch.Tensor, input_2d: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert input_scale is None
assert self.input_quant_op is not None assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d) q_input, input_scale = self.input_quant_op(input_2d)
if self.is_hopper: if self.is_hopper:
...@@ -350,17 +382,15 @@ class W8A8BlockFp8LinearOp: ...@@ -350,17 +382,15 @@ class W8A8BlockFp8LinearOp:
input_2d: torch.Tensor, input_2d: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128) assert self.act_quant_group_shape == GroupShape(1, 128)
q_input, input_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
)
return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
q_input, input_2d,
weight, weight,
input_scale, input_scale,
weight_scale, weight_scale,
list(self.weight_group_shape), self.act_quant_group_shape.col,
input_2d.dtype, input_2d.dtype,
) )
...@@ -369,7 +399,9 @@ class W8A8BlockFp8LinearOp: ...@@ -369,7 +399,9 @@ class W8A8BlockFp8LinearOp:
input_2d: torch.Tensor, input_2d: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert input_scale is None
assert self.input_quant_op is not None assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d) q_input, input_scale = self.input_quant_op(input_2d)
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
...@@ -391,6 +423,7 @@ class W8A8BlockFp8LinearOp: ...@@ -391,6 +423,7 @@ class W8A8BlockFp8LinearOp:
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
torch.Tensor | None,
], ],
torch.Tensor, torch.Tensor,
], ],
...@@ -939,13 +972,11 @@ def requant_weight_ue8m0_inplace( ...@@ -939,13 +972,11 @@ def requant_weight_ue8m0_inplace(
def check_aiter_fp8_linear_support() -> bool: def check_aiter_fp8_linear_support() -> bool:
"""AITER is only supported on ROCm and only for FP8_FNUZ """AITER is only supported on ROCm for MI3XX"""
and at the moment are MI300 series"""
return ( return (
current_platform.is_rocm() current_platform.is_rocm()
and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment