Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
070f45d2
Commit
070f45d2
authored
Nov 09, 2023
by
Ruslan Svirschevski
Browse files
cleanup commented out deletions
parent
781fcd5b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
39 deletions
+5
-39
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+5
-39
No files found.
bitsandbytes/nn/modules.py
View file @
070f45d2
...
@@ -163,30 +163,6 @@ class Params4bit(torch.nn.Parameter):
...
@@ -163,30 +163,6 @@ class Params4bit(torch.nn.Parameter):
self
.
compress_statistics
=
self
.
quant_state
.
nested
self
.
compress_statistics
=
self
.
quant_state
.
nested
self
.
quant_type
=
self
.
quant_state
.
quant_type
self
.
quant_type
=
self
.
quant_state
.
quant_type
return
self
return
self
# @classmethod
# def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
# data = state_dict.pop(prefix.rstrip('.'))
# # extracting components for QuantState from state_dict
# qs_dict = {}
# for k, v in state_dict.items():
# if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
# qs_dict[k] = v
# state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
# qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
# if data.device.type != "cuda":
# raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
# cls.requires_grad = requires_grad
# cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
# cls.blocksize = cls.quant_state.blocksize # this attribute can be deprecated - it duplicates same one in quant_state
# cls.compress_statistics = cls.quant_state.nested # this attribute can be deprecated - it duplicates quant_state.nested
# cls.quant_type = cls.quant_state.quant_type # this attribute can be deprecated - it duplicates same one in quant_state
# self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
# return self, state_dict
def
cuda
(
self
,
device
):
def
cuda
(
self
,
device
):
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
...
@@ -227,7 +203,7 @@ class Params4bit(torch.nn.Parameter):
...
@@ -227,7 +203,7 @@ class Params4bit(torch.nn.Parameter):
class
Linear4bit
(
nn
.
Linear
):
class
Linear4bit
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
quant_type
=
'fp4'
,
device
=
None
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
quant_type
=
'fp4'
,
device
=
None
):
super
().
__init__
(
input_features
,
output_features
,
bias
,
device
)
super
().
__init__
(
input_features
,
output_features
,
bias
,
device
)
self
.
weight
=
Params4bit
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
self
.
weight
=
Params4bit
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
# self.persistent_buffers = [] # TODO consider as way to save quant state
# self.persistent_buffers = [] # TODO consider as way to save quant state
...
@@ -261,18 +237,6 @@ class Linear4bit(nn.Linear):
...
@@ -261,18 +237,6 @@ class Linear4bit(nn.Linear):
for
k
,
v
in
self
.
weight
.
quant_state
.
as_dict
(
packed
=
True
).
items
():
for
k
,
v
in
self
.
weight
.
quant_state
.
as_dict
(
packed
=
True
).
items
():
destination
[
prefix
+
"weight."
+
k
]
=
v
if
keep_vars
else
v
.
detach
()
destination
[
prefix
+
"weight."
+
k
]
=
v
if
keep_vars
else
v
.
detach
()
# def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
# missing_keys, unexpected_keys, error_msgs):
# # Note: super()._load_from_state_dict() is not called here intentionally.
# if self.bias is not None:
# bias_data = state_dict.pop(prefix + "bias", None)
# self.bias.data = bias_data.to(self.bias.data.device)
# self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
# state_dict, prefix=prefix + "weight" + ".", requires_grad=False
# )
# unexpected_keys.extend(state_dict.keys())
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if
self
.
bias
is
not
None
and
self
.
bias
.
dtype
!=
x
.
dtype
:
if
self
.
bias
is
not
None
and
self
.
bias
.
dtype
!=
x
.
dtype
:
...
@@ -295,10 +259,12 @@ class Linear4bit(nn.Linear):
...
@@ -295,10 +259,12 @@ class Linear4bit(nn.Linear):
return
out
return
out
class
LinearFP4
(
Linear4bit
):
class
LinearFP4
(
Linear4bit
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
device
=
None
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
device
=
None
):
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'fp4'
,
device
)
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'fp4'
,
device
)
class
LinearNF4
(
Linear4bit
):
class
LinearNF4
(
Linear4bit
):
''' Implements the NF4 data type.
''' Implements the NF4 data type.
...
@@ -310,7 +276,7 @@ class LinearNF4(Linear4bit):
...
@@ -310,7 +276,7 @@ class LinearNF4(Linear4bit):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
'''
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
device
=
None
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
device
=
None
):
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'nf4'
,
device
)
super
().
__init__
(
input_features
,
output_features
,
bias
,
compute_dtype
,
compress_statistics
,
'nf4'
,
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment