Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
48b3e770
"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "03df281275ad3fcb732a41ab1638c2e89afddb25"
Commit
48b3e770
authored
Sep 12, 2023
by
Ruslan Svirschevski
Browse files
some renaming
parent
5bcc1ddc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
16 deletions
+20
-16
bitsandbytes/functional.py
bitsandbytes/functional.py
+19
-15
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+1
-1
No files found.
bitsandbytes/functional.py
View file @
48b3e770
...
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
...
@@ -567,6 +567,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
return
out
return
out
class
QuantState
:
class
QuantState
:
"""container for quantizationstate components to work with Params4bit and similar clases"""
def
__init__
(
self
,
absmax
,
shape
=
None
,
code
=
None
,
blocksize
=
None
,
quant_type
=
None
,
dtype
=
None
,
offset
=
None
,
state2
=
None
):
def
__init__
(
self
,
absmax
,
shape
=
None
,
code
=
None
,
blocksize
=
None
,
quant_type
=
None
,
dtype
=
None
,
offset
=
None
,
state2
=
None
):
self
.
absmax
=
absmax
self
.
absmax
=
absmax
self
.
shape
=
shape
self
.
shape
=
shape
...
@@ -579,32 +580,35 @@ class QuantState:
...
@@ -579,32 +580,35 @@ class QuantState:
self
.
nested
=
state2
is
not
None
self
.
nested
=
state2
is
not
None
@
classmethod
@
classmethod
def
from_kwargs
(
cls
,
kwargs
,
device
):
def
from_dict
(
cls
,
quant_state_dict
:
dict
[
str
,
torch
.
Tensor
],
device
:
torch
.
device
)
->
'QuantState'
:
"""
unpacks dict of tensors into QuantState
where necessary, convert into strings, torch.dtype, ints, etc.
"""
tensor2str
=
lambda
xx
:
''
.
join
([
chr
(
x
)
for
x
in
xx
]).
strip
(
'.'
)
tensor2str
=
lambda
xx
:
''
.
join
([
chr
(
x
)
for
x
in
xx
]).
strip
(
'.'
)
kwargs
=
{
k
.
split
(
'.'
)[
-
1
]
:
v
for
k
,
v
in
kwargs
.
items
()}
quant_state_dict
=
{
k
.
split
(
'.'
)[
-
1
]
:
v
for
k
,
v
in
quant_state_dict
.
items
()}
if
'nested_absmax'
in
kwargs
:
if
'nested_absmax'
in
quant_state_dict
:
offset
=
kwargs
[
'nested_offset'
]
offset
=
quant_state_dict
[
'nested_offset'
]
state2
=
cls
(
state2
=
cls
(
absmax
=
kwargs
[
'nested_absmax'
].
to
(
device
),
absmax
=
quant_state_dict
[
'nested_absmax'
].
to
(
device
),
code
=
kwargs
[
'nested_code'
].
to
(
device
),
code
=
quant_state_dict
[
'nested_code'
].
to
(
device
),
blocksize
=
kwargs
[
'nested_blocksize'
].
item
(),
blocksize
=
quant_state_dict
[
'nested_blocksize'
].
item
(),
dtype
=
getattr
(
torch
,
tensor2str
(
kwargs
[
'nested_dtype'
])),
dtype
=
getattr
(
torch
,
tensor2str
(
quant_state_dict
[
'nested_dtype'
])),
)
)
else
:
else
:
offset
,
state2
=
None
,
None
offset
,
state2
=
None
,
None
quant_state
=
cls
(
quant_state
=
cls
(
absmax
=
kwargs
[
'absmax'
].
to
(
device
),
absmax
=
quant_state_dict
[
'absmax'
].
to
(
device
),
shape
=
torch
.
Size
(
kwargs
[
'shape'
]),
shape
=
torch
.
Size
(
quant_state_dict
[
'shape'
]),
dtype
=
getattr
(
torch
,
tensor2str
(
kwargs
[
'dtype'
])),
dtype
=
getattr
(
torch
,
tensor2str
(
quant_state_dict
[
'dtype'
])),
blocksize
=
kwargs
[
'blocksize'
].
item
(),
blocksize
=
quant_state_dict
[
'blocksize'
].
item
(),
offset
=
offset
,
offset
=
offset
,
state2
=
state2
,
state2
=
state2
,
quant_type
=
tensor2str
(
kwargs
[
'quant_type'
]),
quant_type
=
tensor2str
(
quant_state_dict
[
'quant_type'
]),
code
=
kwargs
[
'code'
].
to
(
device
),
code
=
quant_state_dict
[
'code'
].
to
(
device
),
)
)
return
quant_state
return
quant_state
...
...
bitsandbytes/nn/modules.py
View file @
48b3e770
...
@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
...
@@ -159,7 +159,7 @@ class Params4bit(torch.nn.Parameter):
data
=
quantized_stats
.
pop
(
'weight'
)
data
=
quantized_stats
.
pop
(
'weight'
)
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
.
requires_grad
=
requires_grad
self
.
requires_grad
=
requires_grad
self
.
quant_state
=
QuantState
.
from_
kwargs
(
kwargs
=
quantized_stats
,
device
=
device
)
self
.
quant_state
=
QuantState
.
from_
dict
(
quant_state_dict
=
quantized_stats
,
device
=
device
)
self
.
blocksize
=
self
.
quant_state
.
blocksize
self
.
blocksize
=
self
.
quant_state
.
blocksize
self
.
compress_statistics
=
self
.
quant_state
.
nested
self
.
compress_statistics
=
self
.
quant_state
.
nested
self
.
quant_type
=
self
.
quant_state
.
quant_type
self
.
quant_type
=
self
.
quant_state
.
quant_type
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment