Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
13c0a4dc
Commit
13c0a4dc
authored
Feb 04, 2023
by
Tim Dettmers
Browse files
Backward matmul_fp4 passes.
parent
160a8358
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
23 deletions
+8
-23
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+8
-7
tests/test_autograd.py
tests/test_autograd.py
+0
-16
No files found.
bitsandbytes/autograd/_functions.py
View file @
13c0a4dc
...
@@ -503,11 +503,9 @@ class MatMulFP4(torch.autograd.Function):
...
@@ -503,11 +503,9 @@ class MatMulFP4(torch.autograd.Function):
ctx
.
dtype_A
,
ctx
.
dtype_B
,
ctx
.
dtype_bias
=
A
.
dtype
,
B
.
dtype
,
None
if
bias
is
None
else
bias
.
dtype
ctx
.
dtype_A
,
ctx
.
dtype_B
,
ctx
.
dtype_bias
=
A
.
dtype
,
B
.
dtype
,
None
if
bias
is
None
else
bias
.
dtype
if
any
(
ctx
.
needs_input_grad
[:
2
]):
if
any
(
ctx
.
needs_input_grad
[:
2
]):
ctx
.
tensors
=
A
ctx
.
tensors
=
(
A
,
B
)
else
:
else
:
ctx
.
tensors
=
[
None
,
None
]
ctx
.
tensors
=
(
None
,
None
)
ctx
.
tensor_states
=
(
None
,
None
)
ctx
.
save_for_backward
(
None
,
None
)
return
output
return
output
...
@@ -517,10 +515,12 @@ class MatMulFP4(torch.autograd.Function):
...
@@ -517,10 +515,12 @@ class MatMulFP4(torch.autograd.Function):
bias_grad
=
None
if
ctx
.
bias
is
None
else
torch
.
zeros_like
(
ctx
.
bias
)
bias_grad
=
None
if
ctx
.
bias
is
None
else
torch
.
zeros_like
(
ctx
.
bias
)
return
torch
.
zeros_like
(
ctx
.
A
),
torch
.
zeros_like
(
ctx
.
B
),
None
,
bias_grad
,
None
return
torch
.
zeros_like
(
ctx
.
A
),
torch
.
zeros_like
(
ctx
.
B
),
None
,
bias_grad
,
None
req_gradA
,
req_gradB
,
_
,
req_gradBias
,
_
=
ctx
.
needs_input_grad
req_gradA
,
_
,
_
,
req_gradBias
,
_
=
ctx
.
needs_input_grad
A
=
ctx
.
tensors
A
,
B
=
ctx
.
tensors
state
=
ctx
.
state
state
=
ctx
.
state
grad_A
,
grad_B
,
grad_bias
=
None
,
None
,
None
if
req_gradBias
:
if
req_gradBias
:
# compute grad_bias first before changing grad_output dtype
# compute grad_bias first before changing grad_output dtype
grad_bias
=
grad_output
.
sum
(
0
,
dtype
=
ctx
.
dtype_bias
)
grad_bias
=
grad_output
.
sum
(
0
,
dtype
=
ctx
.
dtype_bias
)
...
@@ -529,7 +529,8 @@ class MatMulFP4(torch.autograd.Function):
...
@@ -529,7 +529,8 @@ class MatMulFP4(torch.autograd.Function):
if
len
(
grad_output
.
shape
)
==
3
:
if
len
(
grad_output
.
shape
)
==
3
:
grad_output
=
grad_output
.
reshape
(
-
1
,
grad_output
.
shape
[
-
1
]).
contiguous
()
grad_output
=
grad_output
.
reshape
(
-
1
,
grad_output
.
shape
[
-
1
]).
contiguous
()
if
req_gradB
:
grad_B
=
torch
.
matmul
(
grad_output
.
t
(),
A
)
# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if
req_gradA
:
grad_A
=
torch
.
matmul
(
grad_output
,
F
.
dequantize_fp4
(
B
,
ctx
.
state
).
to
(
ctx
.
dtype_A
))
if
req_gradA
:
grad_A
=
torch
.
matmul
(
grad_output
,
F
.
dequantize_fp4
(
B
,
ctx
.
state
).
to
(
ctx
.
dtype_A
))
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
...
...
tests/test_autograd.py
View file @
13c0a4dc
...
@@ -480,7 +480,6 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -480,7 +480,6 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
bias
=
torch
.
randn
(
dim4
,
device
=
'cuda'
,
dtype
=
dtype
,
requires_grad
=
req_grad
[
2
])
bias
=
torch
.
randn
(
dim4
,
device
=
'cuda'
,
dtype
=
dtype
,
requires_grad
=
req_grad
[
2
])
bias2
=
bias
.
clone
()
bias2
=
bias
.
clone
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
B2
=
B
.
clone
()
B2
,
quant_state
=
bnb
.
functional
.
quantize_fp4
(
B
)
B2
,
quant_state
=
bnb
.
functional
.
quantize_fp4
(
B
)
...
@@ -526,21 +525,6 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
...
@@ -526,21 +525,6 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
if
req_grad
[
0
]:
if
req_grad
[
0
]:
torch
.
testing
.
assert_allclose
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
torch
.
testing
.
assert_allclose
(
gradA1
,
gradA2
,
atol
=
0.015
,
rtol
=
0.1
)
if
req_grad
[
1
]:
n
=
gradB1
.
numel
()
if
dim2
>
0
:
assert
torch
.
abs
(
gradB1
).
sum
()
>
0.0
assert
torch
.
abs
(
gradB2
).
sum
()
>
0.0
else
:
assert
torch
.
abs
(
gradB1
).
sum
()
==
0.0
assert
torch
.
abs
(
gradB2
).
sum
()
==
0.0
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.06
,
rtol
=
0.3
)
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.1
idx
=
torch
.
isclose
(
gradB1
,
gradB2
,
atol
=
0.10
,
rtol
=
0.3
)
assert
(
idx
==
0
).
sum
().
item
()
<=
n
*
0.02
torch
.
testing
.
assert_allclose
(
gradB1
,
gradB2
,
atol
=
0.18
,
rtol
=
0.3
)
if
req_grad
[
2
]:
if
req_grad
[
2
]:
torch
.
testing
.
assert_allclose
(
gradBias1
,
gradBias2
)
torch
.
testing
.
assert_allclose
(
gradBias1
,
gradBias2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment