OpenDAS / bitsandbytes · Commit cc608c04
Authored Feb 25, 2023 by Max Ryabinin

Revert the layout if weights were reordered

Parent: cd4d904a
Showing 1 changed file with 45 additions and 16 deletions.
bitsandbytes/nn/modules.py  +45 −16  (view file @ cc608c04)
@@ -9,6 +9,8 @@ import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
 
 import bitsandbytes as bnb
+import bitsandbytes.functional
+from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
 from bitsandbytes.optim import GlobalOptimManager
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -210,7 +212,7 @@ class Int8Params(torch.nn.Parameter):
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
-                       memory_efficient_backward=False, threshold=0.0, index=None):
+                 memory_efficient_backward=False, threshold=0.0, index=None):
         super().__init__(input_features, output_features, bias)
         assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
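For orientation, here is a minimal usage sketch of the constructor above (not part of the commit; the layer sizes and the threshold value are illustrative assumptions). Quantization happens when the module is moved to the GPU, as the "case 1: .cuda was called" comment in the next hunk notes:

    import torch
    import bitsandbytes as bnb

    # Illustrative inference layer: fp16 master weights are dropped
    # (has_fp16_weights=False); activation outliers above the threshold
    # are handled in fp16 by the LLM.int8() matmul.
    layer = bnb.nn.Linear8bitLt(1024, 4096, bias=True,
                                has_fp16_weights=False, threshold=6.0)

    # Moving to CUDA triggers int8 quantization of the weight;
    # the SCB scales are attached to self.weight at this point.
    layer = layer.cuda()
    out = layer(torch.randn(8, 1024, dtype=torch.float16, device="cuda"))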
@@ -225,21 +227,48 @@ class Linear8bitLt(nn.Linear):
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-        weight_name = "SCB"
-
-        # case 1: .cuda was called, SCB is in self.weight
-        param_from_weight = getattr(self.weight, weight_name)
-        # case 2: self.init_8bit_state was called, SCB is in self.state
-        param_from_state = getattr(self.state, weight_name)
-
-        key_name = prefix + f"{weight_name}"
-        if param_from_weight is not None:
-            destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-        elif not self.state.has_fp16_weights and param_from_state is not None:
-            destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
+            # reorder weight layout back from ampere/turing to row
+            reorder_layout = True
+            weight_clone = self.weight.data.clone()
+        else:
+            reorder_layout = False
+
+        try:
+            if reorder_layout:
+                if self.state.tile_indices is None:
+                    order, tile_size = self.state.formatB, self.state.get_tile_size()
+                    transform = lambda x: \
+                        bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row",
+                                                          to_order=order)[0].to(x.device)
+                    with torch.no_grad():
+                        self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(self.state.CxB.device)
+
+                CB = (
+                    undo_layout(self.state.CxB, self.state.tile_indices)
+                )
+
+                self.weight.data = CB
+
+            super()._save_to_state_dict(destination, prefix, keep_vars)
+
+            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+            weight_name = "SCB"
+
+            # case 1: .cuda was called, SCB is in self.weight
+            param_from_weight = getattr(self.weight, weight_name)
+            # case 2: self.init_8bit_state was called, SCB is in self.state
+            param_from_state = getattr(self.state, weight_name)
+
+            key_name = prefix + f"{weight_name}"
+            if param_from_weight is not None:
+                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
+            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+        finally:
+            if reorder_layout:
+                self.weight.data = weight_clone
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
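As a hedged round-trip sketch of the path this commit fixes (the file name and sizes are illustrative, and it assumes a CUDA device with int8 support so that state.CxB actually gets populated in the tiled ampere/turing format):

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0).cuda()
    x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
    layer(x)  # the first forward builds state.CxB in the GPU-specific tiled layout

    # state_dict() runs the _save_to_state_dict above: the weight is temporarily
    # converted back to row-major order for serialization, and the finally block
    # restores the original weight.data (cloned before the reorder).
    sd = layer.state_dict()
    print(sorted(sd.keys()))  # expect 'SCB' to be saved alongside 'bias' and 'weight'
    torch.save(sd, "linear8bit.pt")

    # Because the layout revert is undone afterwards, the live module keeps working:
    layer(x)

Before this commit, the reordered row-major weight was left in place after saving; cloning the buffer and restoring it in a finally block keeps the module usable even if serialization raises.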