"...git@developer.sourcefind.cn:2222/tsoc/superbenchmark.git" did not exist on "d03d110f55d3afa1d4853199f85d8a844fc9e6eb"
Unverified Commit 80d38b8a authored by TJian's avatar TJian Committed by GitHub
Browse files

[V1] [ROCm] [AITER] Upgrade AITER to commit `916bf3c` and bugfix APIs (#20880)


Signed-off-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
parent 211b6a61
...@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ...@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa" ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6487649" ARG AITER_BRANCH="916bf3c"
ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base FROM ${BASE_IMAGE} AS base
......
...@@ -8,11 +8,55 @@ import torch ...@@ -8,11 +8,55 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
def rocm_aiter_gemm_w8a8_impl(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: Optional[torch.Tensor] = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def rocm_aiter_gemm_w8a8_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: Optional[torch.Tensor] = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
return Y
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_gemm_w8a8",
op_func=rocm_aiter_gemm_w8a8_impl,
mutates_args=[],
fake_impl=rocm_aiter_gemm_w8a8_fake,
dispatch_key=current_platform.dispatch_key,
)
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod @classmethod
...@@ -111,10 +155,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -111,10 +155,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
"does not support AITER block scaled GEMM.") "does not support AITER block scaled GEMM.")
from aiter import gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K] # a to be [M, K]
# b to be [N, K] # b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
bias, out_dtype)
...@@ -56,7 +56,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl( ...@@ -56,7 +56,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
) -> torch.Tensor: ) -> torch.Tensor:
import aiter as rocm_aiter import aiter as rocm_aiter
return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
def rocm_aiter_gemm_w8a8_blockscale_fake( def rocm_aiter_gemm_w8a8_blockscale_fake(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment