OpenDAS / bitsandbytes · Commit 781fcd5b

Authored Nov 08, 2023 by Ruslan Svirschevski
Commit message: partially reverted 76b40a5c
Parent: c6d0a847

Showing 3 changed files with 64 additions and 47 deletions:

    bitsandbytes/functional.py    +13  -9
    bitsandbytes/nn/modules.py    +43  -33
    tests/test_linear4bit.py      +8   -5
bitsandbytes/functional.py

@@ -571,9 +571,9 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n

 class QuantState:
     """container for quantization state components to work with Params4bit and similar clases"""
     valid_quant_types = ('fp4', 'nf4')
-    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
-    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
-                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
+    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
+                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']

     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
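Editor's note: the `quant_state.` prefix is dropped from `valid_qs_type_keys` because the updated `from_dict` (next hunk) finds the packed item by substring and then compares only the last dotted component of its key. A minimal sketch of that matching; `packed_key` is a hypothetical key of the shape Linear4bit writes into a state dict, not taken from this commit:

    # Sketch only: suffix matching against the trimmed type keys.
    valid_quant_types = ('fp4', 'nf4')
    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]

    packed_key = "weight.quant_state.bitsandbytes__nf4"  # hypothetical example key
    assert packed_key.split(".")[-1] in valid_qs_type_keys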
@@ -611,16 +611,19 @@ class QuantState:
         """
         # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
         if not len(qs_key) and 'quant_type' not in qs_dict:
             raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1:
-            raise ValueError(f"There should be exaclly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
+        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+            raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")

         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))

+        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
         assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)

         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
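Editor's note: taken together, the new checks accept any tensor-valued key that contains `quant_state` and ends in a valid type tag, instead of requiring an exact key match, and `dict.update` replaces the Python-3.9-only `|=` operator. A self-contained sketch of the detection, with toy tensors standing in for real packed components:

    import torch

    # Toy stand-ins; a real qs_dict comes from a Linear4bit state dict.
    qs_dict = {
        "quant_state.bitsandbytes__nf4": torch.empty(16, dtype=torch.uint8),
        "absmax": torch.ones(4),
    }
    qs_key = [k for k, v in qs_dict.items()
              if "quant_state" in k and isinstance(v, torch.Tensor)]
    assert len(qs_key) == 1
    assert qs_key[0].split(".")[-1] in ("bitsandbytes__fp4", "bitsandbytes__nf4")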
@@ -677,6 +680,7 @@ class QuantState:

     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
-        self.offset = self.offset.to(device)
         if self.nested:
+            self.offset = self.offset.to(device)
             self.state2.absmax = self.state2.absmax.to(device)
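Editor's note: moving `self.offset` under the `if self.nested:` guard matters because `offset` and `state2` only exist when double quantization (`compress_statistics`) is enabled; calling `.to()` on a missing offset would presumably fail for non-nested states. A stand-in illustration (not the bitsandbytes class):

    import torch

    class _QS:  # illustration only, not bitsandbytes.functional.QuantState
        def __init__(self, absmax, offset=None, state2=None):
            self.absmax, self.offset, self.state2 = absmax, offset, state2
            self.nested = state2 is not None

        def to(self, device):
            self.absmax = self.absmax.to(device)
            if self.nested:  # offset/state2 exist only for nested quantization
                self.offset = self.offset.to(device)
                self.state2.absmax = self.state2.absmax.to(device)

    _QS(torch.ones(4)).to("cpu")  # non-nested state: nothing extra to move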
bitsandbytes/nn/modules.py

@@ -155,28 +155,38 @@ class Params4bit(torch.nn.Parameter):
         return self

     @classmethod
-    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-        data = state_dict.pop(prefix.rstrip('.'))
+    def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
+        self = torch.Tensor._make_subclass(cls, data.to(device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
+        self.blocksize = self.quant_state.blocksize
+        self.compress_statistics = self.quant_state.nested
+        self.quant_type = self.quant_state.quant_type
+        return self

+    # @classmethod
+    # def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
+    #     data = state_dict.pop(prefix.rstrip('.'))

-        # extracting components for QuantState from state_dict
-        qs_dict = {}
-        for k, v in state_dict.items():
-            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-                qs_dict[k] = v
-        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
+    #     # extracting components for QuantState from state_dict
+    #     qs_dict = {}
+    #     for k, v in state_dict.items():
+    #         if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+    #             qs_dict[k] = v
+    #     state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+    #     qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}

-        if data.device.type != "cuda":
-            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
+    #     if data.device.type != "cuda":
+    #         raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")

-        cls.requires_grad = requires_grad
-        cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-        cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
-        cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
-        cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
+    #     cls.requires_grad = requires_grad
+    #     cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+    #     cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
+    #     cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
+    #     cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state

-        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-        return self, state_dict
+    #     self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
+    #     return self, state_dict

     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
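Editor's note: the restored `from_prequantized` rebuilds a `Params4bit` from a packed weight tensor plus the remaining quant-state entries. A hedged usage sketch (requires a CUDA GPU; the 32/64 layer sizes and `quant_type='nf4'` are arbitrary choices, not from this commit):

    import torch
    import bitsandbytes as bnb

    linear_q = bnb.nn.Linear4bit(32, 64, quant_type='nf4').cuda()  # quantizes on .cuda()
    sd = linear_q.state_dict()

    bias_data = sd.pop("bias", None)
    weight_data = sd.pop("weight")  # packed 4-bit payload
    weight = bnb.nn.Params4bit.from_prequantized(data=weight_data, quantized_stats=sd)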
@@ -251,17 +261,17 @@ class Linear4bit(nn.Linear):
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()

-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        # Note: super()._load_from_state_dict() is not called here intentionally.
-        if self.bias is not None:
-            bias_data = state_dict.pop(prefix + "bias", None)
-            self.bias.data = bias_data.to(self.bias.data.device)
-
-        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
-            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
-        )
-        unexpected_keys.extend(state_dict.keys())
+    # def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+    #                           missing_keys, unexpected_keys, error_msgs):
+    #     # Note: super()._load_from_state_dict() is not called here intentionally.
+    #     if self.bias is not None:
+    #         bias_data = state_dict.pop(prefix + "bias", None)
+    #         self.bias.data = bias_data.to(self.bias.data.device)
+
+    #     self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+    #         state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+    #     )
+    #     unexpected_keys.extend(state_dict.keys())

     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
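Editor's note: with the `_load_from_state_dict` override commented out, `Linear4bit.load_state_dict()` no longer consumes the packed quant-state entries; callers assemble the layer by hand, as the updated test below does. Continuing the sketch above (`weight` and `bias_data` as built there; shapes must match the saved layer):

    restored = bnb.nn.Linear4bit(32, 64, quant_type='nf4')
    restored.weight = weight.to("cuda")
    if bias_data is not None:
        restored.bias = torch.nn.Parameter(bias_data)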
tests/test_linear4bit.py

@@ -7,8 +7,6 @@ import pytest
 import torch

 import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.nn.modules import Linear4bit


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     # saving to state_dict:
     sd = linear_q.state_dict()

     # restoring from state_dict:
+    bias_data2 = sd.pop("bias", None)
+    weight_data2 = sd.pop("weight")
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)

     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
         ...
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         device=device,  # TODO create on meta device to save loading time
     )

     # loading weights from state_dict:
-    linear_q2.load_state_dict(sd)
+    linear_q2.weight = weight2.to(device)
+    if bias:
+        linear_q2.bias = torch.nn.Parameter(bias_data2)

     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
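Editor's note: the hunk ends as the test begins comparing the original and restored layers. A plausible continuation of the check, hedged (the actual assertions lie outside this hunk):

    assert torch.equal(a.data, b.data)  # assumed: packed 4-bit payloads match bit-for-bit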