OpenDAS / ColossalAI / Commits / d164449d

Commit d164449d (unverified), authored Sep 13, 2022 by YuliangLiu0306, committed via GitHub on Sep 13, 2022

[autoparallel] add resnet autoparallel unit test and add backward weight communication cost (#1589)

parent 7c18a588

Showing 2 changed files with 168 additions and 29 deletions (+168 −29)
colossalai/auto_parallel/solver/op_handler/conv_handler.py    +43 −29
tests/test_auto_parallel/test_solver_with_resnet.py           +125 −0
colossalai/auto_parallel/solver/op_handler/conv_handler.py    (view file @ d164449d)
@@ -103,7 +103,7 @@ class ConvHandler(OperatorHandler):
         # memory_cost pair
         memory_cost = (memory_cost_forward, memory_cost_backward)

-        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
+        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

     def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@@ -132,15 +132,18 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
-        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # This strategy do not need to do all_reduce operation during forward
         communication_cost_forward = 0
-        # compute the backward communication cost of this strategy
-        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        # compute the backward communication cost to all reduce the input activation grad
+        communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
         # total communication cost
-        communication_cost = communication_cost_forward + communication_cost_backward
+        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
@@ -178,11 +181,16 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

-        # This strategy do not need to do all_reduce operation in both forward and backward phase.
-        communication_cost = 0
+        # This strategy do not need to do all_reduce operation in forward phase.
+        communication_cost_forward = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
+        # compute the total cost
+        communication_cost = communication_cost_forward + communication_cost_backward_weight

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
@@ -220,15 +228,17 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute the communication cost of this strategy during forward phase
         communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
-        # This strategy do not need to do all_reduce operation during backward phase
-        communication_cost_backward = 0
-        communication_cost = communication_cost_forward + communication_cost_backward
+        # This strategy do not need to do all_reduce operation to compute the input activation grad
+        communication_cost_backward_activation = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
+        # compute total cost
+        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
@@ -265,7 +275,7 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute the communication cost of this strategy during forward phase
@@ -309,9 +319,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = 1
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute the communication cost of this strategy during forward phase
         communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
@@ -354,7 +363,7 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = 1
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
-        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # This strategy do not need to do all_reduce during forward phase
@@ -398,8 +407,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = 1
         sharding_size_backward_activation = 1
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # This strategy do not need to do all_reduce in both forward and backward phase
         communication_cost = 0
@@ -441,11 +450,17 @@ class ConvHandler(OperatorHandler):
         sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

-        # This strategy do not need to do all_reduce in both forward and backward phase
-        communication_cost = 0
+        # This strategy do not need to do all_reduce in forward phase
+        communication_cost_forward = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_backward_weight, 0)
+        # compute the total communication cost
+        communication_cost = communication_cost_backward_weight + communication_cost_forward

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
@@ -485,9 +500,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
         sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute communication cost during forward phase
         communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
...
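The net effect of the conv_handler.py changes above: `_generate_memory_cost` now also returns the sharded weight size (`memory_cost_backward_weight`), and each conv sharding strategy adds the cost of all-reducing the weight gradient over its data-parallel mesh dimension to its total communication cost. The snippet below is a minimal, self-contained Python sketch of how those three terms compose; it is not the ColossalAI API, and the ring-all-reduce formula, function names, and example sizes are illustrative assumptions.

# Illustrative sketch only: mirrors how the per-strategy communication cost is
# composed after this commit (forward all-reduce + activation-grad all-reduce
# + weight-grad all-reduce). The cost model and numbers are assumptions, not
# values taken from ColossalAI.

def all_reduce_cost(message_size: float, num_devices: int, bandwidth: float = 1e10) -> float:
    """Hypothetical ring all-reduce time: 2 * (n - 1) / n * size / bandwidth."""
    if num_devices <= 1:
        return 0.0
    return 2 * (num_devices - 1) / num_devices * message_size / bandwidth


def strategy_communication_cost(forward_activation: float,
                                backward_activation: float,
                                backward_weight: float,
                                devices_forward: int,
                                devices_backward_activation: int,
                                devices_backward_weight: int) -> float:
    communication_cost_forward = all_reduce_cost(forward_activation, devices_forward)
    communication_cost_backward_activation = all_reduce_cost(backward_activation, devices_backward_activation)
    # the term added by this commit: all-reduce the weight gradient over the
    # data-parallel mesh dimension
    communication_cost_backward_weight = all_reduce_cost(backward_weight, devices_backward_weight)
    return (communication_cost_forward
            + communication_cost_backward_activation
            + communication_cost_backward_weight)


if __name__ == '__main__':
    # Example for an S0S1 = S0R x RS1 strategy on a (2, 4) mesh: no forward
    # all-reduce, activation grad reduced over mesh_dim_1 (4 devices), weight
    # grad reduced over mesh_dim_0 (2 devices). Sizes in bytes are made up.
    print(strategy_communication_cost(0.0, 4.2e6, 1.1e5, 1, 4, 2))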
tests/test_auto_parallel/test_solver_with_resnet.py    (new file, 0 → 100644, view file @ d164449d)
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
        self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
        self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x * 2
        x = self.conv1(x)
        x = self.conv2(x)
        x = x / 2
        x = self.conv3(x)
        x = self.relu(x)
        return x
@pytest.mark.skip("for higher testing speed")
def test_cost_graph():
    physical_mesh_id = torch.arange(0, 8)
    mesh_shape = (2, 4)
    # device mesh layout:
    # [[0, 1, 2, 3],
    #  [4, 5, 6, 7]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    # model = ConvModel(16, 32)
    # input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
    model = resnet50(num_classes=100000)
    input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')}

    graph = tracer.trace(root=model, meta_args=input_sample)
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    #     %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    #     %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
    #     %maxpool : [#users=2] = call_module[target=maxpool](args = (%relu,), kwargs = {})
    #     %layer1_0_conv1 : [#users=1] = call_module[target=layer1.0.conv1](args = (%maxpool,), kwargs = {})
    #     %layer1_0_bn1 : [#users=1] = call_module[target=layer1.0.bn1](args = (%layer1_0_conv1,), kwargs = {})
    #     %layer1_0_relu : [#users=1] = call_module[target=layer1.0.relu](args = (%layer1_0_bn1,), kwargs = {})
    #     %layer1_0_conv2 : [#users=1] = call_module[target=layer1.0.conv2](args = (%layer1_0_relu,), kwargs = {})
    #     %layer1_0_bn2 : [#users=1] = call_module[target=layer1.0.bn2](args = (%layer1_0_conv2,), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%layer1_0_bn2, %maxpool), kwargs = {})
    #     %layer1_0_relu_1 : [#users=2] = call_module[target=layer1.0.relu](args = (%add,), kwargs = {})
    #     %layer1_1_conv1 : [#users=1] = call_module[target=layer1.1.conv1](args = (%layer1_0_relu_1,), kwargs = {})
    #     %layer1_1_bn1 : [#users=1] = call_module[target=layer1.1.bn1](args = (%layer1_1_conv1,), kwargs = {})
    #     %layer1_1_relu : [#users=1] = call_module[target=layer1.1.relu](args = (%layer1_1_bn1,), kwargs = {})
    #     %layer1_1_conv2 : [#users=1] = call_module[target=layer1.1.conv2](args = (%layer1_1_relu,), kwargs = {})
    #     %layer1_1_bn2 : [#users=1] = call_module[target=layer1.1.bn2](args = (%layer1_1_conv2,), kwargs = {})
    #     %add_1 : [#users=1] = call_function[target=operator.add](args = (%layer1_1_bn2, %layer1_0_relu_1), kwargs = {})
    #     ...
    #     %avgpool : [#users=1] = call_module[target=avgpool](args = (%layer4_2_relu_1,), kwargs = {})
    #     %flatten : [#users=1] = call_function[target=torch.flatten](args = (%avgpool, 1), kwargs = {})
    #     %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {})
    #     return fc
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
    # print(len(liveness_dict[0].unique_live_vars))
    # assert False
    solver_options = {'fast_mode': True}
    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
    # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)

    ret = solver.call_solver_serialized_args()
    print(ret)
    strategies_list = list(ret[0])
    print(strategies_list)

    computation_cost = 0
    communication_cost = 0
    communication_cost_bn = 0
    memory_cost = 0
    for index, node in enumerate(graph.nodes):
        if node.op == 'call_module':
            submod = node.graph.owning_module.get_submodule(node.target)
            if type(submod) in ELEMENTWISE_MODULE_OP:
                input_spec = node.args[0].strategies_vector[strategies_list[index]].output_sharding_spec
                print(node.name, input_spec)
                continue
            if type(submod) in BATCHNORM_MODULE_OP:
                communication_cost_bn += node.strategies_vector[strategies_list[index]].communication_cost
        print(node.name, node.strategies_vector[strategies_list[index]].name)
        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
        if isinstance(node_memory_cost, tuple):
            node_memory_cost = node_memory_cost[0]
        memory_cost += node_memory_cost

    print(f'computation cost is {computation_cost}')
    print(f'communication cost is {communication_cost}')
    print(f'memory cost is {memory_cost}')
    print(f'bn communication cost is {communication_cost_bn}')


if __name__ == '__main__':
    test_cost_graph()
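Note: the new test is decorated with @pytest.mark.skip("for higher testing speed"), so a plain pytest run collects but skips it. To exercise it one would presumably remove the skip marker or run the module directly (python tests/test_auto_parallel/test_solver_with_resnet.py), which calls test_cost_graph() through the __main__ guard and prints the solver's computation, communication, and memory cost totals.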