OpenDAS / bitsandbytes

Commit 76b40a5c, authored Oct 25, 2023 by Ruslan Svirschevski
Parent: 965fd5d5

    save/load via state_dict now
Showing 4 changed files with 60 additions and 36 deletions (+60 / -36):
.gitignore                   +1  -0
bitsandbytes/functional.py   +19 -14
bitsandbytes/nn/modules.py   +33 -8
tests/test_linear4bit.py     +7  -14
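
At a high level, the commit lets a quantized Linear4bit layer round-trip through PyTorch's standard state_dict machinery. A minimal sketch of the intended flow, not taken from the diff itself (layer sizes are illustrative; assumes a CUDA-enabled bitsandbytes build):

    import bitsandbytes as bnb

    # build and quantize a 4-bit linear layer on the GPU
    linear_q = bnb.nn.Linear4bit(64, 128, bias=True, quant_type='nf4', device='cuda')

    # save: weight, bias, and the packed quantization state all land in one dict
    sd = linear_q.state_dict()

    # load: a fresh layer restores its weight through Params4bit.from_state_dict
    linear_q2 = bnb.nn.Linear4bit(64, 128, bias=True, quant_type='nf4', device='cuda')
    linear_q2.load_state_dict(sd)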
.gitignore

@@ -133,3 +133,4 @@ dmypy.json
    dependencies
    cuda_build
    .vscode/*
bitsandbytes/functional.py

@@ -568,11 +568,17 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
        return out


    class QuantState:
        """container for quantization state components to work with Params4bit and similar classes"""
        valid_quant_types = ('fp4', 'nf4')
        valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
        valid_qs_keys = ['absmax', 'code', 'nested_absmax', 'nested_code', 'quant_state', 'quant_type',
                         'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']

        def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
            self.absmax = absmax
            self.shape = shape
            self.code = code  # TODO consider renaming to `buckets / centroids / scale`
            self.dtype = dtype
            self.blocksize = blocksize
            self.quant_type = quant_type
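
For reference, a QuantState for non-nested (compress_statistics=False) quantization could be constructed directly like this; the field values are illustrative, and get_4bit_type is the helper visible in the third hunk below:

    import torch
    import bitsandbytes as bnb
    from bitsandbytes.functional import QuantState

    qs = QuantState(
        absmax=torch.rand(128, device='cuda'),                    # per-block max magnitudes
        shape=torch.Size([128, 64]),                              # pre-quantization weight shape
        code=bnb.functional.get_4bit_type('nf4', device='cuda'),  # the 16 nf4 levels
        blocksize=64,
        quant_type='nf4',
        dtype=torch.float16,
    )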
@@ -596,26 +602,26 @@ class QuantState:

        @classmethod
        def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
            """
            unpacks components of state_dict into QuantState
            where necessary, convert into strings, torch.dtype, ints, etc.

            qs_dict: based on state_dict, with only relevant keys, stripped of prefixes.

            item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
            """
            # unpacking tensor with non-tensor components
            qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
            if not qs_key and 'quant_type' not in qs_dict:
                raise ValueError("Expected packed or unpacked quant_state items, found neither")
            elif len(qs_key) > 1:
                raise ValueError(f"There should be exactly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")

            # unpacking minor and non-tensor quant state items if necessary
            if len(qs_key) == 1:
                qs_key = qs_key[0]
                qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))

            qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
            if 'nested_absmax' in qs_dict:
                offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
                state2 = cls(
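
For orientation, the two qs_dict layouts that from_dict accepts look roughly as follows. This is an illustrative sketch: the tensors are placeholders, and the exact serialized value types (e.g. dtype as a string) are an assumption based on the docstring's "convert into strings, torch.dtype, ints" note:

    import torch

    absmax = torch.rand(128)                      # placeholder per-block scales
    packed = torch.empty(256, dtype=torch.uint8)  # placeholder packed payload

    # packed form: one tensor-valued item under a valid_qs_type_keys name carries the
    # minor/non-tensor components; from_dict() expands it via unpack_tensor_to_dict()
    qs_dict_packed = {
        'absmax': absmax,
        'quant_state.bitsandbytes__nf4': packed,
    }

    # unpacked form: every component is already a separate, prefix-stripped entry
    qs_dict_unpacked = {
        'absmax': absmax,
        'quant_type': 'nf4',
        'blocksize': 64,
        'dtype': 'torch.float16',   # assumed string form, converted back by from_dict
        'shape': (128, 64),
    }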
@@ -873,7 +879,6 @@ def get_4bit_type(typename, device=None, blocksize=64):
        return data.to(device)


    def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
        return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
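
quantize_fp4 is a thin wrapper that pins quant_type to 'fp4'. A hedged usage sketch, assuming a CUDA device and the matching dequantize_fp4 helper elsewhere in functional.py:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1024, 1024, dtype=torch.float16, device='cuda')
    q, quant_state = F.quantize_fp4(A, blocksize=64)   # packed 4-bit payload + quant state
    A_hat = F.dequantize_fp4(q, quant_state)           # approximate reconstruction of A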
bitsandbytes/nn/modules.py

@@ -154,14 +154,25 @@ class Params4bit(torch.nn.Parameter):
            return self

        @classmethod
        def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
            self = torch.Tensor._make_subclass(cls, data.to(device))
            self.requires_grad = requires_grad
            self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
            self.blocksize = self.quant_state.blocksize
            self.compress_statistics = self.quant_state.nested
            self.quant_type = self.quant_state.quant_type
            return self

        @classmethod
        def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
            data = state_dict.pop(prefix.rstrip('.'))

            # extracting components for QuantState from state_dict
            qs_dict = {}
            for k, v in state_dict.items():
                if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
                    qs_dict[k] = v
            state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
            qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}

            if data.device.type != "cuda":
                raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")

            cls.requires_grad = requires_grad
            cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
            self = torch.Tensor._make_subclass(cls, data.to(data.device))
            return self, state_dict

        def cuda(self, device):
            w = self.data.contiguous().half().cuda(device)
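
A sketch of the from_prequantized path at the parameter level, reusing linear_q from the overview sketch above. The prefix-stripping line is an assumption inferred from from_dict's "stripped of prefixes" docstring, not code from this commit:

    import bitsandbytes as bnb

    sd = linear_q.state_dict()
    weight_data = sd.pop("weight")      # the quantized storage tensor itself
    sd.pop("bias", None)                # from_prequantized expects only quant-state items
    quant_stats = {k.replace("weight.", ""): v for k, v in sd.items()}  # strip module prefix
    weight = bnb.nn.Params4bit.from_prequantized(data=weight_data, quantized_stats=quant_stats)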
@@ -200,9 +211,11 @@ class Params4bit(torch.nn.Parameter):
            return new_param


    class Linear4bit(nn.Linear):
        def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
            super().__init__(input_features, output_features, bias, device)
            self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
            # self.persistent_buffers = []  # TODO consider as way to save quant state
            self.compute_dtype = compute_dtype
            self.compute_type_is_set = False
@@ -233,6 +246,18 @@ class Linear4bit(nn.Linear):
                for k, v in self.weight.quant_state.as_dict(packed=True).items():
                    destination[prefix + "weight." + k] = v if keep_vars else v.detach()

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs):
            # Note: super()._load_from_state_dict() is not called here intentionally.
            if self.bias is not None:
                bias_data = state_dict.pop(prefix + "bias", None)
                self.bias.data = bias_data.to(self.bias.data.device)

            self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
                state_dict, prefix=prefix + "weight" + ".", requires_grad=False)
            unexpected_keys.extend(state_dict.keys())

        def forward(self, x: torch.Tensor):
            # weights are cast automatically as Int8Params, but the bias has to be cast manually
            if self.bias is not None and self.bias.dtype != x.dtype:
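
Putting the two hooks together, the saved keys for a layer stored under the prefix `layer.` would look roughly like this (illustrative; nested_* entries appear only with compress_statistics=True):

    # Illustrative key layout for a Linear4bit saved under the prefix "layer.":
    #
    #   layer.weight                                -> packed 4-bit storage tensor
    #   layer.bias                                  -> plain bias tensor
    #   layer.weight.absmax                         -> per-block scales
    #   layer.weight.quant_state.bitsandbytes__nf4  -> packed minor/non-tensor components
    #
    # _load_from_state_dict() pops the bias, routes every `weight.*` key matching
    # QuantState.valid_qs_keys into Params4bit.from_state_dict(), and reports any
    # leftovers as unexpected_keys.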
tests/test_linear4bit.py

@@ -16,7 +16,7 @@ from bitsandbytes.nn.modules import Linear4bit
        "quant_type, compress_statistics, bias",
        list(product(["nf4", "fp4"], [False, True], [False, True])),
    )
    def test_linear_serialization(quant_type, compress_statistics, bias):
        original_dtype = torch.float16
        compute_dtype = None
        device = "cuda"
@@ -39,16 +39,10 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
        if bias:
            linear_q.bias.data = linear.bias.data.to(device)

        # saving to state_dict:
        sd = linear_q.state_dict()

        # creating new layer with same params:
        linear_q2 = bnb.nn.Linear4bit(
            linear.in_features,
            linear.out_features,
@@ -56,13 +50,12 @@ def test_linear4_state_dict(quant_type, compress_statistics, bias):
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
            device=device,  # TODO create on meta device to save loading time
        )

        # loading weights from state_dict:
        linear_q2.load_state_dict(sd)

        # MATCHING
        a, b = linear_q.weight, linear_q2.weight
        assert a.device == b.device
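
With a CUDA device available, the reworked test can be run in isolation via `pytest tests/test_linear4bit.py -k test_linear_serialization`.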