Project: OpenDAS / ColossalAI
Commit: e532679c, authored Jan 10, 2023 by oahzxl
Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk
Parents: c1492e50, 7d5640b9
Changes: 461
Showing 20 changed files with 1688 additions and 70 deletions (+1688, -70)
colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py (+14, -4)
colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py (+4, -4)
colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py (+230, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py (+10, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py (+76, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py (+299, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py (+63, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py (+65, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py (+53, -0)
colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py (+34, -0)
colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py (+1, -1)
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py (+65, -21)
colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py (+486, -0)
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py (+143, -11)
colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py (+2, -2)
colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py (+27, -14)
colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py (+15, -5)
colossalai/auto_parallel/tensor_shard/node_handler/registry.py (+6, -1)
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py (+40, -7)
colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py (+55, -0)
Too many changes to show: to preserve performance, only 461 of 461+ files are displayed.
colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py

@@ -2,8 +2,10 @@ from typing import Dict, List, Union

 import torch

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from ..utils import recover_sharding_spec_for_broadcast_shape
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
+
+from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
+from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator

@@ -91,7 +93,15 @@ class AddBMMFunctionHandler(NodeHandler):
         bias_physical_shape = bias_op_data.data.shape
         bias_logical_shape = bias_op_data.logical_shape
         bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
-        bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
-                                                                       bias_physical_shape)
+        bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
+            bias_sharding_spec, bias_logical_shape, bias_physical_shape)
         strategy.sharding_specs[bias_op_data] = bias_sharding_spec
+
+        if len(removed_dims) > 0:
+            comm_action = comm_actions_for_oprands(node=self.node,
+                                                   removed_dims=removed_dims,
+                                                   op_data=bias_op_data,
+                                                   sharding_spec=bias_sharding_spec)
+            strategy.communication_actions[bias_op_data] = comm_action
+
         return strategy
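As a side illustration (not part of this diff): the change above depends on recover_sharding_spec_for_broadcast_shape now also reporting which dimensions exist only because of broadcasting, so a communication action can be attached for the bias operand. A minimal, self-contained sketch of that idea, assuming a simple right-aligned broadcast; the function name and logic below are illustrative, not the ColossalAI implementation.

# Editor's sketch: find logical dims that exist only because of broadcasting.
def broadcast_removed_dims(logical_shape, physical_shape):
    offset = len(logical_shape) - len(physical_shape)
    removed = list(range(offset))                      # leading dims added by broadcasting
    for i, phys_dim in enumerate(physical_shape):
        if phys_dim == 1 and logical_shape[offset + i] != 1:
            removed.append(offset + i)                 # size-1 dims expanded by broadcasting
    return removed

print(broadcast_removed_dims((4, 8, 16), (16,)))       # [0, 1]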
colossalai/auto_parallel/tensor_shard/node_handler/conv_handler.py

@@ -3,9 +3,9 @@ from typing import Dict, List

 import torch
 import torch.nn.functional as F

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
 from ..utils import transpose_partition_dim
-from .node_handler import ModuleHandler, NodeHandler
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator

@@ -15,7 +15,7 @@ __all__ = ['ConvModuleHandler', 'ConvFunctionHandler']
 @operator_registry.register(torch.nn.Conv1d)
 @operator_registry.register(torch.nn.Conv2d)
 @operator_registry.register(torch.nn.Conv3d)
-class ConvModuleHandler(ModuleHandler):
+class ConvModuleHandler(MetaInfoModuleHandler):
     """
     A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module.
     """

@@ -63,7 +63,7 @@ class ConvModuleHandler(ModuleHandler):
 @operator_registry.register(F.conv1d)
 @operator_registry.register(F.conv2d)
 @operator_registry.register(F.conv3d)
-class ConvFunctionHandler(NodeHandler):
+class ConvFunctionHandler(MetaInfoNodeHandler):
     """
     A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions.
     """
colossalai/auto_parallel/tensor_shard/node_handler/embedding_handler.py (new file, 0 → 100644)

from typing import Dict, List, Union

import torch
import torch.nn.functional as F

from colossalai.auto_parallel.tensor_shard.utils import update_partition_dim
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError

from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from .node_handler import ModuleHandler, NodeHandler
from .registry import operator_registry
from .strategy import EmbeddingStrategyGenerator, StrategyGenerator

__all__ = ['EmbeddingModuleHandler', 'EmbeddingFunctionHandler']


def _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy: ShardingStrategy, input_name: str,
                                                                      output_name: str) -> List[ShardingStrategy]:
    """
    This function converts the logical sharding spec to the physical sharding spec for both the input and output
    of the embedding operation.

    Args:
        strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
        input_name (str): the name of the OperationData object for the input.
        output_name (str): the name of the OperationData object for the output.
    """
    # the result will be a list of strategies
    sharding_strategies = []

    # get operation data
    input_op_data = strategy.get_op_data_by_name(input_name)
    output_op_data = strategy.get_op_data_by_name(output_name)
    input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
    output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)

    # recover the last logical dimension to physical dimension
    last_logical_output_dims = len(output_op_data.logical_shape) - 1
    last_physical_output_dims = output_op_data.data.dim() - 1

    # get logger for debug message
    logger = get_dist_logger()

    # For the input of the embedding operation, it can be multi-dimensional. The sharding spec is only generated for
    # logical 1D non-matrix dimension, the logical non-matrix dimension can belong to the 0th to Nth dimension of the
    # physical input shape. Thus, we enumerate to get all possible cases.
    if input_sharding_spec.dim_partition_dict:
        # if bool(input_sharding_spec.dim_partition_dict), it means that the
        # the generated sharding strategy does shard the non-matrix dimension,
        # in this case, we need to do enumeration
        num_input_dims = input_op_data.data.dim()
        for i in range(num_input_dims):
            strategy_copy = strategy.clone()
            input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
            output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
            try:
                # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
                update_partition_dim(sharding_spec=input_sharding_spec,
                                     dim_mapping={0: i},
                                     physical_shape=input_op_data.data.shape,
                                     inplace=True)

                if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
                    dim_mapping = {0: i, last_logical_output_dims: last_physical_output_dims}
                else:
                    dim_mapping = {0: i}

                update_partition_dim(sharding_spec=output_sharding_spec,
                                     dim_mapping=dim_mapping,
                                     physical_shape=output_op_data.data.shape,
                                     inplace=True)
                strategy_copy.name = f'{strategy.name}_{i}'
                sharding_strategies.append(strategy_copy)
            except ShardingNotDivisibleError as e:
                logger.debug(
                    f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
                )
    else:
        # the generated sharding strategy does not shard the non-matrix dimension,
        # in this case, we don't need to do enumeration
        # but instead, we still need to convert the logical shape to physical shape
        strategy_copy = strategy.clone()
        input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
        output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)

        # after updating, the logical shape will be replaced by the physical shape
        update_partition_dim(sharding_spec=input_sharding_spec,
                             dim_mapping={},
                             physical_shape=input_op_data.data.shape,
                             inplace=True)

        if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
            dim_mapping = {last_logical_output_dims: last_physical_output_dims}
        else:
            dim_mapping = {}

        update_partition_dim(sharding_spec=output_sharding_spec,
                             dim_mapping=dim_mapping,
                             physical_shape=output_op_data.data.shape,
                             inplace=True)
        sharding_strategies.append(strategy_copy)

    return sharding_strategies


@operator_registry.register(torch.nn.Embedding)
class EmbeddingModuleHandler(ModuleHandler):
    """
    A EmbeddingModuleHandler which deals with the sharding strategies for nn.Embedding module.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # In nn.Embedding operation, all the dimensions of input will be treated as the batch dimension,
        # and then the sharding spec will be generated based on the logical 1D tensor.
        # After that, the logical sharding info will be enumerated among all the physical dimensions.
        # Finally, the input will be transformed back to its original shape in self.post_process
        input_meta_data = self.node.args[0]._meta_data
        input_logical_shape = input_meta_data.view(-1).shape
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=input_meta_data,
                                               logical_shape=input_logical_shape)

        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'])

        # Same as input, in nn.Embedding operation, all the dimensions of output will be treated as
        # (batch dimension, embedding dimension), and then the sharding spec will be generated based
        # on the logical 2D tensor.
        # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
        # Finally, the output will be transformed back to its original shape in self.post_process
        output_meta_data = self.node._meta_data
        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=output_logical_shape)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        return mapping

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        """
        Convert the sharding spec from the logical shape to the physical shape.
        """
        # create multiple sharding strategies for the inputs
        # as input can be multi-dimensinal and the partition dim is only 2D,
        # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
        strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
                                                                                       input_name=str(
                                                                                           self.node.args[0]),
                                                                                       output_name=str(self.node))
        return strategies


@operator_registry.register(F.embedding)
class EmbeddingFunctionHandler(NodeHandler):
    """
    A EmbeddingFunctionHandler which deals with the sharding strategies for F.embedding.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(EmbeddingStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # In F.embedding operation, all the dimensions of input will be treated as the batch dimension,
        # and then the sharding spec will be generated based on the logical 1D tensor.
        # After that, the logical sharding info will be enumerated among all the physical dimensions.
        # Finally, the input will be transformed back to its original shape in self.post_process
        input_meta_data = self.node.args[0]._meta_data
        input_logical_shape = input_meta_data.view(-1).shape
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data,
                                               logical_shape=input_logical_shape)

        # check if the other operand is a parameter
        if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=data_type,
                                               data=self.node.args[1]._meta_data)

        # Same as input, in F.embedding operation, all the dimensions of output will be treated as
        # (batch dimension, embedding dimension), and then the sharding spec will be generated based
        # on the logical 2D tensor.
        # After that, the logical sharding info of batch dimension will be enumerated among all the physical dimensions.
        # Finally, the output will be transformed back to its original shape in self.post_process
        output_meta_data = self.node._meta_data
        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
        physical_output = OperationData(
            name=str(self.node),
            type=OperationDataType.OUTPUT,
            data=self.node._meta_data,
            logical_shape=output_logical_shape,
        )

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        return mapping

    def post_process(self, strategy: ShardingStrategy):
        """
        Convert the sharding spec from the logical shape to the physical shape.
        """
        # create multiple sharding strategies for the inputs
        # as input can be multi-dimensinal and the partition dim is only 2D,
        # we need to map the partition at logical dim 0 to one of the first few dimensions of the input and output
        strategies = _convert_logical_sharding_to_physical_sharding_spec_for_embedding(strategy=strategy,
                                                                                       input_name=str(
                                                                                           self.node.args[0]),
                                                                                       output_name=str(self.node))
        return strategies
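For reference (not part of this diff), the handler above maps nn.Embedding's N-D input to a logical 1D tensor and its output to a logical 2D (batch, embedding) tensor before generating strategies. A quick PyTorch check of those logical shapes, with shapes picked arbitrarily for illustration:

import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=1000, embedding_dim=64)
x = torch.randint(0, 1000, (8, 32))          # physical input shape: (8, 32)
out = embed(x)                               # physical output shape: (8, 32, 64)

print(x.view(-1).shape)                      # logical input:  torch.Size([256])
print(out.view(-1, out.shape[-1]).shape)     # logical output: torch.Size([256, 64])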
colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py (new file, 0 → 100644)

from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
from .split_handler import SplitHandler
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler

__all__ = [
    'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator',
    'SplitHandler', 'SplitGenerator'
]
colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py (new file, 0 → 100644)

from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import PermuteGenerator

__all__ = ['PermuteHandler']


@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.permute)
class PermuteHandler(NodeHandler):
    """
    A PermuteHandler which deals with the sharding strategies for torch.permute or torch.transpose.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(PermuteGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        permute_dims = []
        if self.node.op == 'call_method':
            # torch.Tensor.permute (input, *dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, int):
                        permute_dims.append(arg._meta_data)
                else:
                    assert isinstance(arg, int), 'The argument in permute node should be either type of Node or int.'
                    permute_dims.append(arg)
        else:
            # torch.permute (input, dims)
            for arg in self.node.args:
                if isinstance(arg, torch.fx.Node):
                    if isinstance(arg._meta_data, (tuple, list)):
                        permute_dims.extend(arg._meta_data)
                else:
                    assert isinstance(
                        arg,
                        (tuple, list)), 'The argument in permute node should be type of Node, Tuple[int] or List[int].'
                    permute_dims.extend(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(num_dims):
            # recover negative value to positive
            if permute_dims[i] < 0:
                permute_dims[i] += num_dims

        physical_shape_operand = OperationData(name='permute_dims',
                                               type=OperationDataType.ARG,
                                               data=list(permute_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "permute_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
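For reference (not part of this diff), PermuteHandler collects the permutation dims from either the method or the functional call form and normalizes negative dims exactly as plain PyTorch would accept them:

import torch

x = torch.rand(2, 3, 4)
dims = [0, -1, -2]                                      # as they might appear in the traced call
num_dims = x.dim()
dims = [d + num_dims if d < 0 else d for d in dims]     # -> [0, 2, 1]

assert torch.equal(x.permute(0, -1, -2), x.permute(*dims))
print(dims)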
colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py (new file, 0 → 100644)

import copy
from typing import List

from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    CommAction,
    CommType,
    MemoryCost,
    ShardingStrategy,
    TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
    check_keep_sharding_status,
    detect_reshape_mapping,
    infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec

__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']


class ReshapeGenerator(FollowingStrategyGenerator):
    """
    ReshapeGenerator is the base class for all the reshape operation.
    """

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy):
        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy):
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")
        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def collate_strategies(self) -> List[ShardingStrategy]:
        return super().collate_strategies()


class ViewGenerator(ReshapeGenerator):
    """
    ViewGenerator deals with the sharding strategies of view op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]

            origin_shape = self.op_data['input'].data.shape
            tgt_shape = self.op_data['tgt_shape'].data
            reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)

            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)

            if keep_sharding_status:
                dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
                                                                                reshape_mapping_dict)
            else:
                dim_partition_dict_for_output = {}

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            if keep_sharding_status:
                name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
            else:
                name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'

                # add comm action for converting input to fully replicated
                total_mesh_dim_list = []
                for mesh_dim_list in dim_partition_dict_for_input.values():
                    total_mesh_dim_list.extend(mesh_dim_list)
                # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
                if len(total_mesh_dim_list) == 1:
                    total_mesh_dim_list = total_mesh_dim_list[0]
                    # the total mesh dim list only has one element, so the shard dim has only one element as well.
                    shard_dim = list(dim_partition_dict_for_input.keys())[0]
                    input_comm_action = self.get_communication_action(
                        sharding_spec=sharding_spec_mapping["input"],
                        communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                        logical_process_axis=total_mesh_dim_list,
                        comm_type=CommType.BEFORE,
                        arg_index=0)
                    # it will gather the input through gather_dim during forward phase.
                    input_comm_action.comm_spec.gather_dim = shard_dim
                    # it will split the input activation grad through shard_dim during backward phase.
                    input_comm_action.comm_spec.shard_dim = shard_dim

                elif len(total_mesh_dim_list) >= 2:
                    source_spec = sharding_spec_mapping["input"]
                    target_spec = ShardingSpec(device_mesh=self.device_mesh,
                                               entire_shape=source_spec.entire_shape,
                                               dim_partition_dict={})
                    comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
                    input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)

                else:
                    input_comm_action = None

                if input_comm_action is not None:
                    communication_action_mapping["input"] = input_comm_action

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list


class PermuteGenerator(ReshapeGenerator):
    """
    PermuteGenerator deals with the sharding strategies of permute op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]

            permute_dims = self.op_data['permute_dims'].data
            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = {}
            for dim_index, permute_dim in enumerate(permute_dims):
                if permute_dim in dim_partition_dict_for_input:
                    dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list


class TransposeGenerator(ReshapeGenerator):
    """
    TransposeGenerator deals with the sharding strategies of permute op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = {}

            transpose_dims = self.op_data['transpose_dims'].data
            dim_0 = transpose_dims[0]
            dim_1 = transpose_dims[1]
            for dim, sharded_dims in dim_partition_dict_for_input.items():
                if dim == dim_0:
                    dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
                elif dim == dim_1:
                    dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
                else:
                    dim_partition_dict_for_output[dim] = sharded_dims

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list


class SplitGenerator(ReshapeGenerator):
    """
    SplitGenerator deals with the sharding strategies of split op.
    """

    def collate_strategies(self) -> List[ShardingStrategy]:
        strategy_list = []
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            recover_dims = None
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
            split_size, split_dim = self.op_data['split_info'].data

            if split_dim in dim_partition_dict_for_input:
                recover_dims = dim_partition_dict_for_input.pop(split_dim)

            dim_partition_dict_for_output = [
                copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
            ]
            assert len(dim_partition_dict_for_output) >= 2

            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'

            # add comm action if the input need to be recovered to replica in the split dimension.
            if recover_dims:
                # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
                if len(recover_dims) == 1:
                    recover_dims = recover_dims[0]
                    input_comm_action = self.get_communication_action(
                        sharding_spec=sharding_spec_mapping["input"],
                        communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                        logical_process_axis=recover_dims,
                        comm_type=CommType.BEFORE,
                        arg_index=0)
                    # it will gather the input through gather_dim during forward phase.
                    input_comm_action.comm_spec.gather_dim = split_dim
                    # it will split the input activation grad through split_dim during backward phase.
                    input_comm_action.comm_spec.shard_dim = split_dim

                elif len(recover_dims) >= 2:
                    # original sharding spec
                    source_spec = input_sharding_spec
                    # target sharding spec
                    target_spec = sharding_spec_mapping["input"]
                    comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
                    input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)

                else:
                    input_comm_action = None

                if input_comm_action is not None:
                    communication_action_mapping["input"] = input_comm_action

            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        return strategy_list
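For reference (not part of this diff), PermuteGenerator above derives the output dim-partition dict by re-indexing the input partitions through permute_dims. A minimal sketch of that remapping in isolation, where dim_partition_dict maps a tensor dim to the mesh axes it is sharded over, as in the code above; the values are made up for illustration:

# Editor's sketch of the remapping done in PermuteGenerator.collate_strategies.
permute_dims = [0, 2, 1, 3]                        # output dim i comes from input dim permute_dims[i]
dim_partition_dict_for_input = {1: [0], 3: [1]}    # input dims 1 and 3 sharded over mesh axes 0 and 1

dim_partition_dict_for_output = {
    out_dim: dim_partition_dict_for_input[in_dim]
    for out_dim, in_dim in enumerate(permute_dims)
    if in_dim in dim_partition_dict_for_input
}
print(dim_partition_dict_for_output)               # {2: [0], 3: [1]}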
colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py (new file, 0 → 100644)

from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import SplitGenerator

__all__ = ['SplitHandler']


@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
class SplitHandler(NodeHandler):
    """
    A SplitHandler which deals with the sharding strategies for torch.permute or torch.split.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
        split_size = self.node.args[1]
        if len(self.node.args) == 3:
            # (input, split_size, split_dim)
            split_dim = self.node.args[2]
        else:
            if self.node.kwargs:
                split_dim = self.node.kwargs['dim']
            else:
                split_dim = 0

        num_dims = self.node.args[0]._meta_data.dim()
        # recover negative value to positive
        if split_dim < 0:
            split_dim += num_dims

        split_info = (split_size, split_dim)
        physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "split_info": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
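For reference (not part of this diff), SplitHandler reads the split size and dimension from the traced call's args/kwargs, defaulting the dimension to 0 and normalizing negative values, which matches torch.split's signature:

import torch

x = torch.rand(6, 4)
chunks = torch.split(x, 2)           # split_size=2, dim defaults to 0
print([t.shape for t in chunks])     # [torch.Size([2, 4]), torch.Size([2, 4]), torch.Size([2, 4])]

chunks = torch.split(x, 2, dim=-1)   # dim=-1 normalizes to dim=1 for a 2D tensor
print([t.shape for t in chunks])     # [torch.Size([6, 2]), torch.Size([6, 2])]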
colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py (new file, 0 → 100644)

from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import TransposeGenerator

__all__ = ['TransposeHandler']


@operator_registry.register(torch.Tensor.transpose)
@operator_registry.register(torch.transpose)
class TransposeHandler(NodeHandler):
    """
    A TransposeHandler which deals with the sharding strategies for torch.permute or torch.transpose.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(TransposeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        transpose_dims = []
        # torch.transpose (input, dim0, dim1)
        for arg in self.node.args:
            if isinstance(arg, torch.fx.Node):
                if isinstance(arg._meta_data, int):
                    transpose_dims.append(arg._meta_data)
            else:
                transpose_dims.append(arg)

        num_dims = self.node._meta_data.dim()
        for i in range(2):
            # recover negative value to positive
            if transpose_dims[i] < 0:
                transpose_dims[i] += num_dims

        physical_shape_operand = OperationData(name='transpose_dims',
                                               type=OperationDataType.ARG,
                                               data=list(transpose_dims))

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "transpose_dims": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
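For reference (not part of this diff), TransposeHandler only ever records two dims (dim0, dim1) and normalizes negatives over those two entries, mirroring torch.transpose:

import torch

x = torch.rand(2, 3, 4)
num_dims = x.dim()
dim0, dim1 = 0, -1
dim0, dim1 = [d + num_dims if d < 0 else d for d in (dim0, dim1)]   # -> 0, 2

assert torch.equal(x.transpose(0, -1), x.transpose(dim0, dim1))
print(x.transpose(dim0, dim1).shape)                                # torch.Size([4, 3, 2])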
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py (new file, 0 → 100644)

from typing import Dict, List

import torch

from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import ViewGenerator

__all__ = ['ViewHandler']


@operator_registry.register(torch.Tensor.reshape)
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.view)
class ViewHandler(NodeHandler):
    """
    A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process

        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        target_shape = self.node._meta_data.shape
        physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "tgt_shape": physical_shape_operand,
            "output": physical_output_operand
        }

        return mapping
colossalai/auto_parallel/tensor_shard/node_handler/getattr_handler.py (new file, 0 → 100644)

from typing import Dict, List

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .strategy import GetattrGenerator, StrategyGenerator

__all__ = ['GetattrHandler']


class GetattrHandler(NodeHandler):
    """
    A GetattrHandler which deals with the sharding strategies for Getattr Node.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(GetattrGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process

        # There are only two possible types for get_attr node:
        # 1. torch.Tensor(torch.nn.Parameters or torch.nn.Buffers)
        # 2. torch.nn.Module
        # temporarily, we just support first case in Tracer, so we don't have to worry about
        # issue related to the node._meta_data type.
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"output": physical_output}

        return mapping
colossalai/auto_parallel/tensor_shard/node_handler/getitem_handler.py

@@ -6,7 +6,7 @@ import torch

 from ..sharding_strategy import OperationData, OperationDataType
 from .node_handler import NodeHandler
 from .registry import operator_registry
-from .strategy import (StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator)
+from .strategy import StrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator

 __all__ = ['GetItemHandler']
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py

@@ -3,12 +3,16 @@ from typing import Dict, List, Union

 import torch
 import torch.nn.functional as F

-from colossalai.auto_parallel.tensor_shard.utils import transpose_partition_dim, update_partition_dim
+from colossalai.auto_parallel.tensor_shard.utils import (
+    check_sharding_spec_validity,
+    transpose_partition_dim,
+    update_partition_dim,
+)
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingNotDivisibleError

-from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
-from .node_handler import ModuleHandler, NodeHandler
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
+from .node_handler import MetaInfoModuleHandler, MetaInfoNodeHandler, ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator

@@ -28,9 +32,11 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     # switch the dimensions of the transposed weight
     sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    transpose_partition_dim(sharding_spec, 0, -1)
+    assert op_data.logical_shape[0] == op_data.data.shape[1] and \
+        op_data.logical_shape[1] == op_data.data.shape[0], \
+        "Expected the logical shape of the linear operator's weight is equal to transposed physical shape"
+    dim_size = len(op_data.logical_shape)
+    transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy

@@ -54,6 +60,23 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     input_op_data = strategy.get_op_data_by_name(input_name)
     output_op_data = strategy.get_op_data_by_name(output_name)
     input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
     output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)

+    # recover the last logical dimension to physical dimension
+    last_logical_input_dims = len(input_op_data.logical_shape) - 1
+    last_logical_output_dims = len(output_op_data.logical_shape) - 1
+    last_physical_input_dims = input_op_data.data.dim() - 1
+    last_physical_output_dims = output_op_data.data.dim() - 1
+
+    if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
+        input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
+    else:
+        input_last_dim_mapping = {}
+
+    if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
+        output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
+    else:
+        output_last_dim_mapping = {}
+
     # get logger for debug message
     logger = get_dist_logger()

@@ -73,14 +96,21 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
             output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
             try:
                 # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
+                input_dim_mapping = {0: i}
+                input_dim_mapping.update(input_last_dim_mapping)
+
                 update_partition_dim(sharding_spec=input_sharding_spec,
-                                     dim_mapping={0: i},
+                                     dim_mapping=input_dim_mapping,
                                      physical_shape=input_op_data.data.shape,
                                      inplace=True)
+
+                output_dim_mapping = {0: i}
+                output_dim_mapping.update(output_last_dim_mapping)
+
                 update_partition_dim(sharding_spec=output_sharding_spec,
-                                     dim_mapping={0: i},
+                                     dim_mapping=output_dim_mapping,
                                      physical_shape=output_op_data.data.shape,
                                      inplace=True)
                 strategy_copy.name = f'{strategy.name}_{i}'
                 sharding_strategies.append(strategy_copy)
             except ShardingNotDivisibleError as e:
                 logger.debug(

@@ -95,12 +125,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
         output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)

         # after updating, the logical shape will be replaced by the physical shape
+        input_dim_mapping = {}
+        input_dim_mapping.update(input_last_dim_mapping)
+
         update_partition_dim(sharding_spec=input_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=input_dim_mapping,
                              physical_shape=input_op_data.data.shape,
                              inplace=True)
+
+        output_dim_mapping = {}
+        output_dim_mapping.update(output_last_dim_mapping)
+
         update_partition_dim(sharding_spec=output_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=output_dim_mapping,
                              physical_shape=output_op_data.data.shape,
                              inplace=True)
         sharding_strategies.append(strategy_copy)

@@ -108,7 +143,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
 @operator_registry.register(torch.nn.Linear)
-class LinearModuleHandler(ModuleHandler):
+class LinearModuleHandler(MetaInfoModuleHandler):
     """
     A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
     """

@@ -116,7 +151,8 @@ class LinearModuleHandler(ModuleHandler):
     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
+        generators.append(
+            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
         return generators

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:

@@ -167,15 +203,16 @@ class LinearModuleHandler(ModuleHandler):
 @operator_registry.register(F.linear)
-class LinearFunctionHandler(NodeHandler):
+class LinearFunctionHandler(MetaInfoNodeHandler):
     """
-    A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
+    A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
     """

     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
-        generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
+        generators.append(
+            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
         return generators

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:

@@ -198,27 +235,34 @@ class LinearFunctionHandler(NodeHandler):
                                                type=data_type,
                                                data=self.node.args[1]._meta_data,
                                                logical_shape=self.node.args[1]._meta_data.shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(
+            name=str(self.node),
+            type=OperationDataType.OUTPUT,
+            data=self.node._meta_data,
+            logical_shape=output_logical_shape,
+        )

         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

-        if self.node.args[2] is not None:
+        if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
-            if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
+            if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
             else:
                 data_type = OperationDataType.ARG
-            physical_bias_operand = OperationData(name=str(self.node.args[2]),
-                                                  type=data_type,
-                                                  data=self.node.args[2]._meta_data)
+            physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
+                                                  type=data_type,
+                                                  data=self.node.kwargs["bias"]._meta_data)
             mapping['bias'] = physical_bias_operand
+
         return mapping

     def post_process(self, strategy: ShardingStrategy):
         # switch the dimensions of the transposed weight
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                           weight_name=str(self.node.args[1]))
         # create multiple sharding strategies for the inputs
         # as input can be multi-dimensinal and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input
...
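For reference (not part of this diff), the linear handler treats the weight's logical shape as the transpose of its physical (out_features, in_features) layout, which is why _update_sharding_spec_for_transposed_weight_for_linear swaps partition dims 0 and dim_size - 1. The underlying identity in plain PyTorch:

import torch
import torch.nn.functional as F

x = torch.rand(8, 16)              # (batch, in_features)
w = torch.rand(32, 16)             # physical weight shape: (out_features, in_features)

# F.linear computes x @ w.T, so the weight participates in the matmul with
# logical shape (in_features, out_features), i.e. w.shape reversed.
assert torch.allclose(F.linear(x, w), x @ w.t())
print(tuple(w.shape[::-1]))        # (16, 32)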
colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py (new file, 0 → 100644)

import operator
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import reduce
from typing import Dict, List, Union

import torch

from colossalai.auto_parallel.tensor_shard.utils.broadcast import (
    BroadcastType,
    get_broadcast_dim_info,
    get_broadcast_shape,
)
from colossalai.tensor.sharding_spec import ShardingSpecException

from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import (
    BatchedMatMulStrategyGenerator,
    DotProductStrategyGenerator,
    LinearProjectionStrategyGenerator,
    MatVecStrategyGenerator,
    StrategyGenerator,
)


class MatMulType(Enum):
    """
    The MatMulType is categorized into 4 types based on the reference of torch.matmul
    in https://pytorch.org/docs/stable/generated/torch.matmul.html.

    DOT: dot product, both tensors are 1D, these two tensors need to have the same number of elements
    MM: matrix-matrix product, both tensors are 2D or the 1st tensor is 1D and the 2nd tensor is 2D
    MV: matrix-vector product: the 1st tensor is 2D and the 2nd tensor is 1D
    BMM: batched matrix-matrix multiplication, one tensor is at least 1D and the other is at least 3D
    """
    DOT = 0
    MM = 1
    MV = 2
    BMM = 3


def get_matmul_type(input_dim: int, other_dim: int):
    """
    Determine which type of matmul operation should be executed for the given tensor dimensions.

    Args:
        input_dim (int): the number of dimensions for the input tenosr
        other_dim (int): the number of dimensions for the other tenosr
    """
    if input_dim == 1 and other_dim == 1:
        matmul_type = MatMulType.DOT
    elif input_dim in [1, 2] and other_dim == 2:
        matmul_type = MatMulType.MM
    elif input_dim == 2 and other_dim == 1:
        matmul_type = MatMulType.MV
    elif input_dim >= 1 and other_dim >= 1 and (input_dim > 2 or other_dim > 2):
        matmul_type = MatMulType.BMM
    else:
        raise ValueError(
            f"The input and other tensors are of {input_dim} and {other_dim} which cannot used to execute matmul operation"
        )
    return matmul_type
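For reference (not part of this diff), get_matmul_type classifies a torch.matmul call purely by operand ranks, following the dispatch rules documented for torch.matmul. The four cases, checked with plain PyTorch shapes:

import torch

print(torch.matmul(torch.rand(4), torch.rand(4)).shape)              # DOT: scalar, torch.Size([])
print(torch.matmul(torch.rand(3, 4), torch.rand(4, 5)).shape)        # MM:  torch.Size([3, 5])
print(torch.matmul(torch.rand(3, 4), torch.rand(4)).shape)           # MV:  torch.Size([3])
print(torch.matmul(torch.rand(2, 3, 4), torch.rand(2, 4, 5)).shape)  # BMM: torch.Size([2, 3, 5])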
class BmmTransform(ABC):
    """
    BmmTransform is an abstraction of the shape conversion between logical and physical operation data
    during the strategy generation.
    """

    @abstractmethod
    def apply(self, shape_mapping: Dict[str, List[int]]):
        pass

    @abstractmethod
    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        pass


class Padder(BmmTransform):
    """
    Add padding to the matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        # keep the padding dim, op_name -> padded_dim
        self.padded_dim_mapping = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = deepcopy(shape_mapping)
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        if len(input_shape) == 1:
            # if the input is a 1D tensor, 1 is prepended to its shape
            # and it will be removed afterwards
            input_shape.insert(0, 1)
            self.padded_dim_mapping['input'] = -2
            self.padded_dim_mapping['output'] = -2
        elif len(other_shape) == 1:
            # if the other is a 1D tensor, 1 is appended to its shape
            # and it will be removed afterwards
            other_shape = other_shape.append(1)
            self.padded_dim_mapping['other'] = -1
            self.padded_dim_mapping['output'] = -1
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        input_op_data = op_data_mapping['input']
        other_op_data = op_data_mapping['other']

        def _remove_padded_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)
            dim_partition_list = [None] * len(tensor_shape)

            # padded dim is a negative number as the padded dim must be a matrix dim
            padded_dim = self.padded_dim_mapping[key]

            # compute the new dim partition
            for tensor_dim, mesh_dims in sharding_spec.dim_partition_dict.items():
                dim_partition_list[tensor_dim] = mesh_dims
            dim_partition_list.pop(padded_dim)
            unpadded_dim_partition_list = {k: v for k, v in enumerate(dim_partition_list) if v is not None}

            # compute unpadded tensor shape
            tensor_shape.pop(padded_dim)

            assert tensor_shape == list(op_data.data.shape), f'{tensor_shape} vs {list(op_data.data.shape)}'

            # update sharding spec
            sharding_spec.__init__(sharding_spec.device_mesh, tensor_shape, unpadded_dim_partition_list)

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()

            # only one of input and other will be padded
            if 'input' in self.padded_dim_mapping:
                _remove_padded_dim('input', strategy_copy)
                _remove_padded_dim('output', strategy_copy)
            elif 'other' in self.padded_dim_mapping:
                _remove_padded_dim('other', strategy_copy)
                _remove_padded_dim('output', strategy_copy)

            strategies.append(strategy_copy)
        except ShardingSpecException as e:
            pass
        return strategies
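For reference (not part of this diff), Padder mirrors torch.matmul's own rule of prepending a 1 to a 1D "input" (or appending a 1 to a 1D "other") before a batched matmul and removing that dim afterwards:

import torch

# 1D input against a batched matrix: a 1 is prepended, the matmul runs, then the dim is removed.
print(torch.matmul(torch.rand(4), torch.rand(2, 4, 5)).shape)   # torch.Size([2, 5])

# 1D other against a batched matrix: a 1 is appended and removed afterwards.
print(torch.matmul(torch.rand(2, 3, 4), torch.rand(4)).shape)   # torch.Size([2, 3])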
class Broadcaster(BmmTransform):
    """
    Broadcast the non-matrix dimensions for batched matrix multiplication.
    """

    def __init__(self) -> None:
        self.broadcast_dim_info = {}

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()

        # get shapes
        input_shape = mapping_copy['input']
        other_shape = mapping_copy['other']

        # sanity check
        assert len(input_shape) > 1 and len(other_shape) > 1

        # broadcast the batch dim and record
        bcast_non_matrix_dims = get_broadcast_shape(input_shape[:-2], other_shape[:-2])

        # store the broadcast dim info
        input_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, input_shape[:-2])
        other_broadcast_dim_info = get_broadcast_dim_info(bcast_non_matrix_dims, other_shape[:-2])
        self.broadcast_dim_info['input'] = input_broadcast_dim_info
        self.broadcast_dim_info['other'] = other_broadcast_dim_info

        # create the full logical shape
        input_shape = bcast_non_matrix_dims + input_shape[-2:]
        other_shape = bcast_non_matrix_dims + other_shape[-2:]
        assert len(input_shape) == len(other_shape)
        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape

        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # remove sharding on the broadcast dim
        def _remove_sharding_on_broadcast_dim(key, strategy):
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            tensor_shape = list(sharding_spec.entire_shape)

            for dim_idx, broadcast_type in self.broadcast_dim_info[key].items():
                if broadcast_type == BroadcastType.MULTIPLE:
                    # if the dim is originally 1 and multiplied during broadcast
                    # we set its sharding to R
                    # e.g. [1, 2, 4] x [4, 4, 8] -> [4, 2, 8]
                    # the dim 0 of [1, 2, 4] is multiplied to 4
                    tensor_shape[dim_idx] = 1
                elif broadcast_type == BroadcastType.PADDDING:
                    # if the dim is padded
                    # we remove its sharding
                    tensor_shape[dim_idx] = None

            tensor_shape_before_broadcast = [dim for dim in tensor_shape if dim is not None]

            physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
                logical_sharding_spec=sharding_spec,
                logical_shape=sharding_spec.entire_shape,
                physical_shape=tensor_shape_before_broadcast)
            strategy.sharding_specs[op_data] = physical_sharding_spec

        # enumerate all sharding strategies
        strategies = []
        try:
            strategy_copy = strategy.clone()
            _remove_sharding_on_broadcast_dim('input', strategy_copy)
            _remove_sharding_on_broadcast_dim('other', strategy_copy)
            strategies.append(strategy_copy)
        except ShardingSpecException as e:
            pass
        return strategies
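The batch-dimension broadcast follows the standard PyTorch/NumPy broadcasting rules. A minimal sketch (not part of the commit) using torch.broadcast_shapes to show the kind of result get_broadcast_shape produces for the non-matrix dims, assuming the same semantics:

import torch

# non-matrix (batch) dims of the two operands
input_batch_dims = (1, 2)    # from an input of shape [1, 2, M, K]
other_batch_dims = (4, 1)    # from an other of shape [4, 1, K, N]

# standard broadcasting: size-1 dims are expanded, so [1, 2] and [4, 1] meet at [4, 2]
print(torch.broadcast_shapes(input_batch_dims, other_batch_dims))    # torch.Size([4, 2])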
class Viewer(BmmTransform):
    """
    Change the shape of the tensor from N-D to 3D
    """

    def __init__(self) -> None:
        self.batch_dims_before_view = None

    def apply(self, shape_mapping: Dict[str, List[int]]):
        mapping_copy = shape_mapping.copy()
        self.batch_dims_before_view = list(mapping_copy['input'][:-2])

        # get shapes
        input_shape = shape_mapping['input']
        other_shape = shape_mapping['other']

        # view to 3d tensor
        assert len(input_shape) >= 3 and len(other_shape) >= 3
        input_shape = [reduce(operator.mul, input_shape[:-2])] + input_shape[-2:]
        other_shape = [reduce(operator.mul, other_shape[:-2])] + other_shape[-2:]
        output_shape = input_shape[:2] + other_shape[2:]
        mapping_copy['input'] = input_shape
        mapping_copy['other'] = other_shape
        mapping_copy['output'] = output_shape
        return mapping_copy

    def recover(self, op_data_mapping: Dict[str, OperationData], strategy: ShardingStrategy):
        # get operation data
        def _update_sharding_spec(key, strategy, physical_batch_dim):
            """
            Map the logical batch dim to the physical batch dim
            """
            op_data = op_data_mapping[key]
            sharding_spec = strategy.get_sharding_spec_by_name(op_data.name)
            dim_partition_dict = sharding_spec.dim_partition_dict
            entire_shape = sharding_spec.entire_shape

            # update the dimension index for the matrix dimensions
            if 2 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view) + 1] = dim_partition_dict.pop(2)
            if 1 in dim_partition_dict:
                dim_partition_dict[len(self.batch_dims_before_view)] = dim_partition_dict.pop(1)

            # map the logical batch dim to physical batch dim
            if 0 in dim_partition_dict:
                batch_dim_shard = dim_partition_dict.pop(0)
                dim_partition_dict[physical_batch_dim] = batch_dim_shard

            # the new shape will be the batch dims + the last 2 matrix dims
            shape_before_view = self.batch_dims_before_view + list(entire_shape[-2:])
            sharding_spec.__init__(sharding_spec.device_mesh, shape_before_view, dim_partition_dict)

        num_batch_dim_before_view = len(self.batch_dims_before_view)

        # enumerate all sharding strategies
        strategies = []
        for i in range(num_batch_dim_before_view):
            # create a new strategy
            strategy_copy = strategy.clone()
            try:
                _update_sharding_spec('input', strategy_copy, i)
                _update_sharding_spec('other', strategy_copy, i)
                _update_sharding_spec('output', strategy_copy, i)
                strategies.append(strategy_copy)
            except ShardingSpecException as e:
                continue
        return strategies
def _get_bmm_logical_shape(input_shape, other_shape, transforms):
    """
    Compute the logical shapes for BMM operation. BMM has a general representation
    [b, i, k] = [b, i, j] x [b, j, k]

    The dimension b is called non-matrix (batch) dimension and the remaining dimensions are called matrix dimensions
    The logical shape for the bmm operands will undergo three stages
        1. append/prepend the 1 to the 1D tensor if there is any
        2. broadcast the non-matrix dimensions
        3. reshape to 3 dimensions
    """
    shape_mapping = {'input': input_shape, 'other': other_shape}

    for transform in transforms:
        shape_mapping = transform.apply(shape_mapping)

    input_shape = shape_mapping.get('input', None)
    other_shape = shape_mapping.get('other', None)
    output_shape = shape_mapping.get('output', None)

    return input_shape, other_shape, output_shape
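To make stages 2 and 3 concrete, here is a small standalone sketch (not the handler's code) that broadcasts the batch dims and then flattens them into a single leading dim, using only plain Python and torch.broadcast_shapes as a stand-in for the project's broadcast helper:

from functools import reduce
import operator
import torch

input_shape = [2, 4, 8, 16]    # [b1, b2, i, j]
other_shape = [4, 16, 32]      # [b2, j, k]

# stage 2: broadcast the non-matrix dims ([2, 4] vs [4] -> [2, 4])
batch_dims = list(torch.broadcast_shapes(input_shape[:-2], other_shape[:-2]))
input_logical = batch_dims + input_shape[-2:]    # [2, 4, 8, 16]
other_logical = batch_dims + other_shape[-2:]    # [2, 4, 16, 32]

# stage 3: flatten all batch dims into one so the op looks like a plain bmm
input_3d = [reduce(operator.mul, input_logical[:-2])] + input_logical[-2:]    # [8, 8, 16]
other_3d = [reduce(operator.mul, other_logical[:-2])] + other_logical[-2:]    # [8, 16, 32]
output_3d = input_3d[:2] + other_3d[2:]                                       # [8, 8, 32]
print(input_3d, other_3d, output_3d)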
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
class MatMulHandler(NodeHandler):
    """
    The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
    According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
    the operands.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # check which type of operation this matmul will call
        self.input_meta_data = self.node.args[0]._meta_data
        self.other_meta_data = self.node.args[1]._meta_data
        self.output_meta_data = self.node._meta_data

        input_dim = self.input_meta_data.dim()
        other_dim = self.other_meta_data.dim()
        self.matmul_type = get_matmul_type(input_dim, other_dim)

        if self.matmul_type == MatMulType.BMM:
            # bmm operation can possibly involve padding, broadcasting and view
            # these transforms will be used to create logical shape and
            # recover physical sharding spec
            self.transforms = [Padder(), Broadcaster(), Viewer()]
        else:
            self.transforms = None

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        generators = []
        op_data_mapping = self.get_operation_data_mapping()
        if self.matmul_type == MatMulType.BMM:
            generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.DOT:
            generators.append(DotProductStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MV:
            generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
        elif self.matmul_type == MatMulType.MM:
            generators.append(
                LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
        return generators
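The dispatch above follows the case analysis documented for torch.matmul: the pair of operand ranks decides whether the call behaves as a dot product, matrix-vector product, matrix-matrix product or batched matmul. A small sketch (not part of the commit) that checks the resulting shapes for each case, assuming standard torch.matmul semantics:

import torch

cases = {
    'DOT (1D x 1D)': (torch.rand(4), torch.rand(4)),               # scalar output
    'MM  (2D x 2D)': (torch.rand(4, 8), torch.rand(8, 16)),        # (4, 16)
    'MV  (2D x 1D)': (torch.rand(4, 8), torch.rand(8)),            # (4,)
    'BMM (3D x 3D)': (torch.rand(2, 4, 8), torch.rand(2, 8, 16)),  # (2, 4, 16)
}
for name, (a, b) in cases.items():
    print(name, tuple(torch.matmul(a, b).shape))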
    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        logical_shape_func = {
            MatMulType.DOT: self._get_logical_shape_for_dot,
            MatMulType.MM: self._get_logical_shape_for_mm,
            MatMulType.MV: self._get_logical_shape_for_mv,
            MatMulType.BMM: self._get_logical_shape_for_bmm
        }
        logical_shapes = logical_shape_func[self.matmul_type]()
        op_data_mapping = self._get_op_data_mapping(*logical_shapes)
        return op_data_mapping

    def _get_op_data_mapping(self, input_logical_shape, other_logical_shape, output_logical_shape):
        # convert list to torch.Size
        if input_logical_shape:
            input_logical_shape = torch.Size(input_logical_shape)

        if other_logical_shape:
            other_logical_shape = torch.Size(other_logical_shape)

        if output_logical_shape:
            output_logical_shape = torch.Size(output_logical_shape)

        # create op data
        input_op_data = OperationData(name=str(self.node.args[0]),
                                      type=OperationDataType.ARG,
                                      data=self.input_meta_data,
                                      logical_shape=input_logical_shape)
        other_op_data = OperationData(name=str(self.node.args[1]),
                                      type=OperationDataType.ARG,
                                      data=self.other_meta_data,
                                      logical_shape=other_logical_shape)
        output_op_data = OperationData(name=str(self.node),
                                       type=OperationDataType.OUTPUT,
                                       data=self.output_meta_data,
                                       logical_shape=output_logical_shape)

        mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
        return mapping

    def _get_logical_shape_for_dot(self):
        """
        The operands for the dot operation have the same logical shape as the physical shape
        """
        return None, None, None

    def _get_logical_shape_for_mm(self):
        """
        We need to handle the input tensor for a matrix-matrix multiplication as the input
        tensor can be a 1D or 2D tensor. If it is a 1D tensor, 1 will be prepended to its shape
        (e.g. [4] -> [1, 4]).
        """
        if self.input_meta_data.dim() == 1:
            input_logical_shape = [1] + list(self.input_meta_data.shape)
            input_logical_shape = torch.Size(input_logical_shape)
        else:
            input_logical_shape = None
        return input_logical_shape, None, None

    def _get_logical_shape_for_mv(self):
        """
        No broadcasting or dim insertion occurs for matrix-vector operation.
        """
        return None, None, None

    def _get_logical_shape_for_bmm(self):
        input_physical_shape = list(self.input_meta_data.shape)
        other_physical_shape = list(self.other_meta_data.shape)
        return _get_bmm_logical_shape(input_physical_shape, other_physical_shape, self.transforms)

    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
        if self.matmul_type in [MatMulType.DOT, MatMulType.MV]:
            return strategy
        elif self.matmul_type == MatMulType.MM:
            if self.input_meta_data.dim() == 1:
                # if a 1 is prepended to the input shape (this occurs when input is a 1D tensor)
                # we need to remove that dim
                input_sharding_spec = strategy.get_sharding_spec_by_name(str(self.node.args[0]))
                input_physical_shape = self.node.args[0]._meta_data.shape
                dim_partition_dict = input_sharding_spec.dim_partition_dict

                # remove the partitioning in the dim 0
                if 0 in dim_partition_dict:
                    dim_partition_dict.pop(0, None)

                # move the partitioning in dim 1 to dim 0
                if -1 in dim_partition_dict:
                    shard = dim_partition_dict.pop(-1)
                    dim_partition_dict[0] = shard

                if 1 in dim_partition_dict:
                    shard = dim_partition_dict.pop(1)
                    dim_partition_dict[0] = shard

                # re-init the sharding spec
                input_sharding_spec.__init__(input_sharding_spec.device_mesh,
                                             entire_shape=input_physical_shape,
                                             dim_partition_dict=dim_partition_dict)
                return strategy
            else:
                return strategy
        elif self.matmul_type == MatMulType.BMM:
            op_data_mapping = self.get_operation_data_mapping()

            strategies = [strategy]
            # recover the physical sharding spec
            for transform in self.transforms[::-1]:
                recovered_stragies = []
                for strategy_ in strategies:
                    output = transform.recover(op_data_mapping, strategy_)
                    if isinstance(output, ShardingStrategy):
                        recovered_stragies.append(output)
                    elif isinstance(output, (list, tuple)):
                        recovered_stragies.extend(output)
                    else:
                        raise TypeError(
                            f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
                strategies = recovered_stragies
            return strategies
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py    View file @ e532679c
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union

import torch
from torch.fx.node import Node

from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
    OperationDataType,
    ShardingSpec,
    ShardingStrategy,
    StrategiesVector,
    TrainCycleItem,
...
...
@@ -49,7 +52,16 @@ class NodeHandler(ABC):
        for node in self.predecessor_node:
            node_name = str(node)
            # get the current sharding spec generated by this node handler
            # we will not compute the resharding costs for the node not counted in the strategy.
            # And the node with tuple or list output need to be handled below.
            node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
            if str(node) not in node_in_strategy:
                continue

            op_data = strategy.get_op_data_by_name(node_name)
            current_sharding_spec = strategy.sharding_specs[op_data]
            # get the sharding specs for this node generated
            # in its own node handler
            assert hasattr(node, 'strategies_vector'), \
...
...
@@ -59,27 +71,83 @@ class NodeHandler(ABC):
                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
            ]

            # get the current sharding spec generated by this node handler
            op_data = strategy.get_op_data_by_name(node_name)
            current_sharding_spec = strategy.sharding_specs[op_data]

            # create data structure to store costs
            if op_data not in resharding_costs:
            if node not in resharding_costs:
                resharding_costs[node] = []

            def _compute_resharding_cost(
                    prev_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
                    current_sharding_spec: Union[ShardingSpec, List[ShardingSpec]],
                    data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
                """
                This is a helper function to compute the resharding cost for a specific strategy of a node.
                """
                if prev_sharding_spec is None:
                    return TrainCycleItem(fwd=0, bwd=0, total=0)
                elif isinstance(prev_sharding_spec, ShardingSpec):
                    if isinstance(data, torch.Tensor):
                        dtype = data.dtype
                        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                        _, _, consistency_cost = shape_consistency_manager.shape_consistency(
                            prev_sharding_spec, current_sharding_spec)

                        resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
                                                         bwd=consistency_cost["backward"] * size_per_elem_bytes,
                                                         total=consistency_cost["total"] * size_per_elem_bytes)
                        return resharding_cost
                    else:
                        # This raise is used to check if we have missed any type of data.
                        # It could be merged into Parameter branch, which means we won't handle
                        # non-tensor arguments.
                        raise ValueError(f'Unsupported data type {type(data)}')
                else:
                    assert isinstance(prev_sharding_spec, (tuple, list)), \
                        f'prev_sharding_spec should be in type of ShardingSpec, List[ShardingSpec], \
                            or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'

                    fwd_cost = 0
                    bwd_cost = 0
                    total_cost = 0
                    for index, (prev_sharding_spec_item,
                                current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
                                                                             current_sharding_spec)):
                        item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
                                                             data[index])
                        fwd_cost += item_cost.fwd
                        bwd_cost += item_cost.bwd
                        total_cost += item_cost.total
                    resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
                    return resharding_cost

            # for each sharding spec generated by the predecessor's node handler
            # compute the resharding cost to switch to the sharding spec generated
            # by the current node handler
            for prev_sharding_spec in prev_sharding_specs:
                _, _, resharding_cost = shape_consistency_manager.shape_consistency(prev_sharding_spec,
                                                                                    current_sharding_spec)
                resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
                                                 bwd=resharding_cost["backward"],
                                                 total=resharding_cost["total"])
                resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
                resharding_costs[node].append(resharding_cost)
        strategy.resharding_costs = resharding_costs
        return strategy

    def get_target_function(self) -> callable:
        """
        This function is used to get the target function for the node handler.
        The target function is used to analyze the costs of strategies.
        """
        if self.node.op in ('placeholder', 'get_attr', 'output'):
            return None

        if self.node.op == 'call_module':
            target = self.node.graph.owning_module.get_submodule(self.node.target)
        elif self.node.op == 'call_function':
            target = self.node.target
        elif self.node.op == 'call_method':
            target = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
        else:
            raise ValueError(f'Unsupported node type: {self.node.op}')

        return target

    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        Register different sharding strategies for the current node.
...
...
@@ -151,6 +219,38 @@ class NodeHandler(ABC):
        pass


class MetaInfoNodeHandler(NodeHandler):
    """
    This is a base class to handle the nodes patched in the meta profiler.

    Note: this class will be integrated into the NodeHandler class in the future, after
    all the functions are patched.
    """

    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        This method is inherited from NodeHandler. It will register the strategies first,
        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
        """
        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
        target = self.get_target_function()
        # Currently we haven't patched all the torch functions and modules, so if the target
        # is not patched, we will use the default cost model to compute the cost.
        # TODO: patch all torch functions and modules to make it clean
        if meta_register.has(target.__class__) or meta_register.has(target):
            metainfo_vector = []
            for strategy in self.strategies_vector:
                metainfo = MetaInfo(strategy, target)
                strategy.compute_cost = metainfo.compute_cost
                strategy.memory_cost = metainfo.memory_cost
                metainfo_vector.append(metainfo)

            # attach metainfos to the handler
            setattr(self, "metainfo_vector", metainfo_vector)

        return self.strategies_vector


class ModuleHandler(NodeHandler):

    def __init__(self, *args, **kwargs) -> None:
...
...
@@ -168,3 +268,35 @@ class ModuleHandler(NodeHandler):
        self.module = module
        self.named_parameters = named_parameters
        self.named_buffers = named_buffers


class MetaInfoModuleHandler(ModuleHandler):
    """
    This is a base class to handle the module patched in the meta profiler.

    Note: this class will be integrated into the ModuleHandler class in the future, after
    all the modules are patched.
    """

    def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector:
        """
        This method is inherited from NodeHandler. It will register the strategies first,
        and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class.
        """
        super().register_strategy(compute_resharding_cost=compute_resharding_cost)
        target = self.get_target_function()
        # Currently we haven't patched all the torch functions and modules, so if the target
        # is not patched, we will use the default cost model to compute the cost.
        # TODO: patch all torch functions and modules to make it clean
        if meta_register.has(target.__class__) or meta_register.has(target):
            metainfo_vector = []
            for strategy in self.strategies_vector:
                metainfo = MetaInfo(strategy, target)
                strategy.compute_cost = metainfo.compute_cost
                strategy.memory_cost = metainfo.memory_cost
                metainfo_vector.append(metainfo)

            # attach metainfos to the handler
            setattr(self, "metainfo_vector", metainfo_vector)

        return self.strategies_vector
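The per-element byte size used in _compute_resharding_cost comes from an empty tensor of the operand's dtype. A quick illustration (not part of the commit) of how a communication volume expressed in elements would be scaled into bytes this way:

import torch

# element_size() gives the number of bytes per element for a dtype
size_fp16 = torch.tensor([], dtype=torch.float16).element_size()    # 2
size_fp32 = torch.tensor([], dtype=torch.float32).element_size()    # 4

# a hypothetical resharding that moves 1024 elements costs this many bytes
num_elements = 1024
print(num_elements * size_fp16, num_elements * size_fp32)    # 2048 4096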
colossalai/auto_parallel/tensor_shard/node_handler/normal_pooling_handler.py    View file @ e532679c
...
...
@@ -3,7 +3,7 @@ from typing import Dict, List
import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import ModuleHandler
from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
from .strategy import NormalPoolStrategyGenerator, StrategyGenerator
...
...
@@ -16,7 +16,7 @@ __all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
class NormPoolingHandler(MetaInfoModuleHandler):
    """
    A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module.
    """
...
...
colossalai/auto_parallel/tensor_shard/node_handler/output_handler.py    View file @ e532679c
...
...
@@ -2,38 +2,51 @@ from typing import Dict, List

import torch

from ..sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh

from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import OutputGenerator, StrategyGenerator

__all__ = ['OuputHandler']
__all__ = ['OutputHandler']


class OuputHandler(NodeHandler):
class OutputHandler(NodeHandler):
    """
    A OuputHandler which deals with the sharding strategies for Output Node.
    A OutputHandler which deals with the sharding strategies for Output Node.
    """

    def __init__(self, node: torch.fx.node.Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 output_option: str) -> None:
        super().__init__(node, device_mesh, strategies_vector)
        self.output_option = output_option

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node))
        generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node,
                                          self.output_option))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process
        dummy_output = torch.empty(1,).to("meta")
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=dummy_output)

        mapping = {"output": physical_output}
        mapping = {}
        output_meta_data = []
        for index, input_node in enumerate(self.predecessor_node):
            if not hasattr(input_node, "_meta_data"):
                print(input_node.name)
            physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_node._meta_data)
            input_meta_data = input_node._meta_data
            physical_inputs = OperationData(name=str(input_node), type=OperationDataType.ARG, data=input_meta_data)
            name_key = f'input_{index}'
            mapping[name_key] = physical_inputs
            output_meta_data.append(input_meta_data)

        assert len(output_meta_data) > 0, f'Output node {self.node} has no input node.'
        if len(output_meta_data) == 1:
            output_meta_data = output_meta_data[0]
        else:
            output_meta_data = tuple(output_meta_data)

        self.node._meta_data = output_meta_data
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=self.node._meta_data)
        mapping["output"] = physical_output
        return mapping
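The new get_operation_data_mapping collapses a single-input output node to a plain tensor and keeps multiple inputs as a tuple. A small sketch (not part of the commit) of that collapsing rule on meta tensors:

import torch

def collapse_output_meta(metas):
    # mirror the rule above: one predecessor -> its meta tensor, several -> a tuple
    assert len(metas) > 0, 'Output node has no input node.'
    return metas[0] if len(metas) == 1 else tuple(metas)

one = [torch.empty(4, device='meta')]
many = [torch.empty(4, device='meta'), torch.empty(2, 3, device='meta')]
print(type(collapse_output_meta(one)))     # <class 'torch.Tensor'>
print(type(collapse_output_meta(many)))    # <class 'tuple'>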
colossalai/auto_parallel/tensor_shard/node_handler/placeholder_handler.py    View file @ e532679c
from typing import Dict, List

from ..sharding_strategy import OperationData, OperationDataType
from torch.fx.node import Node

from colossalai.device.device_mesh import DeviceMesh

from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector
from .node_handler import NodeHandler
from .strategy import PlaceholderGenerator, StrategyGenerator

__all__ = ['PlacehodlerHandler']
__all__ = ['PlaceholderHandler']


class PlacehodlerHandler(NodeHandler):
class PlaceholderHandler(NodeHandler):
    """
    A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node.
    A PlaceholderHandler which deals with the sharding strategies for Placeholder Node.
    """

    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                 placeholder_option: str) -> None:
        super().__init__(node, device_mesh, strategies_vector)
        self.placeholder_option = placeholder_option

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh))
        generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh,
                                               placeholder_option=self.placeholder_option))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
...
...
colossalai/auto_parallel/tensor_shard/node_handler/registry.py    View file @ e532679c
...
...
@@ -8,6 +8,11 @@ class Registry:

    def register(self, source):

        def wrapper(func):
            if isinstance(source, (list, tuple)):
                # support register a list of items for this func
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func
...
...
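The registry change above lets one handler be registered for several operators at once. A minimal standalone sketch of the pattern (the class name and dict layout here are illustrative, not the project's exact implementation):

class SimpleRegistry:
    """Illustrative stand-in for the operator registry above."""

    def __init__(self):
        self.store = {}

    def register(self, source):

        def wrapper(func):
            # a list/tuple of sources maps every element to the same handler
            if isinstance(source, (list, tuple)):
                for element in source:
                    self.store[element] = func
            else:
                self.store[source] = func
            return func

        return wrapper


registry = SimpleRegistry()

@registry.register(['op_a', 'op_b'])
def handler():
    return 'handled'

print(registry.store['op_a'] is registry.store['op_b'])    # True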
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py    View file @ e532679c
...
...
@@ -3,18 +3,17 @@ from typing import Dict, List
import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import ReshapeGenerator, StrategyGenerator

__all__ = ['ReshapeHandler']


@operator_registry.register(torch.reshape)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
class ReshapeHandler(MetaInfoNodeHandler):
    """
    A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
    """
...
...
@@ -25,13 +24,47 @@ class ReshapeHandler(NodeHandler):
        generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def infer_logical_shape(self, data):
        """
        This function is used to infer logical shape for operands.

        Notes: This function is only used for the operands whose data are not only in type of tensor,
               such as tuple of tensor.
        """
        if isinstance(data, torch.Tensor):
            return data.shape
        else:
            assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
            logical_shape = []
            for tensor in data:
                assert isinstance(tensor, torch.Tensor), "input_data should be a tuple of tensor or a tensor."
                logical_shape.append(tensor.shape)
            logical_shape = tuple(logical_shape)
            return logical_shape

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process

        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        input_logical_shape = self.infer_logical_shape(input_data)
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
                                               type=data_type,
                                               data=input_data,
                                               logical_shape=input_logical_shape)

        output_data = self.node._meta_data
        output_logical_shape = self.infer_logical_shape(output_data)
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_data,
                                        logical_shape=output_logical_shape)

        mapping = {"input": physical_input_operand, "output": physical_output}
...
...
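infer_logical_shape simply mirrors the structure of its input: a tensor yields its shape, a tuple of tensors yields a tuple of shapes. A short sketch (not part of the commit) of the same idea, exercised on meta tensors:

import torch

def infer_logical_shape(data):
    # same idea as the handler method above: keep the container structure
    if isinstance(data, torch.Tensor):
        return data.shape
    assert isinstance(data, tuple), "input_data should be a tuple of tensor or a tensor."
    return tuple(t.shape for t in data)

single = torch.empty(2, 4, device='meta')
pair = (torch.empty(2, 4, device='meta'), torch.empty(8, device='meta'))
print(infer_logical_shape(single))    # torch.Size([2, 4])
print(infer_logical_shape(pair))      # (torch.Size([2, 4]), torch.Size([8]))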
colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py    0 → 100644    View file @ e532679c
from typing import Dict, List

import torch

from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import SoftmaxGenerator, StrategyGenerator

__all__ = ['SoftmaxHandler']


@operator_registry.register(torch.nn.Softmax)
@operator_registry.register(torch.nn.functional.softmax)
class SoftmaxHandler(NodeHandler):
    """
    A SoftmaxHandler which deals with the sharding strategies for
    torch.nn.Softmax or torch.nn.functional.softmax.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(SoftmaxGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # check if the input operand is a parameter
        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
            data_type = OperationDataType.PARAM
        else:
            data_type = OperationDataType.ARG

        input_data = self.node.args[0]._meta_data
        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)

        softmax_dim = self.node.kwargs['dim']

        num_dims = self.node.args[0]._meta_data.dim()
        # recover negative value to positive
        if softmax_dim < 0:
            softmax_dim += num_dims

        physical_dim_operand = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)

        output_data = self.node._meta_data
        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)

        mapping = {
            "input": physical_input_operand,
            "softmax_dim": physical_dim_operand,
            "output": physical_output_operand
        }

        return mapping
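The negative-dim normalisation in get_operation_data_mapping follows the usual PyTorch convention. A tiny check (not part of the commit) that dim=-1 and the recovered positive index select the same axis:

import torch
import torch.nn.functional as F

x = torch.rand(2, 3, 4)
softmax_dim = -1
if softmax_dim < 0:
    softmax_dim += x.dim()    # -1 -> 2 for a 3D tensor

# both calls normalise over the same (last) axis
print(torch.allclose(F.softmax(x, dim=-1), F.softmax(x, dim=softmax_dim)))    # True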