OpenDAS / ColossalAI · Commits · 67e1912b

Unverified commit 67e1912b, authored Jan 16, 2023 by YuliangLiu0306, committed via GitHub on Jan 16, 2023.

[autoparallel] support origin activation ckpt on autoparallel system (#2468)

Parent: 3a21485e
Showing 4 changed files with 111 additions and 5 deletions (+111 −5):
- colossalai/auto_parallel/passes/runtime_apply_pass.py (+33 −0)
- colossalai/auto_parallel/passes/runtime_preparation_pass.py (+2 −0)
- colossalai/auto_parallel/tensor_shard/initialize.py (+6 −5)
- tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py (+70 −0)
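The change is the same in all three library files: the auto-parallel runtime passes insert helper nodes (shape-consistency apply, comm-spec apply, size processing) into the traced graph, and each pass now copies the origin node's 'activation_checkpoint' entry in node.meta onto the node it inserts, so that regions the user wrapped in torch.utils.checkpoint.checkpoint survive the sharding transformation. A minimal standalone sketch of that propagation pattern on a plain torch.fx graph (the Tiny module and the torch.neg helper node are made up for illustration; only the meta copy mirrors the committed passes):

import torch
import torch.fx as fx


class Tiny(torch.nn.Module):

    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(Tiny())
graph = gm.graph

# Pretend an earlier pass tagged the relu node as checkpoint region 0.
relu_node = next(n for n in graph.nodes if n.target is torch.relu)
relu_node.meta['activation_checkpoint'] = 0

# Insert a helper node after relu (standing in for a shape-consistency or
# comm-spec node) and carry the annotation over, as the modified passes do.
with graph.inserting_after(relu_node):
    helper = graph.call_function(torch.neg, args=(relu_node,))
if 'activation_checkpoint' in relu_node.meta:
    helper.meta['activation_checkpoint'] = relu_node.meta['activation_checkpoint']

assert helper.meta['activation_checkpoint'] == 0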
colossalai/auto_parallel/passes/runtime_apply_pass.py

@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                                                  runtime_apply,
                                                  args=(node, origin_dict_node, input_dict_node,
                                                        node_to_index_dict[node], user_node_index))
+            if 'activation_checkpoint' in user_node.meta:
+                shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
+
             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                 # substitute the origin node with comm_spec_apply_node
                 new_kwargs[str(node)] = comm_spec_apply_node
                 user.kwargs = new_kwargs
+            if 'activation_checkpoint' in node.meta:
+                comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+
     return gm


+def _act_annotataion_pass(gm: torch.fx.GraphModule):
+    """
+    This pass is used to add the act annotation to the new inserted nodes.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    for node in nodes:
+        if not hasattr(node.meta, 'activation_checkpoint'):
+            from .runtime_preparation_pass import size_processing
+            user_act_annotation = -1
+            input_act_annotation = -1
+            for user_node in node.users.keys():
+                if 'activation_checkpoint' in user_node.meta:
+                    user_act_annotation = user_node.meta['activation_checkpoint']
+                    break
+            for input_node in node._input_nodes.keys():
+                if 'activation_checkpoint' in input_node.meta:
+                    input_act_annotation = input_node.meta['activation_checkpoint']
+                    break
+            if user_act_annotation == input_act_annotation and user_act_annotation != -1:
+                node.meta['activation_checkpoint'] = user_act_annotation
+
+    return gm
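The new _act_annotataion_pass (the misspelling is in the committed name) back-fills annotations for nodes that have none: if a node's first annotated user and first annotated input agree on a checkpoint region, the node is folded into that region. Two details worth noting in the committed body: hasattr(node.meta, 'activation_checkpoint') checks for an attribute on the meta dict rather than key membership, so it is always false and every node gets processed, and the size_processing import inside the loop is unused. A toy reproduction of the propagation rule on a three-node chain (the module and targets are illustrative):

import torch
import torch.fx as fx


class Chain(torch.nn.Module):

    def forward(self, x):
        return torch.neg(torch.relu(torch.sigmoid(x)))


gm = fx.symbolic_trace(Chain())
nodes = {n.target: n for n in gm.graph.nodes if n.op == 'call_function'}

# The producer and the consumer belong to checkpoint region 1; the middle
# node stands in for a freshly inserted, unannotated runtime node.
nodes[torch.sigmoid].meta['activation_checkpoint'] = 1
nodes[torch.neg].meta['activation_checkpoint'] = 1

node = nodes[torch.relu]
user_ann = next((u.meta['activation_checkpoint'] for u in node.users
                 if 'activation_checkpoint' in u.meta), -1)
input_ann = next((i.meta['activation_checkpoint'] for i in node.all_input_nodes
                  if 'activation_checkpoint' in i.meta), -1)
if user_ann == input_ann != -1:
    node.meta['activation_checkpoint'] = user_ann

assert node.meta['activation_checkpoint'] == 1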
colossalai/auto_parallel/passes/runtime_preparation_pass.py

@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # It will be used to replace the original node with processing node in slice object
             node_pairs[node] = size_processing_node
             size_processing_node._meta_data = node._meta_data
+            if 'activation_checkpoint' in node.meta:
+                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+
             user_list = list(node.users.keys())
             for user in user_list:
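Same bookkeeping as in the apply pass: size_processing_node replaces the original size node inside slice objects, so it presumably has to carry the origin node's 'activation_checkpoint' tag for the checkpoint codegen to keep it inside the recompute region; the pass is otherwise unchanged.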
colossalai/auto_parallel/tensor_shard/initialize.py

@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec

@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
     into the forward function.
     '''

-    def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
+    def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
                  origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
         '''
         Args:

@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
     return strategies_constructor


-def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
+def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
     '''
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.

@@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
     return solution


-def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
+def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
                                strategies_constructor: StrategiesConstructor):
     '''
     This method is used to transform the original graph to the sharded graph.

@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
         solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
         return a series of integers, but return the best strategies.
     '''
-    tracer = ColoTracer()
+    tracer = ColoTracer(trace_act_ckpt=True)

     graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
     strategies_constructor = build_strategy_constructor(graph, device_mesh)
     if load_solver_solution:
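Apart from swapping the GraphModule annotations for ColoGraphModule, the behavioral change in initialize.py is that the tracer is now constructed with trace_act_ckpt=True, so regions the user wrapped in torch.utils.checkpoint.checkpoint are recorded during tracing and end up as 'activation_checkpoint' entries in node.meta. A hedged sketch of that entry sequence, mirroring what initialize_model now does (TinyNet is an illustrative stand-in; exactly which nodes carry the region id depends on ColoTracer's recording, which this commit relies on rather than defines):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer


class TinyNet(nn.Module):
    """Toy module whose second linear runs under torch.utils.checkpoint."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 16)
        self.fc2 = nn.Linear(16, 16)

    def forward(self, x):
        x = self.fc1(x)
        x = checkpoint(self.fc2, x)
        return x


model = TinyNet()
meta_args = {'x': torch.rand(1, 64, 16).to('meta')}

tracer = ColoTracer(trace_act_ckpt=True)    # record checkpoint regions while tracing
graph = tracer.trace(root=model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# Nodes traced inside checkpoint() should now carry a region id in their meta.
for node in graph.nodes:
    if 'activation_checkpoint' in node.meta:
        print(node.name, node.meta['activation_checkpoint'])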
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py (new file, mode 100644)

from functools import partial
from typing import Optional, Tuple, Union

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port

HIDDEN_SIZE = 16


class GPT2MLPWithCkpt(nn.Module):

    def __init__(self, intermediate_size, hidden_size):
        super().__init__()
        embed_dim = hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = torch.nn.ReLU()

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = checkpoint(self.c_proj, hidden_states)
        hidden_states = self.act(hidden_states)

        return hidden_states


def check_act_ckpt(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
    input_sample = {
        'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
    }
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    gm = initialize_model(model, input_sample, device_mesh)
    code = gm.module.graph.python_code('self').src
    assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
    assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():
    world_size = 4
    run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_mlp_layer()
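The test spawns four NCCL processes on a 2×2 device mesh and then only inspects the generated source of the compiled module: the two assertions check that a runtime comm-spec apply call is emitted for linear_1 and that the checkpointed region is re-emitted through colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, ...), i.e. that the user's original torch.utils.checkpoint annotation survived solving and the runtime passes. Because of the run_on_environment_flag(name='AUTO_PARALLEL') gate it is skipped unless the AUTO_PARALLEL environment flag is set; presumably something like AUTO_PARALLEL=1 pytest tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py on a machine with four devices runs it, assuming the flag wrapper reads the variable that way.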