OpenDAS / ColossalAI · Commits

Unverified commit 8221fd74, authored Jan 12, 2023 by YuliangLiu0306, committed by GitHub on Jan 12, 2023
Parent: c9ec5190

[autoparallel] update binary elementwise handler (#2451)

* [autoparallel] update binary elementwise handler
* polish
Changes: 3 changed files, with 73 additions and 22 deletions (+73, -22)

- colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py (+22, -5)
- tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py (+48, -15)
- tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py (+3, -2)
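For context on what the handler change addresses: after symbolic tracing, the second operand of a binary elementwise op can be a graph Node whose traced value is a plain int or float (for example `x1.dim()`), or a bare literal (the `2` in `torch.pow(tensor, 2)`). A minimal sketch with plain `torch.fx` (the project uses its own ColoTracer; the module and class name below are illustrative only) shows how such a scalar-valued Node ends up as an argument of the op node:

```python
import torch
import torch.fx


class AddDimConst(torch.nn.Module):
    """Hypothetical toy model: the second operand is a scalar-valued node."""

    def forward(self, x):
        # x.dim() becomes its own call_method node whose runtime value is an int
        return torch.add(x, x.dim())


traced = torch.fx.symbolic_trace(AddDimConst())
for node in traced.graph.nodes:
    print(node.op, node.name, node.args)
# The 'add' node receives the 'dim' node as an argument, so a handler that
# assumes every Node argument carries tensor meta data would fail here.
```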
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py (+22, -5)

```diff
@@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
             return OperationDataType.ARG
 
         def _get_arg_value(idx):
+            non_tensor = False
             if isinstance(self.node.args[idx], Node):
                 meta_data = self.node.args[idx]._meta_data
+                # The meta_data of node type argument could also possibly be a non-tensor object.
+                if not isinstance(meta_data, torch.Tensor):
+                    assert isinstance(meta_data, (int, float))
+                    meta_data = torch.Tensor([meta_data]).to('meta')
+                    non_tensor = True
+
             else:
                 # this is in fact a real data like int 1
                 # but we can deem it as meta data
                 # as it won't affect the strategy generation
                 assert isinstance(self.node.args[idx], (int, float))
                 meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
-            return meta_data
+                non_tensor = True
 
-        input_meta_data = _get_arg_value(0)
-        other_meta_data = _get_arg_value(1)
-        output_meta_data = self.node._meta_data
+            return meta_data, non_tensor
 
+        input_meta_data, non_tensor_input = _get_arg_value(0)
+        other_meta_data, non_tensor_other = _get_arg_value(1)
+        output_meta_data = self.node._meta_data
+        # we need record op_data with non-tensor data in this list,
+        # and filter the non-tensor op_data in post_process.
+        self.non_tensor_list = []
+        # assert False
         input_op_data = OperationData(name=str(self.node.args[0]),
                                       type=_get_op_data_type(input_meta_data),
                                       data=input_meta_data,
@@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=bcast_shape)
+        if non_tensor_input:
+            self.non_tensor_list.append(input_op_data)
+        if non_tensor_other:
+            self.non_tensor_list.append(other_op_data)
         mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
 
         return mapping
@@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
         op_data_mapping = self.get_operation_data_mapping()
         for op_name, op_data in op_data_mapping.items():
-            if not isinstance(op_data.data, torch.Tensor):
+            if op_data in self.non_tensor_list:
                 # remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
                 strategy.sharding_specs.pop(op_data)
 
             else:
                 # convert the logical sharding spec to physical sharding spec if broadcast
                 # e.g. torch.rand(4, 4) + torch.rand(4)
```
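To read the new `_get_arg_value` contract in isolation: it now returns a `(meta_tensor, non_tensor)` pair, where the flag marks operands that were scalars (either a scalar-valued Node or a bare literal) so their sharding specs can be dropped later in `post_process`. A standalone sketch of that logic follows; `meta_of` is a hypothetical accessor standing in for reading a node's `_meta_data`, not a real API:

```python
import torch
from torch.fx import Node


def get_arg_value(arg, meta_of):
    """Sketch of the helper's new contract: return (meta_tensor, non_tensor).

    `meta_of` is a hypothetical stand-in for reading a node's traced meta
    value; the real handler reads `node._meta_data` directly.
    """
    non_tensor = False
    if isinstance(arg, Node):
        meta_data = meta_of(arg)
        # a Node's meta value can itself be a scalar, e.g. x1.dim()
        if not isinstance(meta_data, torch.Tensor):
            assert isinstance(meta_data, (int, float))
            meta_data = torch.Tensor([meta_data]).to('meta')
            non_tensor = True
    else:
        # a bare literal, e.g. the 2 in torch.pow(tensor, 2)
        assert isinstance(arg, (int, float))
        meta_data = torch.Tensor([arg]).to('meta')
        non_tensor = True
    return meta_data, non_tensor
```

The key difference from the pre-change code is the first branch: previously a Node argument was assumed to carry tensor meta data, which broke for scalar-valued nodes like `x1.dim()`.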
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py (+48, -15)

```diff
@@ -122,11 +122,19 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
     assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
 
 
-def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
-    disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+class BEOpModelWithNodeConst(nn.Module):
 
-    class BinaryElementwiseOpModel(nn.Module):
+    def __init__(self, op):
+        super().__init__()
+        self.op = op
 
-        def __init__(self, op, const):
-            super().__init__()
+    def forward(self, x1):
+        const = x1.dim()
+        out = self.op(x1, const)
+        return out
+
+
+class BEOpModelWithIntConst(nn.Module):
+
+    def __init__(self, op, const):
+        super().__init__()
@@ -137,10 +145,18 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
         out = self.op(x1, self.const)
         return out
 
+
+def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    model = BinaryElementwiseOpModel(op, other_dim).cuda()
+    if model_cls == BEOpModelWithNodeConst:
+        model = model_cls(op).cuda()
+    else:
+        model = model_cls(op, other_dim).cuda()
     x1 = torch.rand(4, 4).cuda()
     # the index of binary-elementwise node in computation graph
     node_index = 1
@@ -159,8 +175,13 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
     tracer = ColoTracer()
     meta_args = {'x1': torch.rand(4, 4).to('meta')}
     graph = tracer.trace(model, meta_args=meta_args)
+    print(graph)
+    # assert False
     gm = ColoGraphModule(model, graph)
-    op_node = list(graph.nodes)[1]
+    if model_cls == BEOpModelWithNodeConst:
+        op_node = list(graph.nodes)[2]
+    else:
+        op_node = list(graph.nodes)[1]
     strategies_vector = StrategiesVector(op_node)
@@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
 @parameterize('other_dim', [1, 2])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_binary_elementwise_handler(op, other_dim):
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
     world_size = 4
     run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
                               op=op,
@@ -220,8 +241,19 @@ def test_binary_elementwise_handler(op, other_dim):
                               world_size=world_size,
                               port=free_port())
     mp.spawn(run_func_tensor, nprocs=world_size)
 
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
+    world_size = 4
     run_func_int = partial(check_binary_elementwise_handler_with_int,
                            op=op,
+                           model_cls=model_cls,
                            other_dim=other_dim,
                            world_size=world_size,
                            port=free_port())
@@ -229,4 +261,5 @@ def test_binary_elementwise_handler(op, other_dim):
 
 if __name__ == '__main__':
-    test_binary_elementwise_handler()
+    test_binary_elementwise_handler_with_tensor()
+    test_binary_elementwise_handler_with_int()
```
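The `[2]` versus `[1]` node index follows from the extra node the `x1.dim()` call introduces before the binary op. A quick check with plain `torch.fx` (assuming ColoTracer orders these simple graphs the same way; the class names below mirror the test models but are re-declared locally):

```python
import torch
import torch.fx


class WithNodeConst(torch.nn.Module):

    def forward(self, x1):
        return torch.add(x1, x1.dim())


class WithIntConst(torch.nn.Module):

    def forward(self, x1):
        return torch.add(x1, 2)


for cls in (WithNodeConst, WithIntConst):
    nodes = list(torch.fx.symbolic_trace(cls()).graph.nodes)
    print(cls.__name__, [(n.op, n.name) for n in nodes])
# WithNodeConst: placeholder, dim, add, output -> binary op at index 2
# WithIntConst:  placeholder, add, output      -> binary op at index 1
```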
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py (+3, -2)

```diff
@@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     solver_options = SolverOptions()
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
-    target_node = list(graph.nodes)[node_index]
+    target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
+                  ][node_index]
     if node_type == 'normal':
         solution_len = len(strategies_constructor.leaf_strategies)
         solution = [0] * solution_len
@@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+        gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
```
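The `target_node` change swaps a positional lookup over all graph nodes for one over strategy-bearing nodes, which presumably keeps `node_index` stable when the graph gains nodes that receive no strategies (such as the scalar-producing `dim` node above); the second hunk just threads `strategies_constructor` through to `runtime_preparation_pass` as an extra argument. A toy sketch of the two lookups; `StrategiesVectorStub` is an illustrative stand-in, not the real class:

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class StrategiesVectorStub:
    """Illustrative stand-in for a per-node strategies vector."""
    node: Any


def target_by_graph_index(graph_nodes: List[Any], node_index: int) -> Any:
    # old lookup: position among *all* graph nodes, including any that
    # carry no sharding strategies
    return graph_nodes[node_index]


def target_by_strategy_index(leaf_strategies: List[StrategiesVectorStub],
                             node_index: int) -> Any:
    # new lookup: position among strategy-bearing nodes only
    return [sv.node for sv in leaf_strategies][node_index]
```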