OpenDAS / ColossalAI / Commits / 05020e50

Unverified commit 05020e50, authored Nov 18, 2022 by YuliangLiu0306 and committed by GitHub on Nov 18, 2022.

[autoparallel] support more flexible data type (#1967)

Parent: 5bec3b21

Showing 5 changed files with 74 additions and 14 deletions (+74, -14):
- colossalai/auto_parallel/tensor_shard/node_handler/__init__.py (+3, -1)
- colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py (+4, -0)
- colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py (+31, -2)
- colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py (+34, -11)
- colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py (+2, -0)
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

```diff
@@ -4,6 +4,7 @@ from .binary_elementwise_handler import BinaryElementwiseHandler
 from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .getatrr_handler import GetattrHandler
+from .getitem_handler import GetItemHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
@@ -19,5 +20,6 @@ __all__ = [
     'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
     'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
-    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'GetattrHandler'
+    'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
+    'GetItemHandler', 'GetattrHandler'
 ]
```
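This change simply re-exports the new `GetItemHandler` from the package root. A quick usage sketch (import path taken from the diff above; requires colossalai at this commit):

```python
# After this commit, GetItemHandler is importable alongside the other handlers:
from colossalai.auto_parallel.tensor_shard.node_handler import GetItemHandler, GetattrHandler
```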
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py

```diff
@@ -51,6 +51,10 @@ class NodeHandler(ABC):
         for node in self.predecessor_node:
             node_name = str(node)
             # get the current sharding spec generated by this node handler
+            # TODO: we need to check this in future
+            if not isinstance(node._meta_data, torch.Tensor):
+                continue
+
             op_data = strategy.get_op_data_by_name(node_name)
             current_sharding_spec = strategy.sharding_specs[op_data]
```
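The new guard skips predecessor nodes whose metadata is not a single tensor, which is exactly the case this commit makes room for: ops like `torch.Tensor.split` yield a tuple of tensors. A minimal, self-contained illustration (plain PyTorch, independent of the handler machinery):

```python
import torch

x = torch.randn(4, 8)
pieces = torch.split(x, 2, dim=0)        # tuple of two (2, 8) tensors

# A graph node for `split` would carry this tuple as its metadata,
# so the isinstance(..., torch.Tensor) check above would skip it.
print(isinstance(pieces, torch.Tensor))  # False
print(isinstance(pieces, tuple))         # True
```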
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py

```diff
@@ -11,7 +11,9 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.reshape)
+@operator_registry.register(torch.Tensor.split)
 @operator_registry.register(torch.flatten)
+@operator_registry.register(torch.Tensor.transpose)
 @operator_registry.register(torch.Tensor.permute)
 @operator_registry.register(torch.Tensor.view)
 @operator_registry.register(torch.nn.AdaptiveAvgPool2d)
@@ -26,6 +28,24 @@ class ReshapeHandler(NodeHandler):
         generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
         return generators
 
+    def infer_logical_shape(self, data):
+        """
+        This function is used to infer logical shape for operands.
+
+        Notes: This function is only used for the operands whose data are not only in type of tensor,
+               such as tuple of tensor.
+        """
+        if isinstance(data, torch.Tensor):
+            return data.shape
+        else:
+            assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
+            logical_shape = []
+            for tensor in data:
+                assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
+                logical_shape.append(tensor.shape)
+            logical_shape = tuple(logical_shape)
+            return logical_shape
+
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
@@ -36,10 +56,19 @@ class ReshapeHandler(NodeHandler):
         else:
             data_type = OperationDataType.ARG
+        input_data = self.node.args[0]._meta_data
+        input_logical_shape = self.infer_logical_shape(input_data)
         physical_input_operand = OperationData(name=str(self.node.args[0]),
                                                type=data_type,
-                                               data=self.node.args[0]._meta_data)
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+                                               data=input_data,
+                                               logical_shape=input_logical_shape)
+
+        output_data = self.node._meta_data
+        output_logical_shape = self.infer_logical_shape(output_data)
+        physical_output = OperationData(name=str(self.node),
+                                        type=OperationDataType.OUTPUT,
+                                        data=output_data,
+                                        logical_shape=output_logical_shape)
 
         mapping = {"input": physical_input_operand, "output": physical_output}
```
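The new `infer_logical_shape` helper gives tuple-valued operands a usable "shape": a tensor maps to its own shape, a tuple of tensors to a tuple of their shapes. A standalone sketch of the same logic (simplified for illustration) and what it returns:

```python
import torch

def infer_logical_shape(data):
    # Mirrors the helper added above: tensor -> its shape,
    # tuple of tensors -> tuple of their shapes.
    if isinstance(data, torch.Tensor):
        return data.shape
    assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
    return tuple(t.shape for t in data)

x = torch.randn(4, 8)
print(infer_logical_shape(x))                         # torch.Size([4, 8])
print(infer_logical_shape(torch.split(x, 2, dim=0)))  # (torch.Size([2, 8]), torch.Size([2, 8]))
```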
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py

```diff
@@ -81,9 +81,10 @@ class StrategyGenerator(ABC):
             for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
                 dim_size = len(logical_shape)
                 dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
-                sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                             entire_shape=logical_shape,
-                                             dim_partition_dict=dim_partition_dict_element)
+                sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh,
+                                                     entire_shape=logical_shape,
+                                                     dim_partition_dict=dim_partition_dict_element)
+                sharding_spec.append(sharding_spec_element)
         else:
             assert isinstance(
                 op_data.data, torch.Tensor
@@ -193,18 +194,40 @@ class StrategyGenerator(ABC):
         Args:
             strategy (ShardingStrategy): the ShardingStrategy generated.
             key (str): the name of the operation data defined by the generator.
         """
         op_data = self.op_data[key]
-        sharded_shape = strategy.sharding_specs[op_data].get_sharded_shape_per_device()
-        if len(sharded_shape) == 0:
-            num_elements = 1
-        else:
-            num_elements = reduce(operator.mul, sharded_shape)
-        dtype = self.op_data[key].data.dtype
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        return num_elements * size_per_elem_bytes
+
+        def _compute_size_in_bytes_helper(sharding_spec, meta_data):
+            num_elements = 1
+            sharded_shape = sharding_spec.get_sharded_shape_per_device()
+            if len(sharded_shape) == 0:
+                num_elements = 1
+            else:
+                num_elements = reduce(operator.mul, sharded_shape)
+            dtype = getattr(meta_data, 'dtype')
+            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+            return num_elements * size_per_elem_bytes
+
+        if isinstance(op_data.data, tuple):
+            assert isinstance(strategy.sharding_specs[op_data], list), \
+                'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
+            total_bytes = 0
+            for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
+                meta_data = op_data.data[index]
+                if isinstance(meta_data, torch.Tensor):
+                    element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
+                else:
+                    # if meta_data is not a tensor, we count the memory as 0
+                    element_bytes = 0
+                total_bytes += element_bytes
+        else:
+            if isinstance(op_data.data, torch.Tensor):
+                total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
+            else:
+                # if op_data.data is not a tensor, we count the memory as 0
+                total_bytes = 0
+        return total_bytes
 
     def generate(self) -> List[ShardingStrategy]:
         """
```
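The refactored size computation reduces to (number of elements in the per-device shard) x (bytes per element for the dtype), and factoring it into a helper lets it run once per tensor in a tuple operand. A quick worked check of the arithmetic in plain PyTorch:

```python
import operator
from functools import reduce

import torch

# Suppose a (4, 16) fp16 tensor is sharded into a (4, 8) block per device.
sharded_shape = (4, 8)
num_elements = reduce(operator.mul, sharded_shape)                          # 32
size_per_elem_bytes = torch.tensor([], dtype=torch.float16).element_size()  # 2
print(num_elements * size_per_elem_bytes)                                    # 64 bytes per device
```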
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py

```diff
@@ -10,6 +10,8 @@ from .strategy import StrategyGenerator, UnaryElementwiseGenerator
 __all__ = ['UnaryElementwiseHandler']
 
+@operator_registry.register(torch.Tensor.to)
+@operator_registry.register(torch.Tensor.type)
 @operator_registry.register(torch.abs)
 @operator_registry.register(torch.nn.ReLU)
 class UnaryElementwiseHandler(NodeHandler):
```
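Registering `torch.Tensor.to` and `torch.Tensor.type` here treats dtype casts as shape-preserving unary elementwise ops, so they reuse the existing sharding strategies; this is the user-visible half of the "more flexible data type" support. For context, a minimal sketch of the decorator-registry pattern these `@operator_registry.register(...)` lines rely on (a hypothetical simplified registry, not ColossalAI's actual implementation):

```python
import torch

class OperatorRegistry:
    """Maps a torch op to the handler class that generates its sharding strategies."""

    def __init__(self):
        self._store = {}

    def register(self, op):
        # Returns a class decorator that records `op -> handler class`.
        def wrapper(handler_cls):
            self._store[op] = handler_cls
            return handler_cls
        return wrapper

    def get(self, op):
        return self._store[op]

operator_registry = OperatorRegistry()

@operator_registry.register(torch.Tensor.to)
@operator_registry.register(torch.abs)
class UnaryElementwiseHandler:
    pass

assert operator_registry.get(torch.Tensor.to) is UnaryElementwiseHandler
```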