OpenDAS / bitsandbytes

Commit 6bdb6c35, authored Feb 13, 2023 by Tim Dettmers
Added fp8 simulation layer.
Parent: c9f50506

Showing 4 changed files with 209 additions and 0 deletions:

    bitsandbytes/__init__.py              +1    -0
    bitsandbytes/autograd/_functions.py   +92   -0
    bitsandbytes/nn/modules.py            +16   -0
    tests/test_autograd.py                +100  -0
bitsandbytes/__init__.py

@@ -10,6 +10,7 @@ from .autograd._functions import (
     matmul,
     matmul_cublas,
     mm_cublas,
+    matmul_fp8
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
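With this re-export in place, the new FP8 entry point is reachable from the package root. A minimal sketch (not part of the commit) of the import surface after this change:

    import bitsandbytes as bnb

    # matmul_fp8 is now re-exported from bitsandbytes.autograd._functions
    # via bitsandbytes/__init__.py, alongside matmul and matmul_cublas.
    print(bnb.matmul_fp8)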
bitsandbytes/autograd/_functions.py

@@ -390,6 +390,98 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None
 
 
+class MatMulFP8(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, fw_code=None, bw_code=None):
+        # default of pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            ctx.bias = bias
+            B_shape = state[1]
+            if A.shape[-1] == B_shape[0]:
+                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Dequantize
+        # 2. MatmulnN
+        cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
+        fp8A = F.dequantize_blockwise(cA, state)
+
+        cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
+        fp8B = F.dequantize_blockwise(cB, state)
+
+        output = torch.nn.functional.linear(fp8A, fp8B)
+
+        # 3. Save state
+        ctx.bw_code = bw_code
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = (fp8A, fp8B)
+        else:
+            ctx.tensors = (None, None)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
+        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
+        fp8A, B = ctx.tensors
+        state = ctx.state
+
+        grad_A, grad_B, grad_bias = None, None, None
+
+        cgrad_out, state = F.quantize_blockwise(grad_ouput, code=ctx.bw_code, blocksize=1024)
+        fp8out = F.dequantize_blockwise(cgrad_out, state)
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = fp8out.sum(0, dtype=ctx.dtype_bias)
+
+        # Cast grad_output to fp16
+        if len(grad_output.shape) == 3:
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+
+        # not supported by PyTorch. TODO: create work-around
+        #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
+        if req_gradA: grad_A = torch.matmul(fp8out, B.t())
+        if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
+
+        return grad_A, grad_B, None, grad_bias, None, None
+
+
+def matmul(
+    A: tensor,
+    B: tensor,
+    out: tensor = None,
+    state: MatmulLtState = None,
+    threshold=0.0,
+    bias=None
+):
+    state = state or MatmulLtState()
+    if threshold > 0.0:
+        state.threshold = threshold
+    return MatMul8bitLt.apply(A, B, out, bias, state)
+
+
+def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bias=None):
+    assert quant_state is not None
+    return MatMulFP8.apply(A, B, out, bias, fw_code, bw_code)
+
+
 def matmul(
     A: tensor,
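The new autograd function does not run a true FP8 kernel; it simulates FP8 by round-tripping both operands through a blockwise quantize/dequantize with an FP8 code and then calling an ordinary torch.nn.functional.linear. The sketch below is not part of the commit; it isolates that forward-pass idea, assuming bitsandbytes.functional exposes create_fp8_map, quantize_blockwise, and dequantize_blockwise with the signatures used in the diff above.

    import torch
    import bitsandbytes.functional as F

    def simulated_fp8_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # Build an 8-bit E4M3-style code (4 exponent bits, 3 mantissa bits),
        # as LinearFP8 does for its forward pass.
        code = F.create_fp8_map(True, 4, 3, 8).to(A.device)
        # Quantize to 8-bit indices in 1024-element blocks, then dequantize,
        # so both operands carry FP8 rounding error but stay in A.dtype.
        cA, state_A = F.quantize_blockwise(A, code=code, blocksize=1024)
        cB, state_B = F.quantize_blockwise(B, code=code, blocksize=1024)
        fp8A = F.dequantize_blockwise(cA, state_A)
        fp8B = F.dequantize_blockwise(cB, state_B)
        # B is treated like a linear-layer weight of shape (out_features, in_features).
        return torch.nn.functional.linear(fp8A, fp8B)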
bitsandbytes/nn/modules.py

@@ -343,3 +343,19 @@ class Linear8bitLt(nn.Linear):
                 del self.state.CxB
 
         return out
+
+
+class LinearFP8(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.bw_code = None
+        self.fw_code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.fw_code is None:
+            self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
+            self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x, self.weight.t(), bias=self.bias, fw_code=self.fw_code, code=self.bw_code)
+
+        return out
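LinearFP8 lazily builds two FP8 codes on its first forward pass (an E4M3-style map for the forward computation and an E5M2-style map intended for gradients) and then routes the layer through bnb.matmul_fp8. Below is a minimal, self-contained sketch, not the committed class, of the same simulated-FP8 linear pattern. It only simulates the forward pass and lets autograd differentiate through the dequantized tensors, whereas the commit's MatMulFP8 additionally quantizes the incoming gradient with the bw_code in its backward pass; the bitsandbytes.functional calls are assumed to behave as used in the diffs above.

    import torch
    import torch.nn as nn
    import bitsandbytes.functional as F

    class SimulatedFP8Linear(nn.Linear):
        def __init__(self, in_features, out_features, bias=True):
            super().__init__(in_features, out_features, bias)
            self.fw_code = None  # built lazily on the first forward pass

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.fw_code is None:
                # E4M3-style code: 4 exponent bits, 3 mantissa bits, 8 bits total.
                self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
            # Round-trip the input and the weight through blockwise FP8 quantization.
            cx, sx = F.quantize_blockwise(x, code=self.fw_code, blocksize=1024)
            cw, sw = F.quantize_blockwise(self.weight, code=self.fw_code, blocksize=1024)
            x_fp8 = F.dequantize_blockwise(cx, sx)
            w_fp8 = F.dequantize_blockwise(cw, sw)
            return torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

Usage would mirror nn.Linear, e.g. layer = SimulatedFP8Linear(128, 256).cuda() followed by layer(torch.randn(4, 128, device="cuda")).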
tests/test_autograd.py

@@ -429,3 +429,103 @@ def test_matmullt(
             if req_grad[2]:
                 torch.testing.assert_allclose(gradBias1, gradBias2)
 
+
+n = 1
+k = 3
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
+
+dim2.append(0)
+
+funcs = [(torch.matmul, bnb.matmul_fp8)]
+str_funcs = ["matmul"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True:
+            strval += 'T'
+        else:
+            strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+has_fp16_weights = [True, False]
+has_bias = [True, False]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
+def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            bias = None
+            bias2 = None
+            if has_bias:
+                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias2 = bias.clone()
+            torch.nn.init.xavier_uniform_(B)
+            B2, quant_state = bnb.functional.quantize_fp8(B)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+
+            if has_bias:
+                out_torch += bias
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.115
+
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias1 = bias.grad
+                    bias.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias2 = bias.grad
+                    bias.grad = None
+
+                if req_grad[0]:
+                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+
+                if req_grad[2]:
+                    torch.testing.assert_allclose(gradBias1, gradBias2)
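The new test mirrors test_matmullt and is parametrized over random dimensions, dtypes, gradient flags, transposition, and bias. A hedged sketch (not part of the commit) of selecting just this test with pytest's Python API, equivalent to passing -k test_matmul_fp8 on the command line:

    import pytest

    # Run only the newly added FP8 matmul test; without a CUDA-capable GPU
    # the skipif marker above skips every parametrization.
    pytest.main(["tests/test_autograd.py", "-k", "test_matmul_fp8"])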