OpenDAS / bitsandbytes · Commit 439f2b0c (unverified)
Authored Sep 19, 2022 by Tim Dettmers; committed via GitHub Sep 19, 2022
Parents: 9b5f2eda, 76ce9aa6

Merge pull request #33 from dbaranchuk/memory-efficient-backward

    Memory efficient backward

Showing 4 changed files with 100 additions and 52 deletions (+100 -52):
    bitsandbytes/autograd/_functions.py   +44 -32
    bitsandbytes/nn/modules.py            +15 -6
    tests/test_autograd.py                 +6 -3
    tests/test_modules.py                 +35 -11
bitsandbytes/autograd/_functions.py

 import operator
 import warnings

 import torch
 import bitsandbytes.functional as F

@@ -184,6 +186,7 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()

@@ -209,31 +212,29 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
         requires_gradA = A.requires_grad
         requires_gradB = B.requires_grad
         requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert A.dtype == torch.float16, f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

         if state.threshold > 0.0 and coo_tensorA is not None:

@@ -269,7 +270,7 @@ class MatMul8bitLt(torch.autograd.Function):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
             else:
                 has_grad = False

@@ -290,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
-                (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+                (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
             )
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0

@@ -307,7 +308,13 @@ class MatMul8bitLt(torch.autograd.Function):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
+        if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:

@@ -318,9 +325,9 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:

@@ -328,8 +335,8 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

         clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
         #clone_func = torch.clone
         return clone_func(output.view(output_shape))

     @staticmethod

@@ -337,23 +344,24 @@ class MatMul8bitLt(torch.autograd.Function):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

+        # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

-        grad_A = grad_B = grad_bias = None
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))

         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)

@@ -363,16 +371,20 @@ class MatMul8bitLt(torch.autograd.Function):
             grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
+            if state.CBt is not None:
                 C32grad, Sgrad = F.transform(Cgrad, "col32")
                 if state.CxBt is None:
                     state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')

         return grad_A, grad_B, None, grad_bias, None
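The key change in backward() is the new `elif state.CB is not None` branch: when the Turing/Ampere-format weight (CxB) is not kept around, grad_A is computed by dequantizing the row-major int8 weight CB with its per-row scales SCB and doing a plain matmul with grad_output, instead of requiring fp16 weights (the old assert). Below is a minimal standalone sketch of that dequantize-then-matmul step, not the library code; function and variable names are illustrative and it runs on CPU.

import torch

def grad_wrt_input_from_int8(grad_output, CB_int8, SCB, out_dtype):
    # Dequantize the row-major int8 weight: W ~= CB * SCB / 127 (row-wise absmax scaling).
    W = CB_int8.to(out_dtype).mul_(SCB.unsqueeze(1).to(out_dtype) / 127.0)
    # For y = x @ W.t(), the gradient w.r.t. x is grad_y @ W.
    return torch.matmul(grad_output.to(out_dtype), W)

# Toy check against the full-precision gradient.
torch.manual_seed(0)
W_fp = torch.randn(64, 32)                              # (out_features, in_features)
SCB = W_fp.abs().max(dim=1).values                      # per-row absmax scales
CB = torch.round(W_fp / SCB.unsqueeze(1) * 127).to(torch.int8)
grad_y = torch.randn(16, 64)
grad_x = grad_wrt_input_from_int8(grad_y, CB, SCB, torch.float32)
grad_ref = grad_y @ W_fp
rel_err = (grad_x - grad_ref).abs().mean() / grad_ref.abs().mean()
print(f"mean relative error vs. fp32 weights: {rel_err:.4f}")   # small; only 8-bit quantization noise

The only extra memory this path needs at backward time is the transient dequantized copy of CB, which is why the module below can afford to drop CxB between passes.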
bitsandbytes/nn/modules.py

@@ -221,6 +221,7 @@ class Linear8bitLt(nn.Linear):
         output_features,
         bias=True,
         has_fp16_weights=True,
+        memory_efficient_backward=False,
         threshold=0.0,
         index=None,
     ):

@@ -232,10 +233,13 @@ class Linear8bitLt(nn.Linear):
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

     def init_8bit_state(self):
         self.state.CB = self.weight.CB

@@ -255,11 +259,16 @@ class Linear8bitLt(nn.Linear):
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CB is not None:
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
                 # we converted 8-bit row major to turing/ampere format in the first inference pass
                 # we no longer need the row-major weight
                 del self.state.CB
                 self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB

         return out
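With this change, `memory_efficient_backward=True` keeps the row-major CB (and discards CxB after each pass) so the backward branch above has something to dequantize, at the cost of re-deriving the Turing/Ampere layout on every forward. A rough usage sketch, assuming a CUDA device and a bitsandbytes build that includes this commit; the layer sizes and threshold value are illustrative.

import torch
import bitsandbytes as bnb

# Int8 inference weights; gradients can still flow back to the fp16 input.
layer = bnb.nn.Linear8bitLt(
    32, 64,
    has_fp16_weights=False,           # store the weight as int8
    memory_efficient_backward=True,   # keep CB so backward can dequantize on the fly
    threshold=6.0,
).cuda().half()

x = torch.randn(8, 32, device="cuda", dtype=torch.half, requires_grad=True)
out = layer(x)
out.sum().backward()   # previously blocked by the "Backprop only supported for fp16 weights." assert
print(x.grad.shape)    # torch.Size([8, 32]); the int8 weight itself gets no gradient

Note that the constructor now also sets `requires_grad=has_fp16_weights` on the Int8Params, so in this mode only the inputs (e.g. frozen-weight adapters or prompts upstream of the layer) receive gradients.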
tests/test_autograd.py

@@ -253,7 +253,7 @@ for c in req_grad:
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
-dtype = [torch.float16]
+dtype = [torch.float16, torch.bfloat16, torch.float32]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
 values = list(

@@ -354,7 +354,7 @@ def test_matmullt(
                 state.SCB,
                 SCBt,
                 coo_tensorB,
-            ) = bnb.functional.double_quant(B2)
+            ) = bnb.functional.double_quant(B2.to(torch.float16))
             B2 = state.CB

         if not transpose[0] and transpose[1]:

@@ -367,11 +367,14 @@ def test_matmullt(
         if has_bias:
             out_torch += bias

+        assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).mean().item()
         # print(f'abs error {err:.4f}')
         idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-        assert (idx == 0).sum().item() <= n * 0.0175
+        assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
         idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
         assert (idx == 0).sum().item() <= n * 0.001
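The widened budget (0.0175 for fp16 vs. 0.021 for the newly added bfloat16/float32 inputs) uses the test's existing pattern: bound the fraction of elements that fall outside an elementwise tolerance rather than requiring global closeness. A self-contained illustration of that mismatch-rate check, with made-up data rather than the test's tensors:

import torch

def mismatch_rate(a, b, atol=0.01, rtol=0.1):
    # Fraction of elements that are NOT close under the given tolerances.
    close = torch.isclose(a, b, atol=atol, rtol=rtol)
    return (close == 0).sum().item() / a.numel()

torch.manual_seed(0)
out_ref = torch.randn(1024)
out_approx = out_ref + 0.005 * torch.randn(1024)      # simulated quantization noise
rate = mismatch_rate(out_ref, out_approx)
budget = 0.0175 if out_ref.dtype == torch.float16 else 0.021
assert rate <= budget, f"{rate:.4f} > {budget}"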
tests/test_modules.py

@@ -14,13 +14,15 @@ class MockArgs(object):
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, threshold=threshold
         )

     def forward(self, x):

@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
         .cuda()
         .half()
     )

@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
         .half()
         .to("cuda")
     )

@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .to(torch.float16)
-        .to("cuda")
-    )
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
+    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    mlp = mlp.cuda().half()  # and this line triggers quantization

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()

@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
+
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        mlp.zero_grad()
+        (o1 * grad_proj).sum().backward()
+        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+        scale = grad_ref.abs().mean()
+
+        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+        assert (idx == 0).sum().item() <= b1.numel() * 0.005


 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
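The reference gradient in the new branch follows from the chain rule: MLP8bit is two stacked linear layers with no nonlinearity, so for out = fc2(fc1(x)) the input gradient is (dL/dout) @ W2 @ W1, with the biases dropping out. That is why the test can form grad_ref from the weights grabbed before quantization. A quick fp32 sanity check of the same identity with plain nn.Linear layers (illustrative, not part of the test suite):

import torch
import torch.nn as nn

torch.manual_seed(0)
fc1, fc2 = nn.Linear(32, 64), nn.Linear(64, 32)
x = torch.randn(16, 8, 32, requires_grad=True)
grad_proj = torch.randn(16, 8, 32)

out = fc2(fc1(x))
(out * grad_proj).sum().backward()

# Chain rule for two stacked linear layers: dL/dx = (dL/dout) @ W2 @ W1
grad_ref = grad_proj @ fc2.weight @ fc1.weight
print(torch.allclose(x.grad, grad_ref, atol=1e-4))   # True

The int8 version only has to match this reference up to quantization noise, hence the scale-relative atol and the 0.5% mismatch budget in the test above.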