OpenDAS / bitsandbytes

Commit de354f7d, authored Aug 16, 2022 by Tim Dettmers
Parent: dede3430

    Added fused bias to matmullt.

Showing 3 changed files with 62 additions and 53 deletions:

  bitsandbytes/autograd/_functions.py   +20  -35
  bitsandbytes/nn/modules.py             +4   -7
  tests/test_autograd.py                +38  -11
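The commit threads an optional bias through MatMul8bitLt so the bias is applied during dequantization (inside F.mm_dequant) instead of as a separate broadcast add in Linear8bitLt.forward. For orientation before the diffs, a minimal sketch of how the fused path is exercised after this change; the layer sizes, threshold value, and CUDA device are illustrative assumptions (a GPU is required for the int8 path), not details taken from the diff:

import torch
import bitsandbytes as bnb

# Linear8bitLt with a bias; after this commit its forward() hands self.bias to
# bnb.matmul(..., bias=...) instead of adding it to the output afterwards.
layer = bnb.nn.Linear8bitLt(1024, 4096, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
out = layer(x)  # bias is folded into F.mm_dequant inside MatMul8bitLt.forward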
bitsandbytes/autograd/_functions.py

@@ -201,13 +201,14 @@ class MatmulLtState:
 class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, out=None, state=MatmulLtState()):
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         # default to pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
             ctx.A = A
             ctx.B = B
+            ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
                 return torch.empty(A.shape[:-1] + B.shape[1:], dtype=torch.float16, device=A.device)
             else:

@@ -220,6 +221,7 @@ class MatMul8bitLt(torch.autograd.Function):
         # 5. Save state
         requires_gradA = A.requires_grad
         requires_gradB = B.requires_grad
+        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:

@@ -247,28 +249,7 @@ class MatMul8bitLt(torch.autograd.Function):
             if state.CxB is None:
                 # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                 # we also need to convert it to the turing/ampere format
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
-                # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
-                # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
-                #     # generate outlier index and subB
-                #     outlier_idx = torch.unique(coo_tensorA.colidx).long()
-                #     state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-                #     if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
-                #         # do not use pool for 2nd FFN layer
-                #         state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-                #     else:
-                #         state.idx = outlier_idx
-                #     state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
-                # if state.idx is not None:
-                #     # extract outliers
-                #     CA[:, state.idx] = 0
-                #     CAt[:, state.idx] = 0
-                #     subA = A[:, state.idx]
-                # else:
-                #     subA = None
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)

@@ -326,7 +307,8 @@ class MatMul8bitLt(torch.autograd.Function):
         # 3. Matmul
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
+        # we apply the fused bias here
+        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:

@@ -337,7 +319,7 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB]
+        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]

         if requires_gradA or requires_gradB:
             ctx.tensors = (CAt, subA)

@@ -347,15 +329,16 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        clone_func = torch.clone
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        #clone_func = torch.clone
         return clone_func(output.view(output_shape))

     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
+            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB = ctx.req_grads
+        req_gradA, req_gradB, req_gradBias = ctx.req_grads
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB

@@ -369,7 +352,7 @@ class MatMul8bitLt(torch.autograd.Function):
                     -1, grad_output.shape[-1]
                 ).contiguous()

-        grad_A = grad_B = None
+        grad_A = grad_B = grad_bias = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)

        if req_gradB:

@@ -387,11 +370,12 @@ class MatMul8bitLt(torch.autograd.Function):
                     state.CBt, to_order=formatB, transpose=True
                 )
             gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
                 ctx.grad_shape
             )

+        if req_gradBias:
+            grad_bias = grad_output.sum(0)

-        return grad_A, grad_B, None, None
+        return grad_A, grad_B, None, grad_bias, None


 matmul = MatMul8bitLt.apply

@@ -403,8 +387,9 @@ def matmul(
     out: tensor = None,
     state: MatmulLtState = None,
     threshold=0.0,
+    bias=None
 ):
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
-    return MatMul8bitLt.apply(A, B, out, state)
+    return MatMul8bitLt.apply(A, B, out, bias, state)
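The pattern in this file is standard torch.autograd.Function bookkeeping: forward() gains a bias argument, req_grads tracks a third flag, and backward() must return one gradient slot per forward input (grad_A, grad_B, None for out, grad_bias, None for state), with the bias gradient reduced over the batch dimension via grad_output.sum(0). A plain float32 sketch of the same bookkeeping, independent of the 8-bit kernels; the class name and shapes below are illustrative, not part of bitsandbytes:

import torch

class LinearWithBias(torch.autograd.Function):
    # Mirrors the bias threading in MatMul8bitLt: one gradient slot per
    # forward argument, and grad_bias = grad_output.sum(0).
    @staticmethod
    def forward(ctx, A, B, bias=None):
        ctx.save_for_backward(A, B)
        ctx.has_bias = bias is not None
        out = A @ B
        if bias is not None:
            out = out + bias  # broadcast over the batch dimension
        return out

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        grad_A = grad_output @ B.t()
        grad_B = A.t() @ grad_output
        grad_bias = grad_output.sum(0) if ctx.has_bias else None
        return grad_A, grad_B, grad_bias

A = torch.randn(4, 8, requires_grad=True)
B = torch.randn(8, 3, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
out = LinearWithBias.apply(A, B, bias)
out.sum().backward()
# A.grad, B.grad, and bias.grad are now populated, matching grad_A, grad_B, grad_bias above.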
bitsandbytes/nn/modules.py

@@ -235,9 +235,7 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(
-            self.weight.data, has_fp16_weights=has_fp16_weights
-        )
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)

     def init_8bit_state(self):
         self.state.CB = self.weight.CB

@@ -250,13 +248,12 @@ class Linear8bitLt(nn.Linear):
         if self.weight.CB is not None:
             self.init_8bit_state()
+        if self.bias.dtype != torch.float16:
+            self.bias.data = self.bias.data.half()

         # assert not self.state.has_fp16_weights
         # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

-        out = bnb.matmul(x, self.weight, state=self.state)
-
-        if self.bias is not None:
-            out += self.bias.unsqueeze(0).expand_as(out)
+        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

         if not self.state.has_fp16_weights and self.state.CB is not None:
             # we converted 8-bit row major to turing/ampere format in the first inference pass
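In Linear8bitLt.forward the separate out += self.bias.unsqueeze(0).expand_as(out) disappears and the bias (cast to fp16 if needed) is handed to bnb.matmul instead. The two formulations agree up to quantization error; a plain float32 illustration of the algebra being relied on (ordinary PyTorch, not the int8 path):

import torch

x = torch.randn(2, 8)
W = torch.randn(4, 8)
b = torch.randn(4)

out = x @ W.t()
unfused = out + b.unsqueeze(0).expand_as(out)     # old forward: matmul, then broadcast add
fused = torch.nn.functional.linear(x, W, b)       # bias handed to the matmul call itself
print(torch.allclose(unfused, fused, atol=1e-6))  # True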
tests/test_autograd.py

-from itertools import product
+from itertools import product, permutations

 import pytest
 import torch

@@ -241,11 +241,20 @@ decomp = [0.0, 6.0]
 funcs = [(torch.matmul, bnb.matmul)]
 str_funcs = ["matmul"]
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ["FF", "TF", "TT", "FT"]
+req_grad = list(product([True, False], repeat=3))
+req_grad_str = []
+for c in req_grad:
+    strval = ''
+    for v in c:
+        if v == True:
+            strval += 'T'
+        else:
+            strval += 'F'
+    req_grad_str.append(strval)
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
 dtype = [torch.float16]
 has_fp16_weights = [True, False]
+has_bias = [True, False]
 values = list(
     product(
         dim1,

@@ -258,6 +267,7 @@ values = list(
         transpose,
         decomp,
         has_fp16_weights,
+        has_bias
     )
 )
 str_values = list(

@@ -272,18 +282,14 @@ str_values = list(
         str_transpose,
         decomp,
         has_fp16_weights,
+        has_bias
     )
 )
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(*vals) for vals in str_values
+    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values
 ]

 @pytest.mark.parametrize(
-    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
+    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
     values,
     ids=names,
 )

@@ -298,10 +304,14 @@ def test_matmullt(
     transpose,
     decomp,
     has_fp16_weights,
+    has_bias
 ):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
+    if has_bias == False:
+        req_grad = list(req_grad)
+        req_grad[2] = False

     for i in range(k):

@@ -322,6 +332,11 @@ def test_matmullt(
             requires_grad=req_grad[1],
             dtype=dtype,
         )
+        bias = None
+        bias2 = None
+        if has_bias:
+            bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
+            bias2 = bias.clone()
         torch.nn.init.xavier_uniform_(B)
         B2 = B.clone()

@@ -342,10 +357,13 @@ def test_matmullt(
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, state=state)
+            out_bnb = funcs[1](A, B2, state=state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), state=state)
+            out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2)
+        if has_bias:
+            out_torch += bias

         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).mean().item()

@@ -367,6 +385,9 @@ def test_matmullt(
             gradB1 = B.grad
             A.grad = None
             B.grad = None
+            if has_bias:
+                gradBias1 = bias.grad
+                bias.grad = None

             loss_torch = torch.nn.functional.mse_loss(
                 out_torch, target

@@ -376,6 +397,9 @@ def test_matmullt(
             gradB2 = B.grad
             A.grad = None
             B.grad = None
+            if has_bias:
+                gradBias2 = bias.grad
+                bias.grad = None

         if req_grad[0]:
             torch.testing.assert_allclose(

@@ -397,3 +421,6 @@ def test_matmullt(
             torch.testing.assert_allclose(
                 gradB1, gradB2, atol=0.18, rtol=0.3
             )
+
+        if req_grad[2]:
+            torch.testing.assert_allclose(gradBias1, gradBias2)
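The test now draws requires_grad flags for A, B, and the bias from all 2**3 combinations instead of the old four hand-written pairs, and derives the parametrization id strings from them. A compact equivalent of the diff's loop, showing what req_grad_str ends up containing:

from itertools import product

req_grad = list(product([True, False], repeat=3))
req_grad_str = ["".join("T" if v else "F" for v in c) for c in req_grad]
print(req_grad_str)
# ['TTT', 'TTF', 'TFT', 'TFF', 'FTT', 'FTF', 'FFT', 'FFF']

When has_bias is False, the test forces req_grad[2] to False so the bias-gradient comparison is skipped for those cases.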