Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
1ccb7bde
Commit
1ccb7bde
authored
Apr 03, 2023
by
Tim Dettmers
Browse files
Fixed ParamsIn4 init; fixed PyTorch 2.0 test failure.
parent
4ea489d3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
18 deletions
+17
-18
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+7
-11
tests/test_functional.py
tests/test_functional.py
+2
-2
tests/test_modules.py
tests/test_modules.py
+8
-5
No files found.
bitsandbytes/nn/modules.py
View file @
1ccb7bde
...
@@ -136,12 +136,14 @@ class Embedding(torch.nn.Embedding):
...
@@ -136,12 +136,14 @@ class Embedding(torch.nn.Embedding):
class
Params4bit
(
torch
.
nn
.
Parameter
):
class
Params4bit
(
torch
.
nn
.
Parameter
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
quant_state
=
None
,
blocksize
=
64
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
cls
.
quant_state
=
None
cls
.
quant_state
=
None
cls
.
blocksize
=
blocksize
cls
.
compress_statistics
=
compress_statistics
cls
.
quant_type
=
quant_type
if
data
is
None
:
if
data
is
None
:
data
=
torch
.
empty
(
0
)
data
=
torch
.
empty
(
0
)
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
self
.
blocksize
=
blocksize
self
.
compress_statistics
=
compress_statistics
self
.
quant_type
=
quant_type
return
self
def
cuda
(
self
,
device
):
def
cuda
(
self
,
device
):
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
w
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
...
@@ -177,16 +179,10 @@ class Params4bit(torch.nn.Parameter):
...
@@ -177,16 +179,10 @@ class Params4bit(torch.nn.Parameter):
class
Linear4bit
(
nn
.
Linear
):
class
Linear4bit
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
compute_dtype
=
None
,
compress_statistics
=
True
,
quant_type
=
'fp4'
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
state
=
bnb
.
MatmulLtState
()
self
.
weight
=
Params4bit
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
self
.
weight
=
Params4bit
(
self
.
weight
.
data
,
requires_grad
=
False
,
compress_statistics
=
compress_statistics
,
quant_type
=
quant_type
)
self
.
compute_dtype
=
compute_dtype
self
.
compute_dtype
=
compute_dtype
def
init_8bit_state
(
self
):
pass
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
self
.
state
.
is_training
=
self
.
training
# weights are cast automatically as Int8Params, but the bias has to be cast manually
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if
self
.
bias
is
not
None
and
self
.
bias
.
dtype
!=
x
.
dtype
:
if
self
.
bias
is
not
None
and
self
.
bias
.
dtype
!=
x
.
dtype
:
self
.
bias
.
data
=
self
.
bias
.
data
.
to
(
x
.
dtype
)
self
.
bias
.
data
=
self
.
bias
.
data
.
to
(
x
.
dtype
)
...
@@ -197,7 +193,7 @@ class Linear4bit(nn.Linear):
...
@@ -197,7 +193,7 @@ class Linear4bit(nn.Linear):
if
self
.
compute_dtype
is
not
None
:
if
self
.
compute_dtype
is
not
None
:
x
=
x
.
to
(
self
.
compute_dtype
)
x
=
x
.
to
(
self
.
compute_dtype
)
bias
=
None
if
self
.
bias
is
None
else
self
.
bias
.
half
()
bias
=
None
if
self
.
bias
is
None
else
self
.
bias
.
half
(
self
.
compute_dtype
)
out
=
bnb
.
matmul_4bit
(
x
,
self
.
weight
.
t
(),
bias
=
bias
,
quant_state
=
self
.
weight
.
quant_state
)
out
=
bnb
.
matmul_4bit
(
x
,
self
.
weight
.
t
(),
bias
=
bias
,
quant_state
=
self
.
weight
.
quant_state
)
out
=
out
.
to
(
inp_dtype
)
out
=
out
.
to
(
inp_dtype
)
...
...
tests/test_functional.py
View file @
1ccb7bde
...
@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
...
@@ -1798,7 +1798,7 @@ values.append((batch_size, seqdim, 12288, 4*12288))
names
=
[
"batch_{}_seq_{}_model_{}_hidden_{}"
.
format
(
*
vals
)
for
vals
in
values
]
names
=
[
"batch_{}_seq_{}_model_{}_hidden_{}"
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
iters
=
32
iters
=
1
formatB
=
F
.
get_special_format_str
()
formatB
=
F
.
get_special_format_str
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
"cuda"
).
half
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
"cuda"
).
half
()
...
@@ -2317,7 +2317,7 @@ def test_bench_4bit_dequant(quant_type):
...
@@ -2317,7 +2317,7 @@ def test_bench_4bit_dequant(quant_type):
#print(max_theoretical_s*1e6)
#print(max_theoretical_s*1e6)
b
=
torch
.
randn
(
128
,
1024
*
12
,
device
=
'cuda'
).
half
()
b
=
torch
.
randn
(
128
,
1024
*
12
,
device
=
'cuda'
).
half
()
iters
=
5
00
iters
=
5
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
t0
=
time
.
time
()
for
i
in
range
(
iters
):
for
i
in
range
(
iters
):
...
...
tests/test_modules.py
View file @
1ccb7bde
...
@@ -558,14 +558,17 @@ def test_kbit_backprop(module):
...
@@ -558,14 +558,17 @@ def test_kbit_backprop(module):
relerrs1
.
append
(
relerr1
.
mean
().
item
())
relerrs1
.
append
(
relerr1
.
mean
().
item
())
relerrs2
.
append
(
relerr2
.
mean
().
item
())
relerrs2
.
append
(
relerr2
.
mean
().
item
())
if
isinstance
(
module
,
bnb
.
nn
.
Linear8bitLt
):
#torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
torch
.
testing
.
assert_allclose
(
grad1
,
grad2
,
atol
=
0.008
,
rtol
=
0.05
)
#torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
torch
.
testing
.
assert_allclose
(
bgrad1
,
bgrad2
,
atol
=
0.008
,
rtol
=
0.05
)
else
:
torch
.
testing
.
assert_allclose
(
grad1
,
grad2
,
atol
=
0.015
,
rtol
=
0.05
)
torch
.
testing
.
assert_allclose
(
bgrad1
,
bgrad2
,
atol
=
0.02
,
rtol
=
0.05
)
ref
.
zero_grad
()
ref
.
zero_grad
()
kbit
.
zero_grad
()
kbit
.
zero_grad
()
assert
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
weight
.
grad
is
None
or
kbit
[
0
].
weight
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
assert
kbit
[
0
].
weight
.
grad
is
None
or
kbit
[
0
].
bias
.
grad
.
sum
().
item
()
==
0
print
(
'out'
,
sum
(
errs1
)
/
len
(
errs1
))
print
(
'out'
,
sum
(
errs1
)
/
len
(
errs1
))
print
(
'grad'
,
sum
(
errs2
)
/
len
(
errs2
))
print
(
'grad'
,
sum
(
errs2
)
/
len
(
errs2
))
print
(
'rel out'
,
sum
(
relerrs1
)
/
len
(
relerrs1
))
print
(
'rel out'
,
sum
(
relerrs1
)
/
len
(
relerrs1
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment