Unverified Commit 5be8f1ed authored by yigex's avatar yigex Committed by GitHub
Browse files

ROCM: AITER BLOCK GEMM (#4075)

parent e5760bc4
......@@ -8,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_bool_env_var, is_hip
is_hip_ = is_hip()
if is_hip_ and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale
_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm
......@@ -78,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
elif is_hip_ and get_bool_env_var("CK_MOE"):
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = torch.zeros(
[q_input.shape[0], weight.shape[0]],
dtype=input.dtype,
device=q_input.device,
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment