Unverified Commit 9cac5dd1 authored by shadeMe

Add `device` parameter to `Linear` subclasses

parent ac5550a0
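
For context, a minimal usage sketch (not part of the commit) of what the new parameter enables, assuming a recent `torch` and this version of `bitsandbytes`; `device` is simply forwarded to `torch.nn.Linear.__init__`:

```python
import torch
import bitsandbytes as bnb

# Build the layer as shape-only "meta" tensors: no storage is
# allocated, which is the usual deferred-initialization pattern
# when loading large quantized models.
layer = bnb.nn.LinearNF4(4096, 4096, device="meta")

# Any device torch accepts should work, since the argument is
# forwarded to torch.nn.Linear.
cpu_layer = bnb.nn.Linear8bitLt(4096, 4096, device=torch.device("cpu"))
```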
@@ -199,8 +199,8 @@ class Params4bit(torch.nn.Parameter):
         return new_param

 class Linear4bit(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         self.compute_dtype = compute_dtype
@@ -223,12 +223,12 @@ class Linear4bit(nn.Linear):
         return out

 class LinearFP4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)

 class LinearNF4(Linear4bit):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
-        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
+        super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
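
Illustration only (not from the diff): the subclasses just pin `quant_type` and forward everything else, so these two constructions should be equivalent, assuming the classes above behave as shown:

```python
import bitsandbytes as bnb

# LinearNF4 is Linear4bit with quant_type fixed to 'nf4'; `device`
# now rides along to the nn.Linear base class in both cases.
a = bnb.nn.LinearNF4(1024, 1024, device="meta")
b = bnb.nn.Linear4bit(1024, 1024, quant_type='nf4', device="meta")
assert a.weight.device == b.weight.device  # both shape-only "meta"
```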
@@ -309,8 +309,8 @@ class Int8Params(torch.nn.Parameter):
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
-                 memory_efficient_backward=False, threshold=0.0, index=None):
-        super().__init__(input_features, output_features, bias)
+                 memory_efficient_backward=False, threshold=0.0, index=None, device=None):
+        super().__init__(input_features, output_features, bias, device)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
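
A hedged before/after sketch (illustrative, not from the commit): previously the 8-bit layer always materialized on the CPU first; with `device` it can be created where it is needed:

```python
import bitsandbytes as bnb

# Old pattern: allocate full-precision weights on the CPU, then move
# them (quantization is triggered later, when the layer is moved to
# the GPU).
old = bnb.nn.Linear8bitLt(768, 768, has_fp16_weights=False)

# New pattern: construct directly on the target device.
new = bnb.nn.Linear8bitLt(768, 768, has_fp16_weights=False, device="meta")
```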
@@ -397,8 +397,8 @@ class Linear8bitLt(nn.Linear):
 class OutlierAwareLinear(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
-        super().__init__(input_features, output_features, bias)
+    def __init__(self, input_features, output_features, bias=True, device=None):
+        super().__init__(input_features, output_features, bias, device)
         self.outlier_dim = None
         self.is_quantized = False
@@ -432,9 +432,10 @@ class SwitchBackLinearBnb(nn.Linear):
             memory_efficient_backward=False,
             threshold=0.0,
             index=None,
+            device=None
         ):
         super().__init__(
-            input_features, output_features, bias
+            input_features, output_features, bias, device
         )
         self.state = bnb.MatmulLtState()
         self.index = index
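
One design note worth flagging: `device` is passed positionally as the fourth argument to `torch.nn.Linear.__init__`, whose keyword order is `(in_features, out_features, bias=True, device=None, dtype=None)`, so it lands in the right slot and `dtype` stays at its default. A quick check:

```python
import inspect
import torch

# Confirms the (bias, device, dtype) ordering that the positional
# `device` argument in this commit relies on.
print(inspect.signature(torch.nn.Linear.__init__))
```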