OpenDAS / ColossalAI
"vscode:/vscode.git/clone" did not exist on "ce7ade3882680ddc18a43375a71adaed194c6da4"
Unverified commit 3a462151, authored Sep 23, 2022 by YuliangLiu0306, committed by GitHub on Sep 23, 2022
[autoparallel] add embedding handler (#1620)
Parent: 69448f64
Showing 1 changed file with 176 additions and 0 deletions (+176 −0)
colossalai/auto_parallel/solver/op_handler/embedding_handler.py (new file, mode 100644)
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.solver._utils import exception_handler

__all__ = ['EmbeddingHandler']


class EmbeddingHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of Embedding operators (such as nn.Embedding).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, total_sharding_size):
        input_shape = self.input_data.shape
        weight_shape = self.weight.shape
        input_shape_product = reduce(operator.mul, input_shape, 1)
        weight_shape_product = reduce(operator.mul, weight_shape, 1)
        compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
        return compute_cost

    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward number of partitions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation number of partitions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight number of partitions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy; the first element of this tuple is the forward
                memory cost, and the second element is the backward memory cost.
            memory_cost_forward_activation(float): Memory cost of the forward activation
                per device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of the backward activation
                per device with this specific strategy.
            memory_cost_backward_weight(float): Memory cost of the backward weight
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        numel_input = self.input_data.numel()
        numel_weight = self.weight.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

    @exception_handler
    def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {2: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = 1
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during the forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        # compute the communication cost of this strategy during the backward phase
        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
        communication_cost = communication_cost_forward + communication_cost_backward

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_weight = 1
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # this strategy does not need an all_reduce during the forward phase
        communication_cost_forward = 0
        # compute the communication cost of this strategy during the backward phase
        communication_cost_backward_activation = 0
        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            memory_cost_backward_weight, 0)
        communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
        communication_cost = communication_cost_forward + communication_cost_backward

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategy for an Embedding node, and record all strategies in the strategies_vector.
        '''
        # RRS = RR x SS
        self.split_weight_both_dim(0, 1)
        self.split_weight_both_dim(1, 0)

        # SSR = SS x RR
        self.split_input_both_dim(0, 1)
        self.split_input_both_dim(1, 0)

        return self.strategies_vector
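
For a rough sense of the numbers these cost functions produce, below is a minimal, self-contained sketch of the same arithmetic. The input shape (4, 8), weight shape (1000, 32), float32 element size, and 2 x 2 device mesh are illustrative assumptions, not values taken from this commit.

# Sketch of the compute / memory cost arithmetic used by EmbeddingHandler,
# under assumed shapes, dtype, and mesh (not taken from the commit).
import operator
from functools import reduce

import torch

input_shape = (4, 8)                   # assumed: batch 4, sequence length 8 (token indices)
weight_shape = (1000, 32)              # assumed: vocabulary 1000, embedding dim 32
mesh_shape = (2, 2)                    # assumed 2 x 2 device mesh
total_sharding_size = mesh_shape[0] * mesh_shape[1]

# compute cost, as in _generate_compute_cost: 2 * numel(input) * numel(weight) / mesh size
input_numel = reduce(operator.mul, input_shape, 1)      # 32
weight_numel = reduce(operator.mul, weight_shape, 1)    # 32000
compute_cost = input_numel * weight_numel * 2 / total_sharding_size
print(compute_cost)                    # 512000.0

# memory cost for the RRS1 = RR x S0S1 strategy (split_weight_both_dim(0, 1)):
# output sharded along mesh_dim_1, input replicated, weight sharded along both mesh dims
numel_output = input_numel * weight_shape[1]            # embedded activations: 1024
size_per_elem_bytes = torch.tensor([], dtype=torch.float32).element_size()    # 4, assuming fp32
sharding_size_forward = mesh_shape[1]                   # 2
sharding_size_backward_activation = 1
sharding_size_weight = mesh_shape[0] * mesh_shape[1]    # 4

memory_cost_forward = (numel_output * size_per_elem_bytes / sharding_size_forward
                       + weight_numel * size_per_elem_bytes / sharding_size_weight)
memory_cost_backward = (input_numel * size_per_elem_bytes / sharding_size_backward_activation
                        + weight_numel * size_per_elem_bytes / sharding_size_weight)
print(memory_cost_forward, memory_cost_backward)        # 34048.0 32128.0

The forward and backward memory costs feed the all_reduce_cost calls above, which is how the handler turns per-device activation sizes into the communication term of each ShardingStrategy.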