Commit 89db76fd authored by chenyue3's avatar chenyue3
Browse files

解决gptq的不能开启graph的问题

parent 37771741
......@@ -765,22 +765,33 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
#return quant_ops.gptq_gemm_(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
return torch.ops.vllm.gptq_gemm_(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
def gptq_gemm_(a: torch.Tensor, b_q_weight: torch.Tensor,
               b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
               b_g_idx: torch.Tensor, use_exllama: bool,
               bit: int) -> torch.Tensor:
    """Forward a GPTQ GEMM call to ``quant_ops.gptq_gemm``.

    Every argument is passed through positionally and unchanged, and the
    backend's result is returned as-is. This function is the ``op_func``
    handed to ``direct_register_custom_op`` under the op name
    ``"gptq_gemm_"``.
    """
    # An earlier variant (previously left here as commented-out code) called
    # torch.ops._C.gptq_gemm with the same argument list.
    result = quant_ops.gptq_gemm(
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        b_g_idx,
        use_exllama,
        bit,
    )
    return result
def gptq_gemm_fake_(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
                    b_g_idx: torch.Tensor, use_exllama: bool,
                    bit: int) -> torch.Tensor:
    """Shape-only stand-in for ``gptq_gemm_``.

    Allocates an uninitialized tensor whose first dimension is taken from
    ``a`` and whose second dimension is taken from ``b_gptq_scales``, using
    the dtype and device of ``a``. The remaining arguments are accepted only
    so the signature matches ``gptq_gemm_``; they are not read.
    """
    out_rows = a.size(0)
    out_cols = b_gptq_scales.size(1)
    return torch.empty((out_rows, out_cols),
                       dtype=a.dtype,
                       device=a.device)
# NOTE(review): this span holds both a live and a commented-out copy of the
# same guard and of the same fake registration. It looks like the removed and
# added sides of a diff shown together -- confirm which copy is meant to be
# active before relying on it.
#
# Live copy: when torch.ops._C exposes "gptq_gemm", register a fake
# implementation for the "_C::gptq_gemm" op.
if hasattr(torch.ops._C, "gptq_gemm"):
# if hasattr(torch.ops._C, "gptq_gemm"):
@register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
use_exllama: bool, bit: int) -> torch.Tensor:
# Only allocates the output: shape (a.size(0), b_q_weight.size(1)), with the
# dtype and device of `a`. No GEMM is computed and the other arguments are
# not read.
return torch.empty((a.size(0), b_q_weight.size(1)),
dtype=a.dtype,
device=a.device)
# Commented-out copy of the registration above.
# @register_fake("_C::gptq_gemm")
# def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
# b_gptq_qzeros: torch.Tensor,
# b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
# use_exllama: bool, bit: int) -> torch.Tensor:
# return torch.empty((a.size(0), b_q_weight.size(1)),
# dtype=a.dtype,
# device=a.device)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
......@@ -2601,3 +2612,10 @@ direct_register_custom_op(
mutates_args=[],
fake_impl=awq_gemm_fake,
)
# Register gptq_gemm_ as a custom op, with gptq_gemm_fake_ as its fake
# implementation. The op is declared as mutating none of its arguments.
# NOTE(review): gptq_gemm calls torch.ops.vllm.gptq_gemm_, so this is
# presumably registered under the "vllm" namespace -- confirm against
# direct_register_custom_op's defaults.
direct_register_custom_op(op_name="gptq_gemm_",
                          op_func=gptq_gemm_,
                          fake_impl=gptq_gemm_fake_,
                          mutates_args=[])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment