Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
6cf0f05d
Commit
6cf0f05d
authored
Sep 13, 2023
by
Ruslan Svirschevski
Browse files
rework of non-tensor qs items storage
parent
6a934d4f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
85 additions
and
35 deletions
+85
-35
bitsandbytes/functional.py
bitsandbytes/functional.py
+46
-26
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+5
-9
bitsandbytes/utils.py
bitsandbytes/utils.py
+34
-0
No files found.
bitsandbytes/functional.py
View file @
6cf0f05d
...
...
@@ -13,8 +13,9 @@ from scipy.stats import norm
import
numpy
as
np
from
functools
import
reduce
# Required in Python 3
from
typing
import
Tuple
from
typing
import
Tuple
,
Any
from
torch
import
Tensor
from
bitsandbytes.utils
import
pack_dict_to_tensor
,
unpack_tensor_to_dict
from
.cextension
import
COMPILED_WITH_CUDA
,
lib
...
...
@@ -580,58 +581,77 @@ class QuantState:
self
.
nested
=
state2
is
not
None
@classmethod
def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState':
    """
    Unpack a dict of tensors (plus at most one packed non-tensor item) into a QuantState,
    converting components back into strings, torch.dtype, ints, etc. where necessary.

    qs_dict may contain an item holding all non-tensor components, JSON-packed into a
    single uint8 tensor, under a key like
    `...weight.quant_state.bitsandbytes__[nf4/fp4]`;
    it is detected via the key stored in qs_key, and then unpacked with
    unpack_tensor_to_dict before the QuantState is assembled.

    param: qs_dict -- mapping of (possibly prefixed) component names to values
    param: device -- target device for all tensor components
    """

    # Detect the packed non-tensor item: a tensor whose key mentions "quant_state".
    qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
    if len(qs_key) == 1:
        qs_key = qs_key[0]
        assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \
            f"invalid qs_key value {qs_key}"
        # Merge the unpacked non-tensor components into the main dict.
        qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key))

    qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()}  # strip prefixes

    if 'nested_absmax' in qs_dict:
        # Nested (double) quantization: rebuild the inner state first.
        offset = torch.tensor(float(qs_dict['nested_offset'])).to(device)
        state2 = cls(
            absmax=qs_dict['nested_absmax'].to(device),
            code=qs_dict['nested_code'].to(device),
            blocksize=qs_dict['nested_blocksize'],
            dtype=getattr(torch, qs_dict['nested_dtype']),
        )
    else:
        offset, state2 = None, None

    quant_state = cls(
        absmax=qs_dict['absmax'].to(device),
        shape=torch.Size(qs_dict['shape']),
        dtype=getattr(torch, qs_dict['dtype']),
        blocksize=qs_dict['blocksize'],
        offset=offset,
        state2=state2,
        quant_type=qs_dict['quant_type'],
        code=qs_dict['code'].to(device),
    )
    return quant_state
def as_dict(self, packed=False):
    """
    Return a dict of quant_state components for serialization via _save_to_state_dict().

    param: packed -- if True, return dict[str, torch.Tensor]: all non-tensor
                     components are JSON-packed into one uint8 tensor so the
                     whole result can live inside a state_dict.
    """
    qs_dict = {
        'absmax': self.absmax,
        'code': self.code,
        'shape': tuple(self.shape),
        # str(dtype) looks like 'torch.float16'; take the part after the dot.
        # NOTE: str.strip('torch.') would be wrong here -- it strips a *set of
        # characters* from both ends, not the 'torch.' prefix, and can mangle
        # names that start or end with those letters.
        'dtype': str(self.dtype).split('.')[-1],
        'blocksize': self.blocksize,
        'quant_type': self.quant_type,
    }
    if self.nested:
        qs_dict.update({
            'nested_absmax': self.state2.absmax,
            'nested_code': self.state2.code,
            'nested_offset': self.offset.item(),
            'nested_blocksize': self.state2.blocksize,
            'nested_dtype': str(self.state2.dtype).split('.')[-1],
        })
    if not packed:
        return qs_dict

    # Pack everything non-tensor into one uint8 tensor keyed by quant_type so
    # every value in the returned dict is a torch.Tensor.
    qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)}
    non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
    qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
    return qs_packed_dict
def
to
(
self
,
device
):
# make sure the quantization state is on the right device
...
...
bitsandbytes/nn/modules.py
View file @
6cf0f05d
...
...
@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
data
=
quantized_stats
.
pop
(
'weight'
)
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
.
requires_grad
=
requires_grad
self
.
quant_state
=
QuantState
.
from_dict
(
q
uant_state
_dict
=
quantized_stats
,
device
=
device
)
self
.
quant_state
=
QuantState
.
from_dict
(
q
s
_dict
=
quantized_stats
,
device
=
device
)
self
.
blocksize
=
self
.
quant_state
.
blocksize
self
.
compress_statistics
=
self
.
quant_state
.
nested
self
.
quant_type
=
self
.
quant_state
.
quant_type
...
...
@@ -226,18 +226,14 @@ class Linear4bit(nn.Linear):
def _save_to_state_dict(self, destination, prefix, keep_vars):
    """
    save weight and bias,
    then fill state_dict with components of quant_state
    """
    # Let nn.Linear store weight and bias first.
    super()._save_to_state_dict(destination, prefix, keep_vars)

    quant_state = getattr(self.weight, "quant_state", None)
    if quant_state is not None:
        # Each packed quant_state component (tensors only) goes under the
        # weight's namespace, e.g. "<prefix>weight.absmax".
        for key, tensor in quant_state.as_dict(packed=True).items():
            destination[prefix + "weight." + key] = tensor if keep_vars else tensor.detach()
def
forward
(
self
,
x
:
torch
.
Tensor
):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
...
...
bitsandbytes/utils.py
View file @
6cf0f05d
import
json
import
shlex
import
subprocess
import
torch
...
...
@@ -158,3 +159,36 @@ def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_wei
if
func
is
not
None
:
func
(
module
)
return
model
def pack_dict_to_tensor(source_dict):
    """
    Pack a dictionary into a torch tensor for storing quant_state items in state_dict.

    Parameters:
    - source_dict: The dictionary to be packed.

    Returns:
    A torch tensor containing the packed data.
    """
    encoded = json.dumps(source_dict).encode('utf-8')
    # One uint8 element per byte of the UTF-8 encoded JSON payload.
    return torch.tensor(list(encoded), dtype=torch.uint8)
def unpack_tensor_to_dict(tensor_data):
    """
    Unpack a torch tensor (as produced by pack_dict_to_tensor) into a Python dictionary.

    Parameters:
    - tensor_data: uint8 torch tensor containing UTF-8 encoded JSON bytes.

    Returns:
    A Python dictionary containing the unpacked data.
    """
    # .cpu() first: Tensor.numpy() raises on CUDA tensors, and the packed
    # state may have been moved to the GPU along with the rest of the module.
    json_bytes = bytes(tensor_data.cpu().numpy())
    json_str = json_bytes.decode('utf-8')
    return json.loads(json_str)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment