Unverified commit f8353793, authored by Casper, committed by GitHub
Browse files

Workaround: illegal memory access (#421)

parent b5db7fcd
......@@ -189,7 +189,8 @@ class WQLinear_GEMVFast(torch.nn.Module):
@torch.no_grad()
def forward(self, x):
inputs = x
- if inputs.numel() / inputs.shape[-1] < 8:
+ batch_size, n_tokens, _ = inputs.shape
+ if batch_size < 8 and n_tokens == 1:
out = awq_v2_ext.gemv_forward_cuda_decode(
inputs,
self.qweight,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment