OpenDAS / ColossalAI · Commit b2b2a4af
"docs/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "0c0455700fbea15a1ad663b890278dbdb0d581cd"
[autoparallel] adapt solver with mlp (#1638)

Authored by YuliangLiu0306 on Sep 26, 2022, and committed via GitHub on Sep 26, 2022 (unverified signature).
Parent commit: 04443605

4 changed files with 166 additions and 40 deletions:
  colossalai/auto_parallel/solver/op_handler/__init__.py           +2   -1
  colossalai/auto_parallel/solver/op_handler/dot_handler.py        +54  -35
  colossalai/auto_parallel/solver/op_handler/operator_handler.py   +17  -4
  tests/test_auto_parallel/test_solver_with_mlp.py                  +93  -0
colossalai/auto_parallel/solver/op_handler/__init__.py

@@ -4,9 +4,10 @@ from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
 from .reshape_handler import ReshapeHandler
 from .bcast_op_handler import BcastOpHandler
+from .embedding_handler import EmbeddingHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
 
 __all__ = [
     'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
-    'UnaryElementwiseHandler'
+    'UnaryElementwiseHandler', 'EmbeddingHandler'
 ]
colossalai/auto_parallel/solver/op_handler/dot_handler.py

@@ -410,9 +410,9 @@ class DotHandler(OperatorHandler):
         self.weight = self.module_named_parameters['weight']
         self.output_data = self.node._meta_data
 
-    def _generate_compute_cost(self, input_shape, weight_shape):
+    def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
         # TODO: consider bias addition
-        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
+        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
         return compute_cost
 
     @exception_handler
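With this change the matmul FLOP estimate is divided by the number of shards, so each strategy now reports a per-device compute cost rather than the cost of the full, unsharded matmul. A minimal, self-contained sketch of the updated formula (the example shapes and the 2 x 4 mesh below are illustrative assumptions, not values taken from the commit):

    from functools import reduce
    import operator

    def generate_compute_cost(input_shape, weight_shape, total_sharding_size):
        # 2 * prod(input_shape) * out_features FLOPs, split evenly across the shards
        return reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size

    # e.g. a (16, 32) activation and a (128, 32) nn.Linear weight on a 2 x 4 mesh (8 shards)
    print(generate_compute_cost((16, 32), (128, 32), 8))    # 131072 // 8 = 16384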
@@ -434,15 +434,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute computation cost
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost
-        # no all-reduce required for this case
-        communication_cost = 0
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+        communication_cost = communication_cost_activation_backward + communication_cost_weight_backward
 
         # create and register strategy
         sharding_strategies = ShardingStrategy(name,
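This hunk, and the analogous ones below, replace a single (often zero) forward-only communication term with an explicit sum of the collectives incurred in forward and backward. A self-contained sketch of that bookkeeping, using made-up per-collective costs in place of the values DeviceMesh.all_reduce_cost(...) would return:

    # Assumed per-collective costs (placeholders, not values computed by the solver)
    allreduce_grad_activation = 0.25    # backward: all-reduce of the input gradient over mesh_dim_1
    allreduce_grad_weight = 0.5         # backward: all-reduce of the weight gradient over mesh_dim_0

    # Before this commit the strategy reported 0 here ("no all-reduce required");
    # afterwards both backward collectives are charged to the strategy.
    communication_cost = allreduce_grad_activation + allreduce_grad_weight
    print(communication_cost)           # 0.75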
@@ -474,14 +476,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
+        communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -508,14 +513,18 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
+        communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
+        communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -542,11 +551,12 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
         communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
@@ -576,14 +586,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
+        communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
+        communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -610,14 +622,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = 0
+        communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
+        communication_cost = communication_cost_weight_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -644,14 +658,17 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
+        communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
+        communication_cost = communication_cost_forward_activation
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -678,14 +695,16 @@ class DotHandler(OperatorHandler):
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
-        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+        total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
 
         # compute the memory cost of this strategy
-        toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
-            dim_partition_dict_for_output, dim_partition_dict_for_weight)
+        toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
+            dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
 
         # compute the communication cost of this strategy
-        communication_cost = 0
+        communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(input_grad_memory_cost, 0)
+        communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
colossalai/auto_parallel/solver/op_handler/operator_handler.py

@@ -64,7 +64,8 @@ class OperatorHandler(ABC):
         """
         pass
 
-    def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
+    def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
+                              sharding_spec_for_input):
         '''
         Compute the memory cost per device with this specific strategy.
@@ -102,9 +103,21 @@
             weight_sharding_size *= self.device_mesh.shape[mesh_dim]
         weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
-        total_memory_cost = activation_memory_cost + weight_memory_cost
 
-        return total_memory_cost, activation_memory_cost, weight_memory_cost
+        # compute the memory cost of input grad
+        input_grad_numel = self.input_data.numel()
+        input_grad_sharding_size = 1
+        input_grad_mesh_dims = []
+        for sharding_dim, mesh_dims in sharding_spec_for_input.items():
+            input_grad_mesh_dims.extend(mesh_dims)
+        for mesh_dim in input_grad_mesh_dims:
+            input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
+        input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes
+
+        memory_cost_forward = activation_memory_cost + weight_memory_cost
+        memory_cost_backward = input_grad_memory_cost + weight_memory_cost
+
+        return (memory_cost_forward, memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost
 
     def _generate_resharding_costs(self, sharding_specs):
         # The resharding_cost of weight is counted due to sharing weight cases.
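_generate_memory_cost now separates the forward footprint (activation plus weight shard) from the backward footprint (input gradient plus weight shard) and returns them as a tuple, which is why the strategies in dot_handler.py unpack four values and why the new test takes element [0] of a tuple-valued memory cost. A minimal, self-contained mirror of that bookkeeping (the byte counts are assumed for illustration; the real method derives them from the traced node shapes and the DeviceMesh):

    def generate_memory_cost_demo(activation_mem, weight_mem, input_grad_mem):
        # forward keeps the sharded activation and weight; backward keeps the input grad and weight
        memory_cost_forward = activation_mem + weight_mem
        memory_cost_backward = input_grad_mem + weight_mem
        return (memory_cost_forward, memory_cost_backward), activation_mem, weight_mem, input_grad_mem

    (fwd, bwd), act, w, in_grad = generate_memory_cost_demo(2048.0, 512.0, 2048.0)
    print(fwd, bwd)    # 2560.0 2560.0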
tests/test_auto_parallel/test_solver_with_mlp.py (new file, mode 100644)
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions


class MLP(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim * 4)
        self.linear2 = torch.nn.Linear(dim * 4, dim)
        self.dropout = torch.nn.Dropout(0)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x


@pytest.mark.skip("for higher testing speed")
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = MLP(32)

    input_sample = {'x': torch.rand(16, 32).to('meta')}

    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
    #     %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {})
    #     %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {})
    #     %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
    #     return linear2
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    # # megatron mode if no memory constraints
    # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    # all sharding on out feature dim if memory budget is not sufficient for megatron mode
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0)

    ret = solver.call_solver_serialized_args()
    strategies_list = list(ret[0])
    computation_cost = 0
    communication_cost = 0
    memory_cost = 0
    for index, node in enumerate(graph.nodes):
        print(node.name, node.strategies_vector[strategies_list[index]].name)
        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
        if isinstance(node_memory_cost, tuple):
            node_memory_cost = node_memory_cost[0]
        memory_cost += node_memory_cost

    print(f'computation cost is {computation_cost}')
    print(f'communication cost is {communication_cost}')
    print(f'memory cost is {memory_cost}')


if __name__ == '__main__':
    test_cost_graph()
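The test carries @pytest.mark.skip, so a normal pytest run skips it; the __main__ guard still allows running the file directly with Python. The 2 x 4 logical mesh it builds over 8 devices is also what feeds the new per-shard compute-cost normalization: a strategy that shards a linear layer over both mesh dimensions divides its FLOP estimate by 8. A small illustrative check (not solver output):

    import torch

    physical_mesh_id = torch.arange(0, 8)    # 8 logical devices, as in the test
    mesh_shape = (2, 4)
    total_sharding_size = mesh_shape[0] * mesh_shape[1]
    print(total_sharding_size)    # 8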