Unverified Commit e626d286 authored by TJian's avatar TJian Committed by GitHub
Browse files

[FEAT] [ROCm] [AITER]: Add AITER HIP block quant kernel (#21242)

parent c7ffe93d
...@@ -82,6 +82,13 @@ if current_platform.is_rocm(): ...@@ -82,6 +82,13 @@ if current_platform.is_rocm():
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
and current_platform.is_fp8_fnuz()):
import aiter as rocm_aiter
from aiter import get_hip_quant
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
def dispatch_w8a8_blockscale_func( def dispatch_w8a8_blockscale_func(
...@@ -178,8 +185,12 @@ def apply_w8a8_block_fp8_linear( ...@@ -178,8 +185,12 @@ def apply_w8a8_block_fp8_linear(
block_size, input.dtype) block_size, input.dtype)
else: else:
q_input, x_scale = per_token_group_quant_fp8( if use_aiter_and_is_supported:
input_2d, block_size[1], column_major_scales=use_cutlass) q_input, x_scale = aiter_per1x128_quant(
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
block_size, input.dtype) block_size, input.dtype)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment