Commit 890b6aa7 authored by Casper Hansen

GEMM + GEMV compatibility

parent 5297eccc
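This commit makes the fused modules work with models quantized in either packing format: the Llama QKV fuser now instantiates whichever WQLinear class (GEMM or GEMV) the source projections use and concatenates the packed tensors along the matching axis, and the fused MLP blocks for MPT and Llama select between awq_inference_engine.gemv_forward_cuda and gemm_forward_cuda once at construction time instead of assuming the GEMV kernel.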
@@ -111,20 +111,31 @@ class LlamaFuser:
         q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-        qkv_layer = WQLinear_GEMV(
+        if isinstance(q_proj, WQLinear_GEMV):
+            q_linear = WQLinear_GEMV
+        else:
+            q_linear = WQLinear_GEMM
+        qkv_layer = q_linear(
             q_proj.w_bit,
             q_proj.group_size,
             q_proj.in_features,
             q_proj.out_features + k_proj.out_features + v_proj.out_features,
             q_proj.bias is not None,
-            q_proj.qweight.device
+            next(iter(module.state_dict().values())).device
         )
-        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
-        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
-        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+        if isinstance(qkv_layer, WQLinear_GEMV):
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+            qkv_layer.split_k_iters = q_proj.split_k_iters
+        else:
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
         qkv_layer.bias = bias
-        qkv_layer.split_k_iters = q_proj.split_k_iters
         return qkv_layer
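The two branches above differ only in the concatenation axis: for WQLinear_GEMV the q/k/v tensors are stacked along dim=0, for WQLinear_GEMM along dim=1, which implies the output dimension lives on a different axis of qweight/qzeros/scales in the two formats. A shape-only sketch of that difference; the packed shapes and the pack factor of 8 below are illustrative assumptions, not values taken from this diff:

```python
import torch

in_features, out_q, out_kv, pack = 4096, 4096, 1024, 8

# Assumed GEMV-style layout: output dimension on axis 0 -> stack Q/K/V along dim=0.
q_gemv = torch.zeros(out_q, in_features // pack, dtype=torch.int32)
kv_gemv = torch.zeros(out_kv, in_features // pack, dtype=torch.int32)
qkv_gemv = torch.cat([q_gemv, kv_gemv, kv_gemv], dim=0)
print(qkv_gemv.shape)  # torch.Size([6144, 512])

# Assumed GEMM-style layout: output dimension on axis 1 -> stack Q/K/V along dim=1.
q_gemm = torch.zeros(in_features, out_q // pack, dtype=torch.int32)
kv_gemm = torch.zeros(in_features, out_kv // pack, dtype=torch.int32)
qkv_gemm = torch.cat([q_gemm, kv_gemm, kv_gemm], dim=1)
print(qkv_gemm.shape)  # torch.Size([4096, 768])
```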
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 import torch.nn.functional as F
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
 class QuantMPTMLP(nn.Module):
     def __init__(
@@ -18,15 +19,22 @@ class QuantMPTMLP(nn.Module):
         self.up_proj = up_proj
         self.act = act
         self.down_proj = down_proj
 
+        if isinstance(down_proj, WQLinear_GEMV):
+            self.linear = awq_inference_engine.gemv_forward_cuda
+            self.group_size = down_proj.group_size
+        else:
+            self.linear = awq_inference_engine.gemm_forward_cuda
+            self.group_size = 8
+
     def forward(self, x: torch.Tensor):
         x = x.reshape(-1, x.shape[-1])
-        x = awq_inference_engine.gemv_forward_cuda(
+        x = self.linear(
             x,
             self.up_proj_qweight,
             self.up_proj_scales,
             self.up_proj_qzeros,
-            self.down_proj.group_size
+            self.group_size
         )
 
         return self.down_proj(self.act(x))
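Both fused MLP classes now pick the CUDA kernel once in __init__ instead of hard-coding gemv_forward_cuda in forward. A minimal sketch of that selection as a standalone helper; the helper name is hypothetical and not part of the commit:

```python
import awq_inference_engine
from awq.modules.linear import WQLinear_GEMV

def select_linear_kernel(proj):
    """Mirror the dispatch in the diff: return (kernel_fn, final_arg) for a layer."""
    if isinstance(proj, WQLinear_GEMV):
        # GEMV path: the kernel receives the layer's quantization group size.
        return awq_inference_engine.gemv_forward_cuda, proj.group_size
    # GEMM path: the diff passes a fixed 8 as the final argument instead.
    return awq_inference_engine.gemm_forward_cuda, 8
```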
@@ -37,7 +45,7 @@ class QuantLlamaMLP(nn.Module):
         self,
         gate_proj,
         down_proj,
-        up_proj,
+        up_proj
     ):
         super().__init__()
         self.register_buffer('gate_proj_qweight', gate_proj.qweight)
@@ -53,22 +61,29 @@ class QuantLlamaMLP(nn.Module):
         self.w_bit = gate_proj.w_bit
         self.down_proj = down_proj
 
+        if isinstance(down_proj, WQLinear_GEMV):
+            self.linear = awq_inference_engine.gemv_forward_cuda
+            self.group_size = down_proj.group_size
+        else:
+            self.linear = awq_inference_engine.gemm_forward_cuda
+            self.group_size = 8
+
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
-        gate_output = awq_inference_engine.gemv_forward_cuda(
+        gate_output = self.linear(
             x,
             self.gate_proj_qweight,
             self.gate_proj_scales,
             self.gate_proj_qzeros,
-            self.down_proj.group_size,
+            self.group_size,
         )
-        up_output = awq_inference_engine.gemv_forward_cuda(
+        up_output = self.linear(
             x,
             self.up_proj_qweight,
             self.up_proj_scales,
             self.up_proj_qzeros,
-            self.down_proj.group_size,
+            self.group_size,
         )
         x = F.silu(gate_output) * up_output
         x = x.reshape(out_shape)
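One subtlety in the MLP changes: the final kernel argument used to be read from self.down_proj.group_size on every forward call, whereas it is now resolved once at construction. On the GEMM path that stored value is a fixed 8 rather than the layer's quantization group size, so self.group_size is only a true group size on the GEMV path; what the final argument of gemm_forward_cuda actually represents is not stated in this diff.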