OpenDAS / bitsandbytes / Commits

Commit c6d0a847, authored Nov 08, 2023 by Ruslan Svirschevski

cleanup 0

parent 7d1c9cfe
Showing 2 changed files with 18 additions and 15 deletions:

  bitsandbytes/functional.py    +8 -7
  bitsandbytes/nn/modules.py    +10 -8

bitsandbytes/functional.py
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
    return out


class QuantState:
    """container for quantization state components to work with Params4bit and similar clases"""

    valid_quant_types = ('fp4', 'nf4')
@@ -574,7 +575,6 @@ class QuantState:
    valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state',
                     'quant_type', 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset']

    def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None):
        self.absmax = absmax
        self.shape = shape
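Every component of the constructor above is an optional keyword, so a QuantState can be assembled by hand. A minimal sketch with illustrative tensor values (in practice the quantization routines build this object internally; the numbers below are assumptions, not taken from the diff):

    import torch
    from bitsandbytes.functional import QuantState

    # illustrative: a 256-element tensor quantized in blocks of 64
    absmax = torch.rand(4)                    # one absolute-max scale per block (256 / 64 = 4)
    state = QuantState(
        absmax,
        shape=torch.Size([16, 16]),           # original tensor shape to restore
        code=torch.linspace(-1, 1, 16),       # 16-entry codebook for 4-bit values
        blocksize=64,
        quant_type='nf4',                     # one of valid_quant_types
        dtype=torch.float16,                  # dtype to restore on dequantization
    )
    assert state.nested is False              # state2 omitted, so nothing is nested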
@@ -585,7 +585,7 @@ class QuantState:
        self.offset = offset
        self.state2 = state2
        self.nested = state2 is not None

    def __get_item__(self, idx):
        """
        ensures compatibility with older quant state scheme with nested lists.
@@ -598,7 +598,7 @@ class QuantState:
        else:
            list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type]
        return list_repr[idx]

    @classmethod
    def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState':
        """
@@ -606,7 +606,7 @@ class QuantState:
        where necessary, convert into strings, torch.dtype, ints, etc.

        qs_dict: based on state_dict, with only relevant keys, striped of prefixes.

        item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
        """
@@ -615,8 +615,8 @@ class QuantState:
        if not len(qs_key) and 'quant_type' not in qs_dict:
            raise ValueError("Expected packed or unpacked quant_state items, found neither")
        elif len(qs_key) != 1:
-           raise ValueError(f"There should be exaclly one quant_state item with key from {self.valid_qs_type_keys}. Detected {len(qs_ley)} such items")
+           raise ValueError(f"There should be exaclly one quant_state item with key from {cls.valid_qs_type_keys}. Detected {len(qs_key)} such items")

        # unpacking minor and non-tensor quant state items if necessary
        if len(qs_key) == 1:
            qs_key = qs_key[0]
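from_dict therefore accepts two layouts: an unpacked dict whose keys come from valid_qs_keys, or a packed dict whose small non-tensor items are serialized into a single tensor under a quant_state.bitsandbytes__[nf4/fp4] key. A hypothetical unpacked input is sketched below; the key names follow valid_qs_keys from this diff, while the exact value formats are assumptions:

    import torch
    from bitsandbytes.functional import QuantState

    qs_dict = {
        'absmax': torch.rand(4),
        'quant_map': torch.linspace(-1, 1, 16),
        'quant_type': 'nf4',
        'blocksize': 64,
        'dtype': 'float16',                   # stored as a string, converted on load
        'shape': (16, 16),
    }
    state = QuantState.from_dict(qs_dict=qs_dict, device=torch.device('cuda'))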
@@ -673,7 +673,7 @@ class QuantState:
        non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)}
        qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict)
        return qs_packed_dict

    def to(self, device):
        # make sure the quantization state is on the right device
        self.absmax = self.absmax.to(device)
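The qs_packed_dict tail above belongs to the export path that mirrors from_dict: tensor items stay as tensors, everything else is packed via pack_dict_to_tensor under the quant_state.bitsandbytes__<quant_type> key. A hedged round-trip sketch follows; the exporter name as_dict is an assumption based on this tail and is not visible in the hunk:

    import torch
    from bitsandbytes.functional import QuantState

    state = QuantState(torch.rand(4), shape=torch.Size([16, 16]),
                       code=torch.linspace(-1, 1, 16), blocksize=64,
                       quant_type='nf4', dtype=torch.float16)
    packed = state.as_dict(packed=True)       # assumed name; produces qs_packed_dict
    restored = QuantState.from_dict(qs_dict=packed, device=torch.device('cuda'))
    assert restored.quant_type == state.quant_type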
@@ -682,6 +682,7 @@ class QuantState:
            self.state2.absmax = self.state2.absmax.to(device)
            self.state2.code = self.state2.code.to(device)


def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
    """
    Quantize tensor A in blocks of size 4096 values.
    ...
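Despite the -> Tensor annotation, bitsandbytes releases of this era returned the quantized tensor together with its quantization state; the sketch below assumes that behaviour and a CUDA device:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4096, device='cuda')
    quantized, quant_state = F.quantize_blockwise(A, blocksize=4096, nested=False)
    restored = F.dequantize_blockwise(quantized, quant_state)
    assert restored.shape == A.shape          # quantization is lossy, but the shape round-trips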
bitsandbytes/nn/modules.py
@@ -139,6 +139,7 @@ class Embedding(torch.nn.Embedding):
        return emb


class Params4bit(torch.nn.Parameter):

    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
@@ -152,11 +153,11 @@ class Params4bit(torch.nn.Parameter):
        self.quant_state = quant_state
        self.data = data
        return self

    @classmethod
    def from_state_dict(cls, state_dict, prefix="", requires_grad=False):
        data = state_dict.pop(prefix.rstrip('.'))

        # extracting components for QuantState from state_dict
        qs_dict = {}
        for k, v in state_dict.items():
@@ -164,15 +165,15 @@ class Params4bit(torch.nn.Parameter):
                qs_dict[k] = v
        state_dict = {k: v for k, v in state_dict.items() if k not in qs_dict}
        qs_dict = {k.replace(prefix, ''): v for k, v in qs_dict.items()}

        if data.device.type != "cuda":
            raise ValueError(f"`data.device.type` must be 'cuda', detected {data.device.type}")

        cls.requires_grad = requires_grad
        cls.quant_state = QuantState.from_dict(qs_dict=qs_dict, device=data.device)
-       cls.blocksize = cls.quant_state.blocksize
-       cls.compress_statistics = cls.quant_state.nested
-       cls.quant_type = cls.quant_state.quant_type
+       cls.blocksize = cls.quant_state.blocksize  # this attribute can be deprecated - it duplicates same one in quant_state
+       cls.compress_statistics = cls.quant_state.nested  # this attribute can be deprecated - it duplicates quant_state.nested
+       cls.quant_type = cls.quant_state.quant_type  # this attribute can be deprecated - it duplicates same one in quant_state

        self = torch.Tensor._make_subclass(cls, data=data.to(data.device))
        return self, state_dict
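A hypothetical call to the classmethod above, assuming a checkpoint saved in this commit's serialization format; the file path and key prefix are illustrative, and the weights must already live on CUDA because of the device check:

    import torch
    from bitsandbytes.nn import Params4bit

    sd = torch.load('model_4bit.pt', map_location='cuda')   # hypothetical checkpoint
    param, remaining_sd = Params4bit.from_state_dict(
        sd, prefix='mlp.up_proj.weight.', requires_grad=False)
    # `param` carries its reconstructed QuantState; `remaining_sd` keeps the
    # keys that were not consumed (e.g. biases and other modules' weights)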
@@ -207,14 +208,15 @@ class Params4bit(torch.nn.Parameter):
            self.quant_state.to(device)

            new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                                   requires_grad=self.requires_grad, quant_state=self.quant_state,
                                   blocksize=self.blocksize, compress_statistics=self.compress_statistics,
                                   quant_type=self.quant_type)

            return new_param


class Linear4bit(nn.Linear):

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
        super().__init__(input_features, output_features, bias, device)
        self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
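To see the Linear4bit construction above in context, a usage sketch; that the fp32 weight is quantized when the module is moved to CUDA is an assumption based on how bitsandbytes releases of this era behaved:

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(768, 3072, bias=True,
                              compute_dtype=torch.float16, quant_type='nf4')
    layer = layer.to('cuda')                              # weights quantize on this move
    x = torch.randn(1, 768, device='cuda', dtype=torch.float16)
    y = layer(x)                                          # forward runs the 4-bit matmul path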