Unverified Commit b361f14e authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due...


[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile (#21350)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 01c753ed
......@@ -8,6 +8,7 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def get_token_bin_counts_and_mask(
......@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
return torch.nn.functional.linear(x, weight, bias)
def rocm_unquantized_gemm(layer: torch.nn.Module,
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
def rocm_unquantized_gemm_impl(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
from vllm.platforms.rocm import on_gfx9
k = weight.shape[1]
use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
......@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
return torch.nn.functional.linear(x, weight, bias)
def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Shape-only stand-in for ``rocm_unquantized_gemm_impl``.

    Passed as ``fake_impl`` when the op is registered; it performs no
    matrix multiplication and only produces a tensor of the output shape.

    Args:
        x: Input tensor; all dimensions except the last are carried over
            to the result.
        weight: Weight tensor; its first dimension becomes the last
            dimension of the result.
        bias: Accepted for signature parity with the real implementation;
            not used here.

    Returns:
        An uninitialized tensor created from ``x`` via ``new_empty`` with
        shape ``(*x.shape[:-1], weight.shape[0])``.
    """
    leading_dims = x.shape[:-1]
    out_features = weight.shape[0]
    return x.new_empty((*leading_dims, out_features))
def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the unquantized GEMM through the registered custom op.

    Forwards ``x``, ``weight`` and ``bias`` unchanged to
    ``torch.ops.vllm.rocm_unquantized_gemm_impl`` and returns its result.

    Args:
        layer: Not used by this function; it matches the leading parameter
            of the other ``*_unquantized_gemm`` functions in this file.
        x: Input tensor.
        weight: Weight tensor.
        bias: Optional bias tensor.

    Returns:
        The tensor produced by the custom op.
    """
    gemm_op = torch.ops.vllm.rocm_unquantized_gemm_impl
    return gemm_op(x, weight, bias)
# Register the real GEMM implementation as a custom op, paired with its
# shape-only fake implementation. rocm_unquantized_gemm invokes it as
# torch.ops.vllm.rocm_unquantized_gemm_impl.
# NOTE(review): presumably wrapping the impl in a custom op keeps
# torch.compile from tracing through its Python-level kernel selection
# (the commit title cites a kernel being omitted for batch sizes 1-4) —
# confirm against the commit discussion.
direct_register_custom_op(
op_name="rocm_unquantized_gemm_impl",
op_func=rocm_unquantized_gemm_impl,
# No arguments are declared as mutated in place.
mutates_args=[],
fake_impl=rocm_unquantized_gemm_impl_fake,
# Dispatch key is taken from the active platform object.
dispatch_key=current_platform.dispatch_key,
)
def cpu_unquantized_gemm(layer: torch.nn.Module,
x: torch.Tensor,
weight: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment