OpenDAS / ColossalAI · Commits

Commit de464504
Authored Mar 11, 2022 by LuGY; committed by Frank Lee, Mar 11, 2022
Parent: 272ebfb5

Added activation offload (#331)

* Added activation offload
* Fixed the import bug, used the pytest

Showing 4 changed files with 28 additions and 22 deletions (+28 -22)
colossalai/nn/layer/utils/common.py                 +3  -2
colossalai/utils/__init__.py                        +3  -2
colossalai/utils/activation_checkpoint.py           +16 -10
tests/test_utils/test_activation_checkpointing.py   +6  -8
colossalai/nn/layer/utils/common.py

@@ -13,17 +13,18 @@ from torch import Tensor, nn
 class CheckpointModule(nn.Module):

-    def __init__(self, checkpoint: bool = True):
+    def __init__(self, checkpoint: bool = True, offload: bool = False):
         super().__init__()
         self.checkpoint = checkpoint
         self._use_checkpoint = checkpoint
+        self._offload = offload

     def _forward(self, *args, **kwargs):
         raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')

     def forward(self, *args, **kwargs):
         if self._use_checkpoint:
-            return checkpoint(self._forward, *args, **kwargs)
+            return checkpoint(self._forward, self._offload, *args, **kwargs)
         else:
             return self._forward(*args, **kwargs)
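For context, a minimal sketch of how the new flag is used from a subclass (the MySubLayer class, sizes, and import path below are illustrative assumptions, not part of this commit): subclasses implement _forward, and CheckpointModule.forward routes the call through colossalai's checkpoint together with the offload flag.

import torch
from torch import nn
from colossalai.nn.layer.utils import CheckpointModule  # assumed import path

class MySubLayer(CheckpointModule):
    # hypothetical subclass for illustration only
    def __init__(self, dim, checkpoint=True, offload=False):
        super().__init__(checkpoint=checkpoint, offload=offload)
        self.linear = nn.Linear(dim, dim)

    def _forward(self, x):
        # implement _forward, never forward: the base class decides whether
        # to run it under activation checkpointing (and CPU offload)
        return torch.relu(self.linear(x))

layer = MySubLayer(256, checkpoint=True, offload=True).cuda()
out = layer(torch.randn(4, 256, device='cuda', requires_grad=True))
out.sum().backward()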
colossalai/utils/__init__.py

+from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint
 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
...
@@ -5,11 +6,11 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                     is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
                     multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
                     switch_virtual_pipeline_parallel_rank, sync_model_param)
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
+#from .tensor_detector import TensorDetector

 __all__ = ['checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
...
@@ -17,5 +18,5 @@ __all__ = [
            'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
            'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
            'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-           'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter'
+           'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter', 'TensorDetector'
 ]
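The .cuda import moves ahead of .activation_checkpoint, presumably because activation_checkpoint.py now depends on .cuda (see the next file); the commit message calls this "Fixed the import bug". For callers, the package-level names are unchanged, e.g.:

# both utilities still resolve from the package root
from colossalai.utils import checkpoint, get_current_device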
colossalai/utils/activation_checkpoint.py

@@ -5,14 +5,16 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable

 from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states
+from .cuda import get_current_device


 class CheckpointFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, activation_offload=False, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.activation_offload = activation_offload
+        ctx.device = get_current_device()

         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
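Capturing the device at forward time is what lets backward know where to put offloaded tensors back; note that get_current_device() here is colossalai's own helper (from .cuda), not a torch builtin. A minimal sketch of the restore step backward performs (illustrative only):

import torch
from colossalai.utils import get_current_device

device = get_current_device()   # e.g. device(type='cuda', index=0)
t_cpu = torch.randn(2, 2)       # stands in for an activation parked on the host
t_back = t_cpu.to(device)       # copied back before recomputation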
@@ -32,7 +34,12 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs = []
         for i, arg in enumerate(args):
             if torch.is_tensor(arg):
-                tensor_inputs.append(arg)
+                if ctx.activation_offload:
+                    tmp = arg.detach().cpu()
+                    tmp.requires_grad = arg.requires_grad
+                    tensor_inputs.append(tmp)
+                else:
+                    tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
                 ctx.inputs.append(None)
             else:
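The explicit requires_grad copy in this hunk matters because Tensor.detach() always returns a tensor with requires_grad=False, and the flag does not survive the trip to the CPU on its own. A standalone demonstration (not from the commit):

import torch

a = torch.randn(4, requires_grad=True)
tmp = a.detach().cpu()                  # detach() drops the grad flag
assert tmp.requires_grad is False
tmp.requires_grad = a.requires_grad     # restore it, as the hunk above does
assert tmp.requires_grad is True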
@@ -70,8 +77,9 @@ class CheckpointFunction(torch.autograd.Function):
         # Fill in inputs with appropriate saved tensors.
         for i, idx in enumerate(tensor_indices):
-            inputs[idx] = tensors[i]
+            tmp = tensors[i].detach().to(ctx.device)
+            tmp.requires_grad = tensors[i].requires_grad
+            inputs[idx] = tmp
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
             with torch.enable_grad(), torch.cuda.amp.autocast():
@@ -82,7 +90,6 @@ class CheckpointFunction(torch.autograd.Function):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
-
         # recover the rng states
         torch.set_rng_state(bwd_cpu_rng_state)
         for parallel_mode, state in bwd_seed_states.items():
@@ -103,15 +110,14 @@ class CheckpointFunction(torch.autograd.Function):
         torch.autograd.backward(outputs_with_grad, args_with_grad)
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
-        return (None,) + grads
+        return (None, None) + grads


-def checkpoint(function, *args):
+def checkpoint(function, activation_offload, *args):
     """Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint

     :param function: Describe the forward pass function. It should know how to handle the input tuples.
     :param args: Tuple containing the parameters of the function
     :return: Output of running function with provided args
     """
-    return CheckpointFunction.apply(function, *args)
+    return CheckpointFunction.apply(function, activation_offload, *args)
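Forward now takes two non-tensor leading arguments (run_function and activation_offload), and autograd.Function.backward must return one gradient slot per forward argument, which is why the return value grows from (None,) + grads to (None, None) + grads. A minimal sketch of the updated public call (names and shapes are illustrative; the boolean is positional):

import torch
from colossalai.utils import checkpoint

def run_fn(x, w):
    return torch.matmul(x, w)

x = torch.randn(4, 8, device='cuda', requires_grad=True)
w = torch.randn(8, 8, device='cuda', requires_grad=True)

# activation_offload=True parks the saved inputs on the host during
# forward and copies them back to ctx.device before recomputation
out = checkpoint(run_fn, True, x, w)
out.sum().backward()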
tests/test_utils/test_activation_checkpointing.py

@@ -17,13 +17,14 @@ def forward(x, weight):
     out_ = F.dropout(out, p=0.4, training=True)
     return out_


 @pytest.mark.gpu
-def test_activation_checkpointing():
-    add_seed(ParallelMode.GLOBAL, 1024)
+@pytest.mark.parametrize("cpu_offload", [True, False])
+def test_activation_checkpointing(cpu_offload):
+    if cpu_offload:
+        add_seed(ParallelMode.GLOBAL, 1024)
+        add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.GLOBAL)
     global_cuda_rng_state = torch.cuda.get_rng_state()
-    add_seed(ParallelMode.DATA, 1026)
     set_mode(ParallelMode.DATA)
     data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
     set_mode(ParallelMode.GLOBAL)
@@ -49,13 +50,10 @@ def test_activation_checkpointing():
     set_mode(ParallelMode.DATA)
     torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
     set_mode(ParallelMode.GLOBAL)
-    out = checkpoint(forward, data_, weight_)
+    out = checkpoint(forward, cpu_offload, data_, weight_)
     loss = out.sum()
     loss.backward()
     assert torch.all(data.grad == data_.grad), 'Gradient of the input does not match'
     torch.cuda.empty_cache()


 if __name__ == '__main__':
     test_activation_checkpointing()
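With the parametrize marker the test now runs twice, cpu_offload=True first; since colossalai's seed manager appears to be process-global, the if cpu_offload: guard above registers the seeds only on the first parametrization so the second run does not re-add them. The property under test is that checkpointed recomputation, offloaded or not, reproduces the gradients of a plain forward/backward pass; a distilled version of that check (illustrative only, and ignoring the RNG bookkeeping the real test performs around dropout):

import torch
from colossalai.utils import checkpoint

def assert_grads_match(fn, cpu_offload, *inputs):
    ref = [t.detach().clone().requires_grad_(True) for t in inputs]
    chk = [t.detach().clone().requires_grad_(True) for t in inputs]
    fn(*ref).sum().backward()                            # plain execution
    checkpoint(fn, cpu_offload, *chk).sum().backward()   # recomputed execution
    for r, c in zip(ref, chk):
        assert torch.all(r.grad == c.grad), 'Gradient of the input does not match'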