Commit 75710806 authored by Casper Hansen

Fix MLP module

Switch the MPT and Llama MLP forward passes from gemm_forward_cuda to gemv_forward_cuda, passing self.down_proj.group_size where the old calls hardcoded 8, and inline our_llama_mlp into QuantLlamaMLP.forward.

parent e197a733
@@ -21,7 +21,13 @@ class QuantMPTMLP(nn.Module):
     def forward(self, x: torch.Tensor):
         x = x.reshape(-1, x.shape[-1])
-        x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
+        x = awq_inference_engine.gemv_forward_cuda(
+            x,
+            self.up_proj_qweight,
+            self.up_proj_scales,
+            self.up_proj_qzeros,
+            self.down_proj.group_size
+        )
         return self.down_proj(self.act(x))
@@ -48,18 +54,24 @@ class QuantLlamaMLP(nn.Module):
         self.down_proj = down_proj

     def forward(self, x):
-        return self.down_proj(self.our_llama_mlp(x))
-
-    def our_llama_mlp(self, x):
-        out_shape = x.shape[:-1] + (self.intermediate_size, )
+        out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
-        gate_output = awq_inference_engine.gemm_forward_cuda(
-            x, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, 8
+        gate_output = awq_inference_engine.gemv_forward_cuda(
+            x,
+            self.gate_proj_qweight,
+            self.gate_proj_scales,
+            self.gate_proj_qzeros,
+            self.down_proj.group_size,
         )
-        gate_output = F.silu(gate_output)
-        up_output = awq_inference_engine.gemm_forward_cuda(
-            x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8
+        up_output = awq_inference_engine.gemv_forward_cuda(
+            x,
+            self.up_proj_qweight,
+            self.up_proj_scales,
+            self.up_proj_qzeros,
+            self.down_proj.group_size,
         )
-        c = gate_output * up_output
-        c = c.reshape(out_shape)
-        return c
+        x = F.silu(gate_output) * up_output
+        x = x.reshape(out_shape)
+        x = self.down_proj(x)
+        return x
\ No newline at end of file
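
For context, the Llama hunk above computes the standard SwiGLU MLP: silu(gate_proj(x)) * up_proj(x), projected back through down_proj. Below is a minimal plain-PyTorch sketch of the same computation, with nn.Linear standing in for the quantized gemv_forward_cuda kernel; the class and parameter names here are illustrative and not part of the repository.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class RefLlamaMLP(nn.Module):
        # Unquantized reference for the SwiGLU MLP in the diff above.
        # nn.Linear replaces the quantized gemv kernel calls.
        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            self.intermediate_size = intermediate_size
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out_shape = x.shape[:-1] + (self.intermediate_size,)
            x = x.reshape(-1, x.shape[-1])                    # flatten batch/sequence dims
            x = F.silu(self.gate_proj(x)) * self.up_proj(x)   # SwiGLU gating
            x = x.reshape(out_shape)                          # restore leading dims
            return self.down_proj(x)                          # project back to hidden_size

    mlp = RefLlamaMLP(hidden_size=4096, intermediate_size=11008)
    y = mlp(torch.randn(1, 16, 4096))  # -> shape (1, 16, 4096)

Presumably the switch from gemm_forward_cuda to gemv_forward_cuda targets the matrix-vector shape of single-token decoding, and the gemv kernel takes the quantization group size (self.down_proj.group_size) as its final argument where the old gemm call passed a hardcoded 8.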