Commit 0ecb8fb4 authored by ved1beta

lint

parent 2938c739
@@ -360,12 +360,12 @@ class Params4bit(torch.nn.Parameter):
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in [torch.chunk, torch.split]:
            tensor = args[0]
            result = super().__torch_function__(func, types, args, kwargs)
            if isinstance(result, tuple):
                return tuple(
                    cls(
@@ -393,9 +393,9 @@ class Params4bit(torch.nn.Parameter):
                        module=tensor.module,
                        bnb_quantized=tensor.bnb_quantized,
                    )
        return super().__torch_function__(func, types, args, kwargs)


def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
    if getattr(module.weight, "quant_state", None) is not None:
...
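For context, a minimal sketch of the pattern this diff touches: a torch.nn.Parameter subclass overriding __torch_function__ so that torch.chunk and torch.split return instances of the subclass with their extra metadata re-attached, rather than plain tensors. The MetaParam class and its "tag" attribute below are hypothetical stand-ins, not part of bitsandbytes; Params4bit carries quantization metadata (quant_state, module, bnb_quantized, ...) in the same spot.

import torch


class MetaParam(torch.nn.Parameter):
    # Hypothetical subclass: carries one extra Python attribute ("tag") that
    # plain torch ops would otherwise drop from their results.
    def __new__(cls, data=None, requires_grad=False, tag=None):
        if data is None:
            data = torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.tag = tag
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in (torch.chunk, torch.split):
            source = args[0]
            result = super().__torch_function__(func, types, args, kwargs)
            if isinstance(result, tuple):
                # Re-wrap each chunk so the subclass and its metadata survive.
                return tuple(
                    cls(data=chunk, requires_grad=source.requires_grad, tag=source.tag)
                    for chunk in result
                )
            return result
        return super().__torch_function__(func, types, args, kwargs)


p = MetaParam(torch.arange(8.0), tag="example")
halves = torch.chunk(p, 2)
print(type(halves[0]).__name__, halves[0].tag)  # MetaParam example

The re-wrapping matters because the default Tensor.__torch_function__ only converts the returned chunks to the subclass type; attributes that live on the Python object, like "tag" here or the quantization state in Params4bit, are not copied over unless the override reconstructs each chunk explicitly.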