OpenDAS / bitsandbytes / Commits / 8ae9bb23

Commit 8ae9bb23, authored Aug 23, 2022 by dbaranchuk
Parent: 9d60b3c5

    add memory efficient backward
Showing 2 changed files with 28 additions and 24 deletions:

    bitsandbytes/autograd/_functions.py   +19  -20
    bitsandbytes/nn/modules.py             +9   -4
bitsandbytes/autograd/_functions.py
@@ -245,8 +245,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            else:
-                if state.CxB is None:
-                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
-                    # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            elif state.CxB is None:
+                # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
+                # we also need to convert it to the turing/ampere format
+                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function):
                 outlier_idx = torch.unique(coo_tensorA.colidx)
                 state.idx = outlier_idx
-                # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
-                # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
-                #     # do not use pool for 2nd FFN layer
-                #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
-                # else:
-                #     state.idx = outlier_idx
                 outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
                 state.subB = (
                     (outliers * state.SCB.view(-1, 1) / 127.0)
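For context on the path above: the columns listed in state.idx are the outlier feature dimensions of the activation A, and extract_outliers pulls the matching columns out of the 8-bit weight so that the outlier part of the product can be dequantized and computed in higher precision. A rough plain-PyTorch sketch of that decomposition (not the bitsandbytes kernels; the 6.0 threshold and the shapes are illustrative):

    import torch

    # Toy activation A [batch, hidden] and weight B [out, hidden].
    A = torch.randn(4, 8)
    B = torch.randn(16, 8)
    A[:, 3] *= 50.0                               # make column 3 an outlier dimension

    threshold = 6.0                               # illustrative LLM.int8() threshold
    outlier_idx = torch.unique((A.abs() > threshold).nonzero()[:, 1])

    # Outlier columns of A and the matching columns of B stay in higher precision;
    # this mirrors subA = A[:, idx] and state.subB = B[:, idx].t().contiguous().
    subA = A[:, outlier_idx]                      # [batch, n_outliers]
    subB = B[:, outlier_idx].t().contiguous()     # [n_outliers, out]
    out_outliers = subA @ subB                    # high-precision outlier product
    print(outlier_idx.tolist(), out_outliers.shape)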
@@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
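The switch from .view() to .reshape() above matters because grad_output is not guaranteed to be contiguous when it reaches backward: .view() raises on non-contiguous inputs, while .reshape() silently falls back to a copy. A minimal standalone illustration (the transpose here only serves to produce a non-contiguous tensor):

    import torch

    g = torch.randn(2, 3, 4).transpose(0, 1)        # non-contiguous, shape [3, 2, 4]
    assert not g.is_contiguous()

    try:
        g.view(-1, g.shape[-1])                     # .view() needs contiguous memory
    except RuntimeError as err:
        print("view() failed:", err)

    flat = g.reshape(-1, g.shape[-1]).contiguous()  # .reshape() copies when needed
    print(flat.shape)                               # torch.Size([6, 4])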
@@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+            if state.CxBt is None and state.has_fp16_weights:
+                CBt = state.CBt
+            elif state.CxBt is None:
+                assert state.CBt is None
+                CB = state.CB.half()
+                SCB = state.SCB.unsqueeze(1).half()
+                SCBt = state.SCBt.unsqueeze(1).half()
+                Bt = (CB * SCB).t().contiguous()
+                CBt = (Bt / SCBt).t().to(torch.int8)
+
+            CxBt, SBt = F.transform(
+                CBt, to_order=formatB, transpose=True
+            )
+            gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)

         if req_gradBias:
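This hunk is the heart of the memory-efficient backward: when fp16 weights are not kept (has_fp16_weights=False), the column-wise quantized copy CBt is no longer cached; it is rebuilt on demand from the row-wise code CB and the absmax statistics SCB / SCBt. Because dequantization divides by 127 and requantization multiplies by 127, the factors cancel and the round trip reduces to (CB * SCB).t() / SCBt. A toy plain-PyTorch sketch of that reconstruction, using a simplified stand-in for double_quant rather than the real bitsandbytes kernel:

    import torch

    torch.manual_seed(0)
    B = torch.randn(16, 8)                               # weight [out_features, in_features]

    # Row-wise absmax int8 quantization (what is kept as CB / SCB).
    SCB = B.abs().amax(dim=1)                            # per-row absmax, [out]
    CB = torch.round(B / SCB.unsqueeze(1) * 127).to(torch.int8)

    # Column-wise statistics (SCBt) and the reference column-wise code CBt
    # that the old backward kept in memory.
    SCBt = B.abs().amax(dim=0)                           # per-column absmax, [in]
    CBt_ref = torch.round(B / SCBt.unsqueeze(0) * 127).to(torch.int8)

    # Reconstruction used by the new backward: dequantize row-wise, requantize
    # column-wise; the 127 factors cancel.
    Bt = (CB.half() * SCB.unsqueeze(1).half()).t().contiguous()   # ~ B.t() * 127
    CBt = (Bt / SCBt.unsqueeze(1).half()).t().to(torch.int8)

    # Matches the cached version up to a quantization step or so of rounding error.
    print((CBt.int() - CBt_ref.int()).abs().max().item())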
bitsandbytes/nn/modules.py
@@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
+        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
+        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter):
             B = self.data.contiguous().half().cuda(device)
             CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
             del CBt
-            del SCBt
             self.data = CB
             setattr(self, "CB", CB)
             setattr(self, "SCB", SCB)
+            setattr(self, "SCBt", SCBt)
             return self
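With this change Int8Params.cuda() keeps the small column-wise statistics vector SCBt (needed to rebuild CBt in backward) but still drops the full column-wise code CBt, so only one int8 copy of the weight plus two 1-D scale vectors remain on the GPU. A hedged sketch of the underlying call (requires a CUDA device and bitsandbytes; the comments about which statistic is per-row vs. per-column reflect my reading of the layout, not documented API guarantees):

    import torch
    import bitsandbytes as bnb

    # fp16 weight on the GPU, as produced inside Int8Params.cuda()
    B = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")

    CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)

    # What this commit keeps vs. drops on the parameter:
    #   kept:    CB (int8 code of B), SCB (row statistics), SCBt (column statistics)
    #   dropped: CBt (a second full-size int8 copy of B)
    del CBt

    print(CB.shape, CB.dtype)      # torch.Size([4096, 1024]) torch.int8
    print(SCB.shape, SCBt.shape)   # small 1-D statistics vectors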
@@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
+            new_param.SCBt = self.SCBt
             return new_param
@@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
+        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
+        self.weight.SCBt = None

     def forward(self, x):
         self.state.is_training = self.training
@@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear):
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CB is not None:
-            # we converted 8-bit row major to turing/ampere format in the first inference pass
-            # we no longer need the row-major weight
-            del self.state.CB
-            self.weight.data = self.state.CxB
+        # if not self.state.has_fp16_weights and self.state.CB is not None:
+        #     # we converted 8-bit row major to turing/ampere format in the first inference pass
+        #     # we no longer need the row-major weight
+        #     del self.state.CB
+        #     self.weight.data = self.state.CxB

         return out
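Keeping state.CB alive after the first forward pass (instead of deleting it) is what makes backward possible for a frozen int8 layer: MatMul8bitLt.backward can now rebuild CBt from CB, SCB and SCBt even though no fp16 copy of the weight exists. A hedged end-to-end sketch of the intended use, e.g. training a small fp16 adapter on top of a frozen 8-bit linear layer (requires CUDA and bitsandbytes; the Linear8bitLt constructor arguments follow this fork and may differ in other versions):

    import torch
    import bitsandbytes as bnb

    # Frozen base layer stored in int8 only, plus a small trainable fp16 adapter.
    base = bnb.nn.Linear8bitLt(1024, 1024, bias=True,
                               has_fp16_weights=False, threshold=6.0)
    base.weight.requires_grad = False
    base = base.cuda()                    # quantizes the weight via Int8Params.cuda()

    adapter = torch.nn.Linear(1024, 1024).half().cuda()

    x = torch.randn(8, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
    out = adapter(base(x))                # forward through the frozen int8 layer
    out.float().pow(2).mean().backward()  # grad_A flows through MatMul8bitLt.backward

    print(x.grad.shape, adapter.weight.grad.shape)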