OpenDAS / ColossalAI · Commit 67e1912b (unverified)

[autoparallel] support origin activation ckpt on autoprallel system (#2468)

Authored Jan 16, 2023 by YuliangLiu0306, committed by GitHub on Jan 16, 2023
Parent: 3a21485e

Showing 4 changed files with 111 additions and 5 deletions (+111 −5)
colossalai/auto_parallel/passes/runtime_apply_pass.py          +33 −0
colossalai/auto_parallel/passes/runtime_preparation_pass.py     +2 −0
colossalai/auto_parallel/tensor_shard/initialize.py             +6 −5
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py  +70 −0
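For context: the "origin activation ckpt" being supported is the annotation a user creates by wrapping part of a module's forward in torch.utils.checkpoint.checkpoint. A minimal sketch of such a model (illustrative only; the new test below does the same thing with a GPT2-style MLP):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MLPWithCkpt(nn.Module):
    """Toy module whose second linear layer is checkpointed by the user."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        # The user marks fc2 for activation checkpointing; the passes in this
        # commit propagate that annotation onto the nodes they insert.
        x = checkpoint(self.fc2, x)
        return x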
colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                     runtime_apply,
                     args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index))
+                if 'activation_checkpoint' in user_node.meta:
+                    shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
                 new_args = list(user_node.args)
                 new_kwargs = dict(user_node.kwargs)
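Why the explicit copy above is needed: torch.fx keeps per-node metadata in a plain dict (node.meta), and a node freshly created by a pass starts with empty meta, so a checkpoint annotation does not carry over on its own. A self-contained illustration of the pattern (plain torch.fx, not ColossalAI's actual pass):

import torch
import torch.fx as fx

def fn(x):
    return x + 1

gm = fx.symbolic_trace(fn)
add_node = next(n for n in gm.graph.nodes if n.op == 'call_function')
add_node.meta['activation_checkpoint'] = 0    # pretend the tracer recorded region 0

# A pass inserting a new node must copy the annotation by hand,
# exactly as the hunk above does for shape_consistency_node:
with gm.graph.inserting_after(add_node):
    relu_node = gm.graph.call_function(torch.relu, args=(add_node,))
if 'activation_checkpoint' in add_node.meta:
    relu_node.meta['activation_checkpoint'] = add_node.meta['activation_checkpoint']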
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                 # substitute the origin node with comm_spec_apply_node
                 new_kwargs[str(node)] = comm_spec_apply_node
                 user.kwargs = new_kwargs
+            if 'activation_checkpoint' in node.meta:
+                comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
     return gm
+
+
+def _act_annotataion_pass(gm: torch.fx.GraphModule):
+    """
+    This pass is used to add the act annotation to the new inserted nodes.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    for node in nodes:
+        if not hasattr(node.meta, 'activation_checkpoint'):
+            from .runtime_preparation_pass import size_processing
+            user_act_annotation = -1
+            input_act_annotation = -1
+            for user_node in node.users.keys():
+                if 'activation_checkpoint' in user_node.meta:
+                    user_act_annotation = user_node.meta['activation_checkpoint']
+                    break
+            for input_node in node._input_nodes.keys():
+                if 'activation_checkpoint' in input_node.meta:
+                    input_act_annotation = input_node.meta['activation_checkpoint']
+                    break
+            if user_act_annotation == input_act_annotation and user_act_annotation != -1:
+                node.meta['activation_checkpoint'] = user_act_annotation
+    return gm
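One caveat in the new _act_annotataion_pass: node.meta is a dict, so hasattr(node.meta, 'activation_checkpoint') is always False (dicts expose keys, not attributes), which means the guard admits every node rather than only unannotated ones. The dict-membership test used everywhere else in this file is presumably what was intended; a sketch of that reading (an editorial interpretation, not what the commit ships):

def _infer_act_annotation(node) -> None:
    """Annotate a pass-inserted node when its producers and consumers agree."""
    if 'activation_checkpoint' in node.meta:    # dict membership, not hasattr
        return    # already annotated by the tracer
    user_ann = next((u.meta['activation_checkpoint']
                     for u in node.users if 'activation_checkpoint' in u.meta), -1)
    input_ann = next((i.meta['activation_checkpoint']
                      for i in node._input_nodes if 'activation_checkpoint' in i.meta), -1)
    # Only tag the node when both sides sit inside the same checkpoint region.
    if user_ann == input_ann and user_ann != -1:
        node.meta['activation_checkpoint'] = user_ann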
colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # It will be used to replace the original node with processing node in slice object
             node_pairs[node] = size_processing_node
             size_processing_node._meta_data = node._meta_data
+            if 'activation_checkpoint' in node.meta:
+                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']

             user_list = list(node.users.keys())
             for user in user_list:
colossalai/auto_parallel/tensor_shard/initialize.py
@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
     into the forward function.
     '''

-    def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
+    def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
                  origin_spec_dict: Dict[int, ShardingSpec],
                  comm_actions_dict: Dict[int, Dict[str, CommAction]]):
         '''
         Args:
@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
     return strategies_constructor


-def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
+def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
     '''
     This method is used to solve the best solution for the given graph.
     The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
@@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
     return solution


-def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
+def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
                                strategies_constructor: StrategiesConstructor):
     '''
     This method is used to transform the original graph to the sharded graph.
@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
         solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
         return a series of integers, but return the best strategies.
     '''
-    tracer = ColoTracer()
+    tracer = ColoTracer(trace_act_ckpt=True)

     graph = tracer.trace(root=model, meta_args=meta_args)
-    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm = ColoGraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
     strategies_constructor = build_strategy_constructor(graph, device_mesh)
     if load_solver_solution:
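Taken together, the initialize.py changes mean tracing now records checkpoint regions (trace_act_ckpt=True) and the traced module is a ColoGraphModule throughout. A condensed usage sketch, assuming a 4-GPU NCCL process group has already been launched and reusing the illustrative MLPWithCkpt from above:

import torch
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh

model = MLPWithCkpt(dim=16)                    # any module using torch.utils.checkpoint
meta_args = {'x': torch.rand(1, 64, 16).to('meta')}
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)

gm = initialize_model(model, meta_args, device_mesh)
# The generated forward still wraps the checkpointed region, now through
# colossalai.utils.activation_checkpoint.checkpoint (see the test's assertions).
print(gm.module.graph.python_code('self').src)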
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py (new file · 0 → 100644)
from functools import partial
from typing import Optional, Tuple, Union

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port

HIDDEN_SIZE = 16


class GPT2MLPWithCkpt(nn.Module):

    def __init__(self, intermediate_size, hidden_size):
        super().__init__()
        embed_dim = hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = torch.nn.ReLU()

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = checkpoint(self.c_proj, hidden_states)
        hidden_states = self.act(hidden_states)

        return hidden_states


def check_act_ckpt(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
    input_sample = {
        'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
    }
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    gm = initialize_model(model, input_sample, device_mesh)
    code = gm.module.graph.python_code('self').src
    assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
    assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():
    world_size = 4
    run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_mlp_layer()
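A note on running the new test: the run_on_environment_flag(name='AUTO_PARALLEL') decorator gates it behind an AUTO_PARALLEL environment flag, and it spawns four ranks over the NCCL backend, so it needs four GPUs; rerun_if_address_is_in_use retries the run if the chosen port is already taken.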