OpenDAS / bitsandbytes · Commits

Commit 4d6174bc
authored Aug 25, 2022 by dbaranchuk
parent ef2936a9

    memory efficient fp16 backward

Showing 2 changed files with 6 additions and 41 deletions:

    bitsandbytes/autograd/_functions.py   +5  -35
    bitsandbytes/nn/modules.py            +1  -6
bitsandbytes/autograd/_functions.py  (view file @ 4d6174bc)

@@ -196,7 +196,6 @@ class MatmulLtState:
         self.CxBt = None
         self.SBt = None
-        self.CBt = None


 class MatMul8bitLt(torch.autograd.Function):

@@ -327,15 +326,12 @@ class MatMul8bitLt(torch.autograd.Function):
         #clone_func = torch.clone
         return clone_func(output.view(output_shape))

     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
         req_gradA, req_gradB, req_gradBias = ctx.req_grads
-        CAt, subA = ctx.tensors
-        SCAt, idx = ctx.tensor_states
-        formatB = ctx.formatB
+        assert not req_gradB, "TODO: support weight updates as well"
         state = ctx.state

         if len(grad_output.shape) == 3:

@@ -345,37 +341,11 @@ class MatMul8bitLt(torch.autograd.Function):
         grad_A = grad_B = grad_bias = None

-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
-        if req_gradB:
-            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
-            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
-            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
-            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
-            if state.threshold > 0.0 and subA is not None:
-                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
-
         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                if state.has_fp16_weights:
-                    CBt = state.CBt
-                else:
-                    # Restore CBt from CB
-                    assert state.CBt is None, "CBt should not be stored in state"
-                    CB = state.CB.half()
-                    SCB = state.SCB.unsqueeze(1).half()
-                    SCBt = state.SCBt.unsqueeze(1).half()
-                    Bt = (CB * SCB).t().contiguous()
-                    CBt = (Bt / SCBt).t().to(torch.int8)
-                # intentionally, do not store CxBt in state
-                CxBt, SBt = F.transform(CBt, to_order=formatB, transpose=True)
-            else:
-                CxBt = state.CxBt
-            gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            CB = state.CB.half()
+            SCB = state.SCB.unsqueeze(1).half()
+            B = (CB * SCB) / 127.0
+            grad_A = torch.mm(grad_output, B).view(ctx.grad_shape)

         if req_gradBias:
             grad_bias = grad_output.sum(0)
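Net effect of the hunks above: the input gradient is no longer produced by quantizing grad_output and running an int8 igemmlt against a (possibly rebuilt) transposed weight copy; instead, the stored int8 weight CB is dequantized on the fly with its row-wise scales SCB and multiplied with grad_output in fp16, and the weight-gradient path is dropped behind an assert. Below is a minimal standalone sketch of that gradient rule; it uses plain torch, runs in float32 on CPU for portability, and the helper name grad_wrt_input plus the hand-rolled row-wise quantizer in the usage part are illustrative assumptions, not bitsandbytes API.

import torch

# Sketch of the memory-efficient backward for grad_A, assuming:
#   CB          : int8 weight of shape (out_features, in_features)
#   SCB         : per-row absmax scales of shape (out_features,)
#   grad_output : gradient of shape (batch, out_features)
# Dequantize row-wise, B ~= CB * SCB[:, None] / 127, then one plain matmul
# (fp16 on GPU in the actual diff; grad_output.dtype here so it runs on CPU).

def grad_wrt_input(grad_output, CB, SCB):
    B = CB.to(grad_output.dtype) * SCB.unsqueeze(1).to(grad_output.dtype) / 127.0
    return torch.mm(grad_output, B)   # (batch, out) @ (out, in) -> (batch, in)

# Tiny usage example with a hand-rolled row-wise quantizer (assumption, not
# bnb.functional.double_quant) and random data.
out_features, in_features, batch = 4, 3, 2
W = torch.randn(out_features, in_features)
SCB = W.abs().amax(dim=1)                                  # row-wise absmax
CB = torch.round(W / SCB.unsqueeze(1) * 127).to(torch.int8)
grad_output = torch.randn(batch, out_features)
grad_A = grad_wrt_input(grad_output, CB, SCB)
print(grad_A.shape)                                        # torch.Size([2, 3])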
bitsandbytes/nn/modules.py  (view file @ 4d6174bc)

@@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
-        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
-        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)

@@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter):
             B = self.data.contiguous().half().cuda(device)
             CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
             del CBt
+            del SCBt
             self.data = CB
             setattr(self, "CB", CB)
             setattr(self, "SCB", SCB)
-            setattr(self, "SCBt", SCBt)

             return self

@@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
-            new_param.SCBt = self.SCBt

             return new_param

@@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
-        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
-        self.weight.SCBt = None

     def forward(self, x):
         self.state.is_training = self.training
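On the module side, the only quantization artifacts that survive after this commit are the int8 matrix CB and its row-wise scales SCB: SCBt is deleted right after double_quant and is no longer threaded through Int8Params, Linear8bitLt.init_8bit_state, or the autograd state. A rough standalone sketch of what is kept follows, with a hand-rolled row-wise absmax quantizer standing in for bnb.functional.double_quant (the name quantize_rowwise and the /127 scaling are assumptions inferred from the backward code above).

import torch

def quantize_rowwise(W_fp):
    # Keep only CB (int8 data) and SCB (per-row absmax scales),
    # discarding the transposed copies CBt/SCBt, exactly as the diff does.
    SCB = W_fp.abs().amax(dim=1)
    CB = torch.round(W_fp / SCB.unsqueeze(1) * 127).to(torch.int8)
    return CB, SCB

W = torch.randn(8, 16)
CB, SCB = quantize_rowwise(W)
# Only these two tensors need to live on the parameter/state after this commit;
# the backward sketch above reconstructs everything it needs from them.
print(CB.dtype, CB.shape, SCB.shape)   # torch.int8, (8, 16), (8,)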