OpenDAS / ColossalAI — Commit 1cce6e36 (unverified)
Authored Dec 20, 2022 by YuliangLiu0306; committed by GitHub on Dec 20, 2022

[autoparallel] use metainfo in handler (#2149)

Parent: 9b39170a
Showing 11 changed files with 105 additions and 31 deletions.
Changed files:

  colossalai/auto_parallel/meta_profiler/meta_registry/activation.py                 +1  -1
  colossalai/auto_parallel/meta_profiler/meta_registry/conv.py                       +5  -2
  colossalai/auto_parallel/meta_profiler/meta_registry/linear.py                     +7  -3
  colossalai/auto_parallel/meta_profiler/meta_registry/norm.py                       +1  -1
  colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py                    +1  -1
  colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py           +5  -3
  colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py   +3  -9
  colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py                 +4  -4
  colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py               +9  -5
  colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py                 +67 -0
  colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py       +2  -2
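A pattern repeats across the five meta_registry diffs below: operand tensors that were previously found by scanning `args` for their `OperationDataType` are now read positionally, and bias presence is inferred from the length of `args`. A minimal runnable sketch of the two access styles, using simplified stand-ins for ColossalAI's OperationData/OperationDataType (the real classes live in colossalai/auto_parallel/tensor_shard/sharding_strategy.py):

# Sketch of the access-pattern change repeated in the meta_registry files below.
# OperationData and OperationDataType here are simplified stand-ins, not the real classes.
from dataclasses import dataclass
from enum import Enum
from typing import Any

class OperationDataType(Enum):
    ARG = 0
    PARAM = 1
    OUTPUT = 2

@dataclass
class OperationData:
    name: str
    type: OperationDataType
    data: Any

# Assumed positional layout after this commit: (input, weight, output[, bias]).
args = [
    OperationData("input", OperationDataType.ARG, "x"),
    OperationData("weight", OperationDataType.PARAM, "w"),
    OperationData("output", OperationDataType.OUTPUT, "out"),
    OperationData("bias", OperationDataType.PARAM, "b"),
]

# Old style: scan args for the operand with the wanted type.
input_old = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data

# New style: index into the fixed layout directly.
input_new = args[0].data
assert input_old == input_new

# Bias detection also becomes positional: a 4-element args list means the op has a bias.
weight_tensors = [args[1].data, args[3].data] if len(args) == 4 else [args[1].data]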
colossalai/auto_parallel/meta_profiler/meta_registry/activation.py

@@ -28,7 +28,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     inplace = kwargs.get("inplace", False)
colossalai/auto_parallel/meta_profiler/meta_registry/conv.py

@@ -58,9 +58,12 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]

     # check if conv has bias
     if len(weight_tensors) > 1:
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py

@@ -66,9 +66,13 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
     """
     has_bias: bool = False
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
-    output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
+    input_tensor = args[0].data
+    output_tensor = args[2].data
+    if len(args) == 4:
+        weight_tensors = [args[1].data, args[3].data]
+    else:
+        weight_tensors = [args[1].data]

     # process the dimension of input and output
     if len(input_tensor.shape) > 2:
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py

@@ -45,7 +45,7 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
     bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py

@@ -30,7 +30,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+    input_tensor = args[0].data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data

     # construct forward args for flop mapping
colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py

@@ -2,8 +2,10 @@ from typing import Dict, List

 import torch

-from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+
+from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import BatchNormStrategyGenerator, StrategyGenerator

@@ -13,7 +15,7 @@ __all__ = ['BatchNormModuleHandler']

 @operator_registry.register(torch.nn.BatchNorm1d)
 @operator_registry.register(torch.nn.BatchNorm2d)
 @operator_registry.register(torch.nn.BatchNorm3d)
-class BatchNormModuleHandler(ModuleHandler):
+class BatchNormModuleHandler(MetaInfoModuleHandler):
     """
     A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module.
     """
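For context on the decorators visible in these handler diffs: `@operator_registry.register(...)` maps an operator (an nn.Module subclass or a function) to the handler class that generates its sharding strategies. A minimal sketch of that decorator-registry pattern — illustrative only, not the actual implementation in node_handler/registry.py:

# Sketch of the operator-registry pattern used by the handlers above and below.
class Registry:

    def __init__(self):
        self.store = {}

    def register(self, source):
        # Used as a decorator: maps `source` (a module class or function)
        # to the decorated handler class.
        def wrapper(handler_cls):
            self.store[source] = handler_cls
            return handler_cls
        return wrapper

    def get(self, source):
        return self.store[source]

operator_registry = Registry()

@operator_registry.register('batch_norm')    # stands in for torch.nn.BatchNorm2d
class DummyBatchNormHandler:
    pass

assert operator_registry.get('batch_norm') is DummyBatchNormHandler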
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py

@@ -3,18 +3,12 @@ from typing import Dict, List, Union

 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager

 from ..constants import BCAST_FUNC_OP
 from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import BinaryElementwiseStrategyGenerator, StrategyGenerator

@@ -22,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']

 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(NodeHandler):
+class BinaryElementwiseHandler(MetaInfoNodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two
     operands and broadcasting occurs such as torch.add.
colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py

@@ -3,9 +3,9 @@ from typing import Dict, List

 import torch
 import torch.nn.functional as F

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
 from ..utils import transpose_partition_dim
-from .node_handler import ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator

@@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']

 @operator_registry.register(torch.nn.Conv1d)
 @operator_registry.register(torch.nn.Conv2d)
 @operator_registry.register(torch.nn.Conv3d)
-class ConvModuleHandler(ModuleHandler):
+class ConvModuleHandler(MetaInfoModuleHandler):
     """
     A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
     """

@@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):

 @operator_registry.register(F.conv1d)
 @operator_registry.register(F.conv2d)
 @operator_registry.register(F.conv3d)
-class ConvFunctionHandler(NodeHandler):
+class ConvFunctionHandler(MetaInfoNodeHandler):
     """
     A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
     """
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py

@@ -3,12 +3,16 @@ from typing import Dict, List, Union

 import torch
 import torch.nn.functional as F

-from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import (
+    check_sharding_spec_validity,
+    transpose_partition_dim,
+    update_partition_dim,
+)
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from .node_handler import ModuleHandler, NodeHandler
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator

@@ -139,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy,

 @operator_registry.register(torch.nn.Linear)
-class LinearModuleHandler(ModuleHandler):
+class LinearModuleHandler(MetaInfoModuleHandler):
     """
     A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
     """

@@ -199,7 +203,7 @@ class LinearModuleHandler(ModuleHandler):

 @operator_registry.register(F.linear)
-class LinearFunctionHandler(NodeHandler):
+class LinearFunctionHandler(MetaInfoNodeHandler):
     """
     A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
     """
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py

@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple, Union

 import torch
 from torch.fx.node import Node

+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -133,6 +134,26 @@ class NodeHandler(ABC):
             strategy.resharding_costs = resharding_costs
         return strategy

+    def get_target_function(self) -> callable:
+        """
+        This function is used to get the target function for the node handler.
+        The target function is used to analyze the costs of strategies.
+        """
+        if self.node.op in ('placeholder', 'get_attr', 'output'):
+            return None
+
+        if self.node.op == 'call_module':
+            submod = self.node.graph.owning_module.get_submodule(self.node.target)
+            target = type(submod)
+        elif self.node.op == 'call_function':
+            target = self.node.target
+        elif self.node.op == 'call_method':
+            target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+        else:
+            raise ValueError(f'Unsupported node type: {self.node.op}')
+
+        return target
+
     def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
         """
         Register different sharding strategies for the current node.
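`get_target_function` above leans on how torch.fx encodes call targets: a `call_module` node's target is a submodule path string, while a `call_function` node's target is the function object itself. A standalone sketch of the same resolution, with an illustrative TinyNet model that is not part of the commit:

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class TinyNet(nn.Module):    # illustrative model, not from the commit

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

gm = symbolic_trace(TinyNet())
for node in gm.graph.nodes:
    if node.op == 'call_module':
        # 'fc' is a submodule path; resolve it to the module class (nn.Linear here).
        target = type(gm.get_submodule(node.target))
    elif node.op == 'call_function':
        # the target is already the function object (torch.relu here).
        target = node.target
    else:
        # placeholder/output nodes carry no cost-relevant target.
        target = None
    print(node.op, node.target, target)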
@@ -204,6 +225,29 @@ class NodeHandler(ABC):
         pass


+class MetaInfoNodeHandler(NodeHandler):
+    """
+    This is a base class to handle the nodes patched in the meta profiler.
+
+    Note: this class will be integrated into the NodeHandler class in the future, after
+    all the functions are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
+
+
 class ModuleHandler(NodeHandler):

     def __init__(self, *args, **kwargs) -> None:
@@ -221,3 +265,26 @@ class ModuleHandler(NodeHandler):
         self.module = module
         self.named_parameters = named_parameters
         self.named_buffers = named_buffers
+
+
+class MetaInfoModuleHandler(ModuleHandler):
+    """
+    This is a base class to handle the module patched in the meta profiler.
+
+    Note: this class will be integrated into the ModuleHandler class in the future, after
+    all the modules are patched.
+    """
+
+    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
+        """
+        This method is inherited from NodeHandler. It will register the strategies first,
+        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
+        """
+        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
+        target = self.get_target_function()
+        for strategy in self.strategies_vector:
+            metainfo = MetaInfo(strategy, target)
+            strategy.compute_cost = metainfo.compute_cost
+            strategy.memory_cost = metainfo.memory_cost
+
+        return self.strategies_vector
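Taken together, the two new base classes make opting into meta-profiled costs a one-line base-class swap for a handler, which is exactly what the other files in this commit do. A hypothetical mock of the cost-rewriting loop (MockStrategy and MockMetaInfo are stand-ins, not ColossalAI classes; the real MetaInfo(strategy, target) queries the meta profiler's patched functions):

class MockStrategy:
    """Stand-in for a ShardingStrategy entry in strategies_vector."""

    def __init__(self, name):
        self.name = name
        self.compute_cost = 'handler estimate'
        self.memory_cost = 'handler estimate'

class MockMetaInfo:
    """Stand-in for MetaInfo: pretends to profile costs for (strategy, target)."""

    def __init__(self, strategy, target):
        self.compute_cost = f'profiled compute of {target} under {strategy.name}'
        self.memory_cost = f'profiled memory of {target} under {strategy.name}'

strategies_vector = [MockStrategy('S0'), MockStrategy('S1')]
target = 'torch.nn.Linear'    # stands in for the class get_target_function() returns

# Same loop as MetaInfo{Node,Module}Handler.register_strategy: strategies are
# generated as before, then their costs are overwritten from the profiler.
for strategy in strategies_vector:
    metainfo = MockMetaInfo(strategy, target)
    strategy.compute_cost = metainfo.compute_cost
    strategy.memory_cost = metainfo.memory_cost

for s in strategies_vector:
    print(s.name, '->', s.compute_cost, '|', s.memory_cost)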
colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py

@@ -3,7 +3,7 @@ from typing import Dict, List

 import torch

 from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
 from .registry import operator_registry
 from .strategy import NormalPoolStrategyGenerator, StrategyGenerator

@@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']

 @operator_registry.register(torch.nn.AvgPool1d)
 @operator_registry.register(torch.nn.AvgPool2d)
 @operator_registry.register(torch.nn.AvgPool3d)
-class NormPoolingHandler(ModuleHandler):
+class NormPoolingHandler(MetaInfoModuleHandler):
     """
     A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
     """