OpenDAS / ColossalAI · Commits

Commit 59f10051 (unverified)
Authored Sep 27, 2022 by YuliangLiu0306; committed by GitHub on Sep 27, 2022
Parent: 6135e178

[autoparallel] where handler (#1651)

* [autoparallel] where handler
* fix unit test
Showing 2 changed files with 246 additions and 0 deletions:

  colossalai/auto_parallel/solver/op_handler/where_handler.py   +181  -0
  tests/test_auto_parallel/test_where_handler.py                  +65  -0
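For context (not part of the commit): torch.where(condition, x, y) selects elements from x where condition is true and from y elsewhere, with all three tensors broadcast to a common shape. That broadcasting is why the handler below pads input shapes to the output rank before generating sharding specs. A minimal PyTorch illustration:

import torch

condition = torch.tensor([[True, False], [False, True]])
x = torch.ones(2, 2)
y = torch.zeros(1, 2)          # broadcast along dim 0 to match x
out = torch.where(condition, x, y)
print(out)                     # tensor([[1., 0.], [0., 1.]])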
colossalai/auto_parallel/solver/op_handler/where_handler.py  (new file, 0 → 100644)
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding

__all__ = ['WhereHandler']


class WhereHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of torch.where.
    """

    def __init__(self, *args, **kwargs):
        # TODO: x or y could be scalar
        super().__init__(*args, **kwargs)
        assert len(self.predecessor_node) == 3
        self.condition_data = self.predecessor_node[0]._meta_data
        self.x_data = self.predecessor_node[1]._meta_data
        self.y_data = self.predecessor_node[2]._meta_data
        self.condition = self.predecessor_node[0]
        self.x = self.predecessor_node[1]
        self.y = self.predecessor_node[2]
        self.output_data = self.node._meta_data

    def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
        shape = list(input_.shape)

        # pad the shape to the same length as output_data
        while len(shape) < self.output_data.dim():
            shape.insert(0, 1)
        shape = torch.Size(shape)

        # if the sharding happens on a size one dimension, we should record it as R.
        processed_dim_partition_dict = deepcopy(dim_partition_dict)
        for dim_index, _ in dim_partition_dict.items():
            if shape[dim_index] == 1:
                processed_dim_partition_dict.pop(dim_index)
        for dim_index, sharding_index_list in processed_dim_partition_dict.items():
            sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
            sharding_size = reduce(operator.mul, sharding_list, 1)
            assert shape[dim_index] % sharding_size == 0, \
                f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                     entire_shape=shape,
                                     dim_partition_dict=processed_dim_partition_dict)

        return sharding_spec

    def _generate_compute_cost(self, total_sharding_size):
        lhs_matrix_shape = self.lhs_data.shape[-2:]
        rhs_matrix_shape = self.rhs_data.shape[-2:]
        batch_dimensions_shape = self.output_data.shape[:-2]
        batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
        compute_cost = reduce(
            operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
        return compute_cost

    def _generate_resharding_costs(self, sharding_specs):
        # The resharding_cost of weight is counted due to sharing weight cases.
        dtype = self.node._meta_data.dtype
        nodes = self.predecessor_node
        resharding_costs = {}
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # shape consistency manager is a singleton class
        shape_consistency_manager = ShapeConsistencyManager()

        for input_node, input_spec in zip(nodes, sharding_specs):
            resharding_costs[input_node] = []
            for strategy in input_node.strategies_vector:
                input_sharding_spec = strategy.output_sharding_spec
                assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
                # if the input shape is smaller than the target input, we will fill the input to the same length as the target.
                # Then, use the padded input sharding spec to compute the resharding cost.
                if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
                    new_entire_shape = list(input_sharding_spec.entire_shape)
                    while len(new_entire_shape) < len(input_spec.entire_shape):
                        new_entire_shape.insert(0, 1)
                    new_entire_shape = torch.Size(new_entire_shape)
                    new_device_mesh = input_sharding_spec.device_mesh
                    new_dim_partition_dict = input_sharding_spec.dim_partition_dict
                    input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
                                                       entire_shape=new_entire_shape,
                                                       dim_partition_dict=new_dim_partition_dict)

                # compute the resharding cost
                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                    input_sharding_spec, input_spec)

                # we need to multiply by the element size of the dtype to get the correct communication cost
                resharding_cost = total_resharding_cost * size_per_elem_bytes
                resharding_costs[input_node].append(resharding_cost)

        return resharding_costs

    def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):

        sharding_spec_list = []
        check_duplicated_list = []
        for output_dim_partition_dict in dim_partition_list:
            try:
                output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
            except AssertionError as e:
                warnings.warn(f'{e}')
                break
            sharding_seq = output_sharding_spec.sharding_sequence
            if sharding_seq not in check_duplicated_list:
                check_duplicated_list.append(sharding_seq)
                sharding_spec_list.append(output_sharding_spec)

        return sharding_spec_list

    def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
        # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 here for N-D device mesh scalability.

        output_dim_partition_list = []
        dim_size = self.output_data.dim()
        # enumerate all the 2D sharding cases
        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_2d)

        # enumerate all the 1D sharding cases
        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_0)

        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
        output_dim_partition_list.extend(sharding_list_1d_on_dim_1)

        # add empty dict for the fully replicated case
        output_dim_partition_list.append({})
        output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)

        return output_sharding_spec_list

    @exception_handler
    def _register_strategy(self, output_sharding_spec):
        dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
        sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
        sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input)
        sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input)

        name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
        dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs(
            [sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y])

        # compute the computation cost of this strategy
        sharding_dims = []
        for mesh_dims in dim_partition_dict_for_output.values():
            for mesh_dim in mesh_dims:
                sharding_dims.append(self.device_mesh.shape[mesh_dim])
        sharding_size = reduce(operator.mul, sharding_dims, 1)
        memory_cost = self.output_data.numel() / sharding_size
        compute_cost = memory_cost
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=output_sharding_spec,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_condition, sharding_spec_for_x,
                                                                sharding_spec_for_y))
        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        MESH_DIM_LIST = [0, 1]
        output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
        for output_sharding_spec in output_sharding_specs:
            self._register_strategy(output_sharding_spec)
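Aside (not part of the commit): a standalone sketch of the dimension-padding logic in _generate_sharding_spec above, using a plain dict in place of ColossalAI's ShardingSpec; the helper name pad_partition_dict is made up for illustration. Inputs whose rank is lower than the output are left-padded with size-1 dimensions, and any sharding that lands on a size-1 dimension is dropped, i.e. treated as replicated (R).

import torch
from copy import deepcopy

def pad_partition_dict(input_shape, output_rank, dim_partition_dict):
    # left-pad the shape with 1s until it matches the output rank
    shape = list(input_shape)
    while len(shape) < output_rank:
        shape.insert(0, 1)
    shape = torch.Size(shape)

    # drop shardings that fall on a size-1 dimension (they become R)
    processed = deepcopy(dim_partition_dict)
    for dim_index in dim_partition_dict:
        if shape[dim_index] == 1:
            processed.pop(dim_index)
    return shape, processed

# a (32,) tensor padded to rank 2; the sharding on dim 0 (now size 1) is dropped,
# the sharding on dim 1 (mesh dim 1) is kept
print(pad_partition_dict(torch.Size([32]), 2, {0: [0], 1: [1]}))
# (torch.Size([1, 32]), {1: [1]})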
tests/test_auto_parallel/test_where_handler.py  (new file, 0 → 100644)
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh


class ConvModel(nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out

    def forward(self, condition, x, y):
        output = torch.where(condition, x, y)
        return output


@pytest.mark.skip("temporarily skipped")
def test_where_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    tracer = ColoTracer()
    model = ConvModel(16, 32)
    input_sample = {
        'condition': torch.rand(16, 32).to('meta'),
        'x': torch.rand(16, 32).to('meta'),
        'y': torch.rand(16, 32).to('meta')
    }
    # graph():
    #     %condition : torch.Tensor [#users=1] = placeholder[target=condition]
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %y : torch.Tensor [#users=1] = placeholder[target=y]
    #     %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
    #     return where
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)

    # [condition, x, y, where, output]
    nodes = [node for node in gm.graph.nodes]
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    strategy_map = strategies_constructor.strategy_map
    # check the strategies registered for the where node
    where_node = strategy_map[nodes[3]]
    # ['[S0, S1] = [S0, S1] x [S0, S1] x [S0, S1]', '[S1, S0] = [S1, S0] x [S1, S0] x [S1, S0]', '[S01, R] = [S01, R] x [S01, R] x [S01, R]',
    #  '[R, S01] = [R, S01] x [R, S01] x [R, S01]', '[S0, R] = [S0, R] x [S0, R] x [S0, R]', '[R, S0] = [R, S0] x [R, S0] x [R, S0]',
    #  '[S1, R] = [S1, R] x [S1, R] x [S1, R]', '[R, S1] = [R, S1] x [R, S1] x [R, S1]', '[R, R] = [R, R] x [R, R] x [R, R]']
    assert len(where_node) == 9


if __name__ == '__main__':
    test_where_handler()
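As a quick sanity check on the assertion above (for illustration only, not part of the test), the nine strategies listed in the comment can be counted explicitly for a 2 x 2 mesh and a 2-D output:

# the output sharding patterns the handler is expected to register on a 2x2 mesh
expected_output_shardings = [
    '[S0, S1]', '[S1, S0]',        # both mesh dims, one per tensor dim
    '[S01, R]', '[R, S01]',        # both mesh dims flattened onto one tensor dim
    '[S0, R]', '[R, S0]',          # 1D shardings on mesh dim 0
    '[S1, R]', '[R, S1]',          # 1D shardings on mesh dim 1
    '[R, R]',                      # fully replicated
]
assert len(expected_output_shardings) == 9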