Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
48b3e770
Commit
48b3e770
authored
Sep 12, 2023
by
Ruslan Svirschevski
Browse files
some renaming
parent
5bcc1ddc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
16 deletions
+20
-16
bitsandbytes/functional.py
bitsandbytes/functional.py
+19
-15
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+1
-1
No files found.
bitsandbytes/functional.py
View file @
48b3e770
...
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
...
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return
out
return
out
class
QuantState
:
class
QuantState
:
"""container for quantizationstate components to work with Params4bit and similar clases"""
def
__init__
(
self
,
absmax
,
shape
=
None
,
code
=
None
,
blocksize
=
None
,
quant_type
=
None
,
dtype
=
None
,
offset
=
None
,
state2
=
None
):
def
__init__
(
self
,
absmax
,
shape
=
None
,
code
=
None
,
blocksize
=
None
,
quant_type
=
None
,
dtype
=
None
,
offset
=
None
,
state2
=
None
):
self
.
absmax
=
absmax
self
.
absmax
=
absmax
self
.
shape
=
shape
self
.
shape
=
shape
...
@@ -579,32 +580,35 @@ class QuantState:
...
@@ -579,32 +580,35 @@ class QuantState:
self
.
nested
=
state2
is
not
None
self
.
nested
=
state2
is
not
None
@
classmethod
@
classmethod
def
from_kwargs
(
cls
,
kwargs
,
device
):
def
from_dict
(
cls
,
quant_state_dict
:
dict
[
str
,
torch
.
Tensor
],
device
:
torch
.
device
)
->
'QuantState'
:
"""
unpacks dict of tensors into QuantState
where necessary, convert into strings, torch.dtype, ints, etc.
"""
tensor2str
=
lambda
xx
:
''
.
join
([
chr
(
x
)
for
x
in
xx
]).
strip
(
'.'
)
tensor2str
=
lambda
xx
:
''
.
join
([
chr
(
x
)
for
x
in
xx
]).
strip
(
'.'
)
kwargs
=
{
k
.
split
(
'.'
)[
-
1
]
:
v
for
k
,
v
in
kwargs
.
items
()}
quant_state_dict
=
{
k
.
split
(
'.'
)[
-
1
]
:
v
for
k
,
v
in
quant_state_dict
.
items
()}
if
'nested_absmax'
in
kwargs
:
if
'nested_absmax'
in
quant_state_dict
:
offset
=
kwargs
[
'nested_offset'
]
offset
=
quant_state_dict
[
'nested_offset'
]
state2
=
cls
(
state2
=
cls
(
absmax
=
kwargs
[
'nested_absmax'
].
to
(
device
),
absmax
=
quant_state_dict
[
'nested_absmax'
].
to
(
device
),
code
=
kwargs
[
'nested_code'
].
to
(
device
),
code
=
quant_state_dict
[
'nested_code'
].
to
(
device
),
blocksize
=
kwargs
[
'nested_blocksize'
].
item
(),
blocksize
=
quant_state_dict
[
'nested_blocksize'
].
item
(),
dtype
=
getattr
(
torch
,
tensor2str
(
kwargs
[
'nested_dtype'
])),
dtype
=
getattr
(
torch
,
tensor2str
(
quant_state_dict
[
'nested_dtype'
])),
)
)
else
:
else
:
offset
,
state2
=
None
,
None
offset
,
state2
=
None
,
None
quant_state
=
cls
(
quant_state
=
cls
(
absmax
=
kwargs
[
'absmax'
].
to
(
device
),
absmax
=
quant_state_dict
[
'absmax'
].
to
(
device
),
shape
=
torch
.
Size
(
kwargs
[
'shape'
]),
shape
=
torch
.
Size
(
quant_state_dict
[
'shape'
]),
dtype
=
getattr
(
torch
,
tensor2str
(
kwargs
[
'dtype'
])),
dtype
=
getattr
(
torch
,
tensor2str
(
quant_state_dict
[
'dtype'
])),
blocksize
=
kwargs
[
'blocksize'
].
item
(),
blocksize
=
quant_state_dict
[
'blocksize'
].
item
(),
offset
=
offset
,
offset
=
offset
,
state2
=
state2
,
state2
=
state2
,
quant_type
=
tensor2str
(
kwargs
[
'quant_type'
]),
quant_type
=
tensor2str
(
quant_state_dict
[
'quant_type'
]),
code
=
kwargs
[
'code'
].
to
(
device
),
code
=
quant_state_dict
[
'code'
].
to
(
device
),
)
)
return
quant_state
return
quant_state
...
...
bitsandbytes/nn/modules.py
View file @
48b3e770
...
@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
...
@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
data
=
quantized_stats
.
pop
(
'weight'
)
data
=
quantized_stats
.
pop
(
'weight'
)
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
.
requires_grad
=
requires_grad
self
.
requires_grad
=
requires_grad
self
.
quant_state
=
QuantState
.
from_
kwargs
(
kwargs
=
quantized_stats
,
device
=
device
)
self
.
quant_state
=
QuantState
.
from_
dict
(
quant_state_dict
=
quantized_stats
,
device
=
device
)
self
.
blocksize
=
self
.
quant_state
.
blocksize
self
.
blocksize
=
self
.
quant_state
.
blocksize
self
.
compress_statistics
=
self
.
quant_state
.
nested
self
.
compress_statistics
=
self
.
quant_state
.
nested
self
.
quant_type
=
self
.
quant_state
.
quant_type
self
.
quant_type
=
self
.
quant_state
.
quant_type
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment