Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
3ccf58aa
Unverified
Commit
3ccf58aa
authored
Jan 02, 2023
by
Super Daniel
Committed by
GitHub
Jan 02, 2023
Browse files
[autockpt] make it work. (#2257)
parent
ac373993
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
12 deletions
+12
-12
colossalai/auto_parallel/passes/comm_metainfo_pass.py
colossalai/auto_parallel/passes/comm_metainfo_pass.py
+7
-7
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
...l/tensor_shard/node_handler/binary_elementwise_handler.py
+1
-1
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
...uto_parallel/tensor_shard/node_handler/reshape_handler.py
+2
-2
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
...el/tensor_shard/node_handler/unary_elementwise_handler.py
+2
-2
No files found.
colossalai/auto_parallel/passes/comm_metainfo_pass.py
View file @
3ccf58aa
...
...
@@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
return
meta_info
def
_runtime_apply_meta_info
(
node
:
Node
,
origin
al_sharding
_spec_dict
,
sharding_spec_dict
)
->
MetaInfo
:
def
_runtime_apply_meta_info
(
node
:
Node
,
origin_spec_dict
,
sharding_spec_dict
)
->
MetaInfo
:
"""
This method is used to construct `MetaInto` for shape consistency node
"""
...
...
@@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s
# extract node index and user node index
args
=
node
.
args
node_index
,
user_node_index
=
args
[
3
],
args
[
4
]
origin_sharding_spec
,
target_sharding_spec
=
origin
al_sharding
_spec_dict
[
node_index
],
sharding_spec_dict
[
node_index
][
user_node_index
]
origin_sharding_spec
,
target_sharding_spec
=
origin_spec_dict
[
node_index
],
sharding_spec_dict
[
node_index
][
user_node_index
]
return
_construct_meta_info
(
node
,
origin_sharding_spec
,
target_sharding_spec
)
...
...
@@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
return
meta_info
def
comm_metainfo_pass
(
gm
:
GraphModule
,
sharding_spec_dict
:
Dict
,
origin
al_sharding
_spec_dict
:
Dict
,
comm_actions_dict
:
Dict
):
def
comm_metainfo_pass
(
gm
:
GraphModule
,
sharding_spec_dict
:
Dict
,
origin_spec_dict
:
Dict
,
comm_actions_dict
:
Dict
)
->
GraphModule
:
"""
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
"""
for
node
in
gm
.
graph
.
nodes
:
if
node
.
target
==
runtime_apply
:
setattr
(
node
,
'best_metainfo'
,
_runtime_apply_meta_info
(
node
,
original_sharding_spec_dict
,
sharding_spec_dict
))
setattr
(
node
,
'best_metainfo'
,
_runtime_apply_meta_info
(
node
,
origin_spec_dict
,
sharding_spec_dict
))
elif
node
.
target
==
runtime_comm_spec_apply
:
setattr
(
node
,
'best_metainfo'
,
_runtime_comm_spec_apply_meta_info
(
node
,
comm_actions_dict
))
else
:
pass
return
gm
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
View file @
3ccf58aa
...
...
@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
@
operator_registry
.
register
(
BCAST_FUNC_OP
)
class
BinaryElementwiseHandler
(
NodeHandler
):
class
BinaryElementwiseHandler
(
MetaInfo
NodeHandler
):
"""
An BinaryBcastOpHandler is a node handler which deals with operations which have two
operands and broadcasting occurs such as torch.add.
...
...
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
View file @
3ccf58aa
...
...
@@ -3,7 +3,7 @@ from typing import Dict, List
import
torch
from
..sharding_strategy
import
OperationData
,
OperationDataType
from
.node_handler
import
NodeHandler
from
.node_handler
import
MetaInfoNodeHandler
,
NodeHandler
from
.registry
import
operator_registry
from
.strategy
import
ReshapeGenerator
,
StrategyGenerator
...
...
@@ -13,7 +13,7 @@ __all__ = ['ReshapeHandler']
@
operator_registry
.
register
(
torch
.
flatten
)
@
operator_registry
.
register
(
torch
.
Tensor
.
unsqueeze
)
@
operator_registry
.
register
(
torch
.
nn
.
AdaptiveAvgPool2d
)
class
ReshapeHandler
(
NodeHandler
):
class
ReshapeHandler
(
MetaInfo
NodeHandler
):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
...
...
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py
View file @
3ccf58aa
...
...
@@ -3,7 +3,7 @@ from typing import Dict, List
import
torch
from
..sharding_strategy
import
OperationData
,
OperationDataType
from
.node_handler
import
NodeHandler
from
.node_handler
import
MetaInfoNodeHandler
,
NodeHandler
from
.registry
import
operator_registry
from
.strategy
import
StrategyGenerator
,
UnaryElementwiseGenerator
...
...
@@ -19,7 +19,7 @@ __all__ = ['UnaryElementwiseHandler']
@
operator_registry
.
register
(
torch
.
nn
.
modules
.
dropout
.
Dropout
)
@
operator_registry
.
register
(
torch
.
Tensor
.
contiguous
)
@
operator_registry
.
register
(
torch
.
nn
.
functional
.
dropout
)
class
UnaryElementwiseHandler
(
NodeHandler
):
class
UnaryElementwiseHandler
(
MetaInfo
NodeHandler
):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment