Unverified commit e9add129, authored by Matthias Gehre, committed by GitHub
Browse files

[Bugfix] awq_gemm: fix argument order swap (#30364)


Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 3224ea99
@@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
 qweight = torch.randint(
 -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
 )
-scales = torch.randint(
+scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
+qzeros = torch.randint(
 -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
 )
-qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
 split_k_iters = 8
-opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))
+opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))
@@ -498,15 +498,15 @@ def awq_dequantize(
 def awq_gemm(
 input: torch.Tensor,
 qweight: torch.Tensor,
-qzeros: torch.Tensor,
 scales: torch.Tensor,
+qzeros: torch.Tensor,
 split_k_iters: int,
 ) -> torch.Tensor:
 if envs.VLLM_USE_TRITON_AWQ:
 from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
-return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
-return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
+return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
 # gptq
@@ -632,8 +632,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
 def _awq_gemm_fake(
 input: torch.Tensor,
 qweight: torch.Tensor,
-qzeros: torch.Tensor,
 scales: torch.Tensor,
+qzeros: torch.Tensor,
 split_k_iters: torch.SymInt,
 ) -> torch.Tensor:
 num_in_feats = input.size(0)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment