OpenDAS / bitsandbytes

Commit 42b5fc9a, authored Sep 11, 2022 by dbaranchuk

    add memory efficient backward option

Parent: 843ad063
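Note (not part of the diff): a minimal usage sketch of the option this commit introduces, assuming a CUDA device with int8 tensor-core support and the Linear8bitLt constructor arguments shown in the modules.py hunk further down. Shapes and values are made up, and later revisions refined this backward path, so exact behavior at this commit may differ.

import torch
import bitsandbytes as bnb

# Hypothetical usage sketch: a frozen int8 linear layer that can still pass
# gradients to its input, enabled by the new memory_efficient_backward flag.
linear = bnb.nn.Linear8bitLt(
    768,
    3072,
    bias=True,
    has_fp16_weights=False,           # keep the weight quantized (inference-style)
    threshold=6.0,                    # LLM.int8() outlier threshold
    memory_efficient_backward=True,   # option added by this commit
).half().cuda()                       # .cuda() triggers int8 quantization of the weight

x = torch.randn(4, 768, dtype=torch.float16, device="cuda", requires_grad=True)
out = linear(x)
out.float().sum().backward()          # grad w.r.t. x flows through the int8 weight
print(x.grad.shape)                   # torch.Size([4, 768])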
Showing 2 changed files with 52 additions and 10 deletions (+52 -10)

    bitsandbytes/autograd/_functions.py   +40 -6
    bitsandbytes/nn/modules.py            +12 -4
bitsandbytes/autograd/_functions.py
import operator
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

from dataclasses import dataclass

...
@@ -187,6 +188,8 @@ class MatmulLtState:
    use_pool = False
    formatB = F.get_special_format_str()

    memory_efficient_backward = False

    def reset_grads(self):
        self.CB = None
        self.CxB = None

...
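The new field defaults to False on MatmulLtState, so existing callers are unaffected. The flag can also be set on a state object passed directly to bnb.matmul, which is how Linear8bitLt wires it up in the modules.py hunk further down. A rough sketch, assuming MatmulLtState is re-exported at the package top level (otherwise import it from bitsandbytes.autograd._functions); shapes are invented and a CUDA device is required.

import torch
import bitsandbytes as bnb

# Rough sketch: drive bnb.matmul with an explicit MatmulLtState carrying the new flag.
state = bnb.MatmulLtState()
state.memory_efficient_backward = True   # field added in the hunk above

A = torch.randn(8, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
B = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")

out = bnb.matmul(A, B, state=state)      # same call Linear8bitLt.forward makes
out.sum().backward()                     # gradient w.r.t. A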
@@ -283,6 +286,12 @@ class MatMul8bitLt(torch.autograd.Function):
            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #     # do not use pool for 2nd FFN layer
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            state.subB = ((outliers * state.SCB.view(-1, 1) / 127.0)

...
@@ -332,13 +341,15 @@ class MatMul8bitLt(torch.autograd.Function):
        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, req_gradBias = ctx.req_grads
        assert not req_gradB, "TODO: support weight updates as well"
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state

        # Cast grad_output to fp16
...
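For readers less familiar with the ctx bookkeeping unpacked above (ctx.tensors, ctx.tensor_states, ctx.state): below is a generic, self-contained torch.autograd.Function toy (not from this repository) showing the same save-in-forward / unpack-in-backward pattern. It uses the standard save_for_backward API, whereas MatMul8bitLt stores its tuples directly on ctx.

import torch

class ScaledMatMul(torch.autograd.Function):
    """Toy example of the forward/backward ctx pattern used by MatMul8bitLt."""

    @staticmethod
    def forward(ctx, A, B, scale):
        ctx.save_for_backward(A, B)   # tensors saved for the backward pass
        ctx.scale = scale             # non-tensor state hangs off ctx directly
        return (A @ B) * scale

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors      # unpacked in the order they were saved
        grad_A = grad_output @ B.t() * ctx.scale
        grad_B = A.t() @ grad_output * ctx.scale
        return grad_A, grad_B, None   # one gradient slot per forward input

A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)
ScaledMatMul.apply(A, B, 0.5).sum().backward()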
@@ -352,11 +363,31 @@ class MatMul8bitLt(torch.autograd.Function):
        grad_A = grad_B = grad_bias = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            if state.CBt:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
            elif state.CB:
                CB = state.CB.half()
                SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                CB *= SCB
                grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
            else:
                raise Exception('State must contain either CBt or CB matrix')

        if req_gradBias:
            grad_bias = grad_output.sum(0)

...
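The elif state.CB branch above is the core of the memory-efficient option: rather than keeping a transformed, transposed copy of the weight (CxBt) around just for the backward pass, it dequantizes the row-major int8 matrix CB with its per-row scales SCB and falls back to a plain matmul. Below is a standalone numerical sketch of that dequantize-then-matmul step, with made-up shapes and float32 so it also runs on CPU (the real code stays in fp16 on the GPU).

import torch

# Standalone sketch of the `elif state.CB` fallback above: dequantize the
# row-major int8 weight with its per-row scales and do a plain matmul.
out_features, in_features, batch = 6, 4, 3

CB = torch.randint(-127, 128, (out_features, in_features), dtype=torch.int8)
SCB = torch.rand(out_features) * 10              # per-row absmax scales
grad_output = torch.randn(batch, out_features)

# Mirrors CB.half() * (state.SCB.unsqueeze(1) / 127.0).half() in the diff
W = CB.float() * (SCB.unsqueeze(1) / 127.0)

# out = x @ W.t()  =>  grad_A = grad_output @ W   (as in torch.mm(grad_output, CB))
grad_A = torch.mm(grad_output, W)
print(grad_A.shape)                              # torch.Size([3, 4])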
@@ -367,6 +398,9 @@ class MatMul8bitLt(torch.autograd.Function):
        return grad_A, grad_B, None, grad_bias, None


matmul = MatMul8bitLt.apply


def matmul(
    A: tensor,
    B: tensor,

...
bitsandbytes/nn/modules.py

...
@@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear):
        has_fp16_weights=True,
        threshold=0.0,
        index=None,
        memory_efficient_backward=False
    ):
        super(Linear8bitLt, self).__init__(input_features, output_features, bias

...
@@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear):
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

...
@@ -255,9 +257,15 @@ class Linear8bitLt(nn.Linear):
        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights and self.state.CxB is not None:
            # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
            # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
        if not self.state.has_fp16_weights:
            if not self.state.memory_efficient_backward and self.state.CB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
            elif self.state.memory_efficient_backward and self.state.CxB is not None:
                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
                # Thus, we delete CxB from the state.
                del self.state.CxB

        return out

...
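Taken together with the autograd changes, the forward-pass bookkeeping above trades a one-time conversion for a per-pass one: without the flag, the row-major CB is dropped after the first inference pass and only the turing/ampere-format CxB is kept; with memory_efficient_backward=True, CxB is dropped each pass so CB stays available for the backward fallback. A rough inspection sketch under those assumptions (attribute names taken from the diff; requires a CUDA device, and exact behavior may differ between bitsandbytes versions):

import torch
import bitsandbytes as bnb

def show_state(flag):
    # Hypothetical helper: run one forward pass and report which weight copies
    # remain on the layer's MatmulLtState afterwards.
    layer = bnb.nn.Linear8bitLt(
        64, 64, has_fp16_weights=False, memory_efficient_backward=flag
    ).half().cuda()
    layer(torch.randn(2, 64, dtype=torch.float16, device="cuda"))
    cb = getattr(layer.state, "CB", None) is not None
    cxb = getattr(layer.state, "CxB", None) is not None
    print(f"memory_efficient_backward={flag}: CB kept={cb}, CxB kept={cxb}")

show_state(False)  # expected: CB dropped, CxB kept (weight.data now holds CxB)
show_state(True)   # expected: CB kept, CxB dropped (CB reused by the backward fallback)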