OpenDAS / ColossalAI · Commits

Commit 1e2557e8 (Unverified)
Authored Apr 02, 2022 by LuGY, committed by GitHub on Apr 02, 2022
Parent: 828e4656

[zero] fixed the activation offload (#647)

* fixed the activation offload
* polish
Showing 1 changed file with 28 additions and 13 deletions.

colossalai/utils/activation_checkpoint.py (+28, -13)
@@ -7,6 +7,19 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable
 from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
 from .cuda import get_current_device


+def copy_to_device(obj, device):
+    if torch.is_tensor(obj):
+        return obj.to(device)
+    elif isinstance(obj, list):
+        return [copy_to_device(i, device) for i in obj]
+    elif isinstance(obj, tuple):
+        return tuple([copy_to_device(v, device) for v in obj])
+    elif isinstance(obj, dict):
+        return {k: copy_to_device(v, device) for k, v in obj.items()}
+    else:
+        return obj
+
+
 class CheckpointFunction(torch.autograd.Function):

     @staticmethod
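A quick, hedged illustration of the new helper: it recurses through lists, tuples, and dicts, moves every tensor it finds, and passes non-tensor leaves through unchanged. The import path follows the file shown in this diff; the sample batch is made up.

import torch
from colossalai.utils.activation_checkpoint import copy_to_device  # path as in this diff

batch = {
    'ids': torch.randint(0, 100, (4, 16)),
    'masks': (torch.ones(4, 16), torch.zeros(4, 16)),
    'step': 3,                                   # non-tensor leaves pass through unchanged
}
moved = copy_to_device(batch, 'cuda')
assert moved['ids'].is_cuda and moved['masks'][0].is_cuda
assert moved['step'] == 3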
@@ -26,7 +39,14 @@ class CheckpointFunction(torch.autograd.Function):
             ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
         else:
             ctx.had_autocast_in_fwd = False
+        if activation_offload:
+            inputs_cuda = copy_to_device(args, ctx.device)
+        else:
+            inputs_cuda = args
+
+        with torch.no_grad():
+            outputs = run_function(*inputs_cuda)

         # Save non-tensor inputs in ctx, keep a placeholder None for tensors
         # to be filled out during the backward.
         ctx.inputs = []
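The forward half of the fix in one self-contained sketch: run the checkpointed segment on device-resident copies under no_grad, while the copies kept for backward stay on the CPU. The function and variable names below are illustrative, not ColossalAI API.

import torch
from colossalai.utils.activation_checkpoint import copy_to_device  # helper added above

def forward_with_offload(run_function, activation_offload, device, *args):
    # Compute on the device...
    inputs_cuda = copy_to_device(args, device) if activation_offload else args
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)
    # ...but keep the tensors that backward will need on the CPU.
    saved = [copy_to_device(a, 'cpu') if activation_offload else a
             for a in args if torch.is_tensor(a)]
    return outputs, saved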
@@ -34,10 +54,8 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
-                if ctx.activation_offload:
-                    tmp = arg.detach().cpu()
-                    tmp.requires_grad = arg.requires_grad
-                    tensor_inputs.append(tmp)
+                if activation_offload:
+                    tensor_inputs.append(copy_to_device(arg, 'cpu'))
                 else:
                     tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
@@ -46,18 +64,15 @@ class CheckpointFunction(torch.autograd.Function):
                 ctx.inputs.append(arg)

         ctx.save_for_backward(*tensor_inputs)

-        with torch.no_grad():
-            outputs = run_function(*args)
         return outputs

     @staticmethod
     def backward(ctx, *args):
         if not torch.autograd._is_checkpoint_valid():
-            raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter"
-                               " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
-                               " argument.")
+            raise RuntimeError("Checkpointing is not compatible with .grad() or when an `inputs` parameter is "
+                               "passed to .backward(). Please use .backward() and do not pass its `inputs` argument.")
         # Copy the list to avoid modifying original list.
         inputs = list(ctx.inputs)
         tensor_indices = ctx.tensor_indices
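The re-wrapped message guards the same constraint as upstream reentrant checkpointing: gradients must flow through .backward(), not torch.autograd.grad(). For reference, the upstream torch API contemporary with this commit raises the analogous error:

import torch
from torch.utils.checkpoint import checkpoint  # on newer PyTorch pass use_reentrant=True

x = torch.randn(4, 4, requires_grad=True)
y = checkpoint(lambda t: (t @ t).sum(), x)
try:
    torch.autograd.grad(y, x)          # rejected by reentrant checkpointing
except RuntimeError as err:
    print(err)                         # "Checkpointing is not compatible with .grad() ..."

checkpoint(lambda t: (t @ t).sum(), x).backward()   # the supported path
print(x.grad.shape)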
@@ -74,12 +89,12 @@ class CheckpointFunction(torch.autograd.Function):
         for parallel_mode, state in ctx.fwd_seed_states.items():
             set_seed_states(parallel_mode, state)
         set_mode(ctx.fwd_current_mode)
+        if ctx.activation_offload:
+            tensors = copy_to_device(tensors, ctx.device)

         # Fill in inputs with appropriate saved tensors.
         for i, idx in enumerate(tensor_indices):
-            tmp = tensors[i].detach().to(ctx.device)
-            tmp.requires_grad = tensors[i].requires_grad
-            inputs[idx] = tmp
+            inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))

         if ctx.had_autocast_in_fwd:
             with torch.enable_grad(), torch.cuda.amp.autocast():
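The save-side hunk above and this restore-side hunk are two halves of one pattern: forward stashes CPU copies, backward moves them back before recomputation. A self-contained toy autograd.Function (not the ColossalAI class itself) showing the same round trip:

import torch

class OffloadSquare(torch.autograd.Function):
    """Toy mirror of the offload pattern: stash the activation on the
    CPU in forward, bring it back to the compute device in backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.device = x.device
        ctx.save_for_backward(x.to('cpu'))   # offloaded copy
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x_cpu,) = ctx.saved_tensors
        x = x_cpu.to(ctx.device)             # restore before use
        return grad_out * 2 * x              # d/dx (x * x)

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(8, requires_grad=True, device=dev)
OffloadSquare.apply(x).sum().backward()
print(x.grad[:3])                            # equals 2 * x[:3]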
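Finally, a hedged sketch of driving the fixed path end to end. It assumes the module exposes a checkpoint(function, activation_offload, *args) wrapper around CheckpointFunction.apply; check the rest of the file for the exact entry point.

import torch
from colossalai.utils.activation_checkpoint import checkpoint  # assumed entry point

def block(x):
    return torch.relu(x @ x.t())

x = torch.randn(32, 32, device='cuda', requires_grad=True)
# With activation_offload=True the saved inputs live on the CPU between
# forward and backward, trading transfer time for GPU memory.
y = checkpoint(block, True, x)
y.sum().backward()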