OpenDAS / bitsandbytes / Commits

Commit 726f1470 (unverified)
Authored Nov 08, 2023 by Tim Dettmers; committed by GitHub, Nov 08, 2023

Merge pull request #864 from poedator/save4_fixes

fixes to recent PR #753

Parents: f1ef74f8, 54860539
Showing 3 changed files with 47 additions and 59 deletions:

- bitsandbytes/functional.py (+19, -15)
- bitsandbytes/nn/modules.py (+20, -39)
- tests/test_linear4bit.py (+8, -5)
bitsandbytes/functional.py

```diff
@@ -567,13 +567,13 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     return out


 class QuantState:
     """container for quantization state components to work with Params4bit and similar clases"""
     valid_quant_types = ('fp4', 'nf4')
-    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
     valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type',
                      'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']

     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
@@ -611,16 +611,19 @@ class QuantState:
         """
         # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
         if not len(qs_key) and 'quant_type' not in qs_dict:
             raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1:
-            raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
+        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+            raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")

         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))

+        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
+        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)

         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
@@ -682,6 +685,7 @@ class QuantState:
             self.state2.absmax = self.state2.absmax.to(device)
             self.state2.code = self.state2.code.to(device)


 def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
     ...
```
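The reworked check in `QuantState.from_dict` is deliberately two-stage: it first collects every tensor-valued key that contains `quant_state`, then validates only the key's last dot-separated component against `valid_qs_type_keys`, so packed keys that carry a module prefix (e.g. `weight.quant_state.bitsandbytes__nf4`) are now accepted. A minimal standalone sketch of that logic, using an illustrative key name rather than one taken from a real checkpoint:

```python
import torch

# mirrors the class attributes above
valid_quant_types = ('fp4', 'nf4')
valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]

# illustrative packed entry, as it might appear in a module's state dict
qs_dict = {"weight.quant_state.bitsandbytes__nf4": torch.empty(8, dtype=torch.uint8)}

# stage 1: collect tensor-valued keys that mention quant_state
qs_key = [k for k, v in qs_dict.items()
          if "quant_state" in k and isinstance(v, torch.Tensor)]

# stage 2: validate only the suffix, so module prefixes no longer cause rejection
assert len(qs_key) == 1
assert qs_key[0].split(".")[-1] in valid_qs_type_keys

# the newly added strip-prefixes line reduces every key to its last component
stripped = {k.split('.')[-1]: v for k, v in qs_dict.items()}
assert set(stripped) == {"bitsandbytes__nf4"}
```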
bitsandbytes/nn/modules.py

```diff
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Optional, TypeVar, Union, overload
+from typing import Any, Dict, Optional, TypeVar, Union, overload
 import warnings

 import torch
@@ -139,9 +139,10 @@ class Embedding(torch.nn.Embedding):
         return emb


 class Params4bit(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
+    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
         if data is None:
             data = torch.empty(0)
@@ -154,25 +155,14 @@ class Params4bit(torch.nn.Parameter):
         return self

     @classmethod
-    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-        data = state_dict.pop(prefix.rstrip('.'))
-
-        # extracting components for QuantState from state_dict
-        qs_dict = {}
-        for k, v in state_dict.items():
-            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-                qs_dict[k] = v
-        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-
-        if data.device.type != "cuda":
-            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
-
-        cls.requires_grad = requires_grad,
-        cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-
-        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-        return self, state_dict
+    def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], requires_grad: bool = False, device='cuda', **kwargs) -> "Params4bit":
+        self = torch.Tensor._make_subclass(cls, data.to(device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
+        self.blocksize = self.quant_state.blocksize
+        self.compress_statistics = self.quant_state.nested
+        self.quant_type = self.quant_state.quant_type
+        return self

     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -210,9 +200,10 @@ class Params4bit(torch.nn.Parameter):
         return new_param


 class Linear4bit(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         # self.persistent_buffers = [] # TODO consider as way to save quant state
@@ -246,18 +237,6 @@ class Linear4bit(nn.Linear):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        # Note: super()._load_from_state_dict() is not called here intentionally.
-        if self.bias is not None:
-            bias_data = state_dict.pop(prefix + "bias", None)
-            self.bias.data = bias_data.to(self.bias.data.device)
-
-        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(state_dict, prefix=prefix + "weight" + ".", requires_grad=False)
-        unexpected_keys.extend(state_dict.keys())
-
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
@@ -280,10 +259,12 @@ class Linear4bit(nn.Linear):
         return out


 class LinearFP4(Linear4bit):
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)


 class LinearNF4(Linear4bit):
     ''' Implements the NF4 data type.
@@ -295,7 +276,7 @@ class LinearNF4(Linear4bit):
     Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
     the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
     '''
     def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
     ...
```
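With the `_load_from_state_dict` override removed, restoring is now driven from the caller's side, while `_save_to_state_dict` (unchanged above) still packs the quantization statistics under the `weight.` prefix. A hedged sketch of inspecting the saved form; it requires a CUDA build of bitsandbytes, and the exact key names depend on `QuantState.as_dict(packed=True)`:

```python
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(64, 64, bias=True, quant_type='nf4')
layer = layer.cuda()  # moving to CUDA quantizes the weight into packed 4-bit storage

sd = layer.state_dict()
# alongside "weight" and "bias", expect quantization entries such as
# "weight.absmax" plus one packed tensor whose key should end in
# "quant_state.bitsandbytes__nf4" (carrying the non-tensor stats)
for key, value in sd.items():
    print(key, tuple(value.shape), value.dtype)
```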
tests/test_linear4bit.py

```diff
@@ -7,8 +7,6 @@ import pytest
 import torch

 import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.nn.modules import Linear4bit


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     # saving to state_dict:
     sd = linear_q.state_dict()

+    # restoring from state_dict:
+    bias_data2 = sd.pop("bias", None)
+    weight_data2 = sd.pop("weight")
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
+
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
         ...
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         device=device,  # TODO create on meta device to save loading time
     )
     # loading weights from state_dict:
-    linear_q2.load_state_dict(sd)
+    linear_q2.weight = weight2.to(device)
+    if bias:
+        linear_q2.bias = torch.nn.Parameter(bias_data2)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
     ...
```
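Taken together, the changes make the round trip explicit: pop the plain tensors out of the state dict, hand the remaining entries to `Params4bit.from_prequantized`, and assign the result back to a fresh layer. A condensed, hedged version of the updated test flow (CUDA required; layer sizes are illustrative):

```python
import torch
import bitsandbytes as bnb

device = "cuda"
linear_q = bnb.nn.Linear4bit(32, 32, bias=True, quant_type='nf4').cuda()

# saving to state_dict:
sd = linear_q.state_dict()

# restoring from state_dict: split the plain tensors from the quant stats
bias_data2 = sd.pop("bias", None)
weight_data2 = sd.pop("weight")
weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

# creating a new layer and assigning the restored parameters
linear_q2 = bnb.nn.Linear4bit(32, 32, bias=True, quant_type='nf4', device=device)
linear_q2.weight = weight2.to(device)
if bias_data2 is not None:
    linear_q2.bias = torch.nn.Parameter(bias_data2)

# both layers should now hold identical packed 4-bit weights
assert torch.equal(linear_q.weight, linear_q2.weight)
```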