OpenDAS / bitsandbytes · Commits · 781fcd5b

Commit 781fcd5b, authored Nov 08, 2023 by Ruslan Svirschevski
Parent: c6d0a847

partially reverted 76b40a5c
Showing 3 changed files with 64 additions and 47 deletions:

- bitsandbytes/functional.py (+13, -9)
- bitsandbytes/nn/modules.py (+43, -33)
- tests/test_linear4bit.py (+8, -5)

In short, this partial revert reinstates the `Params4bit.from_prequantized` constructor in place of the `state_dict`-based loading path (`Params4bit.from_state_dict` and `Linear4bit._load_from_state_dict`, both left commented out) and updates the serialization test to match.
bitsandbytes/functional.py (+13, -9)

```diff
@@ -571,9 +571,9 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
 class QuantState:
     """container for quantization state components to work with Params4bit and similar clases"""
     valid_quant_types = ('fp4', 'nf4')
-    valid_qs_type_keys = [f"quant_state.bitsandbytes__{x}" for x in valid_quant_types]
+    valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]
     valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
                      'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']
 
     def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
         self.absmax = absmax
@@ -611,16 +611,19 @@ class QuantState:
         """
         # unpacking tensor with non-tensor components
-        qs_key = [k for k, v in qs_dict.items() if k in cls.valid_qs_type_keys and isinstance(v, torch.Tensor)]
+        qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
         if not len(qs_key) and 'quant_type' not in qs_dict:
             raise ValueError("Expected packed or unpacked quant_state items, found neither")
-        elif len(qs_key) != 1:
-            raise ValueError(f"There should be exaclly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")
+        elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys:
+            raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.")
 
         # unpacking minor and non-tensor quant state items if necessary
         if len(qs_key) == 1:
             qs_key = qs_key[0]
-            qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))
+            qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(qs_key)))
+
+        qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes
+        assert set(qs_dict.keys()).issubset(cls.valid_qs_keys)
 
         if 'nested_absmax' in qs_dict:
             offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
@@ -654,7 +657,7 @@ class QuantState:
             'quant_type': self.quant_type,
             'absmax': self.absmax,
             'blocksize': self.blocksize,
             'quant_map': self.code,
             'dtype': str(self.dtype).strip('torch.'),
             'shape': tuple(self.shape) if self.nested else None,
         }
@@ -677,6 +680,7 @@ class QuantState:
     def to(self, device):
         # make sure the quantization state is on the right device
         self.absmax = self.absmax.to(device)
+        self.offset = self.offset.to(device)
         if self.nested:
             self.offset = self.offset.to(device)
             self.state2.absmax = self.state2.absmax.to(device)
```
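The substantive change in `QuantState.from_dict` is the key matching: instead of requiring an exact match against `valid_qs_type_keys`, the new code accepts any tensor-valued key containing `"quant_state"` and then validates only the key's final dot-separated component. A minimal sketch of that behavior, using hypothetical state-dict keys (not values taken from the commit):

```python
import torch

# Hypothetical packed state-dict entries, as a Linear4bit layer might serialize them.
qs_dict = {
    "weight.quant_state.bitsandbytes__nf4": torch.empty(16, dtype=torch.uint8),
    "weight.absmax": torch.empty(4),
}

valid_quant_types = ('fp4', 'nf4')
valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types]

# The old filter required an exact key match, so a prefixed key like
# "weight.quant_state.bitsandbytes__nf4" was rejected. The new filter accepts
# any tensor-valued key that merely contains "quant_state" ...
qs_key = [k for k, v in qs_dict.items()
          if "quant_state" in k and isinstance(v, torch.Tensor)]

# ... and then validates only the key's final dot-separated component.
assert len(qs_key) == 1
assert qs_key[0].split(".")[-1] in valid_qs_type_keys
```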
bitsandbytes/nn/modules.py (+43, -33)

```diff
@@ -155,28 +155,38 @@ class Params4bit(torch.nn.Parameter):
         return self
 
     @classmethod
-    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
-        data = state_dict.pop(prefix.rstrip('.'))
-
-        # extracting components for QuantState from state_dict
-        qs_dict = {}
-        for k, v in state_dict.items():
-            if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
-                qs_dict[k] = v
-        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
-        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
-
-        if data.device.type != "cuda":
-            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
-        cls.requires_grad = requires_grad
-        cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-        cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
-        cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
-        cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
-        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
-        return self, state_dict
+    def from_prequantized(cls, data, quantized_stats, requires_grad=False, device='cuda', **kwargs):
+        self = torch.Tensor._make_subclass(cls, data.to(device))
+        self.requires_grad = requires_grad
+        self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device)
+        self.blocksize = self.quant_state.blocksize
+        self.compress_statistics = self.quant_state.nested
+        self.quant_type = self.quant_state.quant_type
+        return self
+
+    # @classmethod
+    # def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
+    #     data = state_dict.pop(prefix.rstrip('.'))
+    #     # extracting components for QuantState from state_dict
+    #     qs_dict = {}
+    #     for k, v in state_dict.items():
+    #         if k.replace(prefix, '').split('.')[0] in QuantState.valid_qs_keys:
+    #             qs_dict[k] = v
+    #     state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
+    #     qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}
+    #     if data.device.type != "cuda":
+    #         raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")
+    #     cls.requires_grad = requires_grad
+    #     cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
+    #     cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
+    #     cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
+    #     cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state
+    #     self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
+    #     return self, state_dict
 
     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
@@ -251,17 +261,17 @@ class Linear4bit(nn.Linear):
         for k, v in self.weight.quant_state.as_dict(packed=True).items():
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        # Note: super()._load_from_state_dict() is not called here intentionally.
-        if self.bias is not None:
-            bias_data = state_dict.pop(prefix + "bias", None)
-            self.bias.data = bias_data.to(self.bias.data.device)
-
-        self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
-            state_dict, prefix=prefix + "weight" + ".", requires_grad=False
-        )
-        unexpected_keys.extend(state_dict.keys())
+    # def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+    #                           missing_keys, unexpected_keys, error_msgs):
+    #     # Note: super()._load_from_state_dict() is not called here intentionally.
+    #     if self.bias is not None:
+    #         bias_data = state_dict.pop(prefix + "bias", None)
+    #         self.bias.data = bias_data.to(self.bias.data.device)
+    #     self.weight, state_dict = bnb.nn.Params4bit.from_state_dict(
+    #         state_dict, prefix=prefix + "weight" + ".", requires_grad=False
+    #     )
+    #     unexpected_keys.extend(state_dict.keys())
 
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
```
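With `Linear4bit._load_from_state_dict` commented out, restoring a layer once again goes through `from_prequantized` explicitly: the caller separates the packed 4-bit weight tensor from the remaining quantization statistics. A minimal sketch of that calling pattern, mirroring the updated test below (the helper name is illustrative, not part of the commit):

```python
import bitsandbytes as bnb

def rebuild_weight(linear_q: "bnb.nn.Linear4bit") -> "bnb.nn.Params4bit":
    """Rebuild a Params4bit from a quantized Linear4bit's state_dict (sketch)."""
    sd = linear_q.state_dict()
    # Pop the non-quant-state entries first; everything remaining in `sd`
    # is the packed quantization statistics expected by `quantized_stats`.
    sd.pop("bias", None)
    weight_data = sd.pop("weight")
    return bnb.nn.Params4bit.from_prequantized(data=weight_data, quantized_stats=sd)
```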
tests/test_linear4bit.py (+8, -5)

```diff
@@ -7,8 +7,6 @@ import pytest
 import torch
 
 import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.nn.modules import Linear4bit
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@@ -41,7 +39,10 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     # saving to state_dict:
     sd = linear_q.state_dict()
 
+    # restoring from state_dict:
+    bias_data2 = sd.pop("bias", None)
+    weight_data2 = sd.pop("weight")
+    weight2 = bnb.nn.Params4bit.from_prequantized(quantized_stats=sd, data=weight_data2)
 
     # creating new layer with same params:
     linear_q2 = bnb.nn.Linear4bit(
         linear.in_features,
@@ -53,7 +54,9 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
         device=device,  # TODO create on meta device to save loading time
     )
 
     # loading weights from state_dict:
-    linear_q2.load_state_dict(sd)
+    linear_q2.weight = weight2.to(device)
+    if bias:
+        linear_q2.bias = torch.nn.Parameter(bias_data2)
 
     # MATCHING
     a, b = linear_q.weight, linear_q2.weight
@@ -61,7 +64,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
     assert a.device == b.device
     assert a.dtype == b.dtype
     assert torch.equal(a, b)
 
     q0 = a.quant_state
     q1 = b.quant_state
     for attr in ('code', 'dtype', 'blocksize', 'absmax'):
```
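The last visible line starts an attribute-by-attribute comparison of the two `QuantState` objects; the loop body is truncated in this view. A hedged sketch of what such a comparison can look like (`q0` and `q1` as in the test; the body here is an assumption, not the test's verbatim code):

```python
import torch

def assert_quant_state_equal(q0, q1):
    # Compare QuantState fields: tensor fields need torch.equal,
    # scalar fields plain equality.
    for attr in ('code', 'dtype', 'blocksize', 'absmax'):
        c, d = getattr(q0, attr), getattr(q1, attr)
        if isinstance(c, torch.Tensor):
            assert torch.equal(c, d), f"quant state mismatch in {attr}"
        else:
            assert c == d, f"quant state mismatch in {attr}"
```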