OpenDAS / ColossalAI · Commit f6032ddb

[autoparallel] fix bias addition module (#1800)

Unverified commit f6032ddb, authored Nov 08, 2022 by YuliangLiu0306, committed by GitHub on Nov 08, 2022.
Parent: 6e9730d7

Showing 9 changed files with 438 additions and 20 deletions (+438 -20).
Files changed:

  colossalai/auto_parallel/passes/runtime_apply_pass.py                                          +6   -5
  colossalai/auto_parallel/passes/runtime_preparation_pass.py                                    +69  -2
  colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py                          +8   -1
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py               +6   -2
  colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py                  +1   -1
  colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py                                 +1   -1
  tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py                       +172 -0
  tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py   +146 -0
  tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py                          +29  -8
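Editor's note: the bias-addition patch traces `nn.Linear` / `nn.Conv2d` with bias into a weight-only op followed by an explicit bias add, so the auto-parallel solver can shard weight and bias independently. A minimal illustration of that decomposition in plain PyTorch (not part of the commit; it mirrors the FX graphs shown in the new test file below):

```python
import torch
import torch.nn.functional as F

x = torch.rand(4, 4)
linear = torch.nn.Linear(4, 8)

# Traced form of a biased Linear: weight-only linear, then an explicit bias add.
out_decomposed = F.linear(x, linear.weight) + linear.bias
assert torch.allclose(linear(x), out_decomposed, atol=1e-6)

conv = torch.nn.Conv2d(3, 6, 2)
img = torch.rand(1, 3, 8, 8)
# Conv bias is first reshaped to a broadcastable [1, -1, 1, 1] view.
out_conv = torch.conv2d(img, conv.weight) + conv.bias.view(torch.Size([1, -1, 1, 1]))
assert torch.allclose(conv(img), out_conv, atol=1e-6)
```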
colossalai/auto_parallel/passes/runtime_apply_pass.py

```diff
@@ -93,7 +93,7 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with shape_consistency_node
                     origin_index_args = new_args.index(node)
                     new_args[origin_index_args] = shape_consistency_node
-                    user_node.args = new_args
+                    user_node.args = tuple(new_args)
                 elif str(node) in new_kwargs:
                     # substitute the origin node with shape_consistency_node
                     new_kwargs[str(node)] = shape_consistency_node
@@ -118,10 +118,12 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
             if op_data.type == OperationDataType.PARAM:
                 if comm_action.comm_type == CommType.HOOK:
                     continue
             if comm_action.comm_type == CommType.BEFORE:
-                if comm_action.key_for_kwarg is not None:
+                if op_data.type == OperationDataType.OUTPUT:
+                    comm_object = node
+                elif comm_action.key_for_kwarg is not None:
                     comm_object = node.kwargs[comm_action.key_for_kwarg]
                 else:
                     comm_object = node.args[comm_action.arg_index]
@@ -140,7 +142,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                         # substitute the origin node with comm_spec_apply_node
                         new_args = list(node.args)
                         new_args[comm_action.arg_index] = comm_spec_apply_node
-                        node.args = new_args
+                        node.args = tuple(new_args)
             elif comm_action.comm_type == CommType.AFTER:
                 with mod_graph.inserting_after(node):
@@ -163,7 +165,6 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                         # substitute the origin node with comm_spec_apply_node
                         new_kwargs[str(node)] = comm_spec_apply_node
                         user.kwargs = new_kwargs
     return gm
```
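Editor's note: the recurring `tuple(new_args)` change reflects that `torch.fx` treats `Node.args` as a tuple, so the pass copies it into a list for splicing and converts back on assignment. A minimal sketch of that pattern, assuming only stock `torch.fx` (names below are illustrative, not from the commit):

```python
import operator
import torch
from torch.fx import symbolic_trace

def f(a, b):
    return operator.add(a, b)

gm = symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.target is operator.add)

new_args = list(add_node.args)      # mutable copy of the node's argument tuple
new_args[1] = add_node.args[0]      # splice in a replacement argument
add_node.args = tuple(new_args)     # assign back as a tuple, as fx expects
gm.recompile()

assert gm(torch.tensor(1.0), torch.tensor(5.0)) == 2.0   # b was replaced by a
```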
colossalai/auto_parallel/passes/runtime_preparation_pass.py

```diff
@@ -5,7 +5,12 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationDataType,
+    ShardingStrategy,
+)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -42,7 +47,32 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
 
+        # the get_attr node strategy is kind of pending strategy, which means we will change it
+        # to the same strategy of the user node.
+        if node.op == 'get_attr':
+            assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+            new_sharding_spec = target_sharding_specs[0]
+            user_node = node.strategies_vector.successor_nodes[0]
+            user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
+            op_data_in_user = user_strategy.get_op_data_by_name(str(node))
+            origin_node_sharding_spec_dict[index] = new_sharding_spec
+            origin_pending_strategy = node.best_strategy
+            origin_op_data = origin_pending_strategy.get_op_data_by_name(str(node))
+            new_sharding_specs = origin_pending_strategy.sharding_specs
+            new_sharding_specs[origin_op_data] = new_sharding_spec
+            new_communication_actions = {}
+            if op_data_in_user in user_strategy.communication_actions:
+                new_communication_action = user_strategy.communication_actions.pop(op_data_in_user)
+                new_communication_action.arg_index = 0
+                new_communication_actions[origin_op_data] = new_communication_action
+            new_strategy = ShardingStrategy(name=str(new_sharding_spec.sharding_sequence),
+                                            sharding_specs=new_sharding_specs,
+                                            compute_cost=origin_pending_strategy.compute_cost,
+                                            communication_cost=origin_pending_strategy.communication_cost,
+                                            memory_cost=origin_pending_strategy.memory_cost,
+                                            communication_actions=new_communication_actions)
+            setattr(node, 'best_strategy', new_strategy)
+            setattr(node, 'sharding_spec', new_sharding_spec)
+
         comm_action_dict = {}
         for op_data, comm_action in node.best_strategy.communication_actions.items():
             comm_action_dict[op_data.name] = comm_action
@@ -111,6 +141,43 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
             for name, buffer_sharded in sharded_buffer_dict.items():
                 setattr(target_module, name, buffer_sharded.detach().clone())
 
+        if node.op == 'get_attr':
+            root = node.graph.owning_module
+            atoms = node.target.split(".")
+            attr_len = len(atoms)
+            if attr_len == 1:
+                target_module = root
+                target = getattr(root, atoms[0])
+            else:
+                target_module = root.get_submodule(atoms[-2])
+                target = getattr(target_module, atoms[-1])
+
+            target_sharding_spec = node.sharding_spec
+            if target_sharding_spec.dim_partition_dict != {}:
+                origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
+                setattr(target, 'sharding_spec', origin_sharding_spec)
+                # TODO: build a ColoParamter class to manager the distributed parameters
+                target_sharded = torch.nn.Parameter(
+                    shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                             target_sharding_spec).detach().clone())
+            else:
+                target_sharded = target
+            setattr(target_module, atoms[-1], target_sharded)
+
+            comm_actions = node.best_strategy.communication_actions
+            for operation_data, comm_action in comm_actions.items():
+                comm_spec_to_use = comm_action.comm_spec
+                # register hook to the parameters
+                if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+
+                    def wrapper(param, comm_spec):
+
+                        def hook_fn(grad):
+                            _all_reduce(grad, comm_spec)
+
+                        param.register_hook(hook_fn)
+
+                    wrapper(target_sharded, comm_spec_to_use)
     return gm
```
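Editor's note: the `wrapper`/`hook_fn` closure above binds one communication spec per parameter so the gradient is all-reduced when backward reaches it. A CPU-only sketch of the same hook-registration pattern; `fake_all_reduce` is a stand-in I introduce here for the real `_all_reduce(grad, comm_spec)` collective:

```python
import torch

def fake_all_reduce(grad, comm_spec=None):
    # stand-in for the collective: scale in place so the effect is visible
    return grad.mul_(2.0)

param = torch.nn.Parameter(torch.ones(4))

def wrapper(param, comm_spec):
    # the closure keeps `comm_spec` bound per parameter, as in the pass
    def hook_fn(grad):
        return fake_all_reduce(grad, comm_spec)
    param.register_hook(hook_fn)

wrapper(param, comm_spec=None)

param.sum().backward()
print(param.grad)   # gradient passed through the hook (here: all 2.0)
```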
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py

```diff
@@ -29,8 +29,15 @@ class ReshapeHandler(NodeHandler):
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
 
+        # check if the input operand is a parameter
+        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
+            data_type = OperationDataType.PARAM
+        else:
+            data_type = OperationDataType.ARG
+
         physical_input_operand = OperationData(name=str(self.node.args[0]),
-                                               type=OperationDataType.ARG,
+                                               type=data_type,
                                                data=self.node.args[0]._meta_data)
         physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
```
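Editor's note: with the bias decomposition, the input of a reshape node can now be a module parameter (e.g. a bias flowing through a `.view(...)` node) rather than an activation, so the handler classifies it as PARAM instead of ARG. A trivial plain-PyTorch illustration of the classification (the `classify` helper is hypothetical):

```python
import torch

def classify(operand) -> str:
    # mirrors the new check: PARAM if the meta data is an nn.Parameter, else ARG
    return "PARAM" if isinstance(operand, torch.nn.parameter.Parameter) else "ARG"

bias = torch.nn.Parameter(torch.zeros(8))   # e.g. conv.bias feeding a view node
activation = torch.rand(2, 8)

assert classify(bias) == "PARAM"
assert classify(activation) == "ARG"
```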
colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py

```diff
@@ -96,7 +96,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
                                            arg_index=0)
             input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
 
-        else:
+        elif len(total_mesh_dim_list) >= 2:
             source_spec = sharding_spec_mapping["input"]
             target_spec = ShardingSpec(device_mesh=self.device_mesh,
                                        entire_shape=source_spec.entire_shape,
@@ -104,7 +104,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
             comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
             input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
 
-        communication_action_mapping["input"] = input_comm_action
+        else:
+            input_comm_action = None
+
+        if input_comm_action is not None:
+            communication_action_mapping["input"] = input_comm_action
         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,
                                               communication_action_mapping=communication_action_mapping)
```
colossalai/fx/tracer/bias_addition_patch/patched_bias_addition_module/conv.py

```diff
@@ -43,7 +43,7 @@ class BiasAdditionConv(BiasAdditionModule):
         bias_shape[0] = -1
         bias_reshape_node_kind = 'call_method'
         bias_reshape_node_target = 'view'
-        bias_reshape_node_args = (self.bias_proxy, bias_shape)
+        bias_reshape_node_args = (self.bias_proxy, torch.Size(bias_shape))
         bias_reshape_proxy = self.tracer.create_proxy(bias_reshape_node_kind, bias_reshape_node_target,
                                                       bias_reshape_node_args, {})
         return bias_reshape_proxy
```
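Editor's note: the only change here is that the traced `view` argument becomes a `torch.Size` instead of a plain Python list; my reading (an assumption, not stated in the commit) is that a concrete `Size` is handled more uniformly by the downstream passes, while the numerical result is identical. A quick check of that equivalence:

```python
import torch

bias = torch.rand(6)
bias_shape = [1, -1, 1, 1]        # broadcast shape for a Conv2d bias

a = bias.view(bias_shape)
b = bias.view(torch.Size(bias_shape))

assert a.shape == b.shape == torch.Size([1, 6, 1, 1])
assert torch.equal(a, b)
```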
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py

```diff
@@ -58,7 +58,7 @@ def torch_bmm(input, mat2, *, out=None):
 @meta_patched_function.register(torch.nn.functional.linear)
-def torch_linear(input, mat2, *, out=None):
+def torch_linear(input, mat2, bias=None, *, out=None):
     if out is not None:
         raise ValueError("Don't support in-place abs for MetaTensor analysis")
     output_shape = list(input.shape)
```
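Editor's note: the meta patch for `F.linear` now accepts an explicit `bias` argument, since traced calls may pass it. The shape rule it implements is simply "replace the last input dimension with the weight's output dimension"; a hedged free-standing sketch (the helper name is mine, not the library's):

```python
import torch

def linear_output_shape(input: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Size:
    # F.linear: (..., in_features) x (out_features, in_features) -> (..., out_features)
    output_shape = list(input.shape)
    output_shape[-1] = weight.shape[0]
    return torch.Size(output_shape)

x = torch.rand(2, 2, 4, 16)
w = torch.rand(32, 16)
b = torch.rand(32)

assert linear_output_shape(x, w, b) == torch.nn.functional.linear(x, w, b).shape
```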
tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py (new file, mode 100644)

```python
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
from colossalai.auto_parallel.tensor_shard.solver import (
    CostGraph,
    GraphAnalyser,
    Solver,
    SolverOptions,
    StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port


class LinearModel(torch.nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        x = self.linear(x)
        x = x * 2
        return x


class ConvModel(torch.nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=kernel_size,
                                    bias=bias)

    def forward(self, x):
        x = self.conv(x)
        x = x * 2
        return x


def check_linear_module(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearModel(4, 8).cuda()
    input = torch.rand(4, 4).cuda()
    output_compare = model(input)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    tracer = ColoTracer()
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %linear_weight : [#users=1] = get_attr[target=linear.weight]
    #     %linear_bias : [#users=1] = get_attr[target=linear.bias]
    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #     return mul
    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
    # def forward(self, x : torch.Tensor):
    #     linear_weight = self.linear.weight
    #     linear_bias = self.linear.bias
    #     linear = torch._C._nn.linear(x, linear_weight);  x = linear_weight = None
    #     add = linear + linear_bias;  linear = linear_bias = None
    #     mul = add * 2;  add = None
    #     return mul
    gm = ColoGraphModule(model, graph)
    gm.recompile()
    node_list = list(graph.nodes)

    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    linear_node = node_list[3]
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
    gm = runtime_apply_pass(gm)
    gm.recompile()

    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    assert_close(output, output_compare)


def check_conv_module(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = ConvModel(3, 6, 2).cuda()
    input = torch.rand(4, 3, 64, 64).cuda()
    output_compare = model(input)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    tracer = ColoTracer()
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
    #     %conv_bias : [#users=1] = get_attr[target=conv.bias]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
    #     %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
    #     %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
    #     return mul
    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
    # def forward(self, x : torch.Tensor):
    #     conv_weight = self.conv.weight
    #     conv_bias = self.conv.bias
    #     conv2d = torch.conv2d(x, conv_weight);  x = conv_weight = None
    #     view = conv_bias.view([1, -1, 1, 1]);  conv_bias = None
    #     add = conv2d + view;  conv2d = view = None
    #     mul = add * 2;  add = None
    #     return mul
    gm = ColoGraphModule(model, graph)
    gm.recompile()
    node_list = list(graph.nodes)
    conv_node = node_list[3]
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
    gm = runtime_apply_pass(gm)
    gm.recompile()

    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    assert_close(output, output_compare)


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
    world_size = 4
    run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
    mp.spawn(run_func_linear, nprocs=world_size)
    run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
    mp.spawn(run_func_conv, nprocs=world_size)


if __name__ == '__main__':
    test_bias_addition_module()
```
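Editor's note on the spawn pattern used by this test: `torch.multiprocessing.spawn` calls `fn(rank, *args)`, so the remaining parameters are pinned with `functools.partial` and spawn fills in the rank. A CPU-only sketch (the `worker` function and port value are illustrative, not from the commit):

```python
from functools import partial

import torch.multiprocessing as mp


def worker(rank, world_size, port):
    # the real checks launch a NCCL process group here; this just prints
    print(f"rank {rank}/{world_size} would launch on port {port}")


if __name__ == '__main__':
    run_func = partial(worker, world_size=4, port=12345)
    mp.spawn(run_func, nprocs=4)
```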
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py (new file, mode 100644)

```python
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self

from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
    OperationData,
    OperationDataType,
    ShardingStrategy,
    StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class LinearModule(torch.nn.Module):

    def __init__(self, in_features, out_features, bias):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        x = self.linear(x)
        return x


def check_linear_module_handler(rank, bias, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearModule(16, 32, bias=bias).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    input = torch.rand(2, 2, 4, 16).cuda()
    # the index of linear node in computation graph
    node_index = 3
    # strategy number of linear node
    strategy_number = 10
    # construct input args
    input_args = [input]
    # construct meta arg names
    meta_arg_names = ['x']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names,
                                     node_type='bias_module')

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"x": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)

    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x"
    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([16, 16])

    assert mapping['other'].name == "linear_weight"
    assert mapping['other'].data.shape == torch.Size([32, 16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16, 32])

    assert 'bias' not in mapping

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # one strategy will be converted to different physical sharding spec
    assert len(strategy_name_list) > 8

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list
    assert 'S1S0 = S1R x RS0' in strategy_name_list

    # SR = SS x SR
    assert 'S0R = S0S1 x S1R' in strategy_name_list
    assert 'S1R = S1S0 x S0R' in strategy_name_list

    # RS = RS x SS
    assert 'RS0 = RS1 x S1S0' in strategy_name_list
    assert 'RS1 = RS0 x S0S1' in strategy_name_list

    # RR = RS x SR
    assert 'RR = RS0 x S0R' in strategy_name_list
    assert 'RR = RS1 x S1R' in strategy_name_list

    # RS= RR x RS
    assert 'RS0 = RR x RS0' in strategy_name_list
    assert 'RS1 = RR x RS1' in strategy_name_list

    for strategy in strategies_vector:
        strategy: ShardingStrategy
        input_sharding_spec = strategy.get_sharding_spec_by_name('x')
        weight_sharding_spec = strategy.get_sharding_spec_by_name('linear_weight')
        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=True):
    world_size = 4
    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
    mp.spawn(run_func_module, nprocs=world_size)


if __name__ == '__main__':
    test_linear_handler()
```
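Editor's note on the `logical_shape` assertions above: the test expects a logical shape of `[16, 16]` for a physical `[2, 2, 4, 16]` input because the linear handler treats all leading dimensions as one flattened batch dimension. Plain-PyTorch illustration (not part of the commit):

```python
import torch

x = torch.rand(2, 2, 4, 16)
logical_x = x.reshape(-1, x.shape[-1])       # (2*2*4, 16) == (16, 16)
assert logical_x.shape == torch.Size([16, 16])

w = torch.rand(32, 16)
out = torch.nn.functional.linear(x, w)                    # physical output: (2, 2, 4, 32)
out_logical = torch.nn.functional.linear(logical_x, w)    # logical output: (16, 32)
assert torch.allclose(out.reshape(-1, 32), out_logical)
```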
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py

```diff
@@ -7,6 +7,9 @@ from torch.fx import GraphModule
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import to_global
@@ -56,7 +59,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
                                      strategy_number: int,
                                      input_args: List[torch.Tensor],
                                      meta_arg_names: List[str],
-                                     input_kwargs: Dict[str, torch.Tensor] = {}):
+                                     input_kwargs: Dict[str, torch.Tensor] = {},
+                                     node_type: str = 'normal'):
     for strategy_index in range(strategy_number):
         print(f'#strategy_index: {strategy_index}')
         # We need to copy the model to avoid do backward more than once in same graph
@@ -79,11 +83,21 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
         strategies_constructor.build_strategies_and_cost()
         target_node = list(graph.nodes)[node_index]
 
-        # solution construction
-        solution_len = len(strategies_constructor.leaf_strategies)
-        solution = [0] * solution_len
-        solution[node_index] = strategy_index
+        if node_type == 'normal':
+            solution_len = len(strategies_constructor.leaf_strategies)
+            solution = [0] * solution_len
+            solution[node_index] = strategy_index
+        else:
+            node_vector = strategies_constructor.leaf_strategies[node_index]
+            strategy_to_keep = node_vector[strategy_index]
+            node_vector = [strategy_to_keep]
+            # solution construction
+            cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+            cost_graph.simplify_graph()
+            graph_analyser = GraphAnalyser(gm)
+            solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+            ret = solver.call_solver_serialized_args()
+            solution = list(ret[0])
 
         gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
         gm = runtime_apply_pass(gm)
@@ -110,11 +124,18 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         # extract the strategy used in this iter
         strategy_in_use = target_node.strategies_vector[strategy_index]
-        param_to_shard_dict = dict(model_to_shard.named_parameters())
+        param_to_shard_dict = dict(gm.named_parameters())
         param_to_compare_dict = dict(model_to_compare.named_parameters())
         for name in param_to_shard_dict.keys():
             param_name = name.split('.')[-1]
-            param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
+            if node_type == 'normal':
+                param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
+            else:
+                if 'weight' in name:
+                    param_sharding_spec = list(graph.nodes)[4].sharding_spec
+                elif 'bias' in name:
+                    param_sharding_spec = list(graph.nodes)[5].sharding_spec
+
             grad_sharded = param_to_shard_dict[name].grad
             grad_to_compare = param_to_compare_dict[name].grad
             global_grad = to_global(grad_sharded, param_sharding_spec)
```
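Editor's note: `numerical_test_for_node_strategy` runs the transformed graph for each candidate strategy and compares outputs and gradients against an untouched copy of the model; the new `node_type='bias_module'` branch lets the solver pick consistent strategies for the decomposed weight, bias, and add nodes instead of forcing a single index. The overall shape of that check, reduced to a single-device sketch (the `numerical_check` helper below is hypothetical; the real utility gathers sharded tensors with `to_global` before comparing):

```python
import copy
import torch


def numerical_check(model: torch.nn.Module, transform, *inputs, atol=1e-5):
    reference = copy.deepcopy(model)    # untouched copy for comparison
    transformed = transform(model)      # stand-in for the runtime passes

    out_ref = reference(*inputs)
    out_new = transformed(*inputs)
    assert torch.allclose(out_ref, out_new, atol=atol)

    out_ref.sum().backward()
    out_new.sum().backward()
    for (name, p_ref), (_, p_new) in zip(reference.named_parameters(), transformed.named_parameters()):
        assert torch.allclose(p_ref.grad, p_new.grad, atol=atol), name


# usage sketch: an identity "transform" trivially passes the check
numerical_check(torch.nn.Linear(4, 8), lambda m: m, torch.rand(2, 4))
```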