Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4851f2d6
Unverified
Commit
4851f2d6
authored
Dec 26, 2022
by
YuliangLiu0306
Committed by
GitHub
Dec 26, 2022
Browse files
[autoparallel] update_getattr_handler (#2193)
parent
f10ce01e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
133 additions
and
55 deletions
+133
-55
colossalai/auto_parallel/passes/runtime_preparation_pass.py
colossalai/auto_parallel/passes/runtime_preparation_pass.py
+11
-14
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
...i/auto_parallel/tensor_shard/node_handler/node_handler.py
+1
-6
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
...l/tensor_shard/node_handler/strategy/getattr_generator.py
+47
-11
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
...test_tensor_shard/test_node_handler/test_addmm_handler.py
+52
-21
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
...st_tensor_shard/test_node_handler/test_getattr_handler.py
+10
-1
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
...uto_parallel/test_tensor_shard/test_node_handler/utils.py
+12
-2
No files found.
colossalai/auto_parallel/passes/runtime_preparation_pass.py
View file @
4851f2d6
...
...
@@ -6,6 +6,7 @@ import torch
from
torch.fx
import
symbolic_trace
from
torch.fx.node
import
Node
from
colossalai.auto_parallel.tensor_shard.constants
import
RESHAPE_FUNC_OP
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
CommAction
,
CommType
,
...
...
@@ -96,27 +97,23 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
# to the same strategy of the user node.
if
node
.
op
==
'get_attr'
:
assert
len
(
target_sharding_specs
)
==
1
,
f
'sharing weight is not supported in current version.'
new_sharding_spec
=
target_sharding_specs
[
0
]
user_strategy
=
node
.
strategies_vector
.
successor_nodes
[
0
].
best_strategy
op_data_in_user
=
user_strategy
.
get_op_data_by_name
(
str
(
node
))
origin_node_sharding_spec_dict
[
index
]
=
new_sharding_spec
target_node
=
node
.
strategies_vector
.
successor_nodes
[
0
]
node_name
=
str
(
node
)
if
target_node
.
op
==
'call_function'
and
target_node
.
target
in
RESHAPE_FUNC_OP
:
node_name
=
str
(
target_node
)
target_node
=
target_node
.
strategies_vector
.
successor_nodes
[
0
]
user_strategy
=
target_node
.
best_strategy
op_data_in_user
=
user_strategy
.
get_op_data_by_name
(
node_name
)
origin_pending_strategy
=
node
.
best_strategy
origin_op_data
=
origin_pending_strategy
.
get_op_data_by_name
(
str
(
node
))
new_sharding_specs
=
origin_pending_strategy
.
sharding_specs
new_sharding_specs
[
origin_op_data
]
=
new_sharding_spec
new_communication_actions
=
{}
if
op_data_in_user
in
user_strategy
.
communication_actions
:
new_communication_action
=
user_strategy
.
communication_actions
.
pop
(
op_data_in_user
)
new_communication_action
.
arg_index
=
0
new_communication_actions
[
origin_op_data
]
=
new_communication_action
new_strategy
=
ShardingStrategy
(
name
=
str
(
new_sharding_spec
.
sharding_sequence
),
sharding_specs
=
new_sharding_specs
,
compute_cost
=
origin_pending_strategy
.
compute_cost
,
communication_cost
=
origin_pending_strategy
.
communication_cost
,
memory_cost
=
origin_pending_strategy
.
memory_cost
,
communication_actions
=
new_communication_actions
)
setattr
(
node
,
'best_strategy'
,
new_strategy
)
setattr
(
node
,
'sharding_spec'
,
new_sharding_spec
)
node
.
best_strategy
.
communication_actions
=
new_communication_actions
comm_action_dict
=
{}
for
op_data
,
comm_action
in
node
.
best_strategy
.
communication_actions
.
items
():
comm_action_dict
[
op_data
.
name
]
=
comm_action
...
...
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
View file @
4851f2d6
...
...
@@ -86,12 +86,7 @@ class NodeHandler(ABC):
if
prev_sharding_spec
is
None
:
return
TrainCycleItem
(
fwd
=
0
,
bwd
=
0
,
total
=
0
)
elif
isinstance
(
prev_sharding_spec
,
ShardingSpec
):
if
isinstance
(
data
,
torch
.
nn
.
parameter
.
Parameter
):
# we won't compute the resharding cost for the parameters,
# since the parameters will be sharded before runtime and
# not converted during runtime.
return
TrainCycleItem
(
fwd
=
0
,
bwd
=
0
,
total
=
0
)
elif
isinstance
(
data
,
torch
.
Tensor
):
if
isinstance
(
data
,
torch
.
Tensor
):
dtype
=
data
.
dtype
size_per_elem_bytes
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
_
,
_
,
consistency_cost
=
shape_consistency_manager
.
shape_consistency
(
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
View file @
4851f2d6
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
from
colossalai.auto_parallel.tensor_shard.utils
import
(
enumerate_all_possible_1d_sharding
,
enumerate_all_possible_2d_sharding
,
ignore_sharding_exception
,
)
from
colossalai.tensor.sharding_spec
import
ShardingSpecException
from
.strategy_generator
import
StrategyGenerator
...
...
@@ -37,17 +43,47 @@ class GetattrGenerator(StrategyGenerator):
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
dim_partition_dict_mapping
=
{
"output"
:
{},
}
communication_action_mapping
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
@
ignore_sharding_exception
def
enumerate_all_possible_output
(
self
,
mesh_dim_0
,
mesh_dim_1
):
# we check for the output logical shape to get the number of dimensions
dim_partition_list
=
[]
dim_size
=
len
(
self
.
op_data
[
'output'
].
logical_shape
)
# enumerate all the 2D sharding cases
sharding_list_2d
=
enumerate_all_possible_2d_sharding
(
mesh_dim_0
,
mesh_dim_1
,
dim_size
)
dim_partition_list
.
extend
(
sharding_list_2d
)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0
=
enumerate_all_possible_1d_sharding
(
mesh_dim_0
,
dim_size
)
dim_partition_list
.
extend
(
sharding_list_1d_on_dim_0
)
sharding_list_1d_on_dim_1
=
enumerate_all_possible_1d_sharding
(
mesh_dim_1
,
dim_size
)
dim_partition_list
.
extend
(
sharding_list_1d_on_dim_1
)
# add empty dict for fully replicated case
dim_partition_list
.
append
({})
name
=
'Replica Attribute'
# sharding strategy bookkeeping
strategy_list
=
[]
strategy
=
self
.
get_
sharding
_
strategy
(
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
# convert these dim partition dict to
sharding
strategy
for
dim_partition_dict
in
dim_partition_list
:
dim_partition_dict_mapping
=
dict
(
output
=
dim_partition_dict
)
return
[
strategy
]
try
:
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
communication_action_mapping
=
{}
# get name
name
=
f
"get_attr
{
sharding_spec_mapping
[
'output'
].
sharding_sequence
}
"
sharding_strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
strategy_list
.
append
(
sharding_strategy
)
except
ShardingSpecException
:
continue
return
strategy_list
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
return
self
.
enumerate_all_possible_output
(
0
,
1
)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py
View file @
4851f2d6
...
...
@@ -35,25 +35,59 @@ class AddmmModel(nn.Module):
return
x
def
check_linear_function_handler
(
rank
,
input_shape
,
world_size
,
port
):
class
AddmmModel_with_param
(
nn
.
Module
):
def
__init__
(
self
,
weight_shape
,
bias_shape
):
super
().
__init__
()
self
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
rand
(
weight_shape
))
self
.
bias
=
torch
.
nn
.
Parameter
(
torch
.
rand
(
bias_shape
))
def
forward
(
self
,
m1
):
x
=
torch
.
addmm
(
self
.
bias
,
m1
,
self
.
weight
,
beta
=
3
,
alpha
=
2
)
return
x
def
check_addmm_function_handler
(
rank
,
input_shape
,
model_cls
,
world_size
,
port
):
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
model
=
AddmmModel
().
cuda
()
if
model_cls
==
AddmmModel
:
model
=
AddmmModel
().
cuda
()
else
:
model
=
AddmmModel_with_param
(
weight_shape
=
(
8
,
16
),
bias_shape
=
input_shape
).
cuda
()
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
input
=
torch
.
rand
(
input_shape
).
cuda
()
m1
=
torch
.
rand
(
4
,
8
).
cuda
()
m2
=
torch
.
rand
(
8
,
16
).
cuda
()
# the index of addmm node in computation graph
node_index
=
4
# strategy number of linear node
strategy_number
=
14
# construct input args
input_args
=
[
input
,
m1
,
m2
]
# construct meta arg names
meta_arg_names
=
[
'input'
,
'm1'
,
'm2'
]
if
model_cls
==
AddmmModel
:
input
=
torch
.
rand
(
input_shape
).
cuda
()
m1
=
torch
.
rand
(
4
,
8
).
cuda
()
m2
=
torch
.
rand
(
8
,
16
).
cuda
()
# construct input args
input_args
=
[
input
,
m1
,
m2
]
# construct meta arg names
meta_arg_names
=
[
'input'
,
'm1'
,
'm2'
]
meta_args_for_tracer
=
{}
for
meta_arg
,
input_arg
in
zip
(
meta_arg_names
,
input_args
):
meta_args_for_tracer
[
meta_arg
]
=
input_arg
.
to
(
'meta'
)
# the index of addmm node in computation graph
node_index
=
4
# strategy number of linear node
strategy_number
=
14
else
:
m1
=
torch
.
rand
(
4
,
8
).
cuda
()
# construct input args
input_args
=
[
m1
]
# construct meta arg names
meta_arg_names
=
[
'm1'
]
# the index of addmm node in computation graph
meta_args_for_tracer
=
{}
for
meta_arg
,
input_arg
in
zip
(
meta_arg_names
,
input_args
):
meta_args_for_tracer
[
meta_arg
]
=
input_arg
.
to
(
'meta'
)
node_index
=
4
# strategy number of linear node
strategy_number
=
14
numerical_test_for_node_strategy
(
model
=
model
,
device_mesh
=
device_mesh
,
node_index
=
node_index
,
...
...
@@ -73,12 +107,7 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
# %mul_1 : [#users=1] = call_function[target=operator.mul](args = (2, %linear), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%mul_1, %mul), kwargs = {})
# return add
graph
=
tracer
.
trace
(
model
,
meta_args
=
{
"input"
:
torch
.
rand
(
input_shape
).
to
(
'meta'
),
'm1'
:
torch
.
rand
(
4
,
8
).
to
(
'meta'
),
'm2'
:
torch
.
rand
(
8
,
16
).
to
(
'meta'
),
})
graph
=
tracer
.
trace
(
model
,
meta_args
=
meta_args_for_tracer
)
gm
=
ColoGraphModule
(
model
,
graph
)
# [input_1, m1, m2, addmm, output]
node_list
=
list
(
graph
.
nodes
)
...
...
@@ -155,11 +184,13 @@ def check_linear_function_handler(rank, input_shape, world_size, port):
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
pytest
.
mark
.
dist
@
parameterize
(
'input_shape'
,
[(
16
,),
(
4
,
16
)])
@
parameterize
(
'model_cls'
,
[
AddmmModel
,
AddmmModel_with_param
])
@
rerun_if_address_is_in_use
()
def
test_addmm_handler
(
input_shape
):
def
test_addmm_handler
(
input_shape
,
model_cls
):
world_size
=
4
run_func_function
=
partial
(
check_
linear
_function_handler
,
run_func_function
=
partial
(
check_
addmm
_function_handler
,
input_shape
=
input_shape
,
model_cls
=
model_cls
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func_function
,
nprocs
=
world_size
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py
View file @
4851f2d6
...
...
@@ -39,6 +39,7 @@ def test_getattr_handler():
strategies_vector
=
getattr_strategies_vector
)
getattr_handler
.
register_strategy
(
compute_resharding_cost
=
False
)
# check operation data mapping
mapping
=
getattr_handler
.
get_operation_data_mapping
()
...
...
@@ -51,7 +52,15 @@ def test_getattr_handler():
assert
mapping
[
'output'
].
data
.
shape
==
torch
.
Size
((
16
,
4
,
3
,
3
))
assert
mapping
[
'output'
].
type
==
OperationDataType
.
OUTPUT
strategy_name_list
=
[
val
.
name
for
val
in
getattr_handler
.
strategies_vector
]
assert
"Replica Attribute"
in
strategy_name_list
assert
'get_attr [S0, S1, R, R]'
in
strategy_name_list
assert
'get_attr [S1, S0, R, R]'
in
strategy_name_list
assert
'get_attr [S01, R, R, R]'
in
strategy_name_list
assert
'get_attr [R, S01, R, R]'
in
strategy_name_list
assert
'get_attr [S0, R, R, R]'
in
strategy_name_list
assert
'get_attr [R, S0, R, R]'
in
strategy_name_list
assert
'get_attr [S1, R, R, R]'
in
strategy_name_list
assert
'get_attr [R, S1, R, R]'
in
strategy_name_list
assert
'get_attr [R, R, R, R]'
in
strategy_name_list
if
__name__
==
'__main__'
:
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
View file @
4851f2d6
...
...
@@ -149,10 +149,20 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
param_sharding_spec
=
strategy_in_use
.
get_sharding_spec_by_name
(
param_name
)
else
:
if
'weight'
in
name
:
param_sharding_spec
=
list
(
graph
.
nodes
)[
4
].
sharding_spec
param_sharding_spec
=
None
for
node
in
list
(
graph
.
nodes
):
if
'weight'
in
node
.
name
:
param_sharding_spec
=
node
.
sharding_spec
elif
'bias'
in
name
:
param_sharding_spec
=
list
(
graph
.
nodes
)[
5
].
sharding_spec
param_sharding_spec
=
None
for
node
in
list
(
graph
.
nodes
):
if
'bias'
in
node
.
name
:
param_sharding_spec
=
node
.
sharding_spec
assert
param_sharding_spec
is
not
None
grad_sharded
=
param_to_shard_dict
[
name
].
grad
grad_to_compare
=
param_to_compare_dict
[
name
].
grad
global_grad
=
to_global
(
grad_sharded
,
param_sharding_spec
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment