OpenDAS / ColossalAI · Commits

Commit 50f16a28 (unverified)
Authored Sep 28, 2022 by Frank Lee, committed by GitHub on Sep 28, 2022
Parent: 9ec401a7

[autoparallel] added compute resharding costs for node handler (#1662)
Showing 2 changed files with 56 additions and 2 deletions (+56 -2)

colossalai/auto_parallel/solver/op_handler/node_handler.py    +43 -2
colossalai/auto_parallel/solver/sharding_strategy.py          +13 -0
colossalai/auto_parallel/solver/op_handler/node_handler.py
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from typing import Dict, List
-from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData
+from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
 from ..strategy import StrategyGenerator_V2
...
@@ -28,13 +29,53 @@ class NodeHandler(ABC):
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector

-    def register_strategy(self) -> StrategiesVector:
+    def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None:
+        """
+        Compute the resharding costs and save the costs in the ShardingStrategy object.
+        """
+        # TODO: test this function when other handlers are ready
+        resharding_costs = {}
+        shape_consistency_manager = ShapeConsistencyManager()
+
+        for node in self.predecessor_node:
+            node_name = str(node)
+
+            # get the sharding specs for this node generated
+            # in its own node handler
+            assert hasattr(node, 'strategies_vector'), \
+                f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
+            prev_strategy_vector = node.strategies_vector
+            prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
+
+            # get the current sharding spec generated by this node handler
+            op_data = strategy.get_op_data_by_name(node_name)
+            current_sharding_spec = strategy.sharding_specs[op_data]
+
+            # create data structure to store costs
+            if op_data not in resharding_costs:
+                resharding_costs[op_data] = {}
+
+            # for each sharding spec generated by the predecessor's node handler
+            # compute the resharding cost to switch to the sharding spec generated
+            # by the current node handler
+            for prev_sharding_spec in prev_sharding_specs:
+                fwd_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec, current_sharding_spec)
+                bwd_cost = shape_consistency_manager.shape_consistency(current_sharding_spec, prev_sharding_spec)
+                resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=fwd_cost + bwd_cost)
+                resharding_costs[op_data][prev_sharding_spec] = resharding_cost
+        strategy.resharding_costs = resharding_costs
+
+    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
         """
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
         for generator in strategy_generators:
             strategies = generator.generate()
+
+            # compute the resharding costs based on the previous node
+            # strategies if specified
+            if compute_resharding_cost:
+                strategies = list(map(self.update_resharding_cost, strategies))
             self.strategies_vector.extend(strategies)

         strategies_vector = map(self.post_process, self.strategies_vector)
...
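The new update_resharding_cost hook queries each predecessor node's registered strategies, asks ShapeConsistencyManager for the cost of converting every predecessor sharding spec into the spec chosen by the current strategy (and the reverse conversion for the backward pass), and stores the results on the strategy keyed first by operand and then by the predecessor's spec. A minimal, self-contained sketch of that cost-table layout follows; the OpData / Spec / CostItem classes and the toy cost function are illustrative stand-ins, not the real OperationData / ShardingSpec / TrainCycleItem / ShapeConsistencyManager types from colossalai:

# Sketch only: stand-in types, not the colossalai classes.
from dataclasses import dataclass

@dataclass(frozen=True)
class OpData:
    name: str

@dataclass(frozen=True)
class Spec:
    layout: str    # e.g. "S0R": dim 0 sharded, dim 1 replicated

@dataclass
class CostItem:
    fwd: float
    bwd: float
    total: float

def build_resharding_costs(op_data, prev_specs, current_spec, cost_fn):
    # Mirrors the nested dict produced by update_resharding_cost:
    #   {operand: {predecessor_sharding_spec: cost_item}}
    costs = {op_data: {}}
    for prev_spec in prev_specs:
        fwd = cost_fn(prev_spec, current_spec)    # forward pass reshards prev -> current
        bwd = cost_fn(current_spec, prev_spec)    # backward pass reshards current -> prev
        costs[op_data][prev_spec] = CostItem(fwd=fwd, bwd=bwd, total=fwd + bwd)
    return costs

# toy usage: two candidate predecessor layouts for the same input tensor
table = build_resharding_costs(
    OpData("input_1"),
    [Spec("S0R"), Spec("RR")],
    Spec("S0R"),
    cost_fn=lambda src, dst: 0.0 if src == dst else 1.0,
)
print(table[OpData("input_1")][Spec("RR")].total)    # 2.0: a conversion is needed in both passes

In the commit itself the per-edge costs come from ShapeConsistencyManager.shape_consistency, and the table is only built when register_strategy is called with compute_resharding_cost=True, so existing handlers keep their current behaviour by default.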
colossalai/auto_parallel/solver/sharding_strategy.py

...
@@ -129,6 +129,7 @@ class ShardingStrategy_V2:
     memory_cost: TrainCycleItem = None
     input_resharding_costs: Dict[OperationData, List[float]] = None
     communication_actions: Dict[OperationData, CommSpec] = None
+    resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None

     @property
     def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
...
@@ -153,6 +154,18 @@ class ShardingStrategy_V2:
         specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
         return specs

+    def get_op_data_by_name(self, name: str):
+        for op_data in self.sharding_specs.keys():
+            if op_data.name == name:
+                return op_data
+        raise KeyError(f"Could not find the OperationData with name {name}")
+
+    def get_sharding_spec_by_name(self, name: str):
+        for op_data, sharding_spec in self.sharding_specs.items():
+            if op_data.name == name:
+                return sharding_spec
+        raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
+

 class StrategiesVector(list):
     '''
...
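The two new lookup helpers let a node handler resolve an operand by name without holding a direct reference to its OperationData: update_resharding_cost above calls get_sharding_spec_by_name(str(node)) on each predecessor strategy and get_op_data_by_name(str(node)) on the current one. Below is a small sketch of how a downstream consumer (for example, a cost-based solver) might then read the per-edge costs stored in the new resharding_costs field; the OpData / Spec / CostItem classes are stand-ins, not the real colossalai types, and the min-cost selection is only one plausible way a solver could use the table:

# Sketch only: stand-in types, not the colossalai classes.
from dataclasses import dataclass

@dataclass(frozen=True)
class OpData:          # stand-in for OperationData
    name: str

@dataclass(frozen=True)
class Spec:            # stand-in for ShardingSpec
    layout: str

@dataclass
class CostItem:        # stand-in for TrainCycleItem
    fwd: float
    bwd: float
    total: float

# shape of strategy.resharding_costs after update_resharding_cost has run:
#   {operand: {predecessor_sharding_spec: cost_item}}
x1 = OpData("x1")
resharding_costs = {
    x1: {
        Spec("S0R"): CostItem(fwd=0.0, bwd=0.0, total=0.0),   # layouts already match
        Spec("RR"): CostItem(fwd=2.0, bwd=2.0, total=4.0),    # a conversion is required
    }
}

# a cost-aware consumer can pick the predecessor spec with the smallest edge cost
best_prev_spec = min(resharding_costs[x1], key=lambda spec: resharding_costs[x1][spec].total)
print(best_prev_spec.layout)    # "S0R"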