Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c6d0a847
Commit
c6d0a847
authored
Nov 08, 2023
by
Ruslan Svirschevski
Browse files
cleanup 0
parent
7d1c9cfe
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
15 deletions
+18
-15
bitsandbytes/functional.py
bitsandbytes/functional.py
+8
-7
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+10
-8
No files found.
bitsandbytes/functional.py
View file @
c6d0a847
...
...
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return
out
class
QuantState
:
"""container for quantization state components to work with Params4bit and similar classes"""
valid_quant_types
=
(
'fp4'
,
'nf4'
)
...
...
@@ -574,7 +575,6 @@ class QuantState:
valid_qs_keys
=
[
'absmax'
,
'quant_map'
,
'nested_absmax'
,
'nested_quant_map'
,
'quant_state'
,
'quant_type'
,
'blocksize'
,
'dtype'
,
'shape'
,
'nested_blocksize'
,
'nested_dtype'
,
'nested_offset'
]
def
__init__
(
self
,
absmax
,
shape
=
None
,
code
=
None
,
blocksize
=
None
,
quant_type
=
None
,
dtype
=
None
,
offset
=
None
,
state2
=
None
):
self
.
absmax
=
absmax
self
.
shape
=
shape
...
...
@@ -615,7 +615,7 @@ class QuantState:
if
not
len
(
qs_key
)
and
'quant_type'
not
in
qs_dict
:
raise
ValueError
(
"Expected packed or unpacked quant_state items, found neither"
)
elif
len
(
qs_key
)
!=
1
:
raise
ValueError
(
f
"There should be exaclly one quant_state item with key from
{
self
.
valid_qs_type_keys
}
. Detected
{
len
(
qs_
l
ey
)
}
such items"
)
raise
ValueError
(
f
"There should be exaclly one quant_state item with key from
{
cls
.
valid_qs_type_keys
}
. Detected
{
len
(
qs_
k
ey
)
}
such items"
)
# unpacking minor and non-tensor quant state items if necessary
if
len
(
qs_key
)
==
1
:
...
...
@@ -682,6 +682,7 @@ class QuantState:
self
.
state2
.
absmax
=
self
.
state2
.
absmax
.
to
(
device
)
self
.
state2
.
code
=
self
.
state2
.
code
.
to
(
device
)
def
quantize_blockwise
(
A
:
Tensor
,
code
:
Tensor
=
None
,
absmax
:
Tensor
=
None
,
out
:
Tensor
=
None
,
blocksize
=
4096
,
nested
=
False
)
->
Tensor
:
"""
Quantize tensor A in blocks of size 4096 values.
...
...
bitsandbytes/nn/modules.py
View file @
c6d0a847
...
...
@@ -139,6 +139,7 @@ class Embedding(torch.nn.Embedding):
return
emb
class
Params4bit
(
torch
.
nn
.
Parameter
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
...
...
@@ -170,9 +171,9 @@ class Params4bit(torch.nn.Parameter):
cls
.
requires_grad
=
requires_grad
cls
.
quant_state
=
QuantState
.
from_dict
(
qs_dict
=
qs_dict
,
device
=
data
.
device
)
cls
.
blocksize
=
cls
.
quant_state
.
blocksize
cls
.
compress_statistics
=
cls
.
quant_state
.
nested
cls
.
quant_type
=
cls
.
quant_state
.
quant_type
cls
.
blocksize
=
cls
.
quant_state
.
blocksize
# this attribute can be deprecated - it duplicates the same one in quant_state
cls
.
compress_statistics
=
cls
.
quant_state
.
nested
# this attribute can be deprecated - it duplicates quant_state.nested
cls
.
quant_type
=
cls
.
quant_state
.
quant_type
# this attribute can be deprecated - it duplicates the same one in quant_state
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
=
data
.
to
(
data
.
device
))
return
self
,
state_dict
...
...
@@ -213,6 +214,7 @@ class Params4bit(torch.nn.Parameter):
return
new_param
class
Linear4bit
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
quant_type
=
'fp4'
,
device
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment