OpenDAS / ColossalAI / Commits

Commit cb2c6a24 (Unverified)
Authored Feb 15, 2023 by YuliangLiu0306; committed by GitHub on Feb 15, 2023
Parent: 89f8975f

[autoparallel] refactor runtime pass (#2644)

* [autoparallel] refactor runtime pass
* add unit test
* polish
Showing 5 changed files with 351 additions and 213 deletions:

* colossalai/auto_parallel/passes/constants.py (+5, -0)
* colossalai/auto_parallel/passes/runtime_preparation_pass.py (+227, -210)
* tests/test_auto_parallel/test_pass/test_node_converting_pass.py (+54, -0)
* tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py (+65, -0)
* tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py (+0, -3)
colossalai/auto_parallel/passes/constants.py

```diff
@@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
     torch.nn.ReLU,
     torch.nn.Softmax,
 ]
+
+# SHAPE_ARGUMENT_OPS contains nodes with (input, *shape) style args.
+# This list could be extended if any other method has the same
+# argument style as view and reshape.
+SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]
```
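For context, the three ops collected in `SHAPE_ARGUMENT_OPS` all accept a shape after the input tensor, which is why a single pass can rewrite their arguments uniformly. A minimal illustration (not part of the commit):

```python
import torch

x = torch.arange(8)

# All three SHAPE_ARGUMENT_OPS entries take the input followed by a shape,
# so one rewrite rule covers them all.
a = x.view(4, 2)              # torch.Tensor.view: shape as *args
b = x.reshape(4, 2)           # torch.Tensor.reshape: shape as *args
c = torch.reshape(x, (4, 2))  # torch.reshape: shape as a tuple

assert a.shape == b.shape == c.shape == torch.Size([4, 2])
```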
colossalai/auto_parallel/passes/runtime_preparation_pass.py

```diff
@@ -19,6 +19,8 @@ from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from .constants import SHAPE_ARGUMENT_OPS
 
 shape_consistency_manager = ShapeConsistencyManager()
```
```diff
@@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size],
     return size
 
 
-def _solution_annotatation(gm: torch.fx.GraphModule,
-                           solution: List[int],
-                           strategies_constructor: StrategiesConstructor = None):
+def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
+                               strategies_constructor: StrategiesConstructor):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required at runtime into the graph as placeholder nodes.
     """
     mod_graph = gm.graph
-    # TODO: In future PR, strategies_constructor should be a required argument,
-    # instead of optional argument. This is because we don't need to consider nodes with
-    # no strategy in runtime preparation pass.
-    if strategies_constructor is not None:
-        nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
-        no_strategy_nodes = strategies_constructor.no_strategy_nodes
-    else:
-        nodes = tuple(mod_graph.nodes)
-        no_strategy_nodes = []
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    no_strategy_nodes = strategies_constructor.no_strategy_nodes
 
     # the dict to get origin sharding spec of node
     origin_node_sharding_spec_dict = {}
```
```diff
@@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
+        setattr(node, 'target_sharding_specs', target_sharding_specs)
 
     # the get_attr node strategy is kind of pending strategy, which means we will change it
     # to the same strategy of the user node.
     if node.op == 'get_attr':
```
```diff
@@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
 
 
-def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     """
     In the auto parallel system, tensors may get sharded on different devices, so the size of tensors
     need to be converted to the size of original tensor and managed by the users, such as torch.view,
```
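The motivation is easiest to see with numbers. Below is a minimal sketch, independent of the commit and with an assumed `dim_partition_dict` layout, of the local-to-global size scaling that `size_processing` performs:

```python
import torch

# A (4, 8) global tensor sharded along dim 0 across 2 devices has local size (2, 8).
# Scaling each sharded dim by the product of its mesh dims recovers the global size,
# which is what user-facing calls such as tensor.size() should report.
local_size = torch.Size([2, 8])
dim_partition_dict = {0: [0]}      # tensor dim 0 is sharded over mesh dim 0
device_mesh_info = {0: 2, 1: 2}    # a 2 x 2 device mesh

global_size = list(local_size)
for tensor_dim, mesh_dims in dim_partition_dict.items():
    for mesh_dim in mesh_dims:
        global_size[tensor_dim] *= device_mesh_info[mesh_dim]

assert torch.Size(global_size) == torch.Size([4, 8])
```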
```diff
@@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     nodes = tuple(mod_graph.nodes)
     node_pairs = {}
 
+    # DeviceMesh information instructs the scaling of the size value
+    device_mesh_info = {}
+    for dim, dim_size in enumerate(device_mesh.mesh_shape):
+        device_mesh_info[dim] = dim_size
+
+    def _extract_target_dim(node):
+        '''
+        A helper function to extract the target dimension from a size node.
+        There are two usages of torch.Tensor.size:
+        1. tensor.size()
+        2. tensor.size(dim)
+        If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
+        Otherwise, the output will be in type of torch.Size and this function will return None.
+        '''
+        target_dim = None
+        if len(node.args) > 1:
+            target_dim = node.args[1]
+            if target_dim < 0:
+                target_dim += node.args[0]._meta_data.dim()
+        return target_dim
+
+    def _post_processing(node, size_processing_node):
+        '''
+        This function is used to process the dependency between the size node and its users after
+        inserting the size_processing node.
+        '''
+        # store the original node and the processing node pair in the node_pairs dictionary;
+        # it will be used to replace the original node with the processing node in slice objects
+        node_pairs[node] = size_processing_node
+        size_processing_node._meta_data = node._meta_data
+        if 'activation_checkpoint' in node.meta:
+            size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+
+        user_list = list(node.users.keys())
+        for user in user_list:
+            if user == size_processing_node:
+                continue
+            new_args = list(user.args)
+            new_kwargs = dict(user.kwargs)
+            # the origin node may be a positional argument or keyword argument of the user node
+            if node in new_args:
+                # substitute the origin node with size_processing_node
+                new_args[new_args.index(node)] = size_processing_node
+                user.args = tuple(new_args)
+            elif str(node) in new_kwargs:
+                # substitute the origin node with size_processing_node
+                new_kwargs[str(node)] = size_processing_node
+                user.kwargs = new_kwargs
+
+    def _update_slice_object_args(slice_object):
+        '''
+        This function is used to update the slice object argument list.
+        If the slice object contains a Node argument, then the size node will be replaced with
+        the corresponding size_processing node.
+        '''
+        if isinstance(slice_object, slice):
+            start = slice_object.start
+            stop = slice_object.stop
+            step = slice_object.step
+            if start in node_pairs:
+                start = node_pairs[start]
+            if stop in node_pairs:
+                stop = node_pairs[stop]
+            if step in node_pairs:
+                step = node_pairs[step]
+            return slice(start, stop, step)
+        elif isinstance(slice_object, int):
+            if slice_object in node_pairs:
+                return node_pairs[slice_object]
+            else:
+                return slice_object
+        else:
+            raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
+
     for node in nodes:
         if node.op == 'call_method' and node.target == 'size':
```
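The extracted `_extract_target_dim` helper is small enough to check in isolation. A standalone re-implementation of its logic, with plain tuples standing in for FX node args:

```python
def extract_target_dim(args, ndim):
    # Mirrors the helper: tensor.size() yields None (output stays a torch.Size),
    # tensor.size(dim) yields an int dim, with negative dims normalized by ndim.
    target_dim = None
    if len(args) > 1:
        target_dim = args[1]
        if target_dim < 0:
            target_dim += ndim
    return target_dim

assert extract_target_dim(('tensor',), ndim=3) is None    # tensor.size()
assert extract_target_dim(('tensor', 1), ndim=3) == 1     # tensor.size(1)
assert extract_target_dim(('tensor', -1), ndim=3) == 2    # tensor.size(-1)
```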
```diff
@@ -154,49 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             sharding_spec = node.args[0].sharding_spec
             dim_partition_dict = sharding_spec.dim_partition_dict
 
-            # there are two usages of torch.Tensor.size:
-            # tensor.size()
-            # tensor.size(dim)
-            # if a target_dim is assigned, then the output will be
-            # in type of int, instead of torch.Size
-            target_dim = None
-            if len(node.args) > 1:
-                target_dim = node.args[1]
-                if target_dim < 0:
-                    target_dim += node.args[0]._meta_data.dim()
-
-            # DeviceMesh information instructs the scaling of the size value
-            device_mesh_info = {}
-            for dim, dim_size in enumerate(device_mesh.mesh_shape):
-                device_mesh_info[dim] = dim_size
+            target_dim = _extract_target_dim(node)
 
             # insert size_processing node
             with mod_graph.inserting_after(node):
                 size_processing_node = mod_graph.create_node('call_function',
                                                              size_processing,
                                                              args=(node, dim_partition_dict, device_mesh_info,
                                                                    target_dim, node.name))
-                # store the original node and the processing node pair in the node_pairs dictionary;
-                # it will be used to replace the original node with the processing node in slice objects
-                node_pairs[node] = size_processing_node
-                size_processing_node._meta_data = node._meta_data
-
-            if 'activation_checkpoint' in node.meta:
-                size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
-
-            user_list = list(node.users.keys())
-            for user in user_list:
-                if user == size_processing_node:
-                    continue
-                new_args = list(user.args)
-                new_kwargs = dict(user.kwargs)
-                # the origin node may be a positional argument or keyword argument of the user node
-                if node in new_args:
-                    # substitute the origin node with size_processing_node
-                    new_args[new_args.index(node)] = size_processing_node
-                    user.args = tuple(new_args)
-                elif str(node) in new_kwargs:
-                    # substitute the origin node with size_processing_node
-                    new_kwargs[str(node)] = size_processing_node
-                    user.kwargs = new_kwargs
+            _post_processing(node, size_processing_node)
 
         if node.op == 'call_function' and node.target == operator.getitem:
```
```diff
@@ -217,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             # In this pass, we need to process the last two cases because
             # node arguments may potentially appear in these cases.
             if isinstance(getitem_index, slice):
-                new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
-                if getitem_index.start in node_pairs:
-                    new_start = node_pairs[getitem_index.start]
-                elif getitem_index.stop in node_pairs:
-                    new_stop = node_pairs[getitem_index.stop]
-                elif getitem_index.step in node_pairs:
-                    new_step = node_pairs[getitem_index.step]
-                new_slice_item = slice(new_start, new_stop, new_step)
+                new_slice_item = _update_slice_object_args(getitem_index)
                 new_args = (node.args[0], new_slice_item)
                 node.args = new_args
```
```diff
@@ -237,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 if slice_item is None:
                     new_slice_items.append(None)
                     continue
-                new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
-                if slice_item.start in node_pairs:
-                    new_start = node_pairs[slice_item.start]
-                elif slice_item.stop in node_pairs:
-                    new_stop = node_pairs[slice_item.stop]
-                elif slice_item.step in node_pairs:
-                    new_step = node_pairs[slice_item.step]
-                new_slice_item = slice(new_start, new_stop, new_step)
+                new_slice_item = _update_slice_object_args(slice_item)
                 new_slice_items.append(new_slice_item)
 
             new_args = (node.args[0], tuple(new_slice_items))
```
```diff
@@ -255,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     return gm
 
 
-def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     """
     This pass will process node args to adapt the distributed tensor layout.
     """
     mod_graph = gm.graph
     nodes = tuple(mod_graph.nodes)
 
-    for node in nodes:
-        # skip the placeholder node added in _solution_annotation pass
-        if not hasattr(node, 'sharding_spec'):
-            continue
-
-        def _process_sharding_spec(sharding_spec):
-            if isinstance(sharding_spec, ShardingSpec):
-                dim_partition_dict = sharding_spec.dim_partition_dict
-                device_mesh = sharding_spec.device_mesh
-                return dim_partition_dict, device_mesh
-            if sharding_spec is None:
-                return None, None
-            assert isinstance(sharding_spec,
-                              (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
-            device_mesh = sharding_spec[0].device_mesh
-            dim_partition_dict = []
-            for element in sharding_spec:
-                dim_partition_dict.append(_process_sharding_spec(element))
-            return dim_partition_dict, sharding_spec
-
-        output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
+    def _extract_info_from_sharding_spec(sharding_spec):
+        '''
+        This function is used to extract the dim_partition_dict and device_mesh from
+        a sharding spec instance or a list of sharding specs.
+        '''
+        if isinstance(sharding_spec, ShardingSpec):
+            dim_partition_dict = sharding_spec.dim_partition_dict
+            device_mesh = sharding_spec.device_mesh
+            return dim_partition_dict, device_mesh
+        if sharding_spec is None:
+            return None, None
+        assert isinstance(sharding_spec,
+                          (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+        device_mesh = sharding_spec[0].device_mesh
+        dim_partition_dict = []
+        for element in sharding_spec:
+            dim_partition_dict.append(_extract_info_from_sharding_spec(element))
+        return dim_partition_dict, sharding_spec
+
+    def _process_node_arguments(node):
+        new_args = []
+        for arg in node.args:
+            # There are two args styles:
+            # 1. (input, *shape)
+            # 2. (input, shape)
+            # We will extract the elements from shape and add them into new_args.
+            # Finally, the args style of new_args will be unified to (input, *shape).
+            if isinstance(arg, Node):
+                if isinstance(arg._meta_data, (tuple, list)):
+                    new_args.extend(arg._meta_data)
+                elif isinstance(arg._meta_data, int):
+                    new_args.append(arg._meta_data)
+                else:
+                    new_args.append(arg)
+            else:
+                assert isinstance(
+                    arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+                if isinstance(arg, (tuple, list)):
+                    new_args.extend(arg)
+                else:
+                    new_args.append(arg)
+        return new_args
+
+    def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
+        new_args = _process_node_arguments(node)
+        if node.op == 'call_method':
+            args_to_process = list(new_args[1:])
+        else:
+            args_to_process = list(new_args)
+        for dim, shard_dims in dim_partition_dict.items():
+            total_shard_size = 1
+            for shard_dim in shard_dims:
+                total_shard_size *= device_mesh.shape[shard_dim]
+            # we will skip the dim with -1 value
+            if args_to_process[dim] == -1:
+                continue
+            else:
+                # TODO: add assertion here to make sure the dim size is divisible by total_shard_size
+                args_to_process[dim] //= total_shard_size
+
+        args_to_process = tuple(args_to_process)
+
+        if node.op == 'call_method':
+            new_args = (new_args[0],) + args_to_process
+        else:
+            new_args = args_to_process
+
+        node.args = new_args
+
+    def _filter_node_with_shape_args(node):
+        if node.op == 'call_method':
+            target = getattr(node.args[0]._meta_data.__class__, node.target)
+        elif node.op == 'call_function':
+            target = node.target
+        else:
+            target = None
+
+        if target in SHAPE_ARGUMENT_OPS:
+            return True
+        return False
 
-        if node.op == 'call_method':
-            method = getattr(node.args[0]._meta_data.__class__, node.target)
-            # process the node with (input, *shape) style args
-            if method in (torch.Tensor.view, torch.Tensor.reshape):
-                for arg in node.args:
-                    if isinstance(arg, Node):
-                        if isinstance(arg._meta_data, (int, tuple, list)):
-                            new_args.append(arg._meta_data)
-                        else:
-                            new_args.append(arg)
-                    else:
-                        assert isinstance(
-                            arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
-                        new_args.append(arg)
-
-                for dim, shard_dims in output_dim_partition_dict.items():
-                    total_shard_size = 1
-                    for shard_dim in shard_dims:
-                        total_shard_size *= device_mesh.shape[shard_dim]
-                    # There are two ways to use torch.view:
-                    # 1. torch.view(input, *shape)
-                    # 2. torch.view(input, shape)
-                    if isinstance(new_args[1], int):
-                        # we will skip the dim with -1 value
-                        if new_args[dim + 1] == -1:
-                            continue
-                        else:
-                            new_args[dim + 1] //= total_shard_size
-                    else:
-                        new_args[1] = list(new_args[1])
-                        # we will skip the dim with -1 value
-                        if new_args[1][dim] == -1:
-                            continue
-                        else:
-                            new_args[1][dim] //= total_shard_size
-
-                node.args = tuple(new_args)
-
-        elif node.op == 'call_function':
-            target = node.target
-            # process the node with (input, torch.Size) style args
-            if target in (torch.reshape,):
-                for arg in node.args:
-                    if isinstance(arg, Node):
-                        if isinstance(arg._meta_data, (tuple, list)):
-                            new_args.append(list(arg._meta_data))
-                        else:
-                            new_args.append(arg)
-                    else:
-                        assert isinstance(
-                            arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
-                        new_args.append(list(arg))
-
-                for dim, shard_dims in output_dim_partition_dict.items():
-                    # we will skip the dim with -1 value
-                    if new_args[1][dim] == -1:
-                        continue
-                    total_shard_size = 1
-                    for shard_dim in shard_dims:
-                        total_shard_size *= device_mesh.shape[shard_dim]
-                    new_args[1][dim] //= total_shard_size
-
-                node.args = tuple(new_args)
+    for node in nodes:
+        # skip the placeholder node added in _solution_annotation pass
+        if not hasattr(node, 'sharding_spec'):
+            continue
+        output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
+        if _filter_node_with_shape_args(node):
+            _scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node)
 
     return gm
```
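The arithmetic at the heart of `_scale_args_adapt_sharding_spec`: each shape argument along a sharded dimension is divided by the product of the mesh dimensions it is sharded over, and `-1` is left alone so view/reshape can still infer that dimension. A sketch with illustrative values (not taken from the commit):

```python
# Rough model of the scaling step for a view node x.view(4, 4, 2)
# whose output dim 0 is sharded over mesh dim 0 of a (2, 2) mesh.
mesh_shape = (2, 2)
dim_partition_dict = {0: [0]}      # output dim 0 sharded over mesh dim 0

shape_args = [4, 4, 2]             # shape args of the view call, input excluded
for dim, shard_dims in dim_partition_dict.items():
    total_shard_size = 1
    for shard_dim in shard_dims:
        total_shard_size *= mesh_shape[shard_dim]
    if shape_args[dim] != -1:      # -1 must stay -1 so view can infer it
        shape_args[dim] //= total_shard_size

assert shape_args == [2, 4, 2]     # matches the expectation in the new unit test
```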
```diff
 
-def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
+def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
     """
     Apply the sharding action to the module parameters and buffers following the
     instructions of solver solution.
```
```diff
@@ -361,6 +386,49 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
     nodes = tuple(mod_graph.nodes)
     # This stream is created for overlapping the communication and computation.
     reduction_stream = torch.cuda.Stream()
+
+    def _add_hook_for_grad_communication(node, param):
+
+        comm_actions = node.best_strategy.communication_actions
+
+        def _filter_param_to_hook(node, op_data, comm_action):
+            if node.op == 'call_module' and op_data.type == OperationDataType.PARAM \
+                    and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
+                return True
+            if node.op == 'get_attr' and isinstance(node._meta_data, torch.nn.parameter.Parameter) \
+                    and comm_action.comm_type == CommType.HOOK:
+                return True
+            return False
+
+        for operation_data, comm_action in comm_actions.items():
+            comm_spec_to_use = comm_action.comm_spec
+            # register hook to the parameters
+            if _filter_param_to_hook(node, operation_data, comm_action):
+
+                def wrapper(param, comm_spec, stream, overlap):
+
+                    def hook_fn(grad):
+                        if overlap:
+                            with torch.cuda.stream(stream):
+                                _all_reduce(grad, comm_spec, async_op=True)
+                        else:
+                            _all_reduce(grad, comm_spec, async_op=False)
+
+                    param.register_hook(hook_fn)
+
+                wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
+
+    def _shard_param(param, target_sharding_spec):
+        # apply the sharding spec of parameters
+        if target_sharding_spec.dim_partition_dict != {}:
+            origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
+            setattr(param, 'sharding_spec', origin_sharding_spec)
+            # TODO: build a ColoParameter class to manage the distributed parameters
+            # we could use .data here, because all the operations just happen before the real training
+            # loop, so we don't need to track these operations in the autograd graph.
+            param = torch.nn.Parameter(
+                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                         target_sharding_spec).detach().clone())
+        return param
+
     for node in nodes:
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)
```
```diff
@@ -370,36 +438,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
             setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                # apply the sharding spec of parameters
-                if target_sharding_spec.dim_partition_dict != {}:
-                    origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
-                    setattr(param, 'sharding_spec', origin_sharding_spec)
-                    # TODO: build a ColoParameter class to manage the distributed parameters
-                    # we could use .data here, because all the operations just happen before the real training
-                    # loop, so we don't need to track these operations in the autograd graph.
-                    param = torch.nn.Parameter(
-                        shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                                 target_sharding_spec).detach().clone())
-
+                param = _shard_param(param, target_sharding_spec)
                 setattr(target_module, name, param)
-
-                comm_actions = node.best_strategy.communication_actions
-                for operation_data, comm_action in comm_actions.items():
-                    comm_spec_to_use = comm_action.comm_spec
-                    # register hook to the parameters
-                    if operation_data.type == OperationDataType.PARAM and operation_data.name == name \
-                            and comm_action.comm_type == CommType.HOOK:
-
-                        def wrapper(param, comm_spec, stream, overlap):
-
-                            def hook_fn(grad):
-                                if overlap:
-                                    with torch.cuda.stream(stream):
-                                        _all_reduce(grad, comm_spec, async_op=True)
-                                else:
-                                    _all_reduce(grad, comm_spec, async_op=False)
-
-                            param.register_hook(hook_fn)
-
-                        wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
+                _add_hook_for_grad_communication(node, param)
 
             sharded_buffer_dict = {}
             # apply the sharding spec of buffers
```
```diff
@@ -427,37 +469,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
             target = getattr(target_module, atoms[-1])
             target_sharding_spec = node.sharding_spec
-            if target_sharding_spec.dim_partition_dict != {}:
-                origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
-                setattr(target, 'sharding_spec', origin_sharding_spec)
-                # TODO: build a ColoParameter class to manage the distributed parameters
-                # we could use .data here, because all the operations just happen before the real training
-                # loop, so we don't need to track these operations in the autograd graph.
-                target = torch.nn.Parameter(
-                    shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
-                                                                             target_sharding_spec).detach().clone())
+            target = _shard_param(target, target_sharding_spec)
 
             assert hasattr(target_module, atoms[-1])
             setattr(target_module, atoms[-1], target)
-
-            comm_actions = node.best_strategy.communication_actions
-            for operation_data, comm_action in comm_actions.items():
-                comm_spec_to_use = comm_action.comm_spec
-                # register hook to the parameters
-                if isinstance(node._meta_data, torch.nn.parameter.Parameter) \
-                        and comm_action.comm_type == CommType.HOOK:
-
-                    def wrapper(param, comm_spec, stream, overlap):
-
-                        def hook_fn(grad):
-                            if overlap:
-                                with torch.cuda.stream(stream):
-                                    _all_reduce(grad, comm_spec, async_op=True)
-                            else:
-                                _all_reduce(grad, comm_spec, async_op=False)
-
-                        param.register_hook(hook_fn)
-
-                    wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
+            _add_hook_for_grad_communication(node, target)
 
     return gm
```
```diff
@@ -471,14 +488,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
 def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              solution: List[int],
                              device_mesh: DeviceMesh,
-                             strategies_constructor: StrategiesConstructor = None,
+                             strategies_constructor: StrategiesConstructor,
                              overlap=False):
-    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
+    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
         gm, solution, strategies_constructor)
-    gm = _size_value_converting(gm, device_mesh)
-    gm = _node_args_converting(gm, device_mesh)
+    gm = size_value_converting_pass(gm, device_mesh)
+    gm = node_args_converting_pass(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
-    gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
+    gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap)
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
```
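After this commit the four stages are public, `_pass`-suffixed functions, and `strategies_constructor` is a required argument. A hypothetical driver, assuming a traced `ColoGraphModule`, a solver `solution`, a `DeviceMesh`, and the `StrategiesConstructor` used during strategy search:

```python
# Illustrative only; variable names are assumptions, not from the commit.
gm, convert_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
    gm, solution, device_mesh, strategies_constructor, overlap=False)
```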
tests/test_auto_parallel/test_pass/test_node_converting_pass.py (new file, 0 → 100644)

```python
import torch
import torch.nn.functional as F

from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec


class TestModule(torch.nn.Module):

    def forward(self, x):
        x = x.view(4, 4, 2)
        return x


def insert_narrow(gm, x_node):
    graph = gm.graph
    with graph.inserting_after(x_node):
        shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
    view_node = list(x_node.users.keys())[0]
    new_args = list(view_node.args)
    new_args[0] = shard_node
    view_node.args = tuple(new_args)
    return gm


def test_node_args_converting_pass():
    model = TestModule()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {'x': torch.rand(4, 8).to('meta')}
    input = torch.rand(4, 8)

    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)

    x_node = list(graph.nodes)[0]
    view_node = list(graph.nodes)[1]
    sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, 'sharding_spec', sharding_spec)
    setattr(view_node, 'sharding_spec', sharding_spec)

    gm = ColoGraphModule(model, graph)
    gm = node_args_converting_pass(gm, device_mesh)
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    output = gm(input)
    assert output.shape == torch.Size([2, 4, 2])


if __name__ == '__main__':
    test_node_args_converting_pass()
```
tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py (new file, 0 → 100644)

```python
import torch
import torch.nn.functional as F

from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec


class TestModule(torch.nn.Module):

    def forward(self, x):
        size = x.size()
        return size


def insert_narrow(gm, x_node):
    graph = gm.graph
    with graph.inserting_after(x_node):
        shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
    size_node = list(x_node.users.keys())[0]
    size_node.args = (shard_node,)
    return gm


def recover_narrow(gm, narrow_node):
    graph = gm.graph
    size_node = list(graph.nodes)[2]
    x_node = narrow_node.args[0]
    size_node.args = (x_node,)
    graph.erase_node(narrow_node)
    return gm


def test_size_value_converting_pass():
    model = TestModule()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    meta_args = {'x': torch.rand(4, 8).to('meta')}
    input = torch.rand(4, 8)

    tracer = ColoTracer()
    graph = tracer.trace(root=model, meta_args=meta_args)

    x_node = list(graph.nodes)[0]
    x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
    setattr(x_node, 'sharding_spec', x_sharding_spec)

    gm = ColoGraphModule(model, graph)
    # without the pass, a narrowed (sharded) input reports its local size
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    size = gm(input)
    assert size == torch.Size([2, 8])

    # with the pass, the same sharded input reports the global size
    narrow_node = list(gm.graph.nodes)[1]
    gm = recover_narrow(gm, narrow_node)
    gm = size_value_converting_pass(gm, device_mesh)
    gm = insert_narrow(gm, x_node)
    gm.recompile()
    size = gm(input)
    assert size == torch.Size([4, 8])


if __name__ == '__main__':
    test_size_value_converting_pass()
```
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

```diff
-from faulthandler import disable
 from functools import partial
-from xml.dom import WrongDocumentErr
 
 import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from typing_extensions import Self
 
 from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
```