OpenDAS / ColossalAI · Commits · cd0af9f7

Commit cd0af9f7 (unverified), authored Dec 12, 2022 by YuliangLiu0306, committed by GitHub on Dec 12, 2022

[autoparallel] gpt2lp runtimee test (#2113)
Parent: 9214d1fe

Showing 3 changed files with 261 additions and 25 deletions (+261 −25):

- colossalai/auto_parallel/passes/runtime_preparation_pass.py (+45 −23)
- colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py (+2 −2)
- tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py (+214 −0, new file)
colossalai/auto_parallel/passes/runtime_preparation_pass.py
```diff
@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationDataType,
     ShardingStrategy,
 )
+from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -19,13 +20,23 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()
 
 
-def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
+def _solution_annotatation(gm: torch.fx.GraphModule,
+                           solution: List[int],
+                           strategies_constructor: StrategiesConstructor = None):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required in runtime into graph as placeholder nodes.
     """
     mod_graph = gm.graph
-    nodes = tuple(mod_graph.nodes)
+    # TODO: In future PR, strategies_constructor should be a required argument,
+    # instead of optional argument. This is because we don't need to consider nodes with
+    # no strategy in runtime preparation pass.
+    if strategies_constructor is not None:
+        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+        no_strategy_nodes = strategies_constructor.no_strategy_nodes
+    else:
+        nodes = tuple(mod_graph.nodes)
+        no_strategy_nodes = []
 
     # the dict to get origin sharding spec of node
     origin_node_sharding_spec_dict = {}
```
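A toy illustration of the node-selection change above, with stand-in classes rather than the real ColossalAI API: when a `StrategiesConstructor` is supplied, the pass visits only the nodes that actually own a strategies vector and remembers the strategy-less ones separately.

```python
# Stand-ins for torch.fx nodes and StrategiesVector; illustrative only.
class Vec:
    def __init__(self, node):
        self.node = node

class Constructor:
    def __init__(self):
        self.leaf_strategies = [Vec('c_fc'), Vec('act'), Vec('c_proj')]
        self.no_strategy_nodes = ['size_1']

graph_nodes = ('hidden_states', 'c_fc', 'act', 'c_proj', 'size_1', 'output')

strategies_constructor = Constructor()
if strategies_constructor is not None:
    nodes = [v.node for v in strategies_constructor.leaf_strategies]
    no_strategy_nodes = strategies_constructor.no_strategy_nodes
else:
    nodes, no_strategy_nodes = tuple(graph_nodes), []

assert nodes == ['c_fc', 'act', 'c_proj']    # strategy-less nodes are skipped
```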
```diff
@@ -44,7 +55,10 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
     for index, node in enumerate(nodes):
         target_sharding_specs = []
         for user_node in node.strategies_vector.successor_nodes:
-            target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            if user_node in no_strategy_nodes:
+                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(str(node.name))
+            else:
+                target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
         setattr(node, 'target_sharding_specs', target_sharding_specs)
```
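A toy sketch of the fallback above, with illustrative values rather than the real sharding-spec API: when a consumer node carries no strategy, the producer's own output spec is reused, so no conversion is forced on that edge.

```python
node_spec = 'S01R'                                   # producer's own output spec
user_specs = {'matmul_1': 'RS01', 'size_1': None}    # size_1 carries no strategy

target_specs = []
for user, spec in user_specs.items():
    # no-strategy consumer: fall back to the producer's spec (a no-op conversion)
    target_specs.append(spec if spec is not None else node_spec)
assert target_specs == ['RS01', 'S01R']
```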
```diff
@@ -136,13 +150,17 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     new_args.append(arg)
 
             for dim, shard_dims in output_dim_partition_dict.items():
-                # we will skip the dim with -1 value
-                if new_args[dim + 1] == -1:
-                    continue
                 total_shard_size = 1
                 for shard_dim in shard_dims:
                     total_shard_size *= device_mesh.shape[shard_dim]
-                new_args[dim + 1] //= total_shard_size
+                # There are two ways to use torch.view:
+                # 1. torch.view(input, *shape)
+                # 2. torch.view(input, shape)
+                if isinstance(new_args[1], int):
+                    new_args[dim + 1] //= total_shard_size
+                else:
+                    new_args[1] = list(new_args[1])
+                    new_args[1][dim] //= total_shard_size
             node.args = tuple(new_args)
 
         elif node.op == 'call_function':
```
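For context on the branch above: `Tensor.view` accepts either unpacked sizes or a single shape sequence, so the traced `node.args` differ in layout between the two forms. A standalone PyTorch check:

```python
import torch

x = torch.arange(12)
a = x.view(3, 4)       # form 1: sizes unpacked -> fx args look like (input, 3, 4)
b = x.view((3, 4))     # form 2: one shape tuple -> fx args look like (input, (3, 4))
assert torch.equal(a, b)

# Hence the isinstance branch: for form 1 the size of tensor dim `dim` sits at
# args[dim + 1]; for form 2 it sits inside the sequence at args[1][dim].
```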
```diff
@@ -193,12 +211,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                 setattr(param, 'sharding_spec', origin_sharding_spec)
                 # TODO: build a ColoParamter class to manager the distributed parameters
-                param_sharded = torch.nn.Parameter(
-                    shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                             target_sharding_spec).detach().clone())
-            else:
-                param_sharded = param
-            setattr(target_module, name, param_sharded)
+                # we could use .data here, because all the operations just happen before the real training
+                # loop, so we don't need to track these operations in the autograd graph.
+                param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                    param.data, param.sharding_spec,
+                    target_sharding_spec).detach().clone()
+            setattr(target_module, name, param)
             comm_actions = node.best_strategy.communication_actions
             for operation_data, comm_action in comm_actions.items():
                 comm_spec_to_use = comm_action.comm_spec
```
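A minimal sketch of the `.data` pattern the new lines rely on (plain PyTorch; the slicing stands in for `apply_for_autoparallel_runtime`): mutating `.data` swaps storage without recording autograd history, keeping the original `Parameter` object alive.

```python
import torch

p = torch.nn.Parameter(torch.ones(4))
# Assigning to .data swaps the underlying storage without creating autograd
# history, so the Parameter object (and any references to it) stays alive;
# .detach().clone() makes the new storage independent of the old tensor.
p.data = p.data[:2].detach().clone()    # slicing stands in for the resharding
assert p.requires_grad and p.shape == (2,)
```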
```diff
@@ -212,7 +230,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     param.register_hook(hook_fn)
-                wrapper(param_sharded, comm_spec_to_use)
+                wrapper(param, comm_spec_to_use)
 
     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
```
```diff
@@ -242,12 +260,13 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
                 setattr(target, 'sharding_spec', origin_sharding_spec)
                 # TODO: build a ColoParamter class to manager the distributed parameters
-                target_sharded = torch.nn.Parameter(
-                    shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
-                                                                             target_sharding_spec).detach().clone())
-            else:
-                target_sharded = target
-            setattr(target_module, atoms[-1], target_sharded)
+                # we could use .data here, because all the operations just happen before the real training
+                # loop, so we don't need to track these operations in the autograd graph.
+                target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
+                    target.data, target.sharding_spec,
+                    target_sharding_spec).detach().clone()
+            assert hasattr(target_module, atoms[-1])
+            setattr(target_module, atoms[-1], target)
             comm_actions = node.best_strategy.communication_actions
             for operation_data, comm_action in comm_actions.items():
```
```diff
@@ -262,7 +281,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                     param.register_hook(hook_fn)
-                wrapper(target_sharded, comm_spec_to_use)
+                wrapper(target, comm_spec_to_use)
 
     return gm
```
```diff
@@ -273,9 +292,12 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
     pass
 
 
-def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
-    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
-        gm, solution)
+def runtime_preparation_pass(gm: torch.fx.GraphModule,
+                             solution: List[int],
+                             device_mesh: DeviceMesh,
+                             strategies_constructor: StrategiesConstructor = None):
+    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
+        gm, solution, strategies_constructor)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
```
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
```diff
@@ -41,6 +41,7 @@ class StrategiesConstructor:
         self.leaf_strategies = []
         self.strategy_map = {}
         self.solver_options = solver_options
+        self.no_strategy_nodes = []
 
     def remove_duplicated_strategy(self, strategies_vector):
         '''
```
```diff
@@ -78,12 +79,11 @@ class StrategiesConstructor:
             return _check_no_strategy_for_data(node._meta_data)
 
-        no_strategy_node = []
         for node in self.nodes:
            strategies_vector = StrategiesVector(node)
 
             if _check_no_strategy_for_node(node):
-                no_strategy_node.append(node)
+                self.no_strategy_nodes.append(node)
                 pass
 
             # placeholder node
```
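A toy sketch of this bookkeeping, with illustrative names only: recording strategy-less nodes on the instance, rather than in a throwaway local list, lets a later pass (here the runtime preparation pass) query them.

```python
class ToyConstructor:
    """Illustrative stand-in for StrategiesConstructor, not the real class."""

    def __init__(self, nodes):
        self.nodes = nodes
        self.no_strategy_nodes = []

    def build(self):
        for node in self.nodes:
            if node.startswith('size'):    # stand-in for _check_no_strategy_for_node
                self.no_strategy_nodes.append(node)

constructor = ToyConstructor(['c_fc', 'size_1', 'c_proj'])
constructor.build()
# a later pass can now query the strategy-less nodes from the instance
assert constructor.no_strategy_nodes == ['size_1']
```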
tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py (new file, mode 100644)
```python
import copy
import random
from functools import partial
from typing import Optional, Tuple, Union

import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from transformers.activations import ACT2FN
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
from transformers.pytorch_utils import Conv1D

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
    CostGraph,
    GraphAnalyser,
    Solver,
    SolverOptions,
    StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port

BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768

seed = 128
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class GPT2MLP(nn.Module):

    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        # We temporarily banned the Dropout layer because the rng state need
        # to process to get the correct result.
        # self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        # TODO: the rng state need to be fixed for distributed runtime
        # hidden_states = self.dropout(hidden_states)
        return hidden_states


def check_mlp_layer(rank, model_cls, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
    model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
    input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
    test_model = copy.deepcopy(model)
    test_input = copy.deepcopy(input)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    input_sample = {
        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
    }

    graph = tracer.trace(root=model, meta_args=input_sample)
    print(graph)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    print(gm)
    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
    solver_options = SolverOptions()
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
        gm, solution, device_mesh, strategies_constructor)
    gm = runtime_apply_pass(gm)
    gm.recompile()

    cuda_rng_state = torch.cuda.get_rng_state()
    cpu_rng_state = torch.get_rng_state()
    origin_output = test_model(test_input)
    torch.cuda.set_rng_state(cuda_rng_state)
    torch.set_rng_state(cpu_rng_state)
    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    assert_close(output, origin_output, rtol=1e-03, atol=1e-04)

    #*******************backward starting*******************
    cuda_rng_state = torch.cuda.get_rng_state()
    output.sum().backward()
    torch.cuda.set_rng_state(cuda_rng_state)
    origin_output.sum().backward()
    origin_param_dict = dict(test_model.named_parameters())

    if rank == 0:
        print("*******************backward starting*******************")
        for name, param in model.named_parameters():
            param_grad = param.grad
            origin_param_grad = origin_param_dict[name].grad
            origin_param_size = origin_param_grad.shape[-1]
            print(name, param_grad, origin_param_grad)
            if name == 'c_fc.bias':
                assert_close_loose(param_grad,
                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
                                   rtol=1e-03,
                                   atol=1e-03)
            else:
                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
        print("*******************backward finished*******************")

    if rank == 1:
        for name, param in model.named_parameters():
            param_grad = param.grad
            origin_param_grad = origin_param_dict[name].grad
            origin_param_size = origin_param_grad.shape[-1]
            if name == 'c_fc.bias':
                assert_close_loose(param_grad,
                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
                                   rtol=1e-03,
                                   atol=1e-03)
            else:
                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)

    if rank == 2:
        for name, param in model.named_parameters():
            param_grad = param.grad
            origin_param_grad = origin_param_dict[name].grad
            origin_param_size = origin_param_grad.shape[-1]
            if name == 'c_fc.bias':
                assert_close_loose(param_grad,
                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
                                   rtol=1e-03,
                                   atol=1e-03)
            else:
                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)

    if rank == 3:
        for name, param in model.named_parameters():
            param_grad = param.grad
            origin_param_grad = origin_param_dict[name].grad
            origin_param_size = origin_param_grad.shape[-1]
            if name == 'c_fc.bias':
                assert_close_loose(param_grad,
                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
                                   rtol=1e-03,
                                   atol=1e-03)
            else:
                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
    #*******************backward finished*******************

    #*******************strategy selected*******************
    if rank == 0:
        print("*******************strategy selected*******************")
        strategies_list = solver.last_s_val
        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
        computation_cost = 0
        communication_cost = 0
        memory_cost = 0
        for index, node in enumerate(nodes):
            print(node.name, node.strategies_vector[strategies_list[index]].name)
            computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
            communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
            node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
            if isinstance(node_memory_cost, tuple):
                node_memory_cost = node_memory_cost[0]
            memory_cost += node_memory_cost.activation + node_memory_cost.parameter

        print(f'computation cost is {computation_cost}')
        print(f'communication cost is {communication_cost}')
        print(f'memory cost is {memory_cost}')


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('model_cls', [GPT2MLP])
@rerun_if_address_is_in_use()
def test_mlp_layer(model_cls):
    world_size = 4
    run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_mlp_layer()
```
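A hedged sanity check of the gradient assertions above: they encode a layout in which `c_fc.bias` is sharded along mesh dimension 1 of the 2 × 2 mesh `[[0, 1], [2, 3]]`, so ranks in the same mesh column (0/2 vs 1/3) hold the same half of the gradient. A standalone illustration, assuming that layout:

```python
import torch

full_grad = torch.arange(8.)          # stand-in for origin_param_dict[name].grad
half = full_grad.shape[-1] // 2
for rank in range(4):
    col = rank % 2                    # position along mesh dim 1 of [[0, 1], [2, 3]]
    local = full_grad.narrow(0, col * half, half)
    print(f'rank {rank}: {local.tolist()}')
# ranks 0 and 2 print the first half; ranks 1 and 3 the second half,
# matching the narrow() calls in the per-rank assertions above.
```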