Unverified Commit e9add129 authored by Matthias Gehre's avatar Matthias Gehre Committed by GitHub
Browse files

[Bugfix] awq_gemm: fix argument order swap (#30364)


Signed-off-by: default avatarMatthias Gehre <matthias.gehre@amd.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 3224ea99
...@@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch): ...@@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
qweight = torch.randint( qweight = torch.randint(
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
) )
scales = torch.randint( scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
qzeros = torch.randint(
-2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32 -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
) )
qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
split_k_iters = 8 split_k_iters = 8
opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters)) opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))
...@@ -498,15 +498,15 @@ def awq_dequantize( ...@@ -498,15 +498,15 @@ def awq_dequantize(
def awq_gemm( def awq_gemm(
input: torch.Tensor, input: torch.Tensor,
qweight: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: int, split_k_iters: int,
) -> torch.Tensor: ) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ: if envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
# gptq # gptq
...@@ -632,8 +632,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -632,8 +632,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _awq_gemm_fake( def _awq_gemm_fake(
input: torch.Tensor, input: torch.Tensor,
qweight: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor,
split_k_iters: torch.SymInt, split_k_iters: torch.SymInt,
) -> torch.Tensor: ) -> torch.Tensor:
num_in_feats = input.size(0) num_in_feats = input.size(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment