Commit 6c31a5fe authored by Artidoro Pagnoni
Browse files

t5 model fix

parent 9851a10b
...@@ -190,10 +190,10 @@ class LinearFP4(nn.Linear): ...@@ -190,10 +190,10 @@ class LinearFP4(nn.Linear):
if getattr(self.weight, 'quant_state', None) is None: if getattr(self.weight, 'quant_state', None) is None:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
inp_dtype = x.dtype inp_dtype = x.dtype
x = x.to(torch.float16) x = x.to(torch.float16)
out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias.half(), quant_state=self.weight.quant_state) bias = None if self.bias is None else self.bias.half()
out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
out = out.to(inp_dtype) out = out.to(inp_dtype)
return out return out
......
Markdown is supported
0% — Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment