OpenDAS / bitsandbytes / Commits / 439f2b0c

Unverified commit 439f2b0c, authored Sep 19, 2022 by Tim Dettmers, committed by GitHub Sep 19, 2022

Merge pull request #33 from dbaranchuk/memory-efficient-backward

    Memory efficient backward

Parents: 9b5f2eda 76ce9aa6

Showing 4 changed files with 100 additions and 52 deletions (+100 -52)
bitsandbytes/autograd/_functions.py    +44  -32
bitsandbytes/nn/modules.py             +15  -6
tests/test_autograd.py                 +6   -3
tests/test_modules.py                  +35  -11
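The gist of the change: with memory_efficient_backward enabled, a frozen Linear8bitLt layer (has_fp16_weights=False) keeps the state needed to compute gradients with respect to its input, so backprop can pass through int8 weights. A minimal usage sketch, assuming a CUDA device; the layer sizes, threshold, and loss below are illustrative and not part of this commit:

import torch
import bitsandbytes as bnb

# Frozen int8 layer that still lets gradients flow back to its input.
layer = bnb.nn.Linear8bitLt(
    1024, 4096,
    has_fp16_weights=False,
    memory_efficient_backward=True,
    threshold=6.0,  # illustrative outlier threshold
).cuda().half()

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
out = layer(x)
out.float().pow(2).mean().backward()  # grad for x is reconstructed from the int8 state
print(x.grad.shape)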
bitsandbytes/autograd/_functions.py  (view file @ 439f2b0c)
 import operator
+import warnings
 import torch
 import bitsandbytes.functional as F
@@ -184,6 +186,7 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
@@ -209,31 +212,29 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )
         if state.threshold > 0.0 and coo_tensorA is not None:
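The forward pass no longer asserts fp16 inputs; it warns and casts to fp16 before quantization, and the empty-tensor shortcut now returns the input dtype. A rough standalone sketch of the new casting behaviour (plain PyTorch, no int8 kernels involved):

import warnings
import torch

def cast_input_for_quant(A: torch.Tensor) -> torch.Tensor:
    # Mirrors the change above: warn and cast instead of asserting fp16.
    if A.dtype != torch.float16:
        warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
    return A.to(torch.float16)

A_bf16 = torch.randn(4, 8, dtype=torch.bfloat16)
A_fp16 = cast_input_for_quant(A_bf16)  # emits the warning, returns an fp16 copy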
@@ -269,7 +270,7 @@ class MatMul8bitLt(torch.autograd.Function):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -290,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
                     (outliers * state.SCB.view(-1, 1) / 127.0)
                     .t()
                     .contiguous()
-                    .half()
+                    .to(A.dtype)
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ class MatMul8bitLt(torch.autograd.Function):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
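Since the output is now cast back to the input dtype, a bias that is not fp16 (for example fp32) cannot be fused into mm_dequant; it is added after the cast instead. A small sketch of the two paths, using a plain tensor in place of the int8 GEMM output:

import torch

def finish_output(out_fp16: torch.Tensor, bias, target_dtype):
    # fp16 (or absent) bias: assumed already fused upstream; otherwise add it after the cast.
    if bias is None or bias.dtype == torch.float16:
        return out_fp16.to(target_dtype)
    return out_fp16.to(target_dtype).add_(bias)

out = torch.randn(4, 8, dtype=torch.float16)
bias_fp32 = torch.randn(8, dtype=torch.float32)
print(finish_output(out, bias_fp32, torch.float32).dtype)  # torch.float32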
@@ -318,9 +325,9 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:
@@ -328,8 +335,8 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))

     @staticmethod
@@ -337,23 +344,24 @@ class MatMul8bitLt(torch.autograd.Function):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

+        # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()

-        grad_A = grad_B = grad_bias = None
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ class MatMul8bitLt(torch.autograd.Function):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
-
-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
+            if state.CBt is not None:
+                C32grad, Sgrad = F.transform(Cgrad, "col32")
+                if state.CxBt is None:
+                    state.CxBt, state.SBt = F.transform(
+                        state.CBt, to_order=formatB, transpose=True
+                    )
+                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')

         return grad_A, grad_B, None, grad_bias, None
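The elif branch above is the heart of the memory-efficient backward: when only the row-major int8 weight CB and its per-row scales SCB are available, grad_A is obtained by dequantizing CB (roughly CB * SCB / 127) and doing an ordinary matmul. A self-contained sketch with made-up shapes and values:

import torch

out_features, in_features = 6, 4
CB = torch.randint(-127, 128, (out_features, in_features), dtype=torch.int8)  # int8 weight rows
SCB = torch.rand(out_features) * 3.0                                          # per-row absmax scales
grad_output = torch.randn(2, out_features, dtype=torch.float16)

# Dequantize the weight and fall back to a regular matmul, as in the elif branch.
W = CB.to(torch.float16).mul_(SCB.to(torch.float16).unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, W)
print(grad_A.shape)  # torch.Size([2, 4])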
bitsandbytes/nn/modules.py  (view file @ 439f2b0c)
@@ -221,6 +221,7 @@ class Linear8bitLt(nn.Linear):
         output_features,
         bias=True,
         has_fp16_weights=True,
+        memory_efficient_backward=False,
         threshold=0.0,
         index=None,
     ):
@@ -232,10 +233,13 @@ class Linear8bitLt(nn.Linear):
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -255,11 +259,16 @@ class Linear8bitLt(nn.Linear):
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CB is not None:
-            # we converted 8-bit row major to turing/ampere format in the first inference pass
-            # we no longer need the row-major weight
-            del self.state.CB
-            self.weight.data = self.state.CxB
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB

         return out
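With the flag set, the layer trades a per-forward weight re-layout for the ability to backpropagate through frozen int8 weights: CxB is dropped after each pass and the row-major CB is kept for backward. A hedged sketch of the intended use, a trainable module feeding a frozen int8 layer; the sizes and loss are illustrative and a CUDA device is assumed:

import torch
import bitsandbytes as bnb

proj = torch.nn.Linear(32, 32).cuda().half()             # trainable fp16 module upstream
frozen = bnb.nn.Linear8bitLt(
    32, 64, has_fp16_weights=False, memory_efficient_backward=True
).cuda().half()                                           # frozen int8 layer

x = torch.randn(8, 32, device="cuda", dtype=torch.float16)
loss = frozen(proj(x)).float().pow(2).mean()
loss.backward()
print(proj.weight.grad is not None)  # True: the gradient passed through the int8 layer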
tests/test_autograd.py  (view file @ 439f2b0c)
@@ -253,7 +253,7 @@ for c in req_grad:
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
-dtype = [torch.float16]
+dtype = [torch.float16, torch.bfloat16, torch.float32]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
 values = list(
@@ -354,7 +354,7 @@ def test_matmullt(
             state.SCB,
             SCBt,
             coo_tensorB,
-        ) = bnb.functional.double_quant(B2)
+        ) = bnb.functional.double_quant(B2.to(torch.float16))
         B2 = state.CB
     if not transpose[0] and transpose[1]:
@@ -367,11 +367,14 @@ def test_matmullt(
     if has_bias:
         out_torch += bias

+    assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
     n = out_bnb.numel()
     err = torch.abs(out_bnb - out_torch).mean().item()
     # print(f'abs error {err:.4f}')
     idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-    assert (idx == 0).sum().item() <= n * 0.0175
+    assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
     idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
     assert (idx == 0).sum().item() <= n * 0.001
tests/test_modules.py  (view file @ 439f2b0c)
@@ -14,13 +14,15 @@ class MockArgs(object):
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, threshold=threshold
         )

     def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward)
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .to(torch.float16)
-        .to("cuda")
-    )
+    mlp = MLP8bit(
+        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+    )
+    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    mlp = mlp.cuda().half()  # and this line triggers quantization

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        mlp.zero_grad()
+        (o1 * grad_proj).sum().backward()
+        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+        scale = grad_ref.abs().mean()
+
+        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+        assert (idx == 0).sum().item() <= b1.numel() * 0.005
+

 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
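The reference gradient in the new test follows from the chain rule for two stacked linear layers: for y = fc2(fc1(x)) with weights w1 and w2 (each applied as input @ W.t()), dL/dx = dL/dy @ w2 @ w1, and the biases do not affect it. A tiny fp32 sanity check of that identity, independent of bitsandbytes:

import torch

x = torch.randn(5, 3, requires_grad=True)
w1, w2 = torch.randn(4, 3), torch.randn(2, 4)
y = torch.nn.functional.linear(torch.nn.functional.linear(x, w1), w2)  # fc2(fc1(x)) without biases
grad_y = torch.randn_like(y)
y.backward(grad_y)
assert torch.allclose(x.grad, grad_y @ w2 @ w1, atol=1e-5)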