Commit 0ecb8fb4 authored by ved1beta's avatar ved1beta
Browse files

lint

parent 2938c739
......@@ -360,12 +360,12 @@ class Params4bit(torch.nn.Parameter):
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func in [torch.chunk, torch.split]:
tensor = args[0]
result = super().__torch_function__(func, types, args, kwargs)
if isinstance(result, tuple):
return tuple(
cls(
......@@ -393,9 +393,9 @@ class Params4bit(torch.nn.Parameter):
module=tensor.module,
bnb_quantized=tensor.bnb_quantized,
)
return super().__torch_function__(func, types, args, kwargs)
def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
if getattr(module.weight, "quant_state", None) is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment