OpenDAS / ColossalAI · Commit e532679c

Authored Jan 10, 2023 by oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

Parents: c1492e50, 7d5640b9

Changes: 762 files in total; showing 20 changed files with 1546 additions and 74 deletions (+1546 -74)
colossalai/auto_parallel/passes/__init__.py                                        +0   -0
colossalai/auto_parallel/passes/comm_metainfo_pass.py                              +113 -0
colossalai/auto_parallel/passes/constants.py                                       +8   -0
colossalai/auto_parallel/passes/meta_info_prop.py                                  +165 -0
colossalai/auto_parallel/passes/runtime_apply_pass.py                              +221 -0
colossalai/auto_parallel/passes/runtime_preparation_pass.py                        +471 -0
colossalai/auto_parallel/tensor_shard/constants.py                                 +11  -3
colossalai/auto_parallel/tensor_shard/deprecated/__init__.py                       +3   -3
colossalai/auto_parallel/tensor_shard/deprecated/_utils.py                         +4   -3
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py        +9   -9
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py         +22  -22
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py  +13  -9
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py    +7   -5
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py     +3   -3
colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py         +13  -10
colossalai/auto_parallel/tensor_shard/initialize.py                                +275 -0
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py                     +16  -4
colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py                +91  -0
colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py           +5   -3
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py   +96  -0
colossalai/auto_parallel/passes/__init__.py (new file, empty)
colossalai/auto_parallel/passes/comm_metainfo_pass.py (new file)
from typing import Dict

import torch
from torch.fx import GraphModule
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec

shape_consistency_manager = ShapeConsistencyManager()


def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
                         target_sharding_spec: ShardingSpec) -> MetaInfo:
    # get comm_action_sequence and total_cost from shape_consistency_manager
    _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        origin_sharding_spec, target_sharding_spec)

    meta_info = MetaInfo()
    # NOTE: the cost in shape_consistency_manager.mem_cost is counted in number of elements (numel),
    # so it is scaled by the element size below to get bytes
    # get mem cost for MetaInfo
    mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
    # extract an input node that carries _meta_data and read its element size
    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
    element_length = input_node._meta_data.element_size()

    mem_cost.fwd.activation *= element_length
    mem_cost.fwd.temp *= element_length
    mem_cost.bwd.activation *= element_length
    mem_cost.bwd.temp *= element_length
    mem_cost.total.activation *= element_length

    meta_info.memory_cost = mem_cost

    # get computation cost for MetaInfo
    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
                                            total_cost['backward'] * element_length,
                                            total_cost['total'] * element_length)

    # get tensor shape for MetaInfo
    origin_sharding_spec: ShardingSpec
    target_sharding_spec: ShardingSpec
    input_shape = origin_sharding_spec.get_sharded_shape_per_device()
    output_shape = target_sharding_spec.get_sharded_shape_per_device()

    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
    meta_info.fwd_buffer = []
    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]

    return meta_info


def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
    """
    This method is used to construct the `MetaInfo` for a shape consistency node.
    """
    # extract node index and user node index
    args = node.args
    node_index, user_node_index = args[3], args[4]
    origin_sharding_spec = origin_spec_dict[node_index]
    target_sharding_spec = sharding_spec_dict[node_index][user_node_index]

    return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)


def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo:
    # extract node_index and op_data_name
    node_index, op_data_name = node.args[2], node.args[3]

    comm_action = comm_actions_dict[node_index][op_data_name]
    if isinstance(comm_action.comm_spec, CommSpec):
        # this case is for all_reduce, there will be no memory cost
        meta_info = MetaInfo()
        meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost())
        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
        element_length = output_node._meta_data.element_size()

        total_cost = comm_action.comm_spec.get_comm_cost()
        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
                                                total_cost['backward'] * element_length,
                                                total_cost['total'] * element_length)

        input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
        meta_info.fwd_buffer = []
        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
    else:
        # this case will be handled by the shape consistency manager
        origin_sharding_spec = comm_action.comm_spec['src_spec']
        target_sharding_spec = comm_action.comm_spec['tgt_spec']
        meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)

    return meta_info


def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
                       comm_actions_dict: Dict) -> GraphModule:
    """
    This method manages the metainfo of all the communication nodes (runtime_apply,
    runtime_comm_spec_apply) in the graph.
    """
    for node in gm.graph.nodes:
        if node.target == runtime_apply:
            setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
        elif node.target == runtime_comm_spec_apply:
            setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
        else:
            pass
    return gm
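To make the flow concrete, here is a minimal sketch of how this pass could be invoked; it assumes `gm` has already been rewritten by `runtime_apply_pass` and that the three dicts come from the runtime preparation step. The variable names below are illustrative, not part of this commit:

from colossalai.auto_parallel.passes.comm_metainfo_pass import comm_metainfo_pass

# hypothetical wiring: sharding_spec_dict, origin_spec_dict and comm_actions_dict are
# assumed to be the dicts produced for the solved sharding solution
gm = comm_metainfo_pass(gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict)

# every runtime_apply / runtime_comm_spec_apply node now carries a `best_metainfo`
for node in gm.graph.nodes:
    if hasattr(node, 'best_metainfo'):
        print(node.name, node.best_metainfo.compute_cost)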
colossalai/auto_parallel/passes/constants.py (new file)
import torch

OUTPUT_SAVED_OPS = [torch.nn.functional.relu, torch.nn.functional.softmax, torch.flatten]

OUTPUT_SAVED_MOD = [
    torch.nn.ReLU,
    torch.nn.Softmax,
]
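The two lists are consumed differently: OUTPUT_SAVED_OPS entries are function objects compared against a call_function node's target, while OUTPUT_SAVED_MOD entries are module classes compared against the class of a call_module node's submodule (see MetaInfoProp._is_inplace in meta_info_prop.py below). A quick, illustrative sanity check of that convention, not part of the commit:

import torch

# function objects are matched by identity against node.target
assert torch.nn.functional.relu in OUTPUT_SAVED_OPS
# module entries are classes, matched against the submodule's class
assert type(torch.nn.ReLU()) in OUTPUT_SAVED_MOD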
colossalai/auto_parallel/passes/meta_info_prop.py (new file)
import uuid
from dataclasses import asdict
from typing import List

import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo


def _normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


@compatibility(is_backward_compatible=False)
class MetaInfoProp:

    def __init__(self, module: GraphModule) -> None:
        self.module = module
        self.func_dict = {
            'placeholder': self.placeholder_handler,
            'get_attr': self.get_attr_handler,
            'output': self.output_handler,
            'call_function': self.node_handler,
            'call_module': self.node_handler,
            'call_method': self.node_handler,
        }

    def _set_data_ptr(self, x):
        """
        Assign a uuid as the tensor's data pointer if it does not already have one.
        """
        if isinstance(x, torch.Tensor):
            if not x.data_ptr():
                data_ptr = uuid.uuid4()
                x.data_ptr = lambda: data_ptr

    def _is_inplace(self, node: Node):
        """
        Check if the node is an in-place operation.
        """
        if node.op == 'call_module':
            return node.graph.owning_module.get_submodule(node.target).__class__ in OUTPUT_SAVED_MOD
        elif node.op == "call_function":
            return node.target in OUTPUT_SAVED_OPS
        return False

    def run(self) -> GraphModule:
        """
        Run the meta information propagation pass on the module.
        """
        for node in self.module.graph.nodes:
            node: Node
            self.func_dict[node.op](node)
        return self.module

    @compatibility(is_backward_compatible=False)
    def placeholder_handler(self, node: Node) -> None:
        """
        Handle the placeholder node.
        """
        graph_info = GraphInfo()
        out = _normalize_tuple(getattr(node, '_meta_data', None))
        graph_info.fwd_out = list(out) if out[0] is not None else []
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def get_attr_handler(self, node: Node) -> None:
        """
        Handle the get_attr node.
        """
        graph_info = GraphInfo()
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def output_handler(self, node: Node) -> None:
        """
        Handle the output node.
        """
        graph_info = GraphInfo()
        output_tensors = []
        for par in node._input_nodes:
            if par.meta:
                output_tensors += par.meta["fwd_out"]
        graph_info.fwd_in = output_tensors
        node.meta = {**asdict(graph_info)}

    @compatibility(is_backward_compatible=False)
    def node_handler(self, node: Node) -> None:
        """
        Handle other kinds of nodes.
        """
        assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
        graph_info = GraphInfo()
        meta_info = node.best_metainfo
        meta_info: MetaInfo

        # set data_ptr for input_tensor in MetaInfo class
        input_tensors: List[torch.Tensor] = meta_info.fwd_in
        buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer
        output_tensors: List[torch.Tensor] = meta_info.fwd_out

        if self._is_inplace(node):
            # an inplace operation will not create a new tensor, and it only has one parent node
            # TODO: Verify this observation
            # set data_ptr for input_tensor, buffer_tensor and output_tensor of current node
            parent_node = list(node._input_nodes.keys())[0]
            parent_tensor = parent_node.meta.get("fwd_out")[0]
            parent_tensor: torch.Tensor
            for tensor in input_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in buffer_tensors:
                tensor.data_ptr = parent_tensor.data_ptr
            for tensor in output_tensors:
                tensor.data_ptr = parent_tensor.data_ptr

        else:
            for par in node._input_nodes:
                # set data_ptr for the input_tensor of current node from the output_tensor of its parent node
                for tensor in par.meta.get("fwd_out", []):
                    tensor: torch.Tensor
                    target_input_tensor = next(
                        (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None)
                    if target_input_tensor is not None:
                        target_input_tensor.data_ptr = tensor.data_ptr

            # set data_ptr for any tensor in input_tensors that is not set yet
            for tensor in input_tensors:
                if not tensor.data_ptr():
                    self._set_data_ptr(tensor)

            # set data_ptr for buffer_tensor
            for tensor in buffer_tensors:
                self._set_data_ptr(tensor)

            # set data_ptr for output_tensor
            for tensor in output_tensors:
                self._set_data_ptr(tensor)

        # attach them to graph_info
        graph_info.fwd_in = input_tensors
        graph_info.fwd_tmp = buffer_tensors
        graph_info.fwd_out = output_tensors

        # fetch other memory information
        memory_cost = meta_info.memory_cost
        graph_info.fwd_mem_tmp = memory_cost.fwd.temp
        graph_info.fwd_mem_out = memory_cost.fwd.activation
        graph_info.bwd_mem_tmp = memory_cost.bwd.temp
        graph_info.bwd_mem_out = memory_cost.bwd.activation

        # fetch flop information
        # here we use fwd_time and bwd_time to deal with the case that
        # communication cost is a float
        compute_cost = meta_info.compute_cost
        graph_info.fwd_time = compute_cost.fwd
        graph_info.bwd_time = compute_cost.bwd

        node.meta = {**asdict(graph_info)}
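The data_ptr patching above relies on the fact that the profiled tensors live on the meta device and therefore have no real storage; a uuid stands in for an address so that tensors sharing storage can be detected by pointer equality. A standalone illustration of the trick, assuming a PyTorch version where meta tensors report a zero data pointer (which is what this pass's `if not x.data_ptr()` checks rely on):

import uuid

import torch

a = torch.rand(4, device='meta')
b = torch.rand(4, device='meta')

ptr = uuid.uuid4()
a.data_ptr = lambda: ptr    # give `a` a unique fake address
b.data_ptr = a.data_ptr     # mark `b` as an alias of `a`
assert a.data_ptr() == b.data_ptr()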
colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py → colossalai/auto_parallel/passes/runtime_apply_pass.py
-import builtins
-import copy
-import operator
-from ast import NodeTransformer
-from copy import deepcopy
-from typing import List
+from typing import Dict, List

 import torch
-from torch.fx import symbolic_trace
 from torch.fx.node import Node

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec, _all_reduce, pattern_to_func_dict
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.auto_parallel.meta_profiler import MetaInfo
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+    TrainCycleItem,
+)
+from colossalai.tensor.comm_spec import CommSpec
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec

 shape_consistency_manager = ShapeConsistencyManager()


-def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
+def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int):
     """
     This method will be invoked during runtime to do the shape consistency, which makes sure the activation is
     converted into the form expected by the user node.
     """
     origin_sharding_spec = origin_dict[node_index]
     target_sharding_spec = input_dict[node_index][user_node_index]
     return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)


-def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
+def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
+                                      user_node_index: int):
+    """
+    This method will be invoked during runtime to do the shape consistency, which makes sure activations of type
+    tuple or list are converted into the form expected by the user node.
+    """
+    rst = []
+    for index, (origin_sharding_spec,
+                target_sharding_spec) in enumerate(zip(origin_dict[node_index],
+                                                       input_dict[node_index][user_node_index])):
+        rst.append(
+            shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
+                                                                     target_sharding_spec))
+    rst = type(node)(rst)
+    return rst

-    comm_action = comm_actions_dict[node_index][op_data]

+def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
+    """
+    This method will be invoked during runtime to apply the comm action following the instruction of the comm spec.
+    """
+    comm_action = comm_actions_dict[node_index][op_data_name]
     if isinstance(comm_action.comm_spec, CommSpec):
         rst = comm_action.comm_spec.covert_spec_to_action(tensor)
     else:
...
@@ -37,94 +61,11 @@ def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
     return rst


-def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
-    mod_graph = gm.graph
-    nodes = tuple(mod_graph.nodes)
-
-    # the dict to get origin sharding spec of node
-    origin_node_sharding_spec_dict = {}
-    for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
-        strategies_vector = node.strategies_vector
-        setattr(node, 'best_strategy', strategies_vector[strategy_index])
-        setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
-        origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
-            str(node))
-
-    # apply the sharding spec of parameters
-    for node in nodes:
-        if node.op == 'call_module':
-            target_module = node.graph.owning_module.get_submodule(node.target)
-            for name, param in target_module.named_parameters():
-                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                if target_sharding_spec.dim_partition_dict != {}:
-                    origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
-                    setattr(param, 'sharding_spec', origin_sharding_spec)
-                    param_sharded = torch.nn.Parameter(
-                        shape_consistency_manager.apply_for_autoparallel_runtime(
-                            param.data, param.sharding_spec, target_sharding_spec).detach().clone())
-                else:
-                    param_sharded = param
-                setattr(target_module, name, param_sharded)
-                comm_actions = node.best_strategy.communication_actions
-                for operation_data, comm_action in comm_actions.items():
-                    comm_spec_to_use = comm_action.comm_spec
-                    if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
-
-                        def wrapper(param, comm_spec):
-
-                            def hook_fn(grad):
-                                _all_reduce(grad, comm_spec)
-
-                            param.register_hook(hook_fn)
-
-                        wrapper(param_sharded, comm_spec_to_use)
-
-            sharded_buffer_dict = {}
-            for name, buffer in target_module.named_buffers():
-                origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
-                setattr(buffer, 'sharding_spec', origin_sharding_spec)
-                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
-                sharded_buffer_dict[name] = buffer_sharded
-
-            for name, buffer_sharded in sharded_buffer_dict.items():
-                setattr(target_module, name, buffer_sharded.detach().clone())
-
-    # the dict to get input sharding specs of user node
-    sharding_spec_convert_dict = {}
-    for index, node in enumerate(nodes):
-        target_sharding_specs = []
-        for user_node in node.strategies_vector.successor_nodes:
-            target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
-            target_sharding_specs.append(target_sharding_spec)
-        sharding_spec_convert_dict[index] = target_sharding_specs
-
-    # the dict to record comm actions of nodes
-    comm_actions_dict = {}
-    for index, node in enumerate(nodes):
-        comm_action_dict = {}
-        for op_data, comm_action in node.best_strategy.communication_actions.items():
-            comm_action_dict[op_data.name] = comm_action
-        comm_actions_dict[index] = comm_action_dict
-
-    # add above dicts into graph
-    for node in nodes:
-        if node.op != 'placeholder':
-            with mod_graph.inserting_before(node):
-                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
-                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
-                comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
-            break
-    return sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
-
-
-def shape_consistency_pass(gm: torch.fx.GraphModule):
-    mod_graph = gm.graph
-    nodes = tuple(mod_graph.nodes)
-    input_dict_node = None
-    origin_dict_node = None
+def _preprocess_graph(nodes: List[Node]):
+    """
+    This method is used to extract all the placeholders with sharding information,
+    and to map the nodes onto their index in the origin graph.
+    """
+    # mapping the node into the origin graph index
+    node_to_index_dict = {}
+    index = 0
...
@@ -142,40 +83,110 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
             continue
         node_to_index_dict[node] = index
         index += 1
     assert input_dict_node is not None

-    # add shape consistency apply function into graph
+    return input_dict_node, origin_dict_node, comm_actions_dict_node, node_to_index_dict
+
+
+def _shape_consistency_apply(gm: torch.fx.GraphModule):
+    """
+    This pass is used to add the shape consistency node to the origin graph.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)
+
     for node in nodes:
         if not hasattr(node, 'best_strategy') or node.op == 'output':
             continue

-        for user_node in node.strategies_vector.successor_nodes:
-            user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
+        for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
+            if isinstance(node.sharding_spec, (list, tuple)):
+                assert isinstance(node.target_sharding_specs, (list, tuple)), \
+                    'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
+                total_difference = 0
+                for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
+                                                               node.target_sharding_specs[user_node_index]):
+                    total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
+                if total_difference == 0:
+                    continue
+                with mod_graph.inserting_before(user_node):
+                    shape_consistency_node = mod_graph.create_node('call_function',
+                                                                   runtime_apply_for_iterable_object,
+                                                                   args=(node, origin_dict_node, input_dict_node,
+                                                                         node_to_index_dict[node], user_node_index))
+            else:
+                assert isinstance(node.sharding_spec, ShardingSpec), \
+                    'node.sharding_spec should be type of ShardingSpec, tuple or list.'
+                if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
+                    continue
+                with mod_graph.inserting_before(user_node):
+                    shape_consistency_node = mod_graph.create_node('call_function',
+                                                                   runtime_apply,
+                                                                   args=(node, origin_dict_node, input_dict_node,
+                                                                         node_to_index_dict[node], user_node_index))

-            origin_index_args = user_node.args.index(node)
             new_args = list(user_node.args)
-            new_args[origin_index_args] = shape_consistency_node
-            user_node.args = new_args
+            new_kwargs = dict(user_node.kwargs)
+            # the origin node may be a positional argument or key word argument of user node
+            if node in new_args:
+                # substitute the origin node with shape_consistency_node
+                origin_index_args = new_args.index(node)
+                new_args[origin_index_args] = shape_consistency_node
+                user_node.args = tuple(new_args)
+            elif str(node) in new_kwargs:
+                # substitute the origin node with shape_consistency_node
+                new_kwargs[str(node)] = shape_consistency_node
+                user_node.kwargs = new_kwargs
+
+    return gm
+
+
+def _comm_spec_apply(gm: torch.fx.GraphModule):
+    """
+    This pass is used to add the comm spec apply node to the origin graph.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)
+
     for node in nodes:
         if not hasattr(node, 'best_strategy') or node.op == 'output':
             continue

         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
-            comm_object = node.args[comm_action.arg_index]
             if op_data.type == OperationDataType.PARAM:
+                if comm_action.comm_type == CommType.HOOK:
                     continue
             if comm_action.comm_type == CommType.BEFORE:
                 if op_data.type == OperationDataType.OUTPUT:
                     comm_object = node
+                elif comm_action.key_for_kwarg is not None:
+                    comm_object = node.kwargs[comm_action.key_for_kwarg]
+                else:
+                    comm_object = node.args[comm_action.arg_index]
                 with mod_graph.inserting_before(node):
                     comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                  runtime_comm_spec_apply,
                                                                  args=(comm_object, comm_actions_dict_node,
                                                                        node_to_index_dict[node], op_data.name))
+                # the origin node may be a positional argument or key word argument of user node
+                if comm_action.key_for_kwarg is not None:
+                    # substitute the origin node with comm_spec_apply_node
+                    new_kwargs = dict(node.kwargs)
+                    new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
+                    node.kwargs = new_kwargs
+                else:
+                    # substitute the origin node with comm_spec_apply_node
                     new_args = list(node.args)
                     new_args[comm_action.arg_index] = comm_spec_apply_node
-                node.args = new_args
+                    node.args = tuple(new_args)
             elif comm_action.comm_type == CommType.AFTER:
                 with mod_graph.inserting_after(node):
                     comm_spec_apply_node = mod_graph.create_node('call_function',
...
@@ -187,7 +198,24 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
                 if user == comm_spec_apply_node:
                     continue
                 new_args = list(user.args)
+                new_kwargs = dict(user.kwargs)
+                # the origin node may be a positional argument or key word argument of user node
+                if node in new_args:
+                    # substitute the origin node with comm_spec_apply_node
                     new_args[new_args.index(node)] = comm_spec_apply_node
                     user.args = tuple(new_args)
+                # TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
+                elif str(node) in new_kwargs:
+                    # substitute the origin node with comm_spec_apply_node
+                    new_kwargs[str(node)] = comm_spec_apply_node
+                    user.kwargs = new_kwargs

     return gm
+
+
+def runtime_apply_pass(gm: torch.fx.GraphModule):
+    """
+    This method manages all the passes acting on the distributed training runtime.
+    """
+    gm = _shape_consistency_apply(gm)
+    gm = _comm_spec_apply(gm)
+    return gm
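A rough sketch of how the renamed pass is meant to be driven: the placeholder nodes created during preparation become extra inputs of the recompiled module, so the three dicts are passed alongside the real inputs. Variable names here are illustrative, not part of this commit:

gm = runtime_apply_pass(gm)
gm.recompile()

# the placeholders target 'sharding_spec_convert_dict', 'origin_node_sharding_spec_dict'
# and 'comm_actions_dict', so the recompiled forward expects them as extra arguments
output = gm(*example_inputs, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict)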
colossalai/auto_parallel/passes/runtime_preparation_pass.py (new file; diff collapsed)
colossalai/auto_parallel/tensor_shard/constants.py

-import torch
 import operator
+
+import torch

 __all__ = [
     'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
     'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP',
     'BCAST_FUNC_OP',
...
@@ -25,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
     # TODO: contiguous maybe need some extra processes.
     torch.Tensor.contiguous
 ]

-RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
+RESHAPE_FUNC_OP = [
+    torch.flatten,
+    torch.reshape,
+    torch.transpose,
+    torch.split,
+    torch.permute,
+    operator.getitem,
+]

 RESHAPE_METHOD_OP = [
     torch.Tensor.view,
     torch.Tensor.unsqueeze,
...
@@ -35,7 +43,7 @@ RESHAPE_METHOD_OP = [
 ]

 BCAST_FUNC_OP = [
     torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
-    operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
+    operator.mul, operator.floordiv, operator.truediv, torch.matmul, operator.pow, torch.pow
 ]

 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
...
colossalai/auto_parallel/tensor_shard/deprecated/__init__.py

+from .cost_graph import CostGraph
+from .graph_analysis import GraphAnalyser
 from .options import SolverOptions
-from .strategies_constructor import StrategiesConstructor
 from .sharding_strategy import ShardingStrategy, StrategiesVector
-from .cost_graph import CostGraph
 from .solver import Solver
-from .graph_analysis import GraphAnalyser
\ No newline at end of file
+from .strategies_constructor import StrategiesConstructor
colossalai/auto_parallel/tensor_shard/deprecated/_utils.py

...
@@ -5,10 +5,11 @@ from functools import reduce
 from typing import Dict, List, Optional, Union

 import torch
+from torch.fx.node import Node

 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx.node import Node

 from .constants import INFINITY_COST
...
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py

...
@@ -3,9 +3,9 @@ import warnings
 from functools import reduce

 import torch

-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector

 from .operator_handler import OperatorHandler
...
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py

...
@@ -6,9 +6,9 @@ from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector

 from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
 from .operator_handler import OperatorHandler
...
@@ -431,7 +431,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -451,7 +451,7 @@ class DotHandler(OperatorHandler):
         # create and register strategy
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
@@ -473,7 +473,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -491,7 +491,7 @@ class DotHandler(OperatorHandler):
         communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
         communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
@@ -510,7 +510,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -529,7 +529,7 @@ class DotHandler(OperatorHandler):
         communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
@@ -548,7 +548,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -564,7 +564,7 @@ class DotHandler(OperatorHandler):
         # compute the communication cost of this strategy
         communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
@@ -583,7 +583,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -600,7 +600,7 @@ class DotHandler(OperatorHandler):
         communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
         communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
@@ -619,7 +619,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -636,7 +636,7 @@ class DotHandler(OperatorHandler):
         communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
         communication_cost = communication_cost_weight_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
@@ -655,7 +655,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -673,7 +673,7 @@ class DotHandler(OperatorHandler):
             activation_memory_cost, 0)
         communication_cost = communication_cost_forward_activation
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
@@ -692,7 +692,7 @@ class DotHandler(OperatorHandler):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
...
@@ -709,7 +709,7 @@ class DotHandler(OperatorHandler):
             input_grad_memory_cost, 0)
         communication_cost = communication_cost_activation_backward
         sharding_strategies = ShardingStrategy(name,
-                                               output_sharding_spec=sharding_spec_for_ouput,
+                                               output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
                                                communication_cost=communication_cost,
                                                memory_cost=toatl_memory_cost,
...
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py

...
@@ -2,10 +2,14 @@ import operator
 from functools import reduce

 import torch

-from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
+    enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
-    generate_sharding_size, ignore_sharding_exception)
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+    generate_sharding_size,
+    ignore_sharding_exception,
+)
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector

 from .operator_handler import OperatorHandler
...
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py

 from abc import ABC, abstractmethod
 from typing import Dict, List
-from webbrowser import Opera

 import torch
 import torch.nn as nn
-from abc import ABC, abstractmethod
 from torch.fx.node import Node
-from typing import Dict, List

+from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
-from .._utils import generate_resharding_costs, generate_sharding_spec
+
+from .._utils import generate_resharding_costs, generate_sharding_spec
 from ..sharding_strategy import StrategiesVector

 __all__ = ['OperatorHandler']
...
colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py

...
@@ -4,9 +4,9 @@ import warnings
 from copy import deepcopy

 import torch

-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
-    ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
...
colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py

+import builtins
+import math
+import operator
+from copy import deepcopy
+from typing import Dict, List
+
+import torch
 from torch.fx import Graph, Node

-from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+from ._utils import generate_resharding_costs, generate_sharding_spec
+from .constants import *
+from .op_handler import *
 from .options import SolverOptions
 from .sharding_strategy import ShardingStrategy, StrategiesVector
-from .op_handler import *
-from .constants import *
-from copy import deepcopy
-import math
-import torch
-import operator
-from typing import Dict, List
-from ._utils import generate_sharding_spec, generate_resharding_costs
-import builtins

 __all__ = ['StrategiesConstructor']
...
colossalai/auto_parallel/tensor_shard/initialize.py (new file; diff collapsed)
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

+from .addmm_handler import ADDMMFunctionHandler
 from .batch_norm_handler import BatchNormModuleHandler
+from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
+from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
+from .experimental import PermuteHandler, ViewHandler
+from .getattr_handler import GetattrHandler
+from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
+from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
-from .output_handler import OuputHandler
-from .placeholder_handler import PlacehodlerHandler
+from .output_handler import OutputHandler
+from .placeholder_handler import PlaceholderHandler
 from .registry import operator_registry
 from .reshape_handler import ReshapeHandler
+from .softmax_handler import SoftmaxHandler
+from .sum_handler import SumHandler
+from .tensor_constructor_handler import TensorConstructorHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
 from .where_handler import WhereHandler

 __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
-    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'operator_registry'
+    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
+    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/addmm_handler.py (new file)
from typing import Dict, List, Union

import torch

from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager

from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator

__all__ = ['ADDMMFunctionHandler']


@operator_registry.register(torch.addmm)
@operator_registry.register(torch.Tensor.addmm)
class ADDMMFunctionHandler(NodeHandler):
    """
    This is a NodeHandler class which deals with the addmm operation in PyTorch, i.e. `torch.addmm` and
    `torch.Tensor.addmm`. The bias operand is broadcast to the output shape, so its logical shape differs
    from its physical shape and is recovered in post_process.
    """

    def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
        if isinstance(tensor, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG
        return data_type

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # input operand
        input_data = self.node.args[1]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[1]),
                                               type=self._infer_op_data_type(input_data),
                                               data=input_data)

        # other operand
        other_data = self.node.args[2]._meta_data
        physical_other_operand = OperationData(name=str(self.node.args[2]),
                                               type=self._infer_op_data_type(other_data),
                                               data=other_data)

        # bias physical shape
        bias_logical_shape = self.node._meta_data.shape
        bias_data = self.node.args[0]._meta_data
        physical_bias_operand = OperationData(name=str(self.node.args[0]),
                                              type=self._infer_op_data_type(bias_data),
                                              data=bias_data,
                                              logical_shape=bias_logical_shape)

        # output
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {
            "input": physical_input_operand,
            "other": physical_other_operand,
            "output": physical_output,
            'bias': physical_bias_operand
        }

        return mapping

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(
            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
        return generators

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        # convert bias from its logical sharding spec to its physical sharding spec
        op_data_mapping = self.get_operation_data_mapping()

        bias_op_data = op_data_mapping['bias']
        bias_physical_shape = bias_op_data.data.shape
        bias_logical_shape = bias_op_data.logical_shape
        bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
        bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
            bias_sharding_spec, bias_logical_shape, bias_physical_shape)
        strategy.sharding_specs[bias_op_data] = bias_sharding_spec

        if len(removed_dims) > 0:
            comm_action = comm_actions_for_oprands(node=self.node,
                                                   removed_dims=removed_dims,
                                                   op_data=bias_op_data,
                                                   sharding_spec=bias_sharding_spec)
            strategy.communication_actions[bias_op_data] = comm_action

        return strategy
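For reference on the operand mapping above: `torch.addmm(bias, input, other)` computes `bias + input @ other`, with `bias` broadcast to the output shape; that is why the handler records the output shape as the bias operand's logical shape and later recovers the physical sharding spec. A small illustrative check, not part of the commit:

import torch

bias = torch.ones(4)                    # physical shape (4,), broadcast to (2, 4)
mat1 = torch.randn(2, 3)                # args[1]: the "input" operand
mat2 = torch.randn(3, 4)                # args[2]: the "other" operand

out = torch.addmm(bias, mat1, mat2)     # args[0]: the "bias" operand
assert out.shape == (2, 4)
assert torch.allclose(out, bias + mat1 @ mat2)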
colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py

...
@@ -2,8 +2,10 @@ from typing import Dict, List
 import torch

-from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+
+from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import BatchNormStrategyGenerator, StrategyGenerator
...
@@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']
 @operator_registry.register(torch.nn.BatchNorm1d)
 @operator_registry.register(torch.nn.BatchNorm2d)
 @operator_registry.register(torch.nn.BatchNorm3d)
-class BatchNormModuleHandler(ModuleHandler):
+class BatchNormModuleHandler(MetaInfoModuleHandler):
     """
     A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
     """
...
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py (new file; diff collapsed)