Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cd0af9f7
Unverified
Commit
cd0af9f7
authored
Dec 12, 2022
by
YuliangLiu0306
Committed by
GitHub
Dec 12, 2022
Browse files
[autoparallel] gptmlp runtime test (#2113)
parent
9214d1fe
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
261 additions
and
25 deletions
+261
-25
colossalai/auto_parallel/passes/runtime_preparation_pass.py
colossalai/auto_parallel/passes/runtime_preparation_pass.py
+45
-23
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
...to_parallel/tensor_shard/solver/strategies_constructor.py
+2
-2
tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
...st_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
+214
-0
No files found.
colossalai/auto_parallel/passes/runtime_preparation_pass.py
View file @
cd0af9f7
...
...
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationDataType
,
ShardingStrategy
,
)
from
colossalai.auto_parallel.tensor_shard.solver.strategies_constructor
import
StrategiesConstructor
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.tensor.comm_spec
import
_all_reduce
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
...
...
@@ -19,13 +20,23 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager
=
ShapeConsistencyManager
()
def
_solution_annotatation
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
]):
def
_solution_annotatation
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
strategies_constructor
:
StrategiesConstructor
=
None
):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
"""
mod_graph
=
gm
.
graph
# TODO: In future PR, strategies_constructor should be a required argument,
# instead of optional argument. This is because we don't need to consider nodes with
# no strategy in runtime preparation pass.
if
strategies_constructor
is
not
None
:
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
no_strategy_nodes
=
strategies_constructor
.
no_strategy_nodes
else
:
nodes
=
tuple
(
mod_graph
.
nodes
)
no_strategy_nodes
=
[]
# the dict to get origin sharding spec of node
origin_node_sharding_spec_dict
=
{}
...
...
@@ -44,6 +55,9 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
for
index
,
node
in
enumerate
(
nodes
):
target_sharding_specs
=
[]
for
user_node
in
node
.
strategies_vector
.
successor_nodes
:
if
user_node
in
no_strategy_nodes
:
target_sharding_spec
=
node
.
best_strategy
.
get_sharding_spec_by_name
(
str
(
node
.
name
))
else
:
target_sharding_spec
=
user_node
.
best_strategy
.
get_sharding_spec_by_name
(
str
(
node
.
name
))
target_sharding_specs
.
append
(
target_sharding_spec
)
sharding_spec_convert_dict
[
index
]
=
target_sharding_specs
...
...
@@ -136,13 +150,17 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
new_args
.
append
(
arg
)
for
dim
,
shard_dims
in
output_dim_partition_dict
.
items
():
# we will skip the dim with -1 value
if
new_args
[
dim
+
1
]
==
-
1
:
continue
total_shard_size
=
1
for
shard_dim
in
shard_dims
:
total_shard_size
*=
device_mesh
.
shape
[
shard_dim
]
# There are two ways to use torch.view:
# 1. torch.view(input, *shape)
# 2. torch.view(input, shape)
if
isinstance
(
new_args
[
1
],
int
):
new_args
[
dim
+
1
]
//=
total_shard_size
else
:
new_args
[
1
]
=
list
(
new_args
[
1
])
new_args
[
1
][
dim
]
//=
total_shard_size
node
.
args
=
tuple
(
new_args
)
elif
node
.
op
==
'call_function'
:
...
...
@@ -193,12 +211,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
origin_sharding_spec
=
ShardingSpec
(
device_mesh
,
param
.
shape
,
{})
setattr
(
param
,
'sharding_spec'
,
origin_sharding_spec
)
# TODO: build a ColoParameter class to manage the distributed parameters
param_sharded
=
torch
.
nn
.
Parameter
(
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
param
.
data
,
param
.
sharding_spec
,
target_sharding_spec
).
detach
().
clone
())
else
:
param_sharded
=
param
setattr
(
target_module
,
name
,
param
_sharded
)
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
param
.
data
=
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
param
.
data
,
param
.
sharding_spec
,
target_sharding_spec
).
detach
().
clone
()
setattr
(
target_module
,
name
,
param
)
comm_actions
=
node
.
best_strategy
.
communication_actions
for
operation_data
,
comm_action
in
comm_actions
.
items
():
comm_spec_to_use
=
comm_action
.
comm_spec
...
...
@@ -212,7 +230,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
param
.
register_hook
(
hook_fn
)
wrapper
(
param
_sharded
,
comm_spec_to_use
)
wrapper
(
param
,
comm_spec_to_use
)
sharded_buffer_dict
=
{}
# apply the sharding spec of buffers
...
...
@@ -242,12 +260,13 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
origin_sharding_spec
=
ShardingSpec
(
device_mesh
,
target
.
shape
,
{})
setattr
(
target
,
'sharding_spec'
,
origin_sharding_spec
)
# TODO: build a ColoParameter class to manage the distributed parameters
target_sharded
=
torch
.
nn
.
Parameter
(
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
target
.
data
,
target
.
sharding_spec
,
target_sharding_spec
).
detach
().
clone
())
else
:
target_sharded
=
target
setattr
(
target_module
,
atoms
[
-
1
],
target_sharded
)
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
target
.
data
=
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
target
.
data
,
target
.
sharding_spec
,
target_sharding_spec
).
detach
().
clone
()
assert
hasattr
(
target_module
,
atoms
[
-
1
])
setattr
(
target_module
,
atoms
[
-
1
],
target
)
comm_actions
=
node
.
best_strategy
.
communication_actions
for
operation_data
,
comm_action
in
comm_actions
.
items
():
...
...
@@ -262,7 +281,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
param
.
register_hook
(
hook_fn
)
wrapper
(
target
_sharded
,
comm_spec_to_use
)
wrapper
(
target
,
comm_spec_to_use
)
return
gm
...
...
@@ -273,9 +292,12 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
pass
def
runtime_preparation_pass
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
device_mesh
:
DeviceMesh
):
def
runtime_preparation_pass
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
device_mesh
:
DeviceMesh
,
strategies_constructor
:
StrategiesConstructor
=
None
):
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
=
_solution_annotatation
(
gm
,
solution
)
gm
,
solution
,
strategies_constructor
)
gm
=
_node_args_converting
(
gm
,
device_mesh
)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# gm = implicit_comm_action_apply(gm)
...
...
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
View file @
cd0af9f7
...
...
@@ -41,6 +41,7 @@ class StrategiesConstructor:
self
.
leaf_strategies
=
[]
self
.
strategy_map
=
{}
self
.
solver_options
=
solver_options
self
.
no_strategy_nodes
=
[]
def
remove_duplicated_strategy
(
self
,
strategies_vector
):
'''
...
...
@@ -78,12 +79,11 @@ class StrategiesConstructor:
return
_check_no_strategy_for_data
(
node
.
_meta_data
)
no_strategy_node
=
[]
for
node
in
self
.
nodes
:
strategies_vector
=
StrategiesVector
(
node
)
if
_check_no_strategy_for_node
(
node
):
no_strategy_node
.
append
(
node
)
self
.
no_strategy_node
s
.
append
(
node
)
pass
# placeholder node
...
...
tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
0 → 100644
View file @
cd0af9f7
import
copy
import
random
from
functools
import
partial
from
typing
import
Optional
,
Tuple
,
Union
import
numpy
as
np
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
transformers
from
torch.fx
import
GraphModule
from
transformers.activations
import
ACT2FN
from
transformers.models.gpt2.modeling_gpt2
import
GPT2MLP
from
transformers.pytorch_utils
import
Conv1D
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply_pass
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.tensor_shard.constants
import
BATCHNORM_MODULE_OP
from
colossalai.auto_parallel.tensor_shard.solver
import
(
CostGraph
,
GraphAnalyser
,
Solver
,
SolverOptions
,
StrategiesConstructor
,
)
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
,
to_global
from
colossalai.testing
import
assert_close
,
assert_close_loose
,
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
# Shape of the activation fed to the MLP: (BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).
BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768

# Seed every random number generator used in this file with the same value.
seed = 128
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
class GPT2MLP(nn.Module):
    """GPT-2 style feed-forward block with the dropout layer left out.

    The block applies ``c_fc``, the configured activation and ``c_proj`` in
    that order.

    Args:
        intermediate_size: Width passed to the two ``Conv1D`` projections.
        config: Object exposing ``hidden_size`` and ``activation_function``.
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        hidden_size = config.hidden_size
        # Submodules are registered in the order c_fc, c_proj, act.
        self.c_fc = Conv1D(intermediate_size, hidden_size)
        self.c_proj = Conv1D(hidden_size, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        # The Dropout layer is temporarily left out because the rng state
        # needs extra processing to get the correct result.
        # self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        """Run ``hidden_states`` through ``c_fc``, the activation and ``c_proj``."""
        projected = self.c_fc(hidden_states)
        activated = self.act(projected)
        # TODO: the rng state needs to be fixed for distributed runtime
        # before dropout can be applied to the returned tensor.
        return self.c_proj(activated)
def _assert_param_grads_close(model, origin_param_dict, rank):
    """Compare the gradients held by ``model`` with the reference gradients.

    Every parameter gradient is compared as a whole, except ``c_fc.bias``:
    there the reference gradient is narrowed along dim 0 to the first half on
    even ranks and to the second half on odd ranks before the comparison.
    Rank 0 additionally prints each pair of gradients.

    Args:
        model: Module whose ``named_parameters()`` carry the gradients under test.
        origin_param_dict: Mapping from parameter name to the reference parameter.
        rank: Rank of the calling process.
    """
    verbose = rank == 0
    if verbose:
        print("*******************backward starting*******************")
    for name, param in model.named_parameters():
        param_grad = param.grad
        origin_param_grad = origin_param_dict[name].grad
        if verbose:
            print(name, param_grad, origin_param_grad)
        if name == 'c_fc.bias':
            half_size = origin_param_grad.shape[-1] // 2
            origin_param_grad = origin_param_grad.narrow(0, (rank % 2) * half_size, half_size)
        assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
    if verbose:
        print("*******************backward finished*******************")


def check_mlp_layer(rank, model_cls, world_size, port):
    """Check the auto-parallel runtime of an MLP layer against a plain copy.

    The function launches the distributed environment, builds the model and a
    deep copy of it, traces the model, solves for a sharding strategy on a
    2 x 2 device mesh, applies the runtime passes and then compares the
    forward output and the parameter gradients of the transformed graph
    module with those of the untouched copy.

    Args:
        rank: Rank of this process.
        model_cls: Model class, called with ``intermediate_size`` and ``config``.
        world_size: Number of processes taking part in the test.
        port: Port passed to ``launch`` for the rendezvous on localhost.
    """
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # The GPT2Config option is spelled ``n_positions``.
    config = transformers.GPT2Config(n_positions=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
    model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
    input_tensor = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')

    # Untouched reference copies, taken before the model is traced and sharded.
    test_model = copy.deepcopy(model)
    test_input = copy.deepcopy(input_tensor)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    tracer = ColoTracer()
    input_sample = {
        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
    }
    graph = tracer.trace(root=model, meta_args=input_sample)
    print(graph)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    print(gm)

    graph_analyser = GraphAnalyser(gm)
    # The result is not used here; the call is kept in case the analyser
    # caches state that the solver relies on -- TODO confirm it can be dropped.
    graph_analyser.liveness_analysis()
    solver_options = SolverOptions()
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
    ret = solver.call_solver_serialized_args()

    solution = list(ret[0])
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
        gm, solution, device_mesh, strategies_constructor)
    gm = runtime_apply_pass(gm)
    gm.recompile()

    # Run both forward passes from the same CPU and CUDA rng state.
    cuda_rng_state = torch.cuda.get_rng_state()
    cpu_rng_state = torch.get_rng_state()
    origin_output = test_model(test_input)
    torch.cuda.set_rng_state(cuda_rng_state)
    torch.set_rng_state(cpu_rng_state)
    output = gm(input_tensor, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    assert_close(output, origin_output, rtol=1e-03, atol=1e-04)

    # ******************* backward *******************
    # Run both backward passes from the same CUDA rng state.
    cuda_rng_state = torch.cuda.get_rng_state()
    output.sum().backward()
    torch.cuda.set_rng_state(cuda_rng_state)
    origin_output.sum().backward()
    origin_param_dict = dict(test_model.named_parameters())
    _assert_param_grads_close(model, origin_param_dict, rank)

    # ******************* strategy selected *******************
    if rank == 0:
        print("*******************strategy selected*******************")
        strategies_list = solver.last_s_val
        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]

        computation_cost = 0
        communication_cost = 0
        memory_cost = 0
        for index, node in enumerate(nodes):
            strategy = node.strategies_vector[strategies_list[index]]
            print(node.name, strategy.name)
            computation_cost += strategy.compute_cost.total
            communication_cost += strategy.communication_cost.total
            node_memory_cost = strategy.memory_cost.total
            # When the memory cost is a tuple, only its first entry is counted.
            if isinstance(node_memory_cost, tuple):
                node_memory_cost = node_memory_cost[0]
            memory_cost += node_memory_cost.activation + node_memory_cost.parameter

        print(f'computation cost is {computation_cost}')
        print(f'communication cost is {communication_cost}')
        print(f'memory cost is {memory_cost}')
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('model_cls', [GPT2MLP])
@rerun_if_address_is_in_use()
def test_mlp_layer(model_cls):
    """Spawn four worker processes that each run ``check_mlp_layer``.

    Args:
        model_cls: Model class forwarded to every worker.
    """
    # Four processes, matching the 2 x 2 device mesh built in check_mlp_layer.
    num_procs = 4
    worker = partial(
        check_mlp_layer,
        port=free_port(),
        world_size=num_procs,
        model_cls=model_cls,
    )
    mp.spawn(worker, nprocs=num_procs)


if __name__ == '__main__':
    test_mlp_layer()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment