OpenDAS / bitsandbytes / Commits / 2d321a75

Unverified commit 2d321a75, authored Jun 19, 2023 by Tim Dettmers, committed by GitHub on Jun 19, 2023.
Parents: ac5550a0, b599fdb1

Merge pull request #503 from TimDettmers/efficient_8bit_serialize

Make 8-bit serialization more memory-efficient (v2)
Changes: 2 changed files with 53 additions and 36 deletions (+53 -36)

  bitsandbytes/autograd/_functions.py   +14 -11
  bitsandbytes/nn/modules.py            +39 -25
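At a high level, saving a quantized Linear8bitLt no longer clones the int8 weight and undoes the GPU-specific tiled layout during the save (which temporarily duplicated the weight on the GPU); the weight is saved in whatever layout it currently has, tagged with a new weight_format entry, and a load-time pre-hook converts non-row layouts back to row-major. A minimal, hedged sketch of the user-visible effect (assumes bitsandbytes with these changes and a CUDA GPU whose int8 kernels use the tiled col_turing/col_ampere layout; shapes and the file name are illustrative):

    import torch
    import bitsandbytes as bnb

    linear = bnb.nn.Linear8bitLt(64, 64, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
    linear(torch.randn(4, 64, dtype=torch.float16, device="cuda"))  # first forward pass leaves weight.data in the tiled layout

    state = linear.state_dict()
    # Besides "weight" and "bias", the checkpoint now carries "SCB" (quantization statistics)
    # and "weight_format" (layout tag consumed by maybe_rearrange_weight at load time).
    print(sorted(state.keys()))
    torch.save(state, "linear8bit.pt")  # no extra full-size weight clone is made during this save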
bitsandbytes/autograd/_functions.py  (view file @ 2d321a75)

@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
     return True
 
 
+def _get_tile_size(format):
+    assert format in (
+        "col_turing",
+        "col_ampere",
+    ), f"please find this assert and manually enter tile size for {format}"
+    return (8, 32) if format == "col_turing" else (32, 32)
+
+
+def get_tile_inds(format, device):
+    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
+    with torch.no_grad():
+        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
+
+
 @dataclass
 class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None

@@ -267,20 +280,10 @@ class MatmulLtState:
         self.SBt = None
         self.CBt = None
 
-    def get_tile_size(self):
-        assert self.formatB in (
-            "col_turing",
-            "col_ampere",
-        ), f"please find this assert and manually enter tile size for {self.formatB}"
-        return (8, 32) if self.formatB == "col_turing" else (32, 32)
-
     @property
     def tile_indices(self):
         if self._tile_indices is None:
-            device = self.CxB.device
-            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
-            with torch.no_grad():
-                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
         return self._tile_indices
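The two new module-level helpers take over what used to be MatmulLtState.get_tile_size and the body of its tile_indices property, so the inverse layout permutation can be computed from just a format string and a device; that is exactly what the load hook in bitsandbytes/nn/modules.py (below) needs, since no live MatmulLtState exists at load time. A rough sketch of how get_tile_inds pairs with undo_layout (requires a CUDA device; the 1024x1024 shape is only illustrative and must be divisible by the tile size, 8x32 for "col_turing" and 32x32 for "col_ampere"):

    import torch
    from bitsandbytes.autograd._functions import get_tile_inds, undo_layout

    device = torch.device("cuda")

    # Permutation indices that map one col_turing tile back to row-major order.
    tile_indices = get_tile_inds("col_turing", device)

    # Illustrative int8 tensor standing in for a tiled weight (in practice this is
    # state.CxB / weight.data after the first forward pass of a Linear8bitLt layer).
    tiled = torch.randint(-128, 127, (1024, 1024), dtype=torch.int8, device=device)
    row_major = undo_layout(tiled, tile_indices)
    assert row_major.shape == tiled.shape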
bitsandbytes/nn/modules.py  (view file @ 2d321a75)

@@ -10,7 +10,7 @@ from torch import Tensor, device, dtype, nn
 import bitsandbytes as bnb
 import bitsandbytes.functional
-from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
+from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import OutlierTracer, find_outlier_dims
@@ -306,6 +306,17 @@ class Int8Params(torch.nn.Parameter):
         return new_param
 
 
+def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+    weight = state_dict.get(f"{prefix}weight")
+    if weight is None:
+        # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
+        return
+    weight_format = state_dict.pop(f"{prefix}weight_format", "row")
+
+    if weight_format != "row":
+        tile_indices = get_tile_inds(weight_format, weight.device)
+        state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+
+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
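The new maybe_rearrange_weight hook is wired in through PyTorch's load_state_dict pre-hook mechanism (the _register_load_state_dict_pre_hook call appears in the next hunk), so a checkpoint whose weight was saved in a tiled format is rewritten to row-major before nn.Module matches the state dict against the module's parameters. A toy sketch of that mechanism with made-up names, using the same private PyTorch API; the "scale" key plays the role of "weight_format":

    import torch
    from torch import nn

    class Scaled(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(4))
            # Pre-hooks run before the incoming state dict is matched against parameters,
            # so they can pop auxiliary keys and rewrite tensors in place.
            self._register_load_state_dict_pre_hook(self._unscale)

        @staticmethod
        def _unscale(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            scale = state_dict.pop(f"{prefix}scale", None)  # auxiliary key, analogous to "weight_format"
            if scale is not None:
                state_dict[f"{prefix}weight"] = state_dict[f"{prefix}weight"] / scale

    m = Scaled()
    m.load_state_dict({"weight": torch.full((4,), 6.0), "scale": torch.tensor(2.0)})
    print(m.weight)  # Parameter containing tensor([3., 3., 3., 3.])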
@@ -322,52 +333,55 @@ class Linear8bitLt(nn.Linear):
         self.state.use_pool = True
 
         self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
 
     def _save_to_state_dict(self, destination, prefix, keep_vars):
-        if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
-            # reorder weight layout back from ampere/turing to row
-            reorder_layout = True
-            weight_clone = self.weight.data.clone()
-        else:
-            reorder_layout = False
+        super()._save_to_state_dict(destination, prefix, keep_vars)
 
-        try:
-            if reorder_layout:
-                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
+        # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
+        scb_name = "SCB"
 
-            super()._save_to_state_dict(destination, prefix, keep_vars)
+        # case 1: .cuda was called, SCB is in self.weight
+        param_from_weight = getattr(self.weight, scb_name)
+        # case 2: self.init_8bit_state was called, SCB is in self.state
+        param_from_state = getattr(self.state, scb_name)
+        # case 3: SCB is in self.state, weight layout reordered after first forward()
+        layout_reordered = self.state.CxB is not None
 
-            # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
-            weight_name = "SCB"
+        key_name = prefix + f"{scb_name}"
+        format_name = prefix + "weight_format"
 
-            # case 1: .cuda was called, SCB is in self.weight
-            param_from_weight = getattr(self.weight, weight_name)
-            # case 2: self.init_8bit_state was called, SCB is in self.state
-            param_from_state = getattr(self.state, weight_name)
-
-            key_name = prefix + f"{weight_name}"
+        if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-            elif not self.state.has_fp16_weights and param_from_state is not None:
+                destination[format_name] = "row"
+            elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-        finally:
-            if reorder_layout:
-                self.weight.data = weight_clone
+                destination[format_name] = "row"
+            elif param_from_state is not None:
+                destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
+                destination[format_name] = self.state.formatB
 
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                               missing_keys, unexpected_keys, error_msgs):
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                       error_msgs)
-        for key in unexpected_keys:
+        unexpected_copy = list(unexpected_keys)
+
+        for key in unexpected_copy:
             input_name = key[len(prefix):]
             if input_name == "SCB":
                 if self.weight.SCB is None:
-                    # buffers not yet initialized, can't call them directly without
+                    # buffers not yet initialized, can't access them directly without quantizing first
                     raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
                                        "not supported. Please call module.cuda() before module.load_state_dict()")
 
                 input_param = state_dict[key]
                 self.weight.SCB.copy_(input_param)
 
+                if self.state.SCB is not None:
+                    self.state.SCB = self.weight.SCB
+
                 unexpected_keys.remove(key)
 
     def init_8bit_state(self):
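Taken together, _save_to_state_dict now writes the weight as-is and only adds the SCB statistics and the weight_format tag, while the pre-hook plus _load_from_state_dict restore a usable quantized module. A hedged sketch of the round trip this commit is meant to keep memory-flat (hypothetical file name; assumes a CUDA GPU supported by the int8 kernels):

    import torch
    import bitsandbytes as bnb

    # Save: the weight may already sit in a tiled (col_turing/col_ampere) layout after inference.
    src = bnb.nn.Linear8bitLt(128, 128, has_fp16_weights=False, threshold=6.0).cuda()
    src(torch.randn(2, 128, dtype=torch.float16, device="cuda"))
    torch.save(src.state_dict(), "int8_linear.pt")  # no weight clone / undo_layout pass during save anymore

    # Load: the target module must be quantized (moved to the GPU) before load_state_dict,
    # otherwise the RuntimeError raised in _load_from_state_dict fires.
    dst = bnb.nn.Linear8bitLt(128, 128, has_fp16_weights=False, threshold=6.0).cuda()
    dst.load_state_dict(torch.load("int8_linear.pt"))  # maybe_rearrange_weight undoes any tiled layout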