Unverified commit 5be8f1ed, authored by yigex and committed by GitHub
Browse files

ROCM: AITER BLOCK GEMM (#4075)

parent e5760bc4
...@@ -8,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import ( ...@@ -8,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
w8a8_block_fp8_matmul, w8a8_block_fp8_matmul,
) )
from sglang.srt.utils import is_hip from sglang.srt.utils import get_bool_env_var, is_hip
is_hip_ = is_hip() is_hip_ = is_hip()
if is_hip_ and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale
_is_cuda = torch.cuda.is_available() and torch.version.cuda _is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda: if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm from sgl_kernel import fp8_blockwise_scaled_mm
...@@ -78,6 +81,16 @@ def apply_w8a8_block_fp8_linear( ...@@ -78,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm( output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
) )
elif is_hip_ and get_bool_env_var("CK_MOE"):
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = torch.zeros(
[q_input.shape[0], weight.shape[0]],
dtype=input.dtype,
device=q_input.device,
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
else: else:
q_input, x_scale = per_token_group_quant_fp8( q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False input_2d, block_size[1], column_major_scales=False
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment