OpenDAS / ColossalAI / Commits

Commit 0e9db368 (unverified)
Authored Dec 06, 2022 by YuliangLiu0306; committed via GitHub on Dec 06, 2022
Parent: cdf537a6

[autoparallel] add tensor constructor handler (#2082)

Showing 6 changed files with 171 additions and 2 deletions (+171 / -2)
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py (+2 / -1)
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py (+1 / -0)
colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py (+3 / -1)
colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py (+67 / -0)
colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py (+32 / -0)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py (+66 / -0)
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@@ -14,6 +14,7 @@ from .output_handler import OuputHandler
 from .placeholder_handler import PlacehodlerHandler
 from .registry import operator_registry
 from .reshape_handler import ReshapeHandler
+from .tensor_constructor_handler import TensorConstructorHandler
 from .unary_elementwise_handler import UnaryElementwiseHandler
 from .where_handler import WhereHandler
@@ -22,5 +23,5 @@ __all__ = [
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
-    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler'
+    'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py

@@ -11,6 +11,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.flatten)
+@operator_registry.register(torch.Tensor.unsqueeze)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
 class ReshapeHandler(NodeHandler):
     """
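
For context: unsqueeze inserts a size-1 dimension without moving any data, which is why it can share ReshapeHandler with flatten and other reshape-style operations. A minimal PyTorch illustration (not part of the commit):

import torch

# unsqueeze is a shape-only view: it adds a size-1 dimension and copies nothing,
# so its sharding behavior matches reshape/flatten-style ops.
x = torch.rand(10)
print(x.unsqueeze(0).shape)     # torch.Size([1, 10])
print(x.unsqueeze(-1).shape)    # torch.Size([10, 1])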
colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py

@@ -15,6 +15,7 @@ from .output_generator import OutputGenerator
 from .placeholder_generator import PlaceholderGenerator
 from .reshape_generator import ReshapeGenerator
 from .strategy_generator import StrategyGenerator
+from .tensor_constructor_generator import TensorConstructorGenerator
 from .unary_elementwise_generator import UnaryElementwiseGenerator
 from .where_generator import WhereGenerator
@@ -23,5 +24,6 @@ __all__ = [
     'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
     'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
     'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
-    'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator'
+    'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
+    'TensorConstructorGenerator'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py (new file, mode 100644)

import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    MemoryCost,
    ShardingStrategy,
    TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec

from .strategy_generator import StrategyGenerator

__all__ = ['TensorConstructorGenerator']


class TensorConstructorGenerator(StrategyGenerator):
    """
    TensorConstructorGenerator which deals with
    the sharding strategies for tensor constructor operations, such as torch.arange.
    """

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}

        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred: a constructor node stores no gradients
        bwd_mem_cost = MemoryCost(activation=0, parameter=0)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        # the constructed output is kept fully replicated on every device
        dim_partition_dict_mapping = {
            "output": {},
        }
        communication_action_mapping = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        name = 'Replica Tensor Constructor'

        strategy = self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)
        strategy_list.append(strategy)

        return strategy_list
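
As a sanity check on the cost model above: a tensor constructor node has no inputs and no parameters, so the forward memory cost reduces to the byte size of the constructed output, and the backward cost is zero. A standalone sketch of what that size works out to for the torch.arange case (plain PyTorch arithmetic, not the ColossalAI _compute_size_in_bytes API):

import torch

# torch.arange(10) yields 10 int64 elements, so the forward activation cost
# is 10 * 8 = 80 bytes; a constructor node holds no parameters and no
# backward activations.
output = torch.arange(10)
activation_bytes = output.numel() * output.element_size()
print(activation_bytes)    # 80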
colossalai/auto_parallel/tensor_shard/node_handler/tensor_constructor_handler.py (new file, mode 100644)

from typing import Dict, List

import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator

__all__ = ['TensorConstructorHandler']


@operator_registry.register(torch.arange)
class TensorConstructorHandler(NodeHandler):
    """
    A TensorConstructorHandler which deals with the sharding strategies for tensor constructor operations, such as torch.arange.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(TensorConstructorGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # a constructor node has no input operands; only the constructed output matters
        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {"output": physical_output_operand}

        return mapping
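
Only torch.arange is registered to this handler in the commit. If other constructors such as torch.zeros or torch.ones ever needed the same replica-only treatment, the decorator pattern above suggests they could be routed to the same class; the snippet below is a hypothetical extension, not part of the commit, and assumes operator_registry.register behaves as the plain class decorator shown above:

import torch

from colossalai.auto_parallel.tensor_shard.node_handler.registry import operator_registry
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler

# Hypothetical: attach additional constructor ops to the existing handler,
# mirroring the @operator_registry.register(torch.arange) usage above.
operator_registry.register(torch.zeros)(TensorConstructorHandler)
operator_registry.register(torch.ones)(TensorConstructorHandler)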
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py (new file, mode 100644)

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer


class TensorConstructorModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        arange_node = torch.arange(x.size()[0])
        x = x + arange_node
        return x


def test_tensor_constructor_handler():
    model = TensorConstructorModel()
    tracer = ColoTracer()
    # graph():
    #     %x : torch.Tensor [#users=2] = placeholder[target=x]
    #     %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {})
    #     %getitem : [#users=1] = call_function[target=operator.getitem](args = (%size, 0), kwargs = {})
    #     %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {})
    #     return add
    graph = tracer.trace(model, meta_args={
        "x": torch.rand(10).to('meta'),
    })
    gm = ColoGraphModule(model, graph)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    # node 3 is the %arange node in the traced graph shown above
    arange_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(arange_node)

    # build handler
    handler = TensorConstructorHandler(node=arange_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['output'].name == "arange"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([10])
    assert mapping['output'].type == OperationDataType.OUTPUT

    handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]
    assert 'Replica Tensor Constructor' in strategy_name_list


if __name__ == '__main__':
    test_tensor_constructor_handler()
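
The device mesh in the test folds four physical devices into a 2 x 2 logical mesh; each mesh axis is a candidate sharding dimension, and the 'Replica Tensor Constructor' strategy shards over neither. A small standalone illustration of that layout (plain PyTorch, independent of DeviceMesh):

import torch

# Four physical devices arranged as the 2 x 2 logical mesh used in the test.
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
print(physical_mesh_id.reshape(mesh_shape))
# tensor([[0, 1],
#         [2, 3]])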