warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
warnings.filterwarnings('ignore', message='.*inference or training')
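Rather than suppressing the warning, the mismatch can be fixed at its source by making the 4-bit compute dtype match the fp16 inputs. A minimal sketch, assuming the Hugging Face transformers integration (BitsAndBytesConfig); the model name is a placeholder chosen purely for illustration:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# match the compute dtype to the torch.float16 inputs so Linear4bit
# never emits the slow-speed warning in the first place
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    'facebook/opt-125m',  # placeholder model, for illustration only
    quantization_config=bnb_config,
)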
def forward(self, x: torch.Tensor):
    # weights are cast automatically as Int8Params, but the bias has to be cast manually
    ...
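The elided body is left as-is above; purely to illustrate the comment, here is a self-contained sketch of the manual bias cast, with a hypothetical layer name and a plain fp32 weight standing in for the quantized path:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BiasCastingLinear(nn.Linear):  # hypothetical name, for illustration
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # cast the fp32 bias to the input dtype by hand, mirroring the
        # comment; the quantized-weight path is not reproduced here
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)
        return F.linear(x, self.weight.to(x.dtype), self.bias)

layer = BiasCastingLinear(4, 4)                       # bias starts in fp32
out = layer(torch.randn(2, 4, dtype=torch.float16))   # fp16 input
print(out.dtype)                                      # torch.float16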
@@ -213,6 +236,10 @@ class Linear4bit(nn.Linear):
...
if getattr(self.weight, 'quant_state', None) is None:
    print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
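A minimal sketch of when that message fires and how moving the layer onto the GPU initializes the quantization state; assumes bitsandbytes is installed, a CUDA device is available, and the layer sizes are arbitrary:

import bitsandbytes as bnb

layer = bnb.nn.LinearFP4(128, 64)                  # weight not yet quantized
print(getattr(layer.weight, 'quant_state', None))  # None while still on CPU

layer = layer.cuda()                               # .cuda()/.to(device) quantizes the weight
print(layer.weight.quant_state is not None)        # True: forward() is now safe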