"vscode:/vscode.git/clone" did not exist on "dc1b6822e77f34e8e33fcde7dbb4f80974047342"
Commit 75710806 authored by Casper Hansen's avatar Casper Hansen
Browse files

Fix MLP module

parent e197a733
......@@ -21,7 +21,13 @@ class QuantMPTMLP(nn.Module):
def forward(self, x: torch.Tensor):
x = x.reshape(-1, x.shape[-1])
x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
x = awq_inference_engine.gemv_forward_cuda(
x,
self.up_proj_qweight,
self.up_proj_scales,
self.up_proj_qzeros,
self.down_proj.group_size
)
return self.down_proj(self.act(x))
......@@ -48,18 +54,24 @@ class QuantLlamaMLP(nn.Module):
self.down_proj = down_proj
def forward(self, x):
return self.down_proj(self.our_llama_mlp(x))
def our_llama_mlp(self, x):
out_shape = x.shape[:-1] + (self.intermediate_size, )
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
gate_output = awq_inference_engine.gemm_forward_cuda(
x, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, 8
gate_output = awq_inference_engine.gemv_forward_cuda(
x,
self.gate_proj_qweight,
self.gate_proj_scales,
self.gate_proj_qzeros,
self.down_proj.group_size,
)
gate_output = F.silu(gate_output)
up_output = awq_inference_engine.gemm_forward_cuda(
x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8
up_output = awq_inference_engine.gemv_forward_cuda(
x,
self.up_proj_qweight,
self.up_proj_scales,
self.up_proj_qzeros,
self.down_proj.group_size,
)
c = gate_output * up_output
c = c.reshape(out_shape)
return c
x = F.silu(gate_output) * up_output
x = x.reshape(out_shape)
x = self.down_proj(x)
return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment