OpenDAS / ColossalAI · commit 30e50c8b

Unverified commit 30e50c8b, authored Sep 27, 2022 by Frank Lee, committed via GitHub on Sep 27, 2022.

[autoparallel] implemented all matmul strategy generator (#1650)

Parent: 03978aad

Showing 8 changed files with 440 additions and 76 deletions (+440 −76).
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py             +18  −2
colossalai/auto_parallel/solver/op_handler/node_handler.py                +3  −3
colossalai/auto_parallel/solver/sharding_strategy.py                      +6  −0
colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py   +337 −42
colossalai/auto_parallel/solver/strategy/strategy_generator.py           +12  −5
tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py     +64 −18
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py        +0  −3
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py          +0  −3
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py

@@ -50,8 +50,16 @@ class LinearModuleHandler(ModuleHandler):
         if op_data.name == "weight":
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
+
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)

@@ -111,8 +119,16 @@ class LinearFunctionHandler(NodeHandler):
         if op_data.name == str(self.node.args[1]):
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
+
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
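Both hunks replace an in-place tuple swap with pop-based reinsertion. A minimal standalone sketch (not ColossalAI code) of why: a tensor dimension with no entry in dim_partition_dict is replicated, so direct indexing can raise KeyError whenever only one of the two dims is sharded.

    # Hedged sketch: dim_partition_dict maps tensor dims to device-mesh dims.
    # A replicated dim simply has no entry, so either key may be absent.
    dim_partition_dict = {0: [0]}   # only dim 0 is sharded, along mesh dim 0

    # Old approach: the tuple swap assumes both keys exist and raises KeyError.
    try:
        dim_partition_dict[0], dim_partition_dict[-1] = \
            dim_partition_dict[-1], dim_partition_dict[0]
    except KeyError as e:
        print(f'tuple swap fails on missing key: {e}')

    # New approach: pop with a default tolerates missing keys and only
    # re-inserts entries that actually existed.
    first_dim_partition = dim_partition_dict.pop(-1, None)
    last_dim_partition = dim_partition_dict.pop(0, None)
    if first_dim_partition:
        dim_partition_dict[0] = first_dim_partition
    if last_dim_partition:
        dim_partition_dict[-1] = last_dim_partition

    print(dim_partition_dict)   # {-1: [0]}: the partition moved from dim 0 to dim -1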
colossalai/auto_parallel/solver/op_handler/node_handler.py

@@ -33,12 +33,12 @@ class NodeHandler(ABC):
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
         operand_mapping = self.get_operation_data_mapping()
         for generator in strategy_generators:
-            strategies = generator.generate(operand_mapping)
+            strategies = generator.generate()
             self.strategies_vector.extend(strategies)

-        self.strategies_vector = map(self.post_process, self.strategies_vector)
+        strategies_vector = map(self.post_process, self.strategies_vector)
+        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector

     def post_process(self, strategy: ShardingStrategy_V2):
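The second change here fixes a subtle bug: map() returns a one-shot lazy iterator, so assigning it directly to self.strategies_vector leaves an object that is exhausted after its first traversal and supports neither len() nor indexing. A small standalone sketch in plain Python (no ColossalAI imports) of the difference:

    strategies = ['S0R = S0S1 x S1R', 'RR = RS0 x S0R']

    lazy = map(str.upper, strategies)    # a one-shot iterator, not a list
    print(list(lazy))                    # consuming it here ...
    print(list(lazy))                    # [] -- nothing left for later callers

    materialized = list(map(str.upper, strategies))
    print(len(materialized), materialized[0])   # re-usable, sized, indexable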
colossalai/auto_parallel/solver/sharding_strategy.py

@@ -75,6 +75,12 @@ class OperationData:
         if self.logical_shape is None:
             self.logical_shape = self.data.shape

+    def __repr__(self) -> str:
+        return f'OperationData(name={self.name}, type={self.type})'
+
+    def __hash__(self) -> int:
+        return hash(f'{self.name}-{self.type}')
+

 @dataclass
 class TrainCycleItem:
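OperationData gains __repr__ for readable debugging output and __hash__ so instances can serve as dictionary keys (for example, in the communication_actions mapping iterated over in strategy_generator.py below). A hedged standalone sketch; the stub fields are illustrative, not the full dataclass:

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class OperationData:
        name: str
        type: str            # the real class uses an OperationDataType enum
        data: Any = None

        def __repr__(self) -> str:
            return f'OperationData(name={self.name}, type={self.type})'

        # An explicit __hash__ survives @dataclass, which would otherwise set
        # __hash__ = None because the generated __eq__ makes the class unhashable.
        def __hash__(self) -> int:
            return hash(f'{self.name}-{self.type}')

    weight = OperationData(name='weight', type='PARAM')
    actions = {weight: 'IDENTITY_FWD_ALLREDUCE_BWD'}            # usable as a dict key
    print(actions[OperationData(name='weight', type='PARAM')])  # equal fields, same hash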
colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py (+337 −42)

(This diff is collapsed in the page view.)
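The collapsed file carries the bulk of this commit: strategy generators covering the matmul cases. As a rough illustration of what such a generator enumerates, here is a standalone sketch that reproduces the strategy names asserted in the updated test below; it is not the actual generator, whose cost modelling and sharding-spec construction live in the collapsed diff.

    def enumerate_matmul_strategies(mesh_dims=(0, 1)):
        """Name the 2D-mesh sharding strategies for C = A x B.

        S<i> means sharded along mesh dim i, R means replicated; e.g.
        'S0S1 = S0R x RS1' shards A's rows on mesh dim 0 and B's cols on mesh dim 1.
        """
        strategies = []
        for i, j in (mesh_dims, tuple(reversed(mesh_dims))):
            strategies.append(f'S{i}S{j} = S{i}R x RS{j}')   # SS = SR x RS
            strategies.append(f'S{i}R = S{i}S{j} x S{j}R')   # SR = SS x SR
            strategies.append(f'RS{i} = RS{j} x S{j}S{i}')   # RS = RS x SS
            strategies.append(f'RR = RS{i} x S{i}R')         # RR = RS x SR
            strategies.append(f'RS{i} = RR x RS{i}')         # RS = RR x RS
        return strategies

    print('\n'.join(enumerate_matmul_strategies()))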
colossalai/auto_parallel/solver/strategy/strategy_generator.py

@@ -7,7 +7,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List, Union, Any
-from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
+from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType


 class StrategyGenerator_V2(ABC):

@@ -21,6 +21,10 @@ class StrategyGenerator_V2(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh

+    def is_param(self, op_data_name):
+        other_data = self.op_data[op_data_name]
+        return other_data.type == OperationDataType.PARAM
+
     def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
                               communication_action_mapping: Dict[str, CommSpec]):
         """

@@ -80,7 +84,7 @@ class StrategyGenerator_V2(ABC):
         Compute the communication cost involved in the forward and backward iteration.
         """
-        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+        comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

         def _compute_and_add(data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()

@@ -92,7 +96,7 @@ class StrategyGenerator_V2(ABC):
             # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
             # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
             # so total cost is either for fwd or bwd.
-            if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
+            if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
                 comm_cost.fwd += cost
             elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
                 comm_cost.fwd += cost

@@ -102,9 +106,12 @@ class StrategyGenerator_V2(ABC):
         # check if communication action exists
         # if so, loop over each action and compute the cost of each action
         if strategy.communication_actions is not None:
-            for operand, comm_spec in strategy.communication_actions:
+            for operand, comm_spec in strategy.communication_actions.items():
                 _compute_and_add(operand, comm_spec)

+        # update the total cost
+        comm_cost.total = comm_cost.fwd + comm_cost.bwd
+
         # update the communication cost attribute in-place
         strategy.communication_cost = comm_cost
         return strategy

@@ -146,7 +153,7 @@ class StrategyGenerator_V2(ABC):
         pass

     @abstractmethod
-    def validate(self, *args, **kwargs) -> bool:
+    def validate(self) -> bool:
         """
         Validate if the operands are of desired shape.
         If True, means this generator can be used for the current operation.
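Two of these hunks are behavioural fixes worth spelling out: iterating a dict without .items() yields only its keys, so the old loop over strategy.communication_actions tried to unpack each key, and TrainCycleItem now carries a total field that must be filled in after the loop. A standalone sketch with stubbed-in costs (the names here are placeholders, not the real CommSpec objects):

    # operand-name -> cost stubs standing in for strategy.communication_actions
    communication_actions = {'input': 3.0, 'other': 5.0}

    try:
        for operand, cost in communication_actions:   # old code: iterates keys only
            pass
    except ValueError as e:
        print(f'unpacking a bare dict key fails: {e}')

    fwd = bwd = 0.0
    for operand, cost in communication_actions.items():   # fixed: (key, value) pairs
        fwd += cost                                       # fwd/bwd routing simplified here
    total = fwd + bwd                                     # the added "update the total cost" step
    print(fwd, bwd, total)                                # 8.0 0.0 8.0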
tests/test_auto_parallel/test_linear_handler_v2.py → tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py

@@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh

 def test_linear_module_handler():
-    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
+    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)

@@ -34,32 +34,55 @@ def test_linear_module_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]

+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list


 def test_linear_function_handler():
-    model = nn.Linear(10, 20).to('meta')
+    model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)

@@ -77,27 +100,50 @@ def test_linear_function_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]

+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list


 if __name__ == '__main__':
     test_linear_module_handler()
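These tests trace the model on the 'meta' device, so every shape assertion runs without allocating real tensor storage. Note how the weight's data.shape is [32, 16] (PyTorch's Linear stores out_features x in_features) while its logical_shape is [16, 32], the transposed view the matmul strategies reason about. A standalone sketch of the meta-tracing setup, using plain torch.fx instead of ColoTracer so it runs without ColossalAI installed (behavior as of recent PyTorch):

    import torch
    import torch.nn as nn
    from torch.fx import symbolic_trace

    model = nn.Sequential(nn.Linear(16, 32)).to('meta')   # parameters live on 'meta'
    gm = symbolic_trace(model)                            # symbolic graph, no real data

    weight = dict(model.named_parameters())['0.weight']
    print(weight.is_meta, tuple(weight.shape))            # True (32, 16)

    out = model(torch.rand(4, 16, device='meta'))         # shape inference only
    print(out.is_meta, tuple(out.shape))                  # True (4, 32)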
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py

-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops

tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py

-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops