OpenDAS / bitsandbytes · Commits · 160a8358

Commit 160a8358, authored Feb 04, 2023 by Tim Dettmers

Forward matmul_fp4 tests pass.

parent 3ac5840c

Showing 6 changed files with 254 additions and 23 deletions (+254 −23)
bitsandbytes/__init__.py             +1   −0
bitsandbytes/autograd/_functions.py  +66  −1
bitsandbytes/functional.py           +8   −7
bitsandbytes/nn/modules.py           +62  −0
tests/test_autograd.py               +115 −0
tests/test_functional.py             +2   −15
bitsandbytes/__init__.py

@@ -10,6 +10,7 @@ from .autograd._functions import (
     matmul,
     matmul_cublas,
     mm_cublas,
+    matmul_fp4
 )
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
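With this re-export, the FP4 matmul can be reached from the package root without importing the autograd module directly. A trivial sketch (assuming this revision is installed):

import bitsandbytes as bnb

assert callable(bnb.matmul_fp4)   # new top-level export added by this commit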
bitsandbytes/autograd/_functions.py

@@ -2,7 +2,7 @@ import operator
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Optional
+from typing import Tuple, Optional, List

 import torch

@@ -474,6 +474,67 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None


+class MatMulFP4(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
+    @staticmethod
+    def forward(ctx, A, B, out=None, bias=None, state=None):
+        # default of pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            ctx.bias = bias
+            B_shape = state[1]
+            if A.shape[-1] == B_shape[0]:
+                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
+
+        # 1. Dequantize
+        # 2. Matmul
+        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
+
+        # 3. Save state
+        ctx.state = state
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+
+        if any(ctx.needs_input_grad[:2]):
+            ctx.tensors = A
+        else:
+            ctx.tensors = [None, None]
+            ctx.tensor_states = (None, None)
+            ctx.save_for_backward(None, None)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.is_empty:
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
+
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
+        A = ctx.tensors
+        state = ctx.state
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
+
+        # Cast grad_output to fp16
+        if len(grad_output.shape) == 3:
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
+
+        if req_gradB:
+            grad_B = torch.matmul(grad_output.t(), A)
+
+        if req_gradA:
+            grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
+
+        return grad_A, grad_B, None, grad_bias, None
+
+
 def matmul(
     A: tensor,
     B: tensor,

@@ -486,3 +547,7 @@ def matmul(
     if threshold > 0.0:
         state.threshold = threshold
     return MatMul8bitLt.apply(A, B, out, bias, state)
+
+
+def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
+    return MatMulFP4.apply(A, B, out, bias, quant_state)
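Taken together with the functional.py changes below, the new path is: pack a weight with quantize_fp4, then hand the packed tensor and its quant_state to matmul_fp4, which dequantizes on the fly and runs an ordinary torch.nn.functional.linear. A minimal usage sketch (not part of the commit; shapes and variable names are illustrative only):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(4, 64, device="cuda", dtype=torch.float16)    # activations
W = torch.randn(32, 64, device="cuda", dtype=torch.float16)   # weight in (out_features, in_features) layout

# Pack the weight into 4-bit storage; quant_state = (absmax, input_shape, dtype, blocksize)
W_fp4, quant_state = F.quantize_fp4(W)

# MatMulFP4.forward dequantizes W and computes torch.nn.functional.linear(A, W_dequant, bias)
out = bnb.matmul_fp4(A, W_fp4, quant_state=quant_state)
assert out.shape == (4, 32) and out.dtype == A.dtype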
bitsandbytes/functional.py

@@ -626,7 +626,7 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
     -------
     torch.Tensor:
         The 8-bit tensor with packed 4-bit values.
-    tuple(torch.Tensor, torch.Size, torch.dtype):
+    tuple(torch.Tensor, torch.Size, torch.dtype, int):
         The quantization state to undo the quantization.
     """
     if A.device.type != 'cuda':

@@ -640,10 +640,10 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
     blocks += 1 if n % blocksize > 0 else 0
     absmax = torch.zeros((blocks,), device=A.device)

-    state = (absmax, input_shape, A.dtype)
+    state = (absmax, input_shape, A.dtype, blocksize)

     if out is None:
-        out = torch.zeros(((n+1)//2,), dtype=torch.uint8, device=A.device)
+        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)

     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]

@@ -692,7 +692,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype = quant_state
+        absmax, shape, dtype, blocksize = quant_state

     if out is None:

@@ -700,6 +700,7 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
     n = out.numel()
+
     device = pre_call(A.device)
     is_on_gpu([A, absmax, out])
     if out.dtype == torch.float32:

@@ -710,9 +711,9 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

-    return out
+    is_transposed = (True if A.shape[0] == 1 else False)
+    if is_transposed: return out.t()
+    else: return out


 def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
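The practical effect of these changes: the quantization state now carries the blocksize as its fourth element, and dequantize_fp4 transposes its output when handed a (1, n)-shaped packed tensor. A hedged sketch (illustrative only, assuming this revision is installed):

import torch
import bitsandbytes.functional as F

W = torch.randn(64, 32, device="cuda").half()
W_fp4, state = F.quantize_fp4(W, blocksize=64)        # W_fp4 is uint8 with shape ((n + 1) // 2, 1)

absmax, shape, dtype, blocksize = state                # blocksize is the new fourth element
assert shape == W.shape and blocksize == 64

W_back = F.dequantize_fp4(W_fp4, state)                # restores the original (64, 32) layout
W_back_t = F.dequantize_fp4(W_fp4.t(), state)          # shape[0] == 1 triggers the transposed return
assert W_back_t.shape == (32, 64)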
bitsandbytes/nn/modules.py

@@ -133,6 +133,67 @@ class Embedding(torch.nn.Embedding):
         return emb


+class FP4Params(torch.nn.Parameter):
+    def __new__(cls, data=None, requires_grad=True, quant_state=None):
+        cls.quant_state = None
+        if data is None:
+            data = torch.empty(0)
+        return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+    def cuda(self, device):
+        w = self.data.contiguous().half().cuda(device)
+        w_fp4, quant_state = bnb.functional.quantize_fp4(w)
+        self.data = w_fp4
+        self.quant_state = quant_state
+
+        return self
+
+    @overload
+    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
+        ...
+
+    @overload
+    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
+        ...
+
+    @overload
+    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
+        ...
+
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+        if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
+            return self.cuda(device)
+        else:
+            new_param = FP4Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                                  requires_grad=self.requires_grad, quant_state=self.quant_state)
+
+            return new_param
+
+
+class LinearFP4(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.state = bnb.MatmulLtState()
+        self.weight = FP4Params(self.weight.data, requires_grad=False)
+
+    def init_8bit_state(self):
+        pass
+
+    def forward(self, x: torch.Tensor):
+        self.state.is_training = self.training
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if getattr(self.weight, 'state', None) is None:
+            print('FP4 state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+        out = bnb.matmul_fp(x, self.weight, bias=self.bias, state=self.weight.state)
+
+        return out
+
+
 class Int8Params(torch.nn.Parameter):
     def __new__(

@@ -208,6 +269,7 @@ class Int8Params(torch.nn.Parameter):
         return new_param

+
 class Linear8bitLt(nn.Linear):
     def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                        memory_efficient_backward=False, threshold=0.0, index=None):
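A hedged sketch of the FP4Params lifecycle introduced here (not part of the commit; assumes this revision is installed). Moving the parameter to CUDA is what triggers the quantization. The sketch sticks to FP4Params, since in this revision LinearFP4.forward still calls bnb.matmul_fp and reads self.weight.state, and only the forward matmul_fp4 tests are reported passing:

import torch
import bitsandbytes as bnb
from bitsandbytes.nn.modules import FP4Params

p = FP4Params(torch.randn(32, 64), requires_grad=False)     # plain fp32 data on CPU
p = p.to("cuda")   # .to() routes to .cuda(), which packs the data via quantize_fp4
                   # and stores (absmax, shape, dtype, blocksize) on p.quant_state

assert p.data.dtype == torch.uint8
W = bnb.functional.dequantize_fp4(p.data, p.quant_state)    # recover an fp16 copy of the weight
assert W.shape == (32, 64)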
tests/test_autograd.py

@@ -429,3 +429,118 @@ def test_matmullt(
         if req_grad[2]:
             torch.testing.assert_allclose(gradBias1, gradBias2)

+
+n = 1
+k = 3
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
+
+dim2.append(0)
+
+funcs = [(torch.matmul, bnb.matmul_fp4)]
+str_funcs = ["matmul"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True: strval += 'T'
+        else: strval += 'F'
+    req_grad_str.append(strval)
+
+transpose = [(False, True), (False, False)]
+str_transpose = ["NT", "NN"]
+dtype = [torch.float16, torch.float32]
+has_fp16_weights = [True, False]
+has_bias = [True, False]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
+def test_matmul_fp4(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False
+
+    for i in range(k):
+        # normal multiply
+        if funcs[0] in [torch.mm, torch.matmul]:
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
+            bias = None
+            bias2 = None
+            if has_bias:
+                bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+                bias2 = bias.clone()
+            torch.nn.init.xavier_uniform_(B)
+            B2 = B.clone()
+
+            B2, quant_state = bnb.functional.quantize_fp4(B)
+
+            if not transpose[0] and transpose[1]:
+                out_torch = funcs[0](A, B.t())
+                out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
+            elif not transpose[0] and not transpose[1]:
+                out_torch = funcs[0](A, B)
+                out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
+
+            if has_bias:
+                out_torch += bias
+
+            assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
+            n = out_bnb.numel()
+            err = torch.abs(out_bnb - out_torch).float().mean().item()
+            if n > 0:
+                assert err < 0.11
+
+            if any(req_grad):
+                out_bnb.data.copy_(out_torch)
+                torch.cuda.synchronize()
+                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                loss_bnb.backward()
+                gradA1 = A.grad
+                gradB1 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias1 = bias.grad
+                    bias.grad = None
+
+                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch.backward()
+                gradA2 = A.grad
+                gradB2 = B.grad
+                A.grad = None
+                B.grad = None
+                if has_bias:
+                    gradBias2 = bias.grad
+                    bias.grad = None
+
+            if req_grad[0]:
+                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+            if req_grad[1]:
+                n = gradB1.numel()
+                if dim2 > 0:
+                    assert torch.abs(gradB1).sum() > 0.0
+                    assert torch.abs(gradB2).sum() > 0.0
+                else:
+                    assert torch.abs(gradB1).sum() == 0.0
+                    assert torch.abs(gradB2).sum() == 0.0
+                idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+                assert (idx == 0).sum().item() <= n * 0.1
+                idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                assert (idx == 0).sum().item() <= n * 0.02
+                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+            if req_grad[2]:
+                torch.testing.assert_allclose(gradBias1, gradBias2)
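The new test sweeps random dims (including an empty dim2 == 0 case), NT/NN layouts, fp16/fp32 inputs, and optional bias, checking the FP4 forward output against a plain matmul and the gradients against fixed tolerances. A stripped-down sketch of the core forward check outside the pytest harness (illustrative only; dims chosen within the test's ranges):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(48, 64, device="cuda", dtype=torch.float16)
B = torch.randn(64, 32, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)

B_fp4, quant_state = F.quantize_fp4(B)
out_bnb = bnb.matmul_fp4(A, B_fp4.t(), quant_state=quant_state)   # "NN" layout, as in the test
out_ref = torch.matmul(A, B)

err = (out_bnb - out_ref).abs().float().mean().item()
print(err)   # the test asserts err < 0.11 for comparable shapes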
tests/test_functional.py

@@ -2221,26 +2221,13 @@ def test_fp4_quant():
     A1 = torch.randn(1024, 1024, device='cuda').half()
     qa, SA = F.quantize_fp4(A1, blocksize=64)
     A2 = F.dequantize_fp4(qa, SA)

-    #qa, SA = F.quantize_fp4(A1, blocksize=128)
-    #A2 = F.dequantize_fp4(qa, SA, blocksize=128)
-
-    #A1 = A1.flatten().sort()[0]
-    #A2 = A2.flatten().sort()[0]
-
-    #print(A1)
-    #print(A2)
     err = (A1 - A2).abs().float()
     relerr = (err / A1.abs().float()).mean()
     err = err.mean()

-    print(err, relerr)
+    assert err.item() < 0.1
+    assert relerr.item() < 0.28

-    #assert err.item() < 0.1
-    #assert relerr.item() < 0.28

 def test_bench_fp4_dequant():
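With the blocksize now stored in the quant state, it is easy to see the accuracy trade-off that the blocksize controls. A hedged sketch (illustrative only; smaller blocks track the per-block absmax more closely, so the reconstruction error should shrink as the blocksize decreases):

import torch
import bitsandbytes.functional as F

A1 = torch.randn(1024, 1024, device="cuda").half()
for bs in (4096, 256, 64):
    qa, SA = F.quantize_fp4(A1, blocksize=bs)
    A2 = F.dequantize_fp4(qa, SA)
    relerr = ((A1 - A2).abs() / A1.abs()).float().mean().item()
    print(f"blocksize={bs:4d}  mean relative error={relerr:.3f}")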