OpenDAS / bitsandbytes · Commits

Commit c361f842, authored Feb 05, 2023 by Tim Dettmers

    Fixed matmul_fp4 transpose.

parent cfe4705e
Showing 3 changed files with 5 additions and 5 deletions:

bitsandbytes/autograd/_functions.py  +2 -2
tests/test_autograd.py               +2 -2
tests/test_functional.py             +1 -1
bitsandbytes/autograd/_functions.py
@@ -496,7 +496,7 @@ class MatMulFP4(torch.autograd.Function):

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)

         # 3. Save state
         ctx.state = state
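The fix: torch.nn.functional.linear(A, W, bias) computes A @ W.t() + bias and expects W of shape (out_features, in_features), so the dequantized weight needs the extra .t() here, assuming F.dequantize_fp4 returns it in the (in_features, out_features) layout. A minimal sketch of that shape convention with plain fp32 tensors (names and shapes below are illustrative, not bitsandbytes internals):

import torch

# Illustrative stand-in for F.dequantize_fp4(B, state) under the
# assumed (in_features, out_features) layout.
in_features, out_features, batch = 8, 4, 2
W_stored = torch.randn(in_features, out_features)
A = torch.randn(batch, in_features)
bias = torch.randn(out_features)

# linear(A, W, bias) computes A @ W.t() + bias and wants W as
# (out_features, in_features) -- hence the added .t().
out = torch.nn.functional.linear(A, W_stored.t(), bias)
assert out.shape == (batch, out_features)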
@@ -531,7 +531,7 @@ class MatMulFP4(torch.autograd.Function):

         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t())

         return grad_A, grad_B, None, grad_bias, None
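The same transpose shows up in the backward pass: with the forward effectively computing out = A @ W (the forward's .t() cancels linear's internal transpose), the gradient with respect to A is grad_output @ W.t(). A small check of that identity with plain tensors standing in for the dequantized FP4 weight:

import torch

in_features, out_features, batch = 8, 4, 2
W = torch.randn(in_features, out_features)  # stand-in for the dequantized weight
A = torch.randn(batch, in_features, requires_grad=True)

out = A @ W           # what the fixed forward effectively computes
out.sum().backward()  # grad_output is all ones for a sum loss

# dL/dA = grad_output @ W.t() -- the transpose this commit adds.
grad_A_manual = torch.matmul(torch.ones(batch, out_features), W.t())
assert torch.allclose(A.grad, grad_A_manual)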
tests/test_autograd.py
@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,

         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)

         if has_bias:
             out_torch += bias
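After the fix, the test pairing is symmetric: the quantized weight B2 is transposed exactly when the reference weight B is. A toy analogue of that pairing, with an identity round-trip standing in for bnb's FP4 quantize/dequantize (so the comparison is exact here, whereas the real test tolerates quantization error):

import torch

def fake_matmul_fp4(A, B2, quant_state=None):
    # Identity "dequantize" in place of bnb's FP4 kernels.
    return A @ B2

A = torch.randn(2, 8)
B = torch.randn(4, 8)  # reference weight
B2 = B.clone()         # "quantized" copy in the same layout

# Transpose both or neither -- the pairing this commit establishes.
out_torch = A @ B.t()
out_bnb = fake_matmul_fp4(A, B2.t())
assert torch.allclose(out_torch, out_bnb)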
tests/test_functional.py
@@ -1835,7 +1835,7 @@ def test_bench_matmul(batch, seq, model, hidden):

     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+        bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
     torch.cuda.synchronize()
     print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
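The benchmark brackets the timed loop with torch.cuda.synchronize() so the wall-clock time covers the asynchronous CUDA kernels rather than just their launch. A minimal sketch of that timing pattern, with a plain fp16 matmul standing in for bnb.matmul_fp4 (requires a CUDA device):

import time
import torch

def bench(A, B, iters=100):
    torch.cuda.synchronize()   # drain any queued kernels first
    t0 = time.time()
    for _ in range(iters):
        torch.matmul(A, B)     # stand-in for bnb.matmul_fp4
    torch.cuda.synchronize()   # wait for the timed kernels to finish
    return time.time() - t0

if torch.cuda.is_available():
    A = torch.randn(32, 512, 1024, device="cuda", dtype=torch.float16)
    B = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)
    print(f"fp16 matmul: {bench(A, B):.4f}s")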