OpenDAS / ColossalAI · Commit 51b89d22 (unverified)
Authored Oct 18, 2022 by YuliangLiu0306; committed by GitHub on Oct 18, 2022
[autoparallel] runtime_backward_apply (#1720)
Parent: 393f5940

Showing 2 changed files with 56 additions and 27 deletions:
  colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py  +48 -27
  colossalai/tensor/shape_consistency.py                                   +8 -0
colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py (view file @ 51b89d22)
from ast import NodeTransformer
import torch
from typing import List
from torch.fx import symbolic_trace
...
@@ -10,10 +11,32 @@ import builtins
import operator
from copy import deepcopy

+shape_consistency_manager = ShapeConsistencyManager()

-def apply(*args, **kwargs):
-    shape_consistency_manager = ShapeConsistencyManager()
-    return shape_consistency_manager.apply(*args, **kwargs)

+class ConsistencyApply(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, node, origin_dict, input_dict, node_index, user_node_index):
+        ctx.origin_sharding_spec = origin_dict[node_index]
+        ctx.target_sharding_spec = input_dict[node_index][user_node_index]
+        return shape_consistency_manager.apply_for_autoparallel_runtime(node, ctx.origin_sharding_spec,
+                                                                        ctx.target_sharding_spec)
+
+    @staticmethod
+    def backward(ctx, node_grad):
+        return shape_consistency_manager.apply_for_autoparallel_runtime(node_grad, ctx.target_sharding_spec,
+                                                                        ctx.origin_sharding_spec), None, None, None, None
+
+
+def runtime_apply_for_leaf_node(node, origin_dict, input_dict, node_index, user_node_index):
+    return ConsistencyApply.apply(node, origin_dict, input_dict, node_index, user_node_index)
+
+
+def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
+    origin_sharding_spec = origin_dict[node_index]
+    target_sharding_spec = input_dict[node_index][user_node_index]
+    return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)


def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
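ConsistencyApply above follows the standard custom torch.autograd.Function pattern: forward() stashes its non-tensor arguments on ctx, and backward() must return one value per forward() input, which is why the converted gradient is followed by four Nones (for origin_dict, input_dict, node_index and user_node_index). A minimal, self-contained sketch of that pattern, with a made-up ScaleByFactor example standing in for the sharding conversion:

import torch


class ScaleByFactor(torch.autograd.Function):
    # Illustrative only: one tensor input plus one non-tensor input,
    # mirroring how ConsistencyApply carries its auxiliary arguments.

    @staticmethod
    def forward(ctx, tensor, factor):
        ctx.factor = factor    # stash the non-tensor argument, like the sharding specs above
        return tensor * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient for `tensor`, then None for the non-tensor `factor`
        return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)
ScaleByFactor.apply(x, 2.0).sum().backward()
print(x.grad)    # tensor([2., 2., 2.])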
...
@@ -37,21 +60,19 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
                origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                setattr(param, 'sharding_spec', origin_sharding_spec)
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-               apply(param, target_sharding_spec)
+               shape_consistency_manager.apply(param, target_sharding_spec)

            for name, buffer in target_module.named_buffers():
                origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
                setattr(buffer, 'sharding_spec', origin_sharding_spec)
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-               apply(buffer, target_sharding_spec)
+               shape_consistency_manager.apply(buffer, target_sharding_spec)

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}
    for index, node in enumerate(nodes):
        target_sharding_specs = []
        for user_node in node.strategies_vector.successor_nodes:
-           # node_index = user_node.strategies_vector.predecessor_nodes.index(node)
-           # target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
            target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
            target_sharding_specs.append(target_sharding_spec)
        sharding_spec_convert_dict[index] = target_sharding_specs
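The sharding_spec_convert_dict built here maps each node index to one target spec per user node; at runtime, runtime_apply and ConsistencyApply.forward resolve input_dict[node_index][user_node_index] against it. A small hedged illustration of that two-level lookup, with strings standing in for real ShardingSpec objects (the values are placeholders only):

origin_dict = {0: 'replicated'}                     # origin spec per node index
input_dict = {0: ['shard_dim0', 'shard_dim1']}      # one target spec per user node

node_index, user_node_index = 0, 1
origin_sharding_spec = origin_dict[node_index]
target_sharding_spec = input_dict[node_index][user_node_index]
print(origin_sharding_spec, '->', target_sharding_spec)    # replicated -> shard_dim1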
...
@@ -91,28 +112,28 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
    # add shape consistency apply function into graph
    for node in nodes:
-       if not hasattr(node, 'best_strategy'):
+       if not hasattr(node, 'best_strategy') or node.op == 'output':
            continue

        with mod_graph.inserting_after(node):
            origin_spec_node = mod_graph.create_node('call_function',
                                                     operator.getitem,
                                                     args=(origin_dict_node, node_to_index_dict[node]))
        with mod_graph.inserting_after(origin_spec_node):
            set_sharding_spec_node = mod_graph.create_node('call_function',
                                                           builtins.setattr,
                                                           args=(node, 'sharding_spec', origin_spec_node))
        for user_node in node.strategies_vector.successor_nodes:
-           node_index = user_node.strategies_vector.predecessor_nodes.index(node)
-           with mod_graph.inserting_before(user_node):
-               input_specs_node = mod_graph.create_node('call_function',
-                                                        operator.getitem,
-                                                        args=(input_dict_node, node_to_index_dict[node]))
-               sharding_spec_node = mod_graph.create_node('call_function',
-                                                          operator.getitem,
-                                                          args=(input_specs_node, node_index))
-               shape_consistency_node = mod_graph.create_node('call_function',
-                                                              apply,
-                                                              args=(node, sharding_spec_node))
+           user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
+           if user_node.op != "output":
+               with mod_graph.inserting_before(user_node):
+                   shape_consistency_node = mod_graph.create_node('call_function',
+                                                                  runtime_apply,
+                                                                  args=(node, origin_dict_node, input_dict_node,
+                                                                        node_to_index_dict[node], user_node_index))
+           else:
+               # we need to call an autograd.Function for leaf node
+               with mod_graph.inserting_before(user_node):
+                   shape_consistency_node = mod_graph.create_node('call_function',
+                                                                  runtime_apply_for_leaf_node,
+                                                                  args=(node, origin_dict_node, input_dict_node,
+                                                                        node_to_index_dict[node], user_node_index))

            origin_index_args = user_node.args.index(node)
            new_args = list(user_node.args)
            new_args[origin_index_args] = shape_consistency_node
            user_node.args = new_args

    return gm
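shape_consistency_pass uses the usual torch.fx surgery: build a call_function node inside an inserting_before/inserting_after context, then rewire the consumer's args so it reads from the inserted node (the new_args[origin_index_args] = shape_consistency_node step). A minimal runnable sketch of that pattern on a toy module; the Toy module and the double helper are hypothetical stand-ins for the real pass and runtime_apply:

import torch
from torch.fx import symbolic_trace


def double(x):
    # hypothetical helper standing in for runtime_apply
    return x * 2


class Toy(torch.nn.Module):

    def forward(self, x):
        return x + 1


gm = symbolic_trace(Toy())
graph = gm.graph
output_node = next(n for n in graph.nodes if n.op == 'output')
producer = output_node.args[0]    # the node currently feeding the output
with graph.inserting_before(output_node):
    new_node = graph.create_node('call_function', double, args=(producer,))
# rewire the consumer to read from the inserted node, as the pass above does
output_node.args = (new_node,)
graph.lint()
gm.recompile()
print(gm(torch.ones(2)))    # tensor([4., 4.])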
colossalai/tensor/shape_consistency.py (view file @ 51b89d22)
...
@@ -498,3 +498,11 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        for comm_spec in comm_action_sequence:
            comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
        tensor_with_sharding_spec.sharding_spec = target_spec
        return tensor_with_sharding_spec
+
+    def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
+        _, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
+        for comm_spec in comm_action_sequence:
+            comm_spec.covert_spec_to_action(tensor)
+        tensor.sharding_spec = target_spec
+        return tensor
\ No newline at end of file
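The added apply_for_autoparallel_runtime differs from apply above in that the source and target specs are passed in explicitly instead of being read from tensor.sharding_spec; it then replays the communication actions on the tensor in place and records the target spec on it. A hedged sketch of that control flow; MockCommSpec and its doubling action are stand-ins (a real comm spec would issue a collective), and only the method name covert_spec_to_action mirrors the code above:

import torch


class MockCommSpec:

    def covert_spec_to_action(self, tensor):
        tensor.mul_(2)    # mutate in place, where a real spec would reshard/communicate


tensor = torch.ones(4)
comm_action_sequence = [MockCommSpec(), MockCommSpec()]
for comm_spec in comm_action_sequence:
    comm_spec.covert_spec_to_action(tensor)
tensor.sharding_spec = 'target_spec_placeholder'    # record the spec the tensor now satisfies
print(tensor)    # tensor([4., 4., 4., 4.])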