OpenDAS / ColossalAI / Commits / 0908d0fc
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "ed6426c300a4f3ed343ca9d7b4aad1b6199d1469"
Unverified commit 0908d0fc, authored Sep 07, 2022 by YuliangLiu0306, committed by GitHub on Sep 07, 2022.
[autoparallel]add backward cost info into strategies (#1524)
Parent: 1a359941
Showing 3 changed files with 148 additions and 69 deletions (+148 -69).
colossalai/auto_parallel/solver/conv_handler.py      +140 -65
colossalai/auto_parallel/solver/operator_handler.py  +8 -3
tests/test_auto_parallel/test_conv_handler.py        +0 -1
colossalai/auto_parallel/solver/conv_handler.py (view file @ 0908d0fc)
...
...
@@ -49,11 +49,59 @@ class ConvHandler(OperatorHandler):
         # 3D: (H * W * D) * N * Cout * Cin * kernel
         output_size = self.output_data.shape[2:]
         output_size_product = reduce(operator.mul, output_size, 1)
         input_size = self.input_data.shape[2:]
         input_size_product = reduce(operator.mul, input_size, 1)
         kernel_size = self.weight.shape[2:]
         kernel_size_product = reduce(operator.mul, kernel_size, 1)
-        compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
+        backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
         return compute_cost

+    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight):
+        '''
+        Compute the memory cost per device with this specific strategy.
+
+        Argument:
+            sharding_size_forward(int): The forward activation will be divided into sharding_size_forward number of partitions.
+            sharding_size_backward_activation(int): The backward activation will be divided into sharding_size_backward_activation number of partitions.
+            sharding_size_backward_weight(int): The backward weight will be divided into sharding_size_backward_weight number of partitions.
+
+        Return:
+            memory_cost(Tuple[float]): Memory cost per device with this specific strategy; the first element of the tuple is the forward memory cost, and the second element is the backward memory cost.
+            memory_cost_forward(float): Memory cost of the forward activation per device with this specific strategy.
+            memory_cost_backward_activation(float): Memory cost of the backward activation per device with this specific strategy.
+        '''
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel_output = self.output_data.numel()
+        numel_input = self.input_data.numel()
+        numel_weight = self.weight.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+        # forward memory_cost
+        memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
+
+        # backward memory_cost
+        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
+        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
+        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
+
+        # memory_cost pair
+        memory_cost = (memory_cost_forward, memory_cost_backward)
+
+        return memory_cost, memory_cost_forward, memory_cost_backward_activation
+
     def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
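As a sanity check on the new bookkeeping, here is a minimal, self-contained sketch that replays the same arithmetic outside the handler. The Conv2d workload (N=32, Cin=64, Cout=128, 3x3 kernel, 224x224 feature maps) and the 2x4-mesh sharding factors are illustrative assumptions, not values taken from the commit:

import operator
from functools import reduce

import torch

# assumed Conv2d workload (not from the commit)
bs, channel_in, channel_out = 32, 64, 128
input_data = torch.empty(bs, channel_in, 224, 224)
output_data = torch.empty(bs, channel_out, 224, 224)
weight = torch.empty(channel_out, channel_in, 3, 3)

# compute cost, mirroring the forward/backward split added above
output_size_product = reduce(operator.mul, output_data.shape[2:], 1)
input_size_product = reduce(operator.mul, input_data.shape[2:], 1)
kernel_size_product = reduce(operator.mul, weight.shape[2:], 1)
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost

# memory cost per device, mirroring _generate_memory_cost for an assumed 2x4 mesh:
# forward activation sharded over all 8 devices, activation gradient over 2, weight gradient over 4
size_per_elem_bytes = torch.tensor([], dtype=input_data.dtype).element_size()
memory_cost_forward = output_data.numel() * size_per_elem_bytes / 8
memory_cost_backward_activation = input_data.numel() * size_per_elem_bytes / 2
memory_cost_backward_weight = weight.numel() * size_per_elem_bytes / 4
memory_cost = (memory_cost_forward, memory_cost_backward_activation + memory_cost_backward_weight)

print(compute_cost, memory_cost)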
...
...
@@ -76,14 +124,19 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
+        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

+        # This strategy does not need to do an all_reduce operation during the forward phase
+        communication_cost_forward = 0
+        # compute the backward communication cost of this strategy
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        # total communication cost
+        communication_cost = communication_cost_forward + communication_cost_backward
-        # This strategy does not need to do an all_reduce operation
-        communication_cost = 0

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
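For this strategy the forward pass now adds no all-reduce, while the backward pass all-reduces the activation gradient along mesh_dim_1. Below is a rough, standalone sketch of that split; the ring all-reduce formula, bandwidth, and byte count are hypothetical stand-ins for DeviceMesh.all_reduce_cost and are not part of the commit:

def ring_all_reduce_cost(message_size_bytes, num_devices, bandwidth_bytes_per_s=100e9):
    # hypothetical stand-in for DeviceMesh.all_reduce_cost (ring-algorithm estimate)
    if num_devices <= 1:
        return 0.0
    return 2 * (num_devices - 1) / num_devices * message_size_bytes / bandwidth_bytes_per_s

# assumed 2x4 mesh, so mesh_dim_1 spans 4 devices; assumed gradient shard of ~100 MB per device
memory_cost_backward_activation = 100e6

communication_cost_forward = 0  # no all-reduce needed in the forward phase
communication_cost_backward = ring_all_reduce_cost(memory_cost_backward_activation, num_devices=4)
communication_cost = communication_cost_forward + communication_cost_backward
print(communication_cost)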
...
...
@@ -115,13 +168,13 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

-        # This strategy does not need to do an all_reduce operation
+        # This strategy does not need to do an all_reduce operation in either the forward or the backward phase.
         communication_cost = 0

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
...
...
@@ -154,14 +207,18 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size

-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

+        # compute the communication cost of this strategy during the forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
+        # This strategy does not need to do an all_reduce operation during the backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
...
...
@@ -193,14 +250,17 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size

-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

+        # compute the communication cost of this strategy during forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
+        # compute the communication cost of this strategy during backward phase
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        communication_cost = communication_cost_forward + communication_cost_backward

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
...
...
@@ -232,13 +292,18 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes

-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        sharding_size_forward = 1
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

+        # compute the communication cost of this strategy during forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
+        # This strategy does NOT need an all_reduce during the backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
...
...
@@ -270,15 +335,17 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size

-        # This strategy does not need to do an all_reduce operation
-        communication_cost = 0
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = 1
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
+        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

+        # This strategy does not need to do an all_reduce during the forward phase
+        communication_cost_forward = 0
+        # compute the communication cost of this strategy during backward phase
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
+        communication_cost = communication_cost_forward + communication_cost_backward

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
...
...
@@ -310,12 +377,13 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes
+        sharding_size_forward = 1
+        sharding_size_backward_activation = 1
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

-        # This strategy does not need to do an all_reduce operation
+        # This strategy does not need to do an all_reduce in either the forward or the backward phase
         communication_cost = 0

         sharding_strategies = ShardingStrategy(name,
...
...
@@ -349,13 +417,14 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost = numel * size_per_elem_bytes / sharding_size

-        # This strategy does not need to do an all_reduce operation
+        sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

+        # This strategy does not need to do an all_reduce in either the forward or the backward phase
         communication_cost = 0

         sharding_strategies = ShardingStrategy(name,
...
...
@@ -390,13 +459,19 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)

         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes

-        # compute communication cost
-        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
+        sharding_size_forward = 1
+        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)

+        # compute communication cost during forward phase
+        communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
+        # This strategy does NOT need to do an all_reduce during the backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
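In this fully replicated-output strategy the forward all-reduce now runs over the flattened device mesh (every device), and the backward phase adds no communication. A toy sketch under assumed mesh and tensor sizes; the cost helper is a hypothetical stand-in for flatten_device_mesh.all_reduce_cost:

def ring_all_reduce_cost(message_size_bytes, num_devices, bandwidth_bytes_per_s=100e9):
    # hypothetical stand-in for flatten_device_mesh.all_reduce_cost
    if num_devices <= 1:
        return 0.0
    return 2 * (num_devices - 1) / num_devices * message_size_bytes / bandwidth_bytes_per_s

mesh_shape = (2, 4)                               # assumed device mesh
num_devices_flat = mesh_shape[0] * mesh_shape[1]  # flattened mesh: 8 devices
memory_cost_forward = 820e6                       # assumed replicated output size in bytes

communication_cost_forward = ring_all_reduce_cost(memory_cost_forward, num_devices_flat)
communication_cost_backward = 0  # no all-reduce needed in the backward phase
communication_cost = communication_cost_forward + communication_cost_backward
print(communication_cost)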
...
...
colossalai/auto_parallel/solver/operator_handler.py (view file @ 0908d0fc)
...
...
@@ -85,12 +85,17 @@ class OperatorHandler(ABC):
         '''
         # The resharding_cost of weight is counted due to sharing weight cases.
         resharding_costs = {}
-        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
+        for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
             resharding_costs[input_node] = []
             for strategy in input_node.strategies_vector:
                 input_sharding_spec = strategy.output_sharding_spec
                 assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
+                # compute the resharding cost during forward phase
+                _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(input_sharding_spec, target_spec)
+                # In backward phase, we should convert grad with target_spec into input_sharding_spec
+                _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(target_spec, input_sharding_spec)
+                resharding_cost = resharding_cost_forward + resharding_cost_backward
                 resharding_costs[input_node].append(resharding_cost)
         return resharding_costs
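The resharding cost is now counted in both directions: the activation is converted from the producer's sharding spec to the spec this operator expects during the forward pass, and the gradient is converted back during the backward pass. A minimal sketch of that pattern; the cost table and helper are hypothetical stand-ins for the third return value of ShapeConsistencyManager.shape_consistency:

# hypothetical per-conversion resharding costs, keyed by (source_spec, target_spec)
RESHARD_COST = {
    ("S0R", "RS1"): 3.0,
    ("RS1", "S0R"): 5.0,
}

def resharding_cost_between(src_spec, dst_spec):
    # stand-in for the cost returned by shape_consistency(src_spec, dst_spec)
    if src_spec == dst_spec:
        return 0.0
    return RESHARD_COST.get((src_spec, dst_spec), 1.0)

input_sharding_spec = "S0R"   # layout produced by the predecessor's strategy
target_spec = "RS1"           # layout this operator's strategy expects

# forward: reshard the activation from the producer's layout to the required layout
resharding_cost_forward = resharding_cost_between(input_sharding_spec, target_spec)
# backward: reshard the gradient from the required layout back to the producer's layout
resharding_cost_backward = resharding_cost_between(target_spec, input_sharding_spec)
resharding_cost = resharding_cost_forward + resharding_cost_backward
print(resharding_cost)  # 8.0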
tests/test_auto_parallel/test_conv_handler.py (view file @ 0908d0fc)
...
...
@@ -82,7 +82,6 @@ def test_conv_handler():
                                strategies_vector=strategies_vector,
                                shape_consistency_manager=shape_consistency_manager)
    conv_handler.register_strategy()
    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]
...
...