Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
781fcd5b
Commit
781fcd5b
authored
Nov 08, 2023
by
Ruslan Svirschevski
Browse files
partially reverted
76b40a5c
parent
c6d0a847
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
47 deletions
+64
-47
bitsandbytes/functional.py
bitsandbytes/functional.py
+13
-9
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+43
-33
tests/test_linear4bit.py
tests/test_linear4bit.py
+8
-5
No files found.
bitsandbytes/functional.py
View file @
781fcd5b
...
@@ -571,9 +571,9 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
...
@@ -571,9 +571,9 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
class
QuantState
:
class
QuantState
:
"""container for quantization state components to work with Params4bit and similar clases"""
"""container for quantization state components to work with Params4bit and similar clases"""
valid_quant_types
=
(
'fp4'
,
'nf4'
)
valid_quant_types
=
(
'fp4'
,
'nf4'
)
valid_qs_type_keys
=
[
f
"
quant_state.
bitsandbytes__
{
x
}
"
for
x
in
valid_quant_types
]
valid_qs_type_keys
=
[
f
"bitsandbytes__
{
x
}
"
for
x
in
valid_quant_types
]
valid_qs_keys
=
[
'absmax'
,
'quant_map'
,
'nested_absmax'
,
'nested_quant_map'
,
'quant_state'
,
valid_qs_keys
=
[
'absmax'
,
'quant_map'
,
'nested_absmax'
,
'nested_quant_map'
,
'quant_state'
,
'quant_type'
,
'quant_type'
,
'blocksize'
,
'dtype'
,
'shape'
,
'nested_blocksize'
,
'nested_dtype'
,
'nested_offset'
]
'blocksize'
,
'dtype'
,
'shape'
,
'nested_blocksize'
,
'nested_dtype'
,
'nested_offset'
]
def
__init__
(
self
,
absmax
,
shape
=
None
,
code
=
None
,
blocksize
=
None
,
quant_type
=
None
,
dtype
=
None
,
offset
=
None
,
state2
=
None
):
def
__init__
(
self
,
absmax
,
shape
=
None
,
code
=
None
,
blocksize
=
None
,
quant_type
=
None
,
dtype
=
None
,
offset
=
None
,
state2
=
None
):
self
.
absmax
=
absmax
self
.
absmax
=
absmax
...
@@ -611,16 +611,19 @@ class QuantState:
...
@@ -611,16 +611,19 @@ class QuantState:
"""
"""
# unpacking tensor with non-tensor components
# unpacking tensor with non-tensor components
qs_key
=
[
k
for
k
,
v
in
qs_dict
.
items
()
if
k
in
cls
.
valid_qs_type_keys
and
isinstance
(
v
,
torch
.
Tensor
)]
qs_key
=
[
k
for
k
,
v
in
qs_dict
.
items
()
if
"quant_state"
in
k
and
isinstance
(
v
,
torch
.
Tensor
)]
if
not
len
(
qs_key
)
and
'quant_type'
not
in
qs_dict
:
if
not
len
(
qs_key
)
and
'quant_type'
not
in
qs_dict
:
raise
ValueError
(
"Expected packed or unpacked quant_state items, found neither"
)
raise
ValueError
(
"Expected packed or unpacked quant_state items, found neither"
)
elif
len
(
qs_key
)
!=
1
:
elif
len
(
qs_key
)
!=
1
or
qs_key
[
0
].
split
(
"."
)[
-
1
]
not
in
cls
.
valid_qs_type_keys
:
raise
ValueError
(
f
"There should be exac
l
ly one quant_state item with
key
from
{
cls
.
valid_qs_type_keys
}
.
Detected
{
len
(
qs_key
)
}
such items
"
)
raise
ValueError
(
f
"There should be exac
t
ly one
`
quant_state
`
item with
ending
from
{
cls
.
valid_qs_type_keys
}
.
\n
Detected
{
qs_key
}
.
"
)
# unpacking minor and non-tensor quant state items if necessary
# unpacking minor and non-tensor quant state items if necessary
if
len
(
qs_key
)
==
1
:
if
len
(
qs_key
)
==
1
:
qs_key
=
qs_key
[
0
]
qs_key
=
qs_key
[
0
]
qs_dict
|=
unpack_tensor_to_dict
(
qs_dict
.
pop
(
qs_key
))
qs_dict
.
update
(
unpack_tensor_to_dict
(
qs_dict
.
pop
(
qs_key
)))
qs_dict
=
{
k
.
split
(
'.'
)[
-
1
]:
v
for
k
,
v
in
qs_dict
.
items
()}
# strip prefixes
assert
set
(
qs_dict
.
keys
()).
issubset
(
cls
.
valid_qs_keys
)
if
'nested_absmax'
in
qs_dict
:
if
'nested_absmax'
in
qs_dict
:
offset
=
torch
.
tensor
(
float
(
qs_dict
[
'nested_offset'
])).
to
(
device
)
offset
=
torch
.
tensor
(
float
(
qs_dict
[
'nested_offset'
])).
to
(
device
)
...
@@ -654,7 +657,7 @@ class QuantState:
...
@@ -654,7 +657,7 @@ class QuantState:
'quant_type'
:
self
.
quant_type
,
'quant_type'
:
self
.
quant_type
,
'absmax'
:
self
.
absmax
,
'absmax'
:
self
.
absmax
,
'blocksize'
:
self
.
blocksize
,
'blocksize'
:
self
.
blocksize
,
'quant_map'
:
self
.
code
,
'quant_map'
:
self
.
code
,
'dtype'
:
str
(
self
.
dtype
).
strip
(
'torch.'
),
'dtype'
:
str
(
self
.
dtype
).
strip
(
'torch.'
),
'shape'
:
tuple
(
self
.
shape
)
if
self
.
nested
else
None
,
'shape'
:
tuple
(
self
.
shape
)
if
self
.
nested
else
None
,
}
}
...
@@ -677,6 +680,7 @@ class QuantState:
...
@@ -677,6 +680,7 @@ class QuantState:
def
to
(
self
,
device
):
def
to
(
self
,
device
):
# make sure the quantization state is on the right device
# make sure the quantization state is on the right device
self
.
absmax
=
self
.
absmax
.
to
(
device
)
self
.
absmax
=
self
.
absmax
.
to
(
device
)
self
.
offset
=
self
.
offset
.
to
(
device
)
if
self
.
nested
:
if
self
.
nested
:
self
.
offset
=
self
.
offset
.
to
(
device
)
self
.
offset
=
self
.
offset
.
to
(
device
)
self
.
state2
.
absmax
=
self
.
state2
.
absmax
.
to
(
device
)
self
.
state2
.
absmax
=
self
.
state2
.
absmax
.
to
(
device
)
...
...
bitsandbytes/nn/modules.py
View file @
781fcd5b
...
@@ -155,28 +155,38 @@ class Params4bit(torch.nn.Parameter):
...
@@ -155,28 +155,38 @@ class Params4bit(torch.nn.Parameter):
return
self
return
self
@
classmethod
@
classmethod
def
from_state_dict
(
cls
,
state_dict
,
prefix
=
""
,
requires_grad
=
False
):
def
from_prequantized
(
cls
,
data
,
quantized_stats
,
requires_grad
=
False
,
device
=
'cuda'
,
**
kwargs
):
data
=
state_dict
.
pop
(
prefix
.
rstrip
(
'.'
))
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
.
requires_grad
=
requires_grad
# extracting components for QuantState from state_dict
self
.
quant_state
=
QuantState
.
from_dict
(
qs_dict
=
quantized_stats
,
device
=
device
)
qs_dict
=
{}
self
.
blocksize
=
self
.
quant_state
.
blocksize
for
k
,
v
in
state_dict
.
items
():
self
.
compress_statistics
=
self
.
quant_state
.
nested
if
k
.
replace
(
prefix
,
''
).
split
(
'.'
)[
0
]
in
QuantState
.
valid_qs_keys
:
self
.
quant_type
=
self
.
quant_state
.
quant_type
qs_dict
[
k
]
=
v
return
self
state_dict
=
{
k
:
v
for
k
,
v
in
state_dict
.
items
()
if
k
not
in
qs_dict
}
qs_dict
=
{
k
.
replace
(
prefix
,
''
):
v
for
k
,
v
in
qs_dict
.
items
()}
# @classmethod
# def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
if
data
.
device
.
type
!=
"cuda"
:
# data = state_dict.pop(prefix.rstrip('.'))
raise
ValueError
(
f
"`data.device.type` must be 'cuda', detected
{
data
.
device
.
type
}
"
)
# # extracting components for QuantState from state_dict
cls
.
requires_grad
=
requires_grad
# qs_dict = {}
cls
.
quant_state
=
QuantState
.
from_dict
(
qs_dict
=
qs_dict
,
device
=
data
.
device
)
# for k, v in state_dict.items():
cls
.
blocksize
=
cls
.
quant_state
.
blocksize
# this attribute can be deprecated - it duplicates same one in quant_state
# if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
cls
.
compress_statistics
=
cls
.
quant_state
.
nested
# this attribute can be deprecated - it duplicates quant_state.nested
# qs_dict[k] = v
cls
.
quant_type
=
cls
.
quant_state
.
quant_type
# this attribute can be deprecated - it duplicates same one in quant_state
# state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
# qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
=
data
.
to
(
data
.
device
))
return
self
,
state_dict
# if data.device.type != "cuda":
# raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
# cls.requires_grad = requires_grad
# cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
# cls.blocksize = cls.quant_state.blocksize # this attribute can be deprecated - it duplicates same one in quant_state
# cls.compress_statistics = cls.quant_state.nested # this attribute can be deprecated - it duplicates quant_state.nested
# cls.quant_type = cls.quant_state.quant_type # this attribute can be deprecated - it duplicates same one in quant_state
# self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
# return self, state_dict
def
cuda
(
self
,
device
):
def
cuda
(
self
,
device
):
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
...
@@ -251,17 +261,17 @@ class Linear4bit(nn.Linear):
...
@@ -251,17 +261,17 @@ class Linear4bit(nn.Linear):
for
k
,
v
in
self
.
weight
.
quant_state
.
as_dict
(
packed
=
True
).
items
():
for
k
,
v
in
self
.
weight
.
quant_state
.
as_dict
(
packed
=
True
).
items
():
destination
[
prefix
+
"weight."
+
k
]
=
v
if
keep_vars
else
v
.
detach
()
destination
[
prefix
+
"weight."
+
k
]
=
v
if
keep_vars
else
v
.
detach
()
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
#
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys
,
unexpected_keys
,
error_msgs
):
#
missing_keys, unexpected_keys, error_msgs):
# Note: super()._load_from_state_dict() is not called here intentionally.
#
# Note: super()._load_from_state_dict() is not called here intentionally.
if
self
.
bias
is
not
None
:
#
if self.bias is not None:
bias_data
=
state_dict
.
pop
(
prefix
+
"bias"
,
None
)
#
bias_data = state_dict.pop(prefix + "bias", None)
self
.
bias
.
data
=
bias_data
.
to
(
self
.
bias
.
data
.
device
)
#
self.bias.data = bias_data.to(self.bias.data.device)
self
.
weight
,
state_dict
=
bnb
.
nn
.
Params4bit
.
from_state_dict
(
#
self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
state_dict
,
prefix
=
prefix
+
"weight"
+
"."
,
requires_grad
=
False
#
state_dict, prefix=prefix + "weight" + ".", requires_grad=False
)
#
)
unexpected_keys
.
extend
(
state_dict
.
keys
())
#
unexpected_keys.extend(state_dict.keys())
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
# weights are cast automatically as Int8Params, but the bias has to be cast manually
...
...
tests/test_linear4bit.py
View file @
781fcd5b
...
@@ -7,8 +7,6 @@ import pytest
...
@@ -7,8 +7,6 @@ import pytest
import
torch
import
torch
import
bitsandbytes
as
bnb
import
bitsandbytes
as
bnb
from
bitsandbytes
import
functional
as
F
from
bitsandbytes.nn.modules
import
Linear4bit
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
...
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
...
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
# saving to state_dict:
# saving to state_dict:
sd
=
linear_q
.
state_dict
()
sd
=
linear_q
.
state_dict
()
# restoring from state_dict:
bias_data2
=
sd
.
pop
(
"bias"
,
None
)
weight_data2
=
sd
.
pop
(
"weight"
)
weight2
=
bnb
.
nn
.
Params4bit
.
from_prequantized
(
quantized_stats
=
sd
,
data
=
weight_data2
)
# creating new layer with same params:
# creating new layer with same params:
linear_q2
=
bnb
.
nn
.
Linear4bit
(
linear_q2
=
bnb
.
nn
.
Linear4bit
(
linear
.
in_features
,
linear
.
in_features
,
...
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
...
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
device
=
device
,
# TODO create on meta device to save loading time
device
=
device
,
# TODO create on meta device to save loading time
)
)
# loading weights from state_dict:
# loading weights from state_dict:
linear_q2
.
load_state_dict
(
sd
)
linear_q2
.
weight
=
weight2
.
to
(
device
)
if
bias
:
linear_q2
.
bias
=
torch
.
nn
.
Parameter
(
bias_data2
)
# MATCHING
# MATCHING
a
,
b
=
linear_q
.
weight
,
linear_q2
.
weight
a
,
b
=
linear_q
.
weight
,
linear_q2
.
weight
...
@@ -61,7 +64,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
...
@@ -61,7 +64,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
assert
a
.
device
==
b
.
device
assert
a
.
device
==
b
.
device
assert
a
.
dtype
==
b
.
dtype
assert
a
.
dtype
==
b
.
dtype
assert
torch
.
equal
(
a
,
b
)
assert
torch
.
equal
(
a
,
b
)
q0
=
a
.
quant_state
q0
=
a
.
quant_state
q1
=
b
.
quant_state
q1
=
b
.
quant_state
for
attr
in
(
'code'
,
'dtype'
,
'blocksize'
,
'absmax'
):
for
attr
in
(
'code'
,
'dtype'
,
'blocksize'
,
'absmax'
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment