OpenDAS / bitsandbytes · Commits

Commit cc608c04, authored Feb 25, 2023 by Max Ryabinin

Revert the layout if weights were reordered

Parent: cd4d904a
Showing 1 changed file with 45 additions and 16 deletions

bitsandbytes/nn/modules.py  (+45, -16)
@@ -9,6 +9,8 @@ import torch.nn.functional as F
from torch import Tensor, device, dtype, nn

import bitsandbytes as bnb
import bitsandbytes.functional
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")
@@ -225,6 +227,30 @@ class Linear8bitLt(nn.Linear):
        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
            # reorder weight layout back from ampere/turing to row
            reorder_layout = True
            weight_clone = self.weight.data.clone()
        else:
            reorder_layout = False

        try:
            if reorder_layout:
                if self.state.tile_indices is None:
                    order, tile_size = self.state.formatB, self.state.get_tile_size()
                    transform = lambda x: \
                        bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row", to_order=order)[0].to(x.device)
                    with torch.no_grad():
                        self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(self.state.CxB.device)

                CB = (undo_layout(self.state.CxB, self.state.tile_indices))

                self.weight.data = CB

            super()._save_to_state_dict(destination, prefix, keep_vars)

            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
            ...
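The hunk above is the core of the change: when the layer only holds the turing/ampere-ordered buffer (state.CxB), it probes bitsandbytes.functional.transform once to learn the layout permutation (get_inverse_transform_indices), rewrites weight.data back to row-major with undo_layout, and only then lets the parent _save_to_state_dict run. The sketch below illustrates just that probing idea in isolation; toy_transform and the 8x32 tile size are made-up stand-ins for the bitsandbytes kernels, not the library's implementation.

import torch

def toy_transform(x: torch.Tensor) -> torch.Tensor:
    # stand-in for a device-specific reordering: transpose the tile and flatten it
    return x.t().contiguous().view(-1)

rows, cols = 8, 32                                  # illustrative tile size
positions = torch.arange(rows * cols).view(rows, cols)

# Transform a tile of sequential positions: permuted[j] is the row-major index
# of the element that ended up at output position j.
permuted = toy_transform(positions)

# Invert that permutation once; these indices play the role of tile_indices.
inverse_indices = torch.empty_like(permuted)
inverse_indices[permuted] = torch.arange(permuted.numel())

# Undoing the layout is then a single gather back to row-major order.
tile = torch.randn(rows, cols)
restored = toy_transform(tile)[inverse_indices].view(rows, cols)
assert torch.equal(restored, tile)

Because the permutation depends only on the weight format and tile size, the commit computes these indices once and caches them in state.tile_indices, so repeated saves do not re-probe the transform.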
@@ -240,6 +266,9 @@ class Linear8bitLt(nn.Linear):
                destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
            elif not self.state.has_fp16_weights and param_from_state is not None:
                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
        finally:
            if reorder_layout:
                self.weight.data = weight_clone

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
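Together with the finally block above, the net effect is that state_dict() yields row-major int8 weights (plus the SCB scales) while the live module keeps its GPU-ordered buffer. A rough usage sketch follows, assuming a CUDA GPU supported by the int8 kernels; the layer size, the dummy forward pass that materializes state.CxB, and the file name are illustrative rather than taken from the commit.

import torch
import bitsandbytes as bnb

# with has_fp16_weights=False the weights are quantized to int8 when the module is moved to the GPU
linear = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False).half().cuda()

# the first forward pass converts the quantized weights to the turing/ampere tile format (state.CxB)
with torch.no_grad():
    linear(torch.randn(1, 64, dtype=torch.float16, device="cuda"))

# after this commit, the 'weight' entry in the checkpoint is row-major again,
# while the module itself keeps the tile-formatted buffer (restored in the finally block)
state = linear.state_dict()
torch.save(state, "int8_linear.pt")  # illustrative path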