OpenDAS / bitsandbytes / Commits / c8f564d5

Unverified commit c8f564d5, authored Dec 06, 2023 by Tim Dettmers, committed by GitHub on Dec 06, 2023.

Merge pull request #868 from poedator/fix_1108

Fix for 4bit without compress_statistics.

Parents: bbbed83a, 079d7afe
Changes: 2 changed files with 16 additions and 11 deletions (+16 -11)

bitsandbytes/functional.py    +5 -4
tests/test_linear4bit.py      +11 -7
bitsandbytes/functional.py

@@ -642,7 +642,7 @@ class QuantState:
             blocksize=qs_dict['blocksize'],
             code=qs_dict['quant_map'].to(device),
             dtype=getattr(torch, qs_dict['dtype']),
-            shape=torch.Size(qs_dict['shape']),
+            shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None,
             offset=offset,
             state2=state2,
         )
@@ -651,7 +651,7 @@ class QuantState:
     def as_dict(self, packed=False):
         """
         returns dict of tensors and strings to use in serialization via _save_to_state_dict()
-        param: packed -- returns dict[str, torch.Tensor] for state_dict
+        param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving
         """
         qs_dict = {
             'quant_type': self.quant_type,
@@ -659,19 +659,20 @@ class QuantState:
             'blocksize': self.blocksize,
             'quant_map': self.code,
             'dtype': str(self.dtype).strip('torch.'),
-            'shape': tuple(self.shape) if self.nested else None,
+            'shape': tuple(self.shape),
         }
         if self.nested:
             qs_dict.update({
                 'nested_absmax': self.state2.absmax,
                 'nested_blocksize': self.state2.blocksize,
-                'nested_quant_map': self.state2.code,
+                'nested_quant_map': self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
                 'nested_dtype': str(self.state2.dtype).strip('torch.'),
                 'nested_offset': self.offset.item(),
             })
         if not packed:
             return qs_dict

+        # packed format allows serialization of non-tensor components, critical for saving in safetensors format
         qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
         non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
         qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
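Note on the functional.py change: per the diff above, QuantState.as_dict() previously wrote 'shape': None whenever the quant state was not nested, which is exactly the compress_statistics=False case, and QuantState.from_dict() then crashed on torch.Size(None) when such a checkpoint was loaded. The shape is now always serialized, from_dict() tolerates None for checkpoints written by the old code, and the nested quant map is cloned so that safetensors' shared-tensor handling cannot drop it. Below is a minimal sketch, not part of the commit, of the serialization path this repairs; it assumes a CUDA build of bitsandbytes and uses only the quantize_4bit and QuantState.as_dict API visible in the diff, with everything else illustrative.

import torch
import bitsandbytes.functional as F

w = torch.randn(300, 400, dtype=torch.float16, device="cuda")

# compress_statistics=False produces a non-nested QuantState (no second-level statistics)
_, qstate = F.quantize_4bit(w, compress_statistics=False, quant_type="nf4")

qs_dict = qstate.as_dict(packed=False)

# Before this commit 'shape' was only written for nested states, so it came back
# as None here and reloading failed on torch.Size(None); it is now always present.
assert qs_dict["shape"] == (300, 400)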
tests/test_linear4bit.py

@@ -20,7 +20,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     device = "cuda"
     layer_shape = (300, 400)

-    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype)  # original layer
+    linear = torch.nn.Linear(*layer_shape, dtype=original_dtype, device="cpu")  # original layer

     # Quantizing original layer
     linear_q = bnb.nn.Linear4bit(
@@ -30,19 +30,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,
+        device="meta",
     )
     new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
-    linear_q.weight = new_weight.to(device)
+    linear_q.weight = new_weight
     if bias:
-        linear_q.bias.data = linear.bias.data.to(device)
+        linear_q.bias = torch.nn.Parameter(linear.bias)
+    linear_q = linear_q.to(device)
+
     # saving to state_dict:
     sd = linear_q.state_dict()
+
     # restoring from state_dict:
     bias_data2 = sd.pop("bias", None)
     weight_data2 = sd.pop("weight")
     weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
@@ -51,12 +54,13 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         compute_dtype=compute_dtype,
         compress_statistics=compress_statistics,
         quant_type=quant_type,
-        device=device,  # TODO create on meta device to save loading time
+        device="meta",
     )
     # loading weights from state_dict:
-    linear_q2.weight = weight2.to(device)
+    linear_q2.weight = weight2
     if bias:
         linear_q2.bias = torch.nn.Parameter(bias_data2)
+    linear_q2 = linear_q2.to(device)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
@@ -107,6 +111,6 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         state_path_4bit
     )
     size_ratio = size_4 / size_orig
-    target_compression = 0.143 if original_dtype == torch.float32 else 0.285
+    target_compression = 0.143 if original_dtype == torch.float32 else 0.29  # these numbers get lower as weight shape increases
     ratio_error_msg = f"quantized_size {size_4:,} is larger on disk than {target_compression:.2%} of original size {size_orig:,}"
     assert size_ratio < target_compression, ratio_error_msg
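Note on the test change: the updated test builds both Linear4bit layers on the meta device, attaches the original weight and bias as parameters, and materializes (and quantizes) each layer with a single .to(device) call instead of moving individual tensors to CUDA one by one; restoring goes through a plain state_dict and Params4bit.from_prequantized. A condensed sketch of that save/restore pattern outside the test harness follows; it assumes a CUDA device and the bitsandbytes.nn API exactly as used in the test above, with the fixed compress_statistics=False case picked for illustration.

import torch
import bitsandbytes as bnb

device = "cuda"
linear = torch.nn.Linear(300, 400, dtype=torch.float16, device="cpu")  # original fp16 layer

# Quantize: build the 4-bit layer on the meta device, attach the original weight,
# then materialize and quantize it by moving the whole module to CUDA.
linear_q = bnb.nn.Linear4bit(
    linear.in_features,
    linear.out_features,
    bias=True,
    compress_statistics=False,  # the case this commit fixes
    quant_type="nf4",
    device="meta",
)
linear_q.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
linear_q.bias = torch.nn.Parameter(linear.bias)
linear_q = linear_q.to(device)

# Save and restore through a plain state_dict, as the test does.
sd = linear_q.state_dict()
bias_data2 = sd.pop("bias", None)
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

linear_q2 = bnb.nn.Linear4bit(
    linear.in_features,
    linear.out_features,
    bias=True,
    compress_statistics=False,
    quant_type="nf4",
    device="meta",
)
linear_q2.weight = weight2
linear_q2.bias = torch.nn.Parameter(bias_data2)
linear_q2 = linear_q2.to(device)

# Both layers should now hold identical 4-bit weights and quantization state.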