OpenDAS / bitsandbytes · Commits · 070f45d2

Commit 070f45d2, authored Nov 09, 2023 by Ruslan Svirschevski

cleanup commented out deletions

Parent: 781fcd5b
Showing 1 changed file with 5 additions and 39 deletions:

bitsandbytes/nn/modules.py (+5 -39)
@@ -164,30 +164,6 @@ class Params4bit(torch.nn.Parameter):
         self.quant_type = self.quant_state.quant_type
         return self
-    # @classmethod
-    # def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-    #     data = state_dict.pop(prefix.rstrip('.'))
-    #     # extracting components for QuantState from state_dict
-    #     qs_dict = {}
-    #     for k, v in state_dict.items():
-    #         if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-    #             qs_dict[k] = v
-    #     state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-    #     qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-    #     if data.device.type != "cuda":
-    #         raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
-    #     cls.requires_grad = requires_grad
-    #     cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-    #     cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
-    #     cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
-    #     cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
-    #     self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-    #     return self, state_dict
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
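
For context, Params4bit.cuda() quantizes the fp16 weight on the way to the GPU via bnb.functional.quantize_4bit. A minimal round-trip sketch of that call (assuming a CUDA device and bitsandbytes installed; the shape and blocksize are illustrative):

    import torch
    import bitsandbytes as bnb

    # Toy fp16 weight; Params4bit.cuda() above does the same contiguous().half().cuda() step first.
    w = torch.randn(64, 64, dtype=torch.float16, device="cuda")

    # Returns the packed 4-bit tensor plus a QuantState (absmax, blocksize, quant_type, ...).
    w_4bit, quant_state = bnb.functional.quantize_4bit(
        w, blocksize=64, compress_statistics=True, quant_type="nf4"
    )

    # Dequantize to check the (lossy) round trip.
    w_restored = bnb.functional.dequantize_4bit(w_4bit, quant_state)
    print((w - w_restored).abs().max())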
@@ -227,7 +203,7 @@ class Params4bit(torch.nn.Parameter):
 class Linear4bit(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
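
Construction-wise, Linear4bit starts as an ordinary nn.Linear and only rewraps its weight in Params4bit; quantization happens later, when the module is moved to CUDA. A short usage sketch (sizes and compute_dtype are illustrative):

    import torch
    import bitsandbytes as bnb

    # Weight is wrapped in Params4bit with requires_grad=False, as in __init__ above.
    layer = bnb.nn.Linear4bit(128, 256, bias=True, compute_dtype=torch.float16, quant_type="fp4")

    # .cuda() triggers Params4bit.cuda(), which quantizes the weight.
    layer = layer.cuda()

    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    print(layer(x).shape)  # torch.Size([4, 256])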
@@ -261,18 +237,6 @@ class Linear4bit(nn.Linear):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
-    # def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-    #                           missing_keys, unexpected_keys, error_msgs):
-    #     # Note: super()._load_from_state_dict() is not called here intentionally.
-    #     if self.bias is not None:
-    #         bias_data = state_dict.pop(prefix + "bias", None)
-    #         self.bias.data = bias_data.to(self.bias.data.device)
-    #     self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
-    #         state_dict, prefix=prefix + "weight" + ".", requires_grad=False
-    #     )
-    #     unexpected_keys.extend(state_dict.keys())
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
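
The _save_to_state_dict override shown above flattens the packed quant_state into extra "weight."-prefixed entries next to the usual tensors. A sketch of inspecting the result (the exact metadata key names come from QuantState.as_dict(packed=True) on this branch, so treat them as illustrative):

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(128, 256, compute_dtype=torch.float16, quant_type="nf4").cuda()

    # Quantization metadata rides along under "weight." keys.
    for key in layer.state_dict():
        print(key)
    # e.g. "weight", "bias", plus entries such as "weight.absmax", "weight.quant_map", ...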
@@ -295,10 +259,12 @@ class Linear4bit(nn.Linear):
         return out

 class LinearFP4(Linear4bit):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)

 class LinearNF4(Linear4bit):
     ''' Implements the NF4 data type.
@@ -310,7 +276,7 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
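
LinearFP4 and LinearNF4 are thin wrappers that only pin the quant_type argument of Linear4bit to 'fp4' or 'nf4'. A short comparison sketch (sizes and compute_dtype are illustrative):

    import torch
    import bitsandbytes as bnb

    # Identical Linear4bit machinery; only the 4-bit code table differs.
    fp4 = bnb.nn.LinearFP4(64, 64, compute_dtype=torch.float16).cuda()
    nf4 = bnb.nn.LinearNF4(64, 64, compute_dtype=torch.float16).cuda()

    x = torch.randn(2, 64, dtype=torch.float16, device="cuda")
    print(fp4(x).shape, nf4(x).shape)  # torch.Size([2, 64]) torch.Size([2, 64])

As the docstring above notes, the NF4 code table is built by create_normal_map in functional.py and is tailored to normally distributed weights.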