OpenDAS / bitsandbytes · Commits

Commit ca323658
Authored Feb 13, 2023 by Tim Dettmers

Added forward/backward tests; removed bias.

parent 6bdb6c35
Showing 3 changed files with 48 additions and 53 deletions

bitsandbytes/autograd/_functions.py    +14  -22
bitsandbytes/nn/modules.py              +3   -1
tests/test_autograd.py                 +31  -30
bitsandbytes/autograd/_functions.py

@@ -395,15 +395,14 @@ class MatMulFP8(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, fw_code=None, bw_code=None):
+    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B
-            ctx.bias = bias
-            B_shape = state[1]
+            B_shape = B.shape
             if A.shape[-1] == B_shape[0]:
                 return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
             else:
@@ -414,17 +413,17 @@ class MatMulFP8(torch.autograd.Function):
         # 2. MatmulnN
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
-        fp8A = F.dequantize_blockwise(cA, state)
+        fp8A = F.dequantize_blockwise(cA, state).to(A.dtype)

         cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state)
+        fp8B = F.dequantize_blockwise(cB, state).to(B.dtype)

-        output = torch.nn.functional.linear(fp8A, fp8B)
+        output = torch.matmul(fp8A, fp8B)

         # 3. Save state
         ctx.bw_code = bw_code
-        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (fp8A, fp8B)
@@ -436,21 +435,15 @@ class MatMulFP8(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None

-        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
+        req_gradA, req_gradB, _, _, _ = ctx.needs_input_grad
         fp8A, B = ctx.tensors
-        state = ctx.state

-        grad_A, grad_B, grad_bias = None, None, None
+        grad_A, grad_B = None, None

-        cgrad_out, state = F.quantize_blockwise(grad_ouput, code=ctx.bw_code, blocksize=1024)
-        fp8out = F.dequantize_blockwise(cgrad_out, state)
+        cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=1024)
+        fp8out = F.dequantize_blockwise(cgrad_out, state).to(grad_output.dtype)

-        if req_gradBias:
-            # compute grad_bias first before changing grad_output dtype
-            grad_bias = fp8out.sum(0, dtype=ctx.dtype_bias)
-
         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
@@ -461,7 +454,7 @@ class MatMulFP8(torch.autograd.Function):
         if req_gradA: grad_A = torch.matmul(fp8out, B.t())
         if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)

-        return grad_A, grad_B, None, grad_bias, None, None
+        return grad_A, grad_B, None, None, None


 def matmul(
@@ -478,9 +471,8 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bias=None):
-    assert quant_state is not None
-    return MatMulFP8.apply(A, B, out, bias, fw_code, bw_code)
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None):
+    return MatMulFP8.apply(A, B, out, fw_code, bw_code)


 def matmul(
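With this change the forward pass simulates FP8 as a blockwise quantize/dequantize round trip in the input dtype followed by a plain torch.matmul (instead of torch.nn.functional.linear), and bias handling is gone from MatMulFP8 entirely. A minimal sketch of that numerical path, assuming quantize_blockwise, dequantize_blockwise, and create_fp8_map behave as they are used in this diff (CUDA build of bitsandbytes; shapes are hypothetical):

import torch
import bitsandbytes.functional as F


def simulated_fp8_matmul(A, B, fw_code):
    # Mirror of the numerics in the updated MatMulFP8.forward (no autograd
    # bookkeeping): blockwise-quantize each input against the FP8 code,
    # dequantize back to the original dtype, then use a plain matmul.
    cA, state_A = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
    fp8A = F.dequantize_blockwise(cA, state_A).to(A.dtype)

    cB, state_B = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
    fp8B = F.dequantize_blockwise(cB, state_B).to(B.dtype)

    return torch.matmul(fp8A, fp8B)


A = torch.randn(16, 32, device="cuda", dtype=torch.float16)
B = torch.randn(32, 64, device="cuda", dtype=torch.float16)
fw_code = F.create_fp8_map(True, 4, 3, 8).to(A.device)  # 4 exponent / 3 mantissa bits, as in the test

out = simulated_fp8_matmul(A, B, fw_code)
err = (out - torch.matmul(A, B)).abs().float().mean()  # same error metric the new test bounds by 0.20

The backward pass applies the same round trip to grad_output, only with the 5-exponent-bit bw_code instead of fw_code.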
bitsandbytes/nn/modules.py

@@ -355,7 +355,9 @@ class LinearFP8(nn.Linear):
             self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
             self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)

-        out = bnb.matmul_fp8(x, self.weight.t(), bias=self.bias, fw_code=self.fw_code, code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, code=self.bw_code)
+        if self.bias is not None:
+            out += self.bias

         return out
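Because matmul_fp8 no longer accepts a bias, LinearFP8 now adds self.bias after the FP8 matmul, so the bias is never quantized and its gradient flows through ordinary autograd rather than through MatMulFP8.backward. A short sketch of that caller-side pattern with plain tensors (hypothetical shapes; assumes bnb.matmul_fp8 and bnb.functional.create_fp8_map as used in this commit):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

x = torch.randn(8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
weight = torch.randn(32, 64, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.randn(32, device="cuda", dtype=torch.float16, requires_grad=True)

# The same code maps LinearFP8 builds lazily on its first forward call.
fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)

# New calling convention: no bias argument on matmul_fp8; the caller adds it afterwards.
out = bnb.matmul_fp8(x, weight.t(), fw_code, bw_code)
out = out + bias  # LinearFP8 does an in-place `out += self.bias`; out-of-place here

out.sum().backward()    # grad_output is blockwise-quantized with bw_code in MatMulFP8.backward
print(bias.grad.shape)  # torch.Size([32]); the bias gradient comes from the add, not from MatMulFP8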
tests/test_autograd.py

@@ -456,18 +456,16 @@ transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
 dtype = [torch.float16, torch.float32]
 has_fp16_weights = [True, False]
-has_bias = [True, False]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
-str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
-names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
-def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    if has_bias == False:
-        req_grad = list(req_grad)
-        req_grad[2] = False
+    req_grad = list(req_grad)
+    req_grad[2] = False
     for i in range(k):
         # normal multiply
@@ -475,32 +473,24 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
             A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
             B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
             target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
-            bias = None
-            bias2 = None
-            if has_bias:
-                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
-                bias2 = bias.clone()

             torch.nn.init.xavier_uniform_(B)
-            B2, quant_state = bnb.functional.quantize_fp8(B)
+            fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
+            bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)

             if not transpose[0] and transpose[1]:
                 out_torch = funcs[0](A, B.t())
-                out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+                out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
             elif not transpose[0] and not transpose[1]:
                 out_torch = funcs[0](A, B)
-                out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+                out_bnb = funcs[1](A, B, fw_code, bw_code)

-            if has_bias:
-                out_torch += bias
-
             assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"

             n = out_bnb.numel()
             err = torch.abs(out_bnb - out_torch).float().mean().item()
             if n > 0:
-                assert err < 0.115
+                assert err < 0.20

             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
                 torch.cuda.synchronize()
@@ -510,9 +500,6 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
                 gradB1 = B.grad
                 A.grad = None
                 B.grad = None
-                if has_bias:
-                    gradBias1 = bias.grad
-                    bias.grad = None

                 loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                 loss_torch.backward()
@@ -520,12 +507,26 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
                 gradB2 = B.grad
                 A.grad = None
                 B.grad = None
-                if has_bias:
-                    gradBias2 = bias.grad
-                    bias.grad = None

                 if req_grad[0]:
                     torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
-                if req_grad[2]:
-                    torch.testing.assert_allclose(gradBias1, gradBias2)
+                if req_grad[1]:
+                    n = gradB1.numel()
+                    if dim2 > 0:
+                        assert torch.abs(gradB1).sum() > 0.0
+                        assert torch.abs(gradB2).sum() > 0.0
+                    else:
+                        assert torch.abs(gradB1).sum() == 0.0
+                        assert torch.abs(gradB2).sum() == 0.0
+
+                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.1
+                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                    assert (idx == 0).sum().item() <= n * 0.02
+
+                    grad_err = (gradB1 - gradB2).abs().mean()
+                    assert grad_err.item() < 0.003
+                    torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
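The rewritten backward check bounds the weight gradient three ways: at most 10% of elements may fail torch.isclose at atol=0.06, at most 2% at atol=0.10, and the mean absolute error must stay below 0.003, followed by a final assert_allclose at atol=0.18. A small self-contained helper illustrating the outlier-fraction pattern used there (hypothetical tensors; runs on CPU):

import torch


def assert_mostly_close(a, b, atol, rtol, max_outlier_frac):
    # At most `max_outlier_frac` of the elements may violate torch.isclose(a, b),
    # mirroring the pattern in the updated test:
    #   idx = torch.isclose(gradB1, gradB2, atol=..., rtol=...)
    #   assert (idx == 0).sum().item() <= n * frac
    idx = torch.isclose(a, b, atol=atol, rtol=rtol)
    n = a.numel()
    assert (idx == 0).sum().item() <= n * max_outlier_frac


# Hypothetical gradients standing in for gradB1/gradB2 from the test.
gradB1 = torch.randn(128, 64)
gradB2 = gradB1 + 0.002 * torch.randn(128, 64)

assert_mostly_close(gradB1, gradB2, atol=0.06, rtol=0.3, max_outlier_frac=0.1)   # <= 10% outliers
assert_mostly_close(gradB1, gradB2, atol=0.10, rtol=0.3, max_outlier_frac=0.02)  # <= 2% outliers
assert (gradB1 - gradB2).abs().mean().item() < 0.003

Tolerating a small fraction of outliers keeps the test stable under the blockwise FP8 rounding of grad_output while still catching systematic gradient errors.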