OpenDAS / bitsandbytes

Commit 591f6039, authored Sep 18, 2022 by justheuristic
add memory efficient backward
Parent: 579b8c78
Showing 2 changed files with 17 additions and 8 deletions (+17 / -8):

  bitsandbytes/autograd/_functions.py   +0 / -1
  tests/test_modules.py                 +17 / -7
bitsandbytes/autograd/_functions.py
@@ -381,7 +381,6 @@ class MatMul8bitLt(torch.autograd.Function):
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
             elif state.CB is not None:
-                raise NotImplementedError("WIP")
                 CB = state.CB.to(ctx.dtype_B)
                 CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
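The change above removes the `raise NotImplementedError("WIP")` guard, so the `state.CB is not None` branch now runs: the int8 weight matrix `CB` is dequantized on the fly with its per-row scales `SCB` and multiplied with `grad_output` to form `grad_A`, instead of requiring a retained fp16 copy of the weights. The following is a minimal, standalone sketch of that dequantize-then-matmul step with hypothetical tensor names (`CB_int8`, `SCB`, `grad_shape`); it illustrates the arithmetic only and is not the library code itself.

import torch

def grad_A_from_int8_weights(grad_output, CB_int8, SCB, grad_shape, dtype_A=torch.float16):
    # Dequantize: int8 codes times (per-row scale / 127) approximate the original weights.
    CB = CB_int8.to(dtype_A)
    CB.mul_(SCB.unsqueeze(1).div_(127.0).to(CB.dtype))
    # grad_A = grad_output @ W, reshaped back to the activation's original shape.
    return torch.matmul(grad_output, CB).view(grad_shape).to(dtype_A)

# Toy check on CPU with random data (hypothetical shapes).
out_features, in_features, batch = 64, 32, 4
W = torch.randn(out_features, in_features)
SCB = W.abs().max(dim=1).values                                   # per-row absmax scales
CB_int8 = torch.round(W / SCB.unsqueeze(1) * 127.0).to(torch.int8)
grad_out = torch.randn(batch, out_features)
grad_A = grad_A_from_int8_weights(grad_out, CB_int8, SCB, (batch, in_features), torch.float32)
print(grad_A.shape)  # torch.Size([4, 32])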
tests/test_modules.py
@@ -14,13 +14,15 @@ class MockArgs(object):
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )

     def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .to(torch.float16)
         .to("cuda")
     )
@@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.device.type == "cuda"


 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
     l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
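Taken together, the test changes parametrize test_linear8bitlt_no_fp16_weights over the new memory_efficient_backward flag and thread it through MLP8bit into both Linear8bitLt layers. Below is a minimal usage sketch of that flag, not taken from the commit itself; it assumes a CUDA device and the bitsandbytes build from this commit, and the threshold value is chosen arbitrarily for illustration.

import torch
import bitsandbytes as bnb

# Mirrors the pattern exercised by the updated test: int8 weights, no fp16 copy kept,
# memory-efficient backward enabled.
layer = (
    bnb.nn.Linear8bitLt(
        32, 64, has_fp16_weights=False, memory_efficient_backward=True, threshold=6.0
    )
    .cuda()
    .half()
)

x = torch.randn(4, 32, dtype=torch.float16, device="cuda", requires_grad=True)
out = layer(x)
out.float().sum().backward()   # input gradient now flows through the dequantized int8 weights
print(x.grad.shape)            # torch.Size([4, 32])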