OpenDAS / deepspeed · Commits · a825f996

Commit a825f996 (unverified)
Authored Sep 18, 2020 by Shaden Smith, committed by GitHub on Sep 18, 2020
Parent: a74a604a

Fix activation checkpoint unit tests for GPU systems (#421)

Showing 2 changed files with 170 additions and 12 deletions:
  deepspeed/runtime/activation_checkpointing/checkpointing.py   +13  -12
  tests/unit/test_activation_checkpointing.py                  +157   -0
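For orientation: the API exercised by this commit is deepspeed.checkpointing.checkpoint, which recomputes a module's forward pass during backward instead of storing its activations. A minimal sketch of the call pattern used by the new tests (illustrative only; as the test file notes, checkpoint() assumes torch.distributed is already initialized):

import torch
import deepspeed

# Sketch of the call pattern exercised below; assumes torch.distributed is
# initialized and a CUDA device is available.
ckpt = deepspeed.checkpointing.checkpoint

module = torch.nn.Linear(20, 20).cuda()
x = torch.rand(20, device='cuda', requires_grad=True)

# Recompute module's forward during backward instead of storing activations.
out = ckpt(module, x)
out.sum().backward()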
deepspeed/runtime/activation_checkpointing/checkpointing.py  (mode changed 100755 → 100644)
@@ -481,13 +481,16 @@ class CheckpointFunction(torch.autograd.Function):
         if SYNCHRONIZE:
             torch.cuda.synchronize()

-        # Tensors returned from forward() may not be differentiable, e.g., attention mask
-        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
+        # Tensors returned from forward() may not be differentiable.
+        if torch.is_tensor(outputs):
+            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
+        else:
+            non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
         ctx.mark_non_differentiable(*non_grad_outputs)
         return outputs

     @staticmethod
-    def backward(ctx, *args):
+    def backward(ctx, *grads):
         global timers
         #see_memory_usage("In backward", force=True)
         #removing pointers to the contiguous buffer memory
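Note on the forward() change: iterating directly over a tensor yields slices along its first dimension, so the old list comprehension did not treat a single-tensor return as one output. The sketch below is not DeepSpeed code (Passthrough and its shapes are made up); it only illustrates the fixed pattern of marking non-floating-point outputs with ctx.mark_non_differentiable for both tuple and single-tensor returns.

import torch

class Passthrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mask):
        # Hypothetical outputs: one differentiable tensor and one mask.
        outputs = (x * 2, mask.clone())
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
        ctx.mark_non_differentiable(*non_grad_outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_x, grad_mask):
        # Gradient flows only through the floating-point output.
        return grad_x * 2, None

x = torch.rand(4, requires_grad=True)
mask = torch.zeros(4, dtype=torch.bool)
y, m = Passthrough.apply(x, mask)
y.sum().backward()   # m is correctly marked non-differentiable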
@@ -553,17 +556,15 @@ class CheckpointFunction(torch.autograd.Function):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs, )

-        # Go over args and build the list of gradient tensors. This is usually just args,
-        # but if the forward pass returns tensors that do not require_grad then we should
-        # adjust the arguments to autograd.backward() too. This happens when forward()
-        # returns indices or a mask (such as an attention mask).
-        # We skip the first needs_input_grad because it corresponds to run_function.
+        # Construct arguments to autograd.backward().
+        # This is usually just outputs and grads, but forward() can return tensors that
+        # are not differentiable.
         output_tensors = []
         grad_tensors = []
-        for idx, need_grad in enumerate(ctx.needs_input_grad[1:]):
-            if need_grad:
-                output_tensors.append(outputs[idx])
-                grad_tensors.append(args[idx])
+        for out, grad in zip(outputs, grads):
+            if out.requires_grad:
+                output_tensors.append(out)
+                grad_tensors.append(grad)

         torch.autograd.backward(output_tensors, grad_tensors)
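Note on the backward() change: each recomputed output is paired with its incoming gradient, and non-differentiable outputs are dropped before calling torch.autograd.backward(). A self-contained sketch of that filtering with made-up tensors:

import torch

# Pair each output with its incoming gradient and drop non-differentiable
# outputs (e.g., a boolean mask) before calling torch.autograd.backward().
x = torch.rand(4, requires_grad=True)
outputs = (x * 3, torch.zeros(4, dtype=torch.bool))   # float output + mask
grads = (torch.ones(4), torch.zeros(4))               # one grad per output

output_tensors, grad_tensors = [], []
for out, grad in zip(outputs, grads):
    if out.requires_grad:
        output_tensors.append(out)
        grad_tensors.append(grad)

torch.autograd.backward(output_tensors, grad_tensors)
print(x.grad)   # tensor of 3s: only the differentiable output contributed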
tests/unit/test_activation_checkpointing.py  (new file, mode 100644)
# TODO: add tests with model parallelism for activation partitioning and other features.

from copy import deepcopy

import pytest
import torch

import deepspeed
ckpt = deepspeed.checkpointing.checkpoint

from common import distributed_test


def _compute(module, *inputs, do_checkpoint=False):
    if do_checkpoint:
        outputs = ckpt(module, *inputs)
    else:
        outputs = module(*inputs)

    if torch.is_tensor(outputs):
        outputs = (outputs, )

    sum(o.sum() for o in outputs if o.requires_grad).backward()
    grads = [p.grad for p in module.parameters()]
    input_grads = [inp.grad for inp in inputs]

    return {
        'outputs': outputs,
        'module_grads': grads,
        'input_grads': input_grads,
    }


# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint(module, *inputs):
    # Move to device
    module.cuda()

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    base = _compute(module_, *inputs_, do_checkpoint=False)

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

    for group in base.keys():
        for b, t in zip(base[group], test[group]):
            # Catch grad `None`s, etc.
            if not torch.is_tensor(b):
                assert b == t
            elif b.is_floating_point():
                assert torch.allclose(b, t)
            else:
                assert torch.equal(b, t)
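The helper above calls module.eval() because dropout would make the baseline and checkpointed runs diverge, as the "until we fork the RNG between tests" comment notes. One possible direction, sketched here as an assumption and not as anything in this commit, is to fork and reseed the RNG around each run:

import torch

# Hypothetical variant (not in this commit): fork and reseed the RNG so both
# the baseline and checkpointed runs start from identical random state.
# Assumes a single CUDA device (device 0); whether dropout also matches during
# the recomputed forward depends on the checkpoint implementation's RNG handling.
def _compute_with_forked_rng(module, *inputs, do_checkpoint=False):
    with torch.random.fork_rng(devices=[0]):
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)
        return _compute(module, *inputs, do_checkpoint=do_checkpoint)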
#
# Helpers
#


class MaskedLinear(torch.nn.Linear):
    def forward(self, x, mask):
        out = super().forward(x)
        if mask.is_floating_point():
            out = out * mask
        else:
            # must cast BoolTensor in older torch versions
            out = out * mask.type_as(out)
        return out


class MaskedLinearSeq(MaskedLinear):
    """Tests pipeline modules by also returning the mask."""
    def forward(self, x, mask):
        return super().forward(x, mask), mask


class MaskedLinearSeqDup(MaskedLinearSeq):
    """MaskedLinearSeq, but with more outputs than inputs and in a different order."""
    def forward(self, x, mask):
        dup = x.clone().detach() * 1.38  # just an arbitrary scaling
        x, mask = super().forward(x, mask)
        return dup, x, mask


HIDDEN_DIM = 20


def _mixed_mask(size=HIDDEN_DIM):
    entries = torch.randn(size)
    mask = torch.where(entries > 0, torch.ones(size), torch.zeros(size))
    mask = mask.bool()
    return mask


def _bool_to_float(btensor, dtype=torch.float32):
    """Converts a torch.BoolTensor to an equivalent dtype."""
    ones = torch.ones(size=btensor.size(), dtype=dtype)
    zeros = torch.zeros(size=btensor.size(), dtype=dtype)
    return torch.where(btensor, ones, zeros)


#
# Tests
#


def test_ckpt_inputs1_outputs1():
    module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs)


# both bool and float are important, as bool is not differentiable
@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs1(mask):
    module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)
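The bool/float parametrization matters because boolean tensors can never carry gradients and report is_floating_point() as False, which is exactly the non-differentiable-output case the checkpointing fix handles. A quick illustration:

import torch

# Boolean tensors exercise the non-differentiable-output path above.
bool_mask = torch.zeros(4, dtype=torch.bool)
print(bool_mask.is_floating_point())   # False

try:
    bool_mask.requires_grad_(True)
except RuntimeError as e:
    print(e)   # only floating point (and complex) tensors can require gradients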
@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs2(mask):
    module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs3(mask):
    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)