OpenDAS / bitsandbytes · Commit 4bd11518
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "33d7e89c42e0fe7b4a277d7a5bae12ba14828dd8"
Commit 4bd11518, authored May 07, 2023 by Tim Dettmers

Fixed gradient accumulation test.

parent 675baa79
Showing 2 changed files with 11 additions and 10 deletions (+11 / -10):

  bitsandbytes/autograd/_functions.py   +0 / -1
  tests/test_modules.py                 +11 / -9
bitsandbytes/autograd/_functions.py (view file @ 4bd11518)

@@ -456,7 +456,6 @@ class MatMul8bitLt(torch.autograd.Function):
        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
        if req_gradB:
            #grad_B = torch.matmul(grad_output.t(), A)
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
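The commented-out line in the hunk above shows the reference computation the quantized path approximates: grad_B is conceptually grad_output.t() @ A, with double_quant, transform, and igemmlt carrying out that matmul on int8 tensors. Below is a minimal fp16 sketch of that reference, not code from this commit; the tensor shapes are illustrative assumptions and a CUDA device is assumed.

```python
import torch

# Plain fp16 reference for the gradient w.r.t. B that the quantized path above
# approximates; mirrors the commented-out line `grad_B = torch.matmul(grad_output.t(), A)`.
# Shapes are made up for this sketch; a CUDA device is assumed for the fp16 matmul.
A = torch.randn(8, 32, dtype=torch.float16, device="cuda")            # forward input
grad_output = torch.randn(8, 32, dtype=torch.float16, device="cuda")  # upstream gradient
grad_B_reference = torch.matmul(grad_output.t(), A)                   # same math, no quantization
print(grad_B_reference.shape)  # torch.Size([32, 32])
```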
tests/test_modules.py (view file @ 4bd11518)

@@ -332,12 +332,13 @@ def test_linear8bitlt_inference(threshold):
def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
    l1[0].weight.data.copy_(l2[0].weight.data)
    l1[1].weight.data.copy_(l2[1].weight.data)
    l1[0].bias.data.copy_(l2[0].bias.data)
    l1[1].bias.data.copy_(l2[1].bias.data)
    opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
    acc_steps = 10

@@ -353,7 +354,6 @@ def test_linear8bitlt_accumulated_gradient():
        assert l1[0].state.CxB is not None
        assert l1[1].state.CxB is not None
        print(i)
        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)

@@ -368,9 +368,11 @@ def test_linear8bitlt_accumulated_gradient():
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
            l1[0].bias.data.copy_(l2[0].bias.data)
            l1[1].bias.data.copy_(l2[1].bias.data)
        else:
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)
            torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
            torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)

@pytest.mark.parametrize("threshold", [0.0, 2.0])
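The test above runs an int8 stack (bnb.nn.Linear8bitLt) against an fp16 reference (torch.nn.Linear), stepping the optimizer only every acc_steps iterations, so it relies on gradients accumulating the same way across repeated backward() calls in both stacks. As a standalone sketch of that accumulation property, here is a plain-PyTorch example with made-up sizes and no bitsandbytes involved, not code from the commit:

```python
import torch

torch.manual_seed(0)

# Standalone sketch (not from the commit): accumulating gradients over micro-batches
# via repeated backward() calls matches one backward pass over the full batch when
# the loss is a sum over samples.
lin = torch.nn.Linear(4, 4)
x = torch.randn(6, 4)
target = torch.zeros(6, 4)

# one backward pass over the full batch
lin.zero_grad()
torch.nn.functional.mse_loss(lin(x), target, reduction="sum").backward()
full_grad = lin.weight.grad.clone()

# same data as 3 micro-batches, gradients accumulated across backward() calls
lin.zero_grad()
for xb, tb in zip(x.chunk(3), target.chunk(3)):
    torch.nn.functional.mse_loss(lin(xb), tb, reduction="sum").backward()

torch.testing.assert_close(lin.weight.grad, full_grad, atol=1e-5, rtol=1e-5)
```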