OpenDAS / bitsandbytes

Commit 2d321a75 (unverified), authored Jun 19, 2023 by Tim Dettmers, committed by GitHub on Jun 19, 2023
Parents: ac5550a0, b599fdb1

Merge pull request #503 from TimDettmers/efficient_8bit_serialize

Make 8-bit serialization more memory-efficient (v2)

Changes: 2 changed files with 53 additions and 36 deletions
  bitsandbytes/autograd/_functions.py  +14 -11
  bitsandbytes/nn/modules.py           +39 -25
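The change targets the plain PyTorch state_dict round trip for quantized Linear8bitLt layers: the layout bookkeeping moves from save time to load time, so the weight no longer has to be cloned and un-tiled when the checkpoint is written. A minimal usage sketch of that round trip (assuming a CUDA build of bitsandbytes; the layer sizes and file name are illustrative, not from the diff):

import torch
import bitsandbytes as bnb

# build a quantized layer: moving it to the GPU converts the weight to int8
# and stores the quantization statistics in weight.SCB
linear = bnb.nn.Linear8bitLt(64, 64, bias=True, has_fp16_weights=False).cuda()

# save: besides the int8 weight, the checkpoint now carries the extra
# "SCB" and "weight_format" entries written by _save_to_state_dict below
torch.save(linear.state_dict(), "linear_int8.pt")

# load: the destination module must already be quantized (.cuda() first);
# the maybe_rearrange_weight pre-hook added in this merge converts the stored
# weight back to row-major layout if it was saved in a turing/ampere tile format
restored = bnb.nn.Linear8bitLt(64, 64, bias=True, has_fp16_weights=False).cuda()
restored.load_state_dict(torch.load("linear_int8.pt"))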
bitsandbytes/autograd/_functions.py

@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
     return True


+def _get_tile_size(format):
+    assert format in (
+        "col_turing",
+        "col_ampere",
+    ), f"please find this assert and manually enter tile size for {format}"
+    return (8, 32) if format == "col_turing" else (32, 32)
+
+
+def get_tile_inds(format, device):
+    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
+    with torch.no_grad():
+        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
+
+
 @dataclass
 class MatmulLtState:
+    _tile_indices: Optional[torch.Tensor] = None

@@ -267,20 +280,10 @@ class MatmulLtState:
         self.SBt = None
         self.CBt = None

-    def get_tile_size(self):
-        assert self.formatB in (
-            "col_turing",
-            "col_ampere",
-        ), f"please find this assert and manually enter tile size for {self.formatB}"
-        return (8, 32) if self.formatB == "col_turing" else (32, 32)
-
     @property
     def tile_indices(self):
         if self._tile_indices is None:
-            device = self.CxB.device
-            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
-            with torch.no_grad():
-                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
         return self._tile_indices
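In get_tile_inds above, F.transform is probed once per layout and get_inverse_transform_indices records where every element of a tile ends up, so the permutation can be undone later by undo_layout. The snippet below is a conceptual sketch of that idea in plain PyTorch, not the bitsandbytes implementation: a random permutation stands in for the opaque GPU layout transform.

import torch

tile = torch.randn(8, 32)             # stand-in for one "col_turing"-sized tile
perm = torch.randperm(tile.numel())   # stand-in for the opaque GPU layout transform
transformed = tile.flatten()[perm]    # the tile as the device layout would store it

# probe the transform with position codes (0..N-1) to learn where each element went,
# then argsort to get the indices that put everything back in row-major order
codes = torch.arange(tile.numel())
inverse_indices = torch.argsort(codes[perm])
restored = transformed[inverse_indices].reshape(8, 32)
assert torch.equal(restored, tile)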
bitsandbytes/nn/modules.py

@@ -10,7 +10,7 @@ from torch import Tensor, device, dtype, nn
 import bitsandbytes as bnb
 import bitsandbytes.functional
-from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
+from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims

@@ -306,6 +306,17 @@ class Int8Params(torch.nn.Parameter):
         return new_param


+def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    weight = state_dict.get(f"{prefix}weight")
+    if weight is None:
+        # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
+        return
+    weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+
+    if weight_format != "row":
+        tile_indices = get_tile_inds(weight_format, weight.device)
+        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+
+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
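maybe_rearrange_weight has the signature PyTorch expects for a load_state_dict pre-hook; it is registered in the next hunk via nn.Module._register_load_state_dict_pre_hook, which lets the module rewrite the incoming state dict before the normal key matching runs. A small standalone sketch of that mechanism, using a toy module and a made-up "weight_scale" key rather than bitsandbytes code:

import torch
import torch.nn as nn

def unscale_weight(state_dict, prefix, local_metadata, strict,
                   missing_keys, unexpected_keys, error_msgs):
    # runs before the normal key matching: pop a bookkeeping entry and
    # rewrite the weight in place, much like maybe_rearrange_weight does
    scale = state_dict.pop(f"{prefix}weight_scale", None)
    if scale is not None:
        state_dict[f"{prefix}weight"] = state_dict[f"{prefix}weight"] / scale

lin = nn.Linear(4, 4)
lin._register_load_state_dict_pre_hook(unscale_weight)

sd = nn.Linear(4, 4).state_dict()
sd["weight_scale"] = torch.tensor(2.0)   # extra key, analogous to "weight_format"
sd["weight"] = sd["weight"] * sd["weight_scale"]
lin.load_state_dict(sd)                  # loads cleanly; the hook undid the scaling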
@@ -322,52 +333,55 @@ class Linear8bitLt(nn.Linear):
         self.state.use_pool = True

         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)

     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
-
-        try:
-            if reorder_layout:
-                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
-
         super()._save_to_state_dict(destination, prefix, keep_vars)

         # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-            weight_name = "SCB"
+        scb_name = "SCB"

         # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
+        param_from_weight = getattr(self.weight, scb_name)
         # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
+        param_from_state = getattr(self.state, scb_name)
+        # case 3: SCB is in self.state, weight layout reordered after first forward()
+        layout_reordered = self.state.CxB is not None

-            key_name = prefix + f"{weight_name}"
+        key_name = prefix + f"{scb_name}"
+        format_name = prefix + "weight_format"

+        if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-            elif not self.state.has_fp16_weights and param_from_state is not None:
-                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
+                destination[format_name] = "row"
+            elif param_from_state is not None and not layout_reordered:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = "row"
+            elif param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = self.state.formatB

     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-        for key in unexpected_keys:
+        unexpected_copy = list(unexpected_keys)
+
+        for key in unexpected_copy:
             input_name = key[len(prefix):]
             if input_name == "SCB":
                 if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without quantizing first
+                    # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                        "not supported. Please call module.cuda() before module.load_state_dict()")

                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)

                 if self.state.SCB is not None:
                     self.state.SCB = self.weight.SCB

+                unexpected_keys.remove(key)
+
     def init_8bit_state(self):
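Note the pairing of unexpected_keys.remove(key) with iterating over unexpected_copy rather than unexpected_keys itself: removing the consumed "SCB" entry keeps strict loading happy, and iterating over a copy is what makes the in-loop removal safe, since mutating a list while looping over it skips elements. A plain-Python illustration (toy key names, not bitsandbytes code):

keys = ["a.SCB", "b.SCB"]
for key in keys:          # iterating and mutating the same list
    keys.remove(key)      # removing shifts the remaining items left...
print(keys)               # ['b.SCB'] -- the second key was never visited

keys = ["a.SCB", "b.SCB"]
for key in list(keys):    # iterate over a copy instead (the pattern used above)
    keys.remove(key)
print(keys)               # []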