Commit 56983e67 authored by zhuwenwen's avatar zhuwenwen
Browse files

解决gptq的不能开启graph的问题

parent e4bff95c
...@@ -614,6 +614,34 @@ def gptq_gemm( ...@@ -614,6 +614,34 @@ def gptq_gemm(
# use_v2_format, # use_v2_format,
# bit, # bit,
# ) # )
# return quant_ops.gptq_gemm(
# a,
# b_q_weight,
# b_gptq_qzeros,
# b_gptq_scales,
# b_g_idx,
# use_exllama,
# bit,
# )
return torch.ops.vllm.gptq_gemm_(
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
b_g_idx,
use_exllama,
bit,
)
def gptq_gemm_(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_exllama: bool,
bit: int) -> torch.Tensor:
return quant_ops.gptq_gemm( return quant_ops.gptq_gemm(
a, a,
b_q_weight, b_q_weight,
...@@ -621,25 +649,36 @@ def gptq_gemm( ...@@ -621,25 +649,36 @@ def gptq_gemm(
b_gptq_scales, b_gptq_scales,
b_g_idx, b_g_idx,
use_exllama, use_exllama,
bit) bit,
)
def gptq_gemm_fake_(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_exllama: bool,
bit: int) -> torch.Tensor:
return torch.empty((a.shape[0], b_gptq_scales.shape[1]), dtype=a.dtype, device=a.device)
if hasattr(torch.ops._C, "gptq_gemm"):
@register_fake("_C::gptq_gemm") # if hasattr(torch.ops._C, "gptq_gemm"):
def _gptq_gemm_fake(
a: torch.Tensor, # @register_fake("_C::gptq_gemm")
b_q_weight: torch.Tensor, # def _gptq_gemm_fake(
b_gptq_qzeros: torch.Tensor, # a: torch.Tensor,
b_gptq_scales: torch.Tensor, # b_q_weight: torch.Tensor,
b_g_idx: torch.Tensor, # b_gptq_qzeros: torch.Tensor,
use_exllama: bool, # b_gptq_scales: torch.Tensor,
use_v2_format: bool, # b_g_idx: torch.Tensor,
bit: int, # use_exllama: bool,
) -> torch.Tensor: # use_v2_format: bool,
return torch.empty( # bit: int,
(a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device # ) -> torch.Tensor:
) # return torch.empty(
# (a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device
# )
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
...@@ -3352,3 +3391,10 @@ direct_register_custom_op( ...@@ -3352,3 +3391,10 @@ direct_register_custom_op(
mutates_args=[], mutates_args=[],
fake_impl=awq_gemm_fake, fake_impl=awq_gemm_fake,
) )
direct_register_custom_op(
op_name="gptq_gemm_",
op_func=gptq_gemm_,
mutates_args=[],
fake_impl=gptq_gemm_fake_,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment