Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
7a117e44
Commit
7a117e44
authored
Sep 13, 2023
by
Ruslan Svirschevski
Browse files
cleanup1
parent
6cf0f05d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
11 deletions
+14
-11
bitsandbytes/functional.py
bitsandbytes/functional.py
+12
-10
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+2
-1
No files found.
bitsandbytes/functional.py
View file @
7a117e44
...
@@ -593,6 +593,8 @@ class QuantState:
...
@@ -593,6 +593,8 @@ class QuantState:
# unpacking tensor with non-tensor components
# unpacking tensor with non-tensor components
qs_key
=
[
k
for
k
,
v
in
qs_dict
.
items
()
if
"quant_state"
in
k
and
isinstance
(
v
,
torch
.
Tensor
)]
qs_key
=
[
k
for
k
,
v
in
qs_dict
.
items
()
if
"quant_state"
in
k
and
isinstance
(
v
,
torch
.
Tensor
)]
assert
len
(
qs_key
)
==
1
or
not
qs_key
and
'quant_type'
in
qs_dict
,
\
f
"`qs_dict` must contain packed quant_state items, or be unpacked. Found keys:
{
tuple
(
qs_dict
.
keys
())
}
"
if
len
(
qs_key
)
==
1
:
if
len
(
qs_key
)
==
1
:
qs_key
=
qs_key
[
0
]
qs_key
=
qs_key
[
0
]
assert
'bitsandbytes__nf4'
in
qs_key
or
'bitsandbytes__fp4'
in
qs_key
,
\
assert
'bitsandbytes__nf4'
in
qs_key
or
'bitsandbytes__fp4'
in
qs_key
,
\
...
@@ -605,22 +607,22 @@ class QuantState:
...
@@ -605,22 +607,22 @@ class QuantState:
offset
=
torch
.
tensor
(
float
(
qs_dict
[
'nested_offset'
])).
to
(
device
)
offset
=
torch
.
tensor
(
float
(
qs_dict
[
'nested_offset'
])).
to
(
device
)
state2
=
cls
(
state2
=
cls
(
absmax
=
qs_dict
[
'nested_absmax'
].
to
(
device
),
absmax
=
qs_dict
[
'nested_absmax'
].
to
(
device
),
code
=
qs_dict
[
'nested_code'
].
to
(
device
),
blocksize
=
qs_dict
[
'nested_blocksize'
],
blocksize
=
qs_dict
[
'nested_blocksize'
],
code
=
qs_dict
[
'nested_code'
].
to
(
device
),
dtype
=
getattr
(
torch
,
qs_dict
[
'nested_dtype'
]),
dtype
=
getattr
(
torch
,
qs_dict
[
'nested_dtype'
]),
)
)
else
:
else
:
offset
,
state2
=
None
,
None
offset
,
state2
=
None
,
None
quant_state
=
cls
(
quant_state
=
cls
(
quant_type
=
qs_dict
[
'quant_type'
],
absmax
=
qs_dict
[
'absmax'
].
to
(
device
),
absmax
=
qs_dict
[
'absmax'
].
to
(
device
),
shape
=
torch
.
Size
(
qs_dict
[
'shape'
]),
dtype
=
getattr
(
torch
,
qs_dict
[
'dtype'
]),
blocksize
=
qs_dict
[
'blocksize'
],
blocksize
=
qs_dict
[
'blocksize'
],
code
=
qs_dict
[
'code'
].
to
(
device
),
dtype
=
getattr
(
torch
,
qs_dict
[
'dtype'
]),
shape
=
torch
.
Size
(
qs_dict
[
'shape'
]),
offset
=
offset
,
offset
=
offset
,
state2
=
state2
,
state2
=
state2
,
quant_type
=
qs_dict
[
'quant_type'
],
code
=
qs_dict
[
'code'
].
to
(
device
),
)
)
return
quant_state
return
quant_state
...
@@ -630,20 +632,20 @@ class QuantState:
...
@@ -630,20 +632,20 @@ class QuantState:
param: packed -- returns dict[str, torch.Tensor] for state_dict
param: packed -- returns dict[str, torch.Tensor] for state_dict
"""
"""
qs_dict
=
{
qs_dict
=
{
'quant_type'
:
self
.
quant_type
,
'absmax'
:
self
.
absmax
,
'absmax'
:
self
.
absmax
,
'blocksize'
:
self
.
blocksize
,
'code'
:
self
.
code
,
'code'
:
self
.
code
,
'shape'
:
tuple
(
self
.
shape
),
'dtype'
:
str
(
self
.
dtype
).
strip
(
'torch.'
),
'dtype'
:
str
(
self
.
dtype
).
strip
(
'torch.'
),
'blocksize'
:
self
.
blocksize
,
'shape'
:
tuple
(
self
.
shape
)
if
self
.
nested
else
None
,
'quant_type'
:
self
.
quant_type
,
}
}
if
self
.
nested
:
if
self
.
nested
:
qs_dict
.
update
({
qs_dict
.
update
({
'nested_absmax'
:
self
.
state2
.
absmax
,
'nested_absmax'
:
self
.
state2
.
absmax
,
'nested_code'
:
self
.
state2
.
code
,
'nested_offset'
:
self
.
offset
.
item
(),
'nested_blocksize'
:
self
.
state2
.
blocksize
,
'nested_blocksize'
:
self
.
state2
.
blocksize
,
'nested_code'
:
self
.
state2
.
code
,
'nested_dtype'
:
str
(
self
.
state2
.
dtype
).
strip
(
'torch.'
),
'nested_dtype'
:
str
(
self
.
state2
.
dtype
).
strip
(
'torch.'
),
'nested_offset'
:
self
.
offset
.
item
(),
})
})
if
not
packed
:
if
not
packed
:
return
qs_dict
return
qs_dict
...
...
bitsandbytes/nn/modules.py
View file @
7a117e44
...
@@ -156,7 +156,8 @@ class Params4bit(torch.nn.Parameter):
...
@@ -156,7 +156,8 @@ class Params4bit(torch.nn.Parameter):
@
classmethod
@
classmethod
def
from_prequantized
(
cls
,
quantized_stats
,
data
=
None
,
requires_grad
=
False
,
device
=
'cuda'
,
**
kwargs
):
def
from_prequantized
(
cls
,
quantized_stats
,
data
=
None
,
requires_grad
=
False
,
device
=
'cuda'
,
**
kwargs
):
if
data
is
None
:
if
data
is
None
:
data
=
quantized_stats
.
pop
(
'weight'
)
weight_key
=
[
k
for
k
in
quantized_stats
if
k
.
endswith
(
".weight"
)][
0
]
data
=
quantized_stats
.
pop
(
weight_key
)
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
.
to
(
device
))
self
.
requires_grad
=
requires_grad
self
.
requires_grad
=
requires_grad
self
.
quant_state
=
QuantState
.
from_dict
(
qs_dict
=
quantized_stats
,
device
=
device
)
self
.
quant_state
=
QuantState
.
from_dict
(
qs_dict
=
quantized_stats
,
device
=
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment