Commit 56983e67 authored by zhuwenwen
Browse files

解决gptq的不能开启graph的问题

parent e4bff95c
...@@ -614,33 +614,72 @@ def gptq_gemm( ...@@ -614,33 +614,72 @@ def gptq_gemm(
# use_v2_format, # use_v2_format,
# bit, # bit,
# ) # )
return quant_ops.gptq_gemm(
# return quant_ops.gptq_gemm(
# a,
# b_q_weight,
# b_gptq_qzeros,
# b_gptq_scales,
# b_g_idx,
# use_exllama,
# bit,
# )
return torch.ops.vllm.gptq_gemm_(
a, a,
b_q_weight, b_q_weight,
b_gptq_qzeros, b_gptq_qzeros,
b_gptq_scales, b_gptq_scales,
b_g_idx, b_g_idx,
use_exllama, use_exllama,
bit) bit,
)
if hasattr(torch.ops._C, "gptq_gemm"):
@register_fake("_C::gptq_gemm") def gptq_gemm_(
def _gptq_gemm_fake(
a: torch.Tensor, a: torch.Tensor,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, b_g_idx: torch.Tensor,
use_exllama: bool, use_exllama: bool,
use_v2_format: bool, bit: int) -> torch.Tensor:
bit: int, return quant_ops.gptq_gemm(
) -> torch.Tensor: a,
return torch.empty( b_q_weight,
(a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device b_gptq_qzeros,
b_gptq_scales,
b_g_idx,
use_exllama,
bit,
) )
def gptq_gemm_fake_(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_exllama: bool,
    bit: int,
) -> torch.Tensor:
    """Fake (shape-only) implementation paired with the ``gptq_gemm_`` op.

    No matrix multiplication is performed and no input data is read; only
    an uninitialized tensor with the expected output metadata is allocated.

    Args:
        a: Activation tensor; its first dimension, dtype and device are
            copied to the output.
        b_q_weight: Unused here; kept for signature parity with the real op.
        b_gptq_qzeros: Unused here; kept for signature parity.
        b_gptq_scales: Its second dimension becomes the output's second
            dimension.
        b_g_idx: Unused here; kept for signature parity.
        use_exllama: Unused here; kept for signature parity.
        bit: Unused here; kept for signature parity.

    Returns:
        An uninitialized tensor of shape
        ``(a.shape[0], b_gptq_scales.shape[1])`` with ``a``'s dtype and
        device.
    """
    # NOTE(review): assumes `a` is 2-D and that dim 1 of the scales tensor
    # is the output-feature dimension -- confirm against the real kernel.
    out_shape = (a.shape[0], b_gptq_scales.shape[1])
    return torch.empty(out_shape, dtype=a.dtype, device=a.device)
# if hasattr(torch.ops._C, "gptq_gemm"):
# @register_fake("_C::gptq_gemm")
# def _gptq_gemm_fake(
# a: torch.Tensor,
# b_q_weight: torch.Tensor,
# b_gptq_qzeros: torch.Tensor,
# b_gptq_scales: torch.Tensor,
# b_g_idx: torch.Tensor,
# use_exllama: bool,
# use_v2_format: bool,
# bit: int,
# ) -> torch.Tensor:
# return torch.empty(
# (a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device
# )
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
...@@ -3352,3 +3391,10 @@ direct_register_custom_op( ...@@ -3352,3 +3391,10 @@ direct_register_custom_op(
mutates_args=[], mutates_args=[],
fake_impl=awq_gemm_fake, fake_impl=awq_gemm_fake,
) )
# Register gptq_gemm_ as a custom op named "gptq_gemm_", pairing it with
# gptq_gemm_fake_ as its fake implementation.
# mutates_args is empty: the op is declared as modifying none of its inputs.
# NOTE(review): gptq_gemm calls this op as torch.ops.vllm.gptq_gemm_, so the
# helper presumably registers under the "vllm" namespace -- confirm against
# direct_register_custom_op's definition.
direct_register_custom_op(
op_name="gptq_gemm_",
op_func=gptq_gemm_,
mutates_args=[],
fake_impl=gptq_gemm_fake_,
)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment