"src/vscode:/vscode.git/clone" did not exist on "f07a16e09bb5b1cf4fa2306bfa4ea791f24fa968"
Commit 890b6aa7 authored by Casper Hansen

GEMM + GEMV compatibility

parent 5297eccc
@@ -111,20 +111,31 @@ class LlamaFuser:
         q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
 
-        qkv_layer = WQLinear_GEMV(
+        if isinstance(q_proj, WQLinear_GEMV):
+            q_linear = WQLinear_GEMV
+        else:
+            q_linear = WQLinear_GEMM
+
+        qkv_layer = q_linear(
             q_proj.w_bit,
             q_proj.group_size,
             q_proj.in_features,
             q_proj.out_features + k_proj.out_features + v_proj.out_features,
             q_proj.bias is not None,
-            q_proj.qweight.device,
+            next(iter(module.state_dict().values())).device
         )
-        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
-        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
-        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+
+        if isinstance(qkv_layer, WQLinear_GEMV):
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
+            qkv_layer.split_k_iters = q_proj.split_k_iters
+        else:
+            qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+            qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+            qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+
         qkv_layer.bias = bias
-        qkv_layer.split_k_iters = q_proj.split_k_iters
 
         return qkv_layer
...
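The concat axis flips between the two branches because the kernels pack weights differently: GEMV keeps output features on dim 0, while GEMM keeps input features on dim 0 and packs output features along dim 1. A minimal sketch of the shape bookkeeping, assuming the usual 4-bit-into-int32 pack factor of 8 (scales/qzeros layouts omitted for brevity):

import torch

# Illustrative shapes only: 4-bit weights packed into int32 give a
# pack factor of 32 // 4 = 8.
w_bit, pack = 4, 32 // 4
in_features, out_features = 4096, 4096

# Assumed packing conventions for the two kernels:
#   GEMV: qweight is (out_features, in_features // pack)  -> outputs on dim 0
#   GEMM: qweight is (in_features, out_features // pack)  -> outputs on dim 1
qweight_gemv = torch.zeros(out_features, in_features // pack, dtype=torch.int32)
qweight_gemm = torch.zeros(in_features, out_features // pack, dtype=torch.int32)

# Fusing Q, K and V must concatenate along whichever axis carries the
# output features, which is exactly the dim=0 / dim=1 split in the diff.
fused_gemv = torch.cat([qweight_gemv] * 3, dim=0)
fused_gemm = torch.cat([qweight_gemm] * 3, dim=1)
assert fused_gemv.shape == (3 * out_features, in_features // pack)
assert fused_gemm.shape == (in_features, 3 * out_features // pack)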
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 import torch.nn.functional as F
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
 class QuantMPTMLP(nn.Module):
     def __init__(
@@ -18,15 +19,22 @@ class QuantMPTMLP(nn.Module):
         self.up_proj = up_proj
         self.act = act
         self.down_proj = down_proj
 
+        if isinstance(down_proj, WQLinear_GEMV):
+            self.linear = awq_inference_engine.gemv_forward_cuda
+            self.group_size = down_proj.group_size
+        else:
+            self.linear = awq_inference_engine.gemm_forward_cuda
+            self.group_size = 8
+
     def forward(self, x: torch.Tensor):
         x = x.reshape(-1, x.shape[-1])
-        x = awq_inference_engine.gemv_forward_cuda(
+        x = self.linear(
             x,
             self.up_proj_qweight,
             self.up_proj_scales,
             self.up_proj_qzeros,
-            self.down_proj.group_size
+            self.group_size
         )
 
         return self.down_proj(self.act(x))
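The same pattern recurs in both MLP classes: the kernel function and its trailing argument are resolved once in __init__, so forward() stays branch-free. Note the trailing argument means different things per path: the GEMV kernel takes the quantization group size, while the constant 8 on the GEMM path is presumably its split-k iteration count. A stripped-down sketch of the dispatch, with plain matmuls standing in for the compiled awq_inference_engine kernels and a hypothetical DispatchingMLP name:

import torch

# Toy stand-ins for the compiled kernels; both just do a dense matmul here.
# In the real module these are awq_inference_engine.gemv_forward_cuda and
# gemm_forward_cuda, and `extra` is the group size or the split-k factor.
def fake_gemv(x, w, extra): return x @ w
def fake_gemm(x, w, extra): return x @ w

class DispatchingMLP(torch.nn.Module):  # hypothetical, for illustration
    def __init__(self, weight: torch.Tensor, use_gemv: bool, group_size: int):
        super().__init__()
        self.register_buffer('weight', weight)
        # Resolve the kernel and its trailing argument once, at construction.
        if use_gemv:
            self.linear, self.group_size = fake_gemv, group_size
        else:
            self.linear, self.group_size = fake_gemm, 8

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.reshape(-1, x.shape[-1])
        return self.linear(x, self.weight, self.group_size)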
@@ -37,7 +45,7 @@ class QuantLlamaMLP(nn.Module):
         self,
         gate_proj,
         down_proj,
-        up_proj,
+        up_proj
     ):
         super().__init__()
         self.register_buffer('gate_proj_qweight', gate_proj.qweight)
@@ -53,22 +61,29 @@ class QuantLlamaMLP(nn.Module):
         self.w_bit = gate_proj.w_bit
         self.down_proj = down_proj
 
+        if isinstance(down_proj, WQLinear_GEMV):
+            self.linear = awq_inference_engine.gemv_forward_cuda
+            self.group_size = down_proj.group_size
+        else:
+            self.linear = awq_inference_engine.gemm_forward_cuda
+            self.group_size = 8
+
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
-        gate_output = awq_inference_engine.gemv_forward_cuda(
+        gate_output = self.linear(
             x,
             self.gate_proj_qweight,
             self.gate_proj_scales,
             self.gate_proj_qzeros,
-            self.down_proj.group_size,
+            self.group_size,
         )
 
-        up_output = awq_inference_engine.gemv_forward_cuda(
+        up_output = self.linear(
             x,
             self.up_proj_qweight,
             self.up_proj_scales,
             self.up_proj_qzeros,
-            self.down_proj.group_size,
+            self.group_size,
         )
 
         x = F.silu(gate_output) * up_output
         x = x.reshape(out_shape)
...
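For reference, the computation QuantLlamaMLP performs is the standard Llama SwiGLU-style MLP; an unquantized equivalent with plain nn.Linear layers makes the F.silu(gate) * up structure easy to check against:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Unquantized reference for the fused module above: the quantized version
# computes the same thing, but routes the gate/up matmuls through the
# packed GEMM or GEMV kernel selected at construction time.
class LlamaMLPReference(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))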