Commit 39a5084a authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-yql' into 'v0.11.0-dev'

解决gptq的不能开启graph的问题 (Fix GPTQ not working when graph mode is enabled)

See merge request dcutoolkit/deeplearing/vllm!304
parents b256f7ac 89db76fd
...@@ -765,22 +765,33 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -765,22 +765,33 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool, b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor: bit: int) -> torch.Tensor:
#return quant_ops.gptq_gemm_(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
return torch.ops.vllm.gptq_gemm_(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
def gptq_gemm_(a: torch.Tensor, b_q_weight: torch.Tensor,
               b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
               b_g_idx: torch.Tensor, use_exllama: bool,
               bit: int) -> torch.Tensor:
    """Run the GPTQ GEMM by delegating to ``quant_ops.gptq_gemm``.

    This is the function passed as ``op_func`` when the ``gptq_gemm_``
    custom op is registered further down in this file.

    Every argument is forwarded unchanged and in the same order, and the
    result of the underlying call is returned as-is.
    """
    return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                               b_g_idx, use_exllama, bit)
def gptq_gemm_fake_(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
                    b_g_idx: torch.Tensor, use_exllama: bool,
                    bit: int) -> torch.Tensor:
    """Shape-only stand-in for ``gptq_gemm_``.

    Allocates an uninitialised tensor of shape
    ``(a.shape[0], b_gptq_scales.shape[1])`` with the dtype and device of
    ``a``. No GEMM is performed.

    Only ``a`` and ``b_gptq_scales`` are read; the remaining parameters are
    accepted so the signature mirrors ``gptq_gemm_``.
    """
    out_shape = (a.shape[0], b_gptq_scales.shape[1])
    return torch.empty(out_shape, dtype=a.dtype, device=a.device)
if hasattr(torch.ops._C, "gptq_gemm"): # if hasattr(torch.ops._C, "gptq_gemm"):
@register_fake("_C::gptq_gemm") # @register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, # def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, # b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, # b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
use_exllama: bool, bit: int) -> torch.Tensor: # use_exllama: bool, bit: int) -> torch.Tensor:
return torch.empty((a.size(0), b_q_weight.size(1)), # return torch.empty((a.size(0), b_q_weight.size(1)),
dtype=a.dtype, # dtype=a.dtype,
device=a.device) # device=a.device)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
...@@ -2601,3 +2612,10 @@ direct_register_custom_op( ...@@ -2601,3 +2612,10 @@ direct_register_custom_op(
mutates_args=[], mutates_args=[],
fake_impl=awq_gemm_fake, fake_impl=awq_gemm_fake,
) )
# Register the GPTQ GEMM wrapper as the custom op "gptq_gemm_", pairing the
# real implementation with its fake counterpart (which only allocates an
# output tensor). The op declares that it mutates none of its arguments.
# NOTE(review): gptq_gemm calls torch.ops.vllm.gptq_gemm_, so this presumably
# registers under the "vllm" namespace - confirm in direct_register_custom_op.
direct_register_custom_op(
    op_name="gptq_gemm_",
    op_func=gptq_gemm_,
    fake_impl=gptq_gemm_fake_,
    mutates_args=[],
)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment