Unverified Commit f8353793 authored by Casper's avatar Casper Committed by GitHub
Browse files

Workaround: illegal memory access (#421)

parent b5db7fcd
...@@ -189,7 +189,8 @@ class WQLinear_GEMVFast(torch.nn.Module): ...@@ -189,7 +189,8 @@ class WQLinear_GEMVFast(torch.nn.Module):
@torch.no_grad() @torch.no_grad()
def forward(self, x): def forward(self, x):
inputs = x inputs = x
if inputs.numel() / inputs.shape[-1] < 8: batch_size, n_tokens, _ = inputs.shape
if batch_size < 8 and n_tokens == 1:
out = awq_v2_ext.gemv_forward_cuda_decode( out = awq_v2_ext.gemv_forward_cuda_decode(
inputs, inputs,
self.qweight, self.qweight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment