OpenDAS / ColossalAI / Commits / 36c0f3ea

Unverified commit 36c0f3ea, authored Nov 15, 2022 by YuliangLiu0306, committed by GitHub Nov 15, 2022

[autoparallel] remove redundancy comm node (#1893)

parent 9183e0de

Showing 5 changed files with 23 additions and 20 deletions (+23 / -20)
colossalai/auto_parallel/passes/runtime_apply_pass.py                                  +2 / -0
colossalai/auto_parallel/passes/runtime_preparation_pass.py                            +5 / -3
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py                     +5 / -3
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py    +8 / -8
colossalai/tensor/comm_spec.py                                                         +3 / -6
colossalai/auto_parallel/passes/runtime_apply_pass.py (view file @ 36c0f3ea)
@@ -81,6 +81,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
             continue
         for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
+            if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
+                continue
             with mod_graph.inserting_before(user_node):
                 shape_consistency_node = mod_graph.create_node('call_function', runtime_apply,
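This guard is the "redundancy" the commit title removes: if a producer's sharding spec already matches what a consumer expects, the pass no longer inserts a shape-consistency (communication) node between them. A minimal, self-contained sketch of that idea follows; the sharding_sequence helpers below are illustrative stand-ins, not ColossalAI's real ShardingSpec API.

    # Illustrative sketch only. A "sharding sequence" is modeled as one tag per
    # tensor dim, e.g. "S0" (sharded on mesh axis 0) or "R" (replicated).
    def sharding_sequence_difference(src_seq, dst_seq):
        """Count dimensions whose sharding differs; 0 means the layouts already match."""
        return sum(1 for src, dst in zip(src_seq, dst_seq) if src != dst)

    def needs_conversion_node(src_seq, dst_seq):
        # Difference of 0: the tensor is already laid out as the consumer expects,
        # so inserting a runtime_apply (communication) node would be redundant.
        return sharding_sequence_difference(src_seq, dst_seq) != 0

    print(needs_conversion_node(["S0", "R"], ["S0", "R"]))  # False -> skip insertion
    print(needs_conversion_node(["S0", "R"], ["R", "S0"]))  # True  -> insert conversion node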
colossalai/auto_parallel/passes/runtime_preparation_pass.py (view file @ 36c0f3ea)
@@ -47,6 +47,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
+        setattr(node, 'target_sharding_specs', target_sharding_specs)

         # the get_attr node strategy is kind of pending strategy, which means we will change it
         # to the same strategy of the user node.
         if node.op == 'get_attr':
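The new setattr is what lets the apply pass above read node.target_sharding_specs when it decides whether a conversion node is needed. torch.fx nodes are ordinary Python objects, so per-node metadata can be attached directly; a tiny sketch of the mechanism (the attached value here is dummy data, not a real sharding spec):

    import torch
    from torch.fx import symbolic_trace

    gm = symbolic_trace(torch.nn.Linear(4, 4))
    for node in gm.graph.nodes:
        # Attach consumer-facing metadata to each node; downstream passes read it back.
        setattr(node, 'target_sharding_specs', ['R', 'S0'])   # dummy placeholder value
    assert all(hasattr(n, 'target_sharding_specs') for n in gm.graph.nodes)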
@@ -95,7 +96,8 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
     """
     mod_graph = gm.graph
     nodes = tuple(mod_graph.nodes)
+    # This stream is created for overlaping the communication and computation.
+    reduction_stream = torch.cuda.Stream()
     for node in nodes:
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)
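The dedicated reduction_stream is there so gradient communication can later be overlapped with computation on the default stream. A minimal sketch of that general pattern with plain torch.distributed, assuming a CUDA device and an initialized process group (this is not the commit's code, just the stream idiom it sets up for):

    import torch
    import torch.distributed as dist

    reduction_stream = torch.cuda.Stream()

    def all_reduce_on_side_stream(grad: torch.Tensor) -> torch.Tensor:
        # Order the side stream after the kernels that produced `grad`, launch the
        # collective there, then make the default stream wait before consuming it.
        reduction_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(reduction_stream):
            dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        torch.cuda.current_stream().wait_stream(reduction_stream)
        return grad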
@@ -122,7 +124,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
                 def wrapper(param, comm_spec):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec)
+                        _all_reduce(grad, comm_spec, async_op=False)

                     param.register_hook(hook_fn)
@@ -172,7 +174,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
                 def wrapper(param, comm_spec):

                     def hook_fn(grad):
-                        _all_reduce(grad, comm_spec)
+                        _all_reduce(grad, comm_spec, async_op=False)

                     param.register_hook(hook_fn)
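Both hunks make the gradient hook call _all_reduce with an explicit async_op=False, matching the new signature in comm_spec.py below. For reference, the same hook pattern written against plain torch.distributed (a sketch, assuming a process group is already initialized; ColossalAI routes the call through its comm_spec instead):

    import torch
    import torch.distributed as dist

    def attach_grad_allreduce(param: torch.nn.Parameter, group=None) -> None:
        def hook_fn(grad: torch.Tensor) -> torch.Tensor:
            # Synchronous all-reduce (async_op=False): the call returns only after the
            # collective finishes, so the reduced gradient is ready for the optimizer.
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group, async_op=False)
            return grad

        param.register_hook(hook_fn)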
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py (view file @ 36c0f3ea)
@@ -74,11 +74,13 @@ class NodeHandler(ABC):
                 if op_data.type == OperationDataType.PARAM:
                     resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
                 else:
+                    dtype = op_data.data.dtype
+                    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                     _, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec,
                                                                                         current_sharding_spec)
-                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
-                                                     bwd=resharding_cost["backward"],
-                                                     total=resharding_cost["total"])
+                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
+                                                     bwd=resharding_cost["backward"] * size_per_elem_bytes,
+                                                     total=resharding_cost["total"] * size_per_elem_bytes)
                 resharding_costs[node].append(resharding_cost)
         strategy.resharding_costs = resharding_costs
         return strategy
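The change weights each resharding cost by the byte width of one element of the operand's dtype, so the solver compares costs in bytes moved rather than raw element counts. The scaling factor comes from a standard PyTorch query:

    import torch

    # element_size() reports the number of bytes one element of the dtype occupies.
    for dtype in (torch.float32, torch.float16, torch.bfloat16, torch.int8):
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        print(dtype, size_per_elem_bytes)   # 4, 2, 2, 1 bytes respectively

    # E.g. a resharding cost of 1024 elements becomes 4096 cost units under fp32
    # but only 2048 under fp16, so cheaper dtypes are also cheaper to reshard.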
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py (view file @ 36c0f3ea)
@@ -218,7 +218,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_0,
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}
@@ -254,7 +254,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1],
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}
@@ -300,7 +300,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0],
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}
@@ -331,14 +331,14 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # TODO: The strategies below should be uncommented after runtime
         # passes ready.

         # SR = SR x R WITH SYNC_BN
-        # strategy_list.append(self.split_input_batch(0))
-        # strategy_list.append(self.split_input_batch(1))
+        strategy_list.append(self.split_input_batch(0))
+        strategy_list.append(self.split_input_batch(1))

         # SS = SS x S WITH SYNC_BN
-        # strategy_list.append(self.split_input_both_dim(0, 1))
-        # strategy_list.append(self.split_input_both_dim(1, 0))
+        strategy_list.append(self.split_input_both_dim(0, 1))
+        strategy_list.append(self.split_input_both_dim(1, 0))

         # S01R = S01R x R WITH SYNC_BN
-        # strategy_list.append(self.split_input_batch_1d(0, 1))
+        strategy_list.append(self.split_input_batch_1d(0, 1))

         return strategy_list
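The newly enabled strategies split BatchNorm's input along the batch dimension, which only works if batch statistics are synchronized across the devices sharing that split; the comm_type change from AFTER to IMPLICIT suggests the collective is treated as part of the op itself rather than as a separate node appended after it. For intuition, this is the same role SyncBatchNorm plays in plain PyTorch data parallelism:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
    # convert_sync_batchnorm swaps every BatchNorm*d for a SyncBatchNorm that
    # all-reduces mean/var across the process group during training.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)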
colossalai/tensor/comm_spec.py (view file @ 36c0f3ea)
@@ -23,9 +23,7 @@ def _all_gather(tensor, comm_spec):
                 torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
                 for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
             ]
-            tensor = tensor
-            group = process_group
-            dist.all_gather(tensor_list, tensor, group=group)
+            dist.all_gather(tensor_list, tensor, group=process_group)
             output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
             return output
@@ -37,7 +35,6 @@ def _split(tensor, comm_spec):
     process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
     for rank_list, _ in process_groups_list:
         if dist.get_rank() in rank_list:
-            tensor = tensor
             dim = comm_spec.shard_dim
             length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
             start = length * rank_list.index(dist.get_rank())
@@ -69,7 +66,7 @@ def _all_to_all(tensor, comm_spec):
     return output


-def _all_reduce(tensor, comm_spec):
+def _all_reduce(tensor, comm_spec, async_op=False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
@@ -78,7 +75,7 @@ def _all_reduce(tensor, comm_spec):
         if dist.get_rank() in rank_list:
             if not tensor.is_contiguous():
                 tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
+            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
             return tensor
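The new async_op flag is forwarded directly to torch.distributed.all_reduce. With async_op=True the collective returns a work handle instead of blocking, so callers can overlap it with computation and synchronize later; the hooks in runtime_preparation_pass.py pass async_op=False to keep the blocking behaviour. A small sketch of the flag's semantics in plain torch.distributed, assuming a process group is already initialized (not ColossalAI code):

    import torch
    import torch.distributed as dist

    def reduce_maybe_async(tensor: torch.Tensor, async_op: bool = False) -> torch.Tensor:
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()
        work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=async_op)
        if async_op:
            # ... run independent computation here, then synchronize ...
            work.wait()
        return tensor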