OpenDAS / bitsandbytes · Commit f734076e
"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "8a0d7b36f7821fe55175f0d4e3ca6299b3817a6c"
Authored Jun 09, 2023 by Max Ryabinin
Improve memory efficiency of 8-bit serialization
Parent: 4fb37d45
Showing 1 changed file with 33 additions and 26 deletions.
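In short: `_save_to_state_dict` previously cloned the int8 weight and called `undo_layout` to restore row-major order on every save, temporarily doubling the weight's memory footprint. With this change, the weight is saved in whatever layout it currently has, alongside a `weight_format` entry ("row" or `state.formatB`), and `_load_from_state_dict` undoes a non-row layout once, at load time, using `get_tile_inds`.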
bitsandbytes/nn/modules.py (+33, -26)
@@ -10,7 +10,7 @@ from torch import Tensor, device, dtype, nn
 
 import bitsandbytes as bnb
 import bitsandbytes.functional
-from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
+from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims
@@ -306,7 +306,6 @@ class Int8Params(torch.nn.Parameter):
 
         return new_param
 
-
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                  memory_efficient_backward=False, threshold=0.0, index=None):
@@ -324,50 +323,58 @@ class Linear8bitLt(nn.Linear):
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
+        super()._save_to_state_dict(destination, prefix, keep_vars)
 
-        try:
-            if reorder_layout:
-                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        scb_name = "SCB"
 
-            super()._save_to_state_dict(destination, prefix, keep_vars)
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, scb_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, scb_name)
+        # case 3: SCB is in self.state, weight layout reordered after first forward()
+        layout_reordered = self.state.CxB is not None
 
-            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-            weight_name = "SCB"
+        key_name = prefix + f"{scb_name}"
+        format_name = prefix + "weight_format"
 
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
+        if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[format_name] = "row"
+            elif param_from_state is not None and not layout_reordered:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = "row"
+            elif param_from_state is not None:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
+                destination[format_name] = self.state.formatB
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                       error_msgs)
-        for key in unexpected_keys:
+        unexpected_copy = list(unexpected_keys)
+
+        for key in unexpected_copy:
             input_name = key[len(prefix):]
             if input_name == "SCB":
                 if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
+                    # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                        "not supported. Please call module.cuda() before module.load_state_dict()")
 
                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
+
+                if self.state.SCB is not None:
+                    self.state.SCB = self.weight.SCB
+
                 unexpected_keys.remove(key)
 
+            if input_name == "weight_format":
+                input_param = state_dict[key]
+                if input_param != "row":
+                    tile_indices = get_tile_inds(input_param, self.weight.device)
+                    self.weight.data = self.weight.CB = undo_layout(self.weight.data, tile_indices)
+
+                unexpected_keys.remove(key)
+
     def init_8bit_state(self):
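To illustrate the path this commit changes, here is a minimal save/load round trip. This sketch is not part of the commit; it assumes a CUDA device with bitsandbytes installed, and the layer sizes are illustrative.

import torch
import bitsandbytes as bnb

# Build a quantized layer: with has_fp16_weights=False, the weights are
# quantized to int8 when the module is moved to the GPU.
fp16_linear = torch.nn.Linear(64, 64, dtype=torch.float16)
int8_linear = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
int8_linear.load_state_dict(fp16_linear.state_dict())
int8_linear = int8_linear.cuda()

# The first forward() may reorder the weight into the GPU-specific
# Turing/Ampere layout ("case 3" in _save_to_state_dict above).
x = torch.randn(1, 64, dtype=torch.float16, device="cuda")
int8_linear(x)

# After this commit, saving no longer clones the weight to undo the layout;
# the checkpoint stores the weight as-is plus "SCB" and "weight_format".
ckpt = int8_linear.state_dict()

# Loading requires an already-quantized (.cuda()) module; _load_from_state_dict
# restores SCB and undoes a non-row layout via get_tile_inds/undo_layout.
restored = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False).cuda()
restored.load_state_dict(ckpt)

Compared to the previous approach, the save path no longer holds two copies of the int8 weight (the removed weight_clone), and the layout conversion runs once at load time, and only for non-row checkpoints, instead of on every save.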