ColossalAI — commit a4d1f59c (unverified)
Authored Oct 28, 2022 by YuliangLiu0306; committed by GitHub on Oct 28, 2022
[autoparallel] add numerical test for handlers (#1769)
Parent: b0f7c8bd
Showing 8 changed files with 468 additions and 145 deletions (+468, −145).
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py                 +92  −21
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py             +49  −21
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py     +80  −21
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py                    +68  −20
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py                   +24  −8
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py             +45  −14
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py                 +90  −31
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py                               +20  −9
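
The reworked handler tests below all share one scaffolding: a per-rank check_* function launches a Colossal-AI NCCL process group, builds a DeviceMesh with init_process_group=True, and calls numerical_test_for_node_strategy, while a pytest-facing wrapper spawns one process per device with mp.spawn. The following is a minimal sketch of that pattern assembled from the APIs used in this commit (launch, DeviceMesh, free_port, numerical_test_for_node_strategy); check_some_handler and test_some_handler are hypothetical names, and the module, node_index and strategy_number values simply mirror the linear-module case from the commit, not a new test.

from functools import partial

import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


def check_some_handler(rank, world_size, port):
    # each spawned process joins the NCCL group for its rank
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # a 2 x 2 logical mesh over 4 physical GPUs
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)
    model = nn.Sequential(nn.Linear(16, 32)).cuda()
    input_args = [torch.rand(2, 2, 4, 16).cuda()]
    # node_index / strategy_number are handler-specific; the values below mirror the linear-module case
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=1,
                                     strategy_number=10,
                                     input_args=input_args,
                                     meta_arg_names=['input'])


def test_some_handler():
    world_size = 4
    run_func = partial(check_some_handler, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)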
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py  (view file @ a4d1f59c)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class AddBMMTensorMethodModule(nn.Module):
...
@@ -20,11 +29,30 @@ class AddBMMTorchFunctionModule(nn.Module):
        return torch.addbmm(bias, x1, x2)


@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
def test_2d_device_mesh(module, bias_shape):
    model = module()
def check_2d_device_mesh(rank, module, bias_shape, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = module().cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    x1 = torch.rand(4, 8, 16).cuda()
    x2 = torch.rand(4, 16, 8).cuda()
    bias = torch.rand(bias_shape).cuda()
    # the index of addbmm node in computation graph
    node_index = 3
    # strategy number of addbmm node on 2d device mesh
    strategy_number = 7
    # construct input args
    input_args = [bias, x1, x2]
    # construct meta arg names
    meta_arg_names = ['bias', 'x1', 'x2']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={
...
@@ -32,12 +60,8 @@ def test_2d_device_mesh(module, bias_shape):
        "x1": torch.rand(4, 8, 16).to('meta'),
        'x2': torch.rand(4, 16, 8).to('meta')
    })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)
...
@@ -78,7 +102,6 @@ def test_2d_device_mesh(module, bias_shape):
    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]
    # one batch dim
    assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list
...
@@ -110,10 +133,31 @@ def test_2d_device_mesh(module, bias_shape):
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
def test_1d_device_mesh(module, bias_shape):
    model = module()
def check_1d_device_mesh(rank, module, bias_shape, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (1, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    model = module().cuda()
    x1 = torch.rand(4, 8, 16).cuda()
    x2 = torch.rand(4, 16, 8).cuda()
    bias = torch.rand(bias_shape).cuda()
    # the index of addbmm node in computation graph
    node_index = 3
    # strategy number of addbmm node on 2d device mesh
    strategy_number = 1
    # construct input args
    input_args = [bias, x1, x2]
    # construct meta arg names
    meta_arg_names = ['bias', 'x1', 'x2']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={
...
@@ -121,12 +165,7 @@ def test_1d_device_mesh(module, bias_shape):
        "x1": torch.rand(4, 8, 16).to('meta'),
        'x2': torch.rand(4, 16, 8).to('meta')
    })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (1, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)
...
@@ -184,6 +223,38 @@ def test_1d_device_mesh(module, bias_shape):
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


@pytest.mark.skip("skip due to bias cases not ready")
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
@rerun_if_address_is_in_use()
def test_2d_device_mesh(module, bias_shape):
    world_size = 4
    run_func = partial(check_2d_device_mesh, module=module, bias_shape=bias_shape, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.skip("skip due to bias cases not ready")
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
@rerun_if_address_is_in_use()
def test_1d_device_mesh(module, bias_shape):
    world_size = 4
    run_func = partial(check_1d_device_mesh, module=module, bias_shape=bias_shape, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_1d_device_mesh()
    # test_2d_device_mesh()
    test_2d_device_mesh()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py  (view file @ a4d1f59c)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \
    BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType,
                                                                     StrategiesVector)
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
import pytest
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


@pytest.mark.skip("skip due to passes not ready")
def test_bn_module_handler():
    model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
def check_bn_module_handler(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = nn.Sequential(nn.BatchNorm2d(16)).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    input = torch.rand(4, 16, 64, 64).cuda()
    # the index of bn node in computation graph
    node_index = 1
    # the total number of bn strategies without sync bn mode
    # TODO: add sync bn stategies after related passes ready
    strategy_number = 4
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=[input],
                                     meta_arg_names=['input'])
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
...
@@ -20,10 +45,6 @@ def test_bn_module_handler():
    #     return _0
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    bn_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(bn_mod_node)
...
@@ -40,25 +61,21 @@ def test_bn_module_handler():
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([16])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([16])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['output'].type == OperationDataType.OUTPUT
...
@@ -75,16 +92,27 @@ def test_bn_module_handler():
    # RS01 = RS01 x S01
    assert 'RS01 = RS01 x S01' in strategy_name_list

    # temporarily skip the sync bn test
    # TODO: test sync bn after the implicit runtime pass completed
    # SR = SR x R WITH SYNC_BN
    assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
    assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
    # assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
    # assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list

    # SS = SS x S WITH SYNC_BN
    assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
    assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
    # assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
    # assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list

    # S01R = S01R x R WITH SYNC_BN
    assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
    # assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bn_module_handler():
    world_size = 4
    run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py  (view file @ a4d1f59c)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


@parameterize('op', [torch.add])
@parameterize('other_dim', [1, 2])
def test_binary_elementwise_handler_with_tensor(op, other_dim):
def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    class BinaryElementwiseOpModel(nn.Module):
...
@@ -22,16 +31,32 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
            out = self.op(x1, x2)
            return out

    model = BinaryElementwiseOpModel(op)
    tracer = ColoTracer()
    model = BinaryElementwiseOpModel(op).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    x1 = torch.rand(4, 4).cuda()
    x2 = torch.rand([4] * other_dim).cuda()
    # the index of binary-elementwise node in computation graph
    node_index = 2
    # strategy number of binary-elementwise node
    strategy_number = 9
    # construct input args
    input_args = [x1, x2]
    # construct meta arg names
    meta_arg_names = ['x1', 'x2']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
    graph = tracer.trace(model, meta_args=meta_args)
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    op_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(op_node)
...
@@ -97,9 +122,9 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
        assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]


@parameterize('op', [torch.add])
@parameterize('other', [1, 2])
def test_binary_elementwise_handler_with_int(op, other):
def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    class BinaryElementwiseOpModel(nn.Module):
...
@@ -112,16 +137,30 @@ def test_binary_elementwise_handler_with_int(op, other):
            out = self.op(x1, self.const)
            return out

    model = BinaryElementwiseOpModel(op, other)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    model = BinaryElementwiseOpModel(op, other_dim).cuda()
    x1 = torch.rand(4, 4).cuda()
    # the index of binary-elementwise node in computation graph
    node_index = 1
    # strategy number of binary-elementwise node
    strategy_number = 9
    # construct input args
    input_args = [x1]
    # construct meta arg names
    meta_arg_names = ['x1']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    meta_args = {'x1': torch.rand(4, 4).to('meta')}
    graph = tracer.trace(model, meta_args=meta_args)
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    op_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(op_node)
...
@@ -168,6 +207,26 @@ def test_binary_elementwise_handler_with_int(op, other):
        assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence


@parameterize('op', [torch.add])
@parameterize('other_dim', [1, 2])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler(op, other_dim):
    world_size = 4
    run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
                              op=op,
                              other_dim=other_dim,
                              world_size=world_size,
                              port=free_port())
    mp.spawn(run_func_tensor, nprocs=world_size)
    run_func_int = partial(check_binary_elementwise_handler_with_int,
                           op=op,
                           other_dim=other_dim,
                           world_size=world_size,
                           port=free_port())
    mp.spawn(run_func_int, nprocs=world_size)


if __name__ == '__main__':
    test_binary_elementwise_handler_with_tensor()
    test_binary_elementwise_handler_with_int()
    test_binary_elementwise_handler()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py  (view file @ a4d1f59c)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


class BMMTensorMethodModule(nn.Module):
...
@@ -21,22 +29,37 @@ class BMMTorchFunctionModule(nn.Module):
        return torch.bmm(x1, x2)


@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
    model = module()
def check_2d_device_mesh(rank, module, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = module().cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    x1 = torch.rand(4, 8, 16).cuda()
    x2 = torch.rand(4, 16, 8).cuda()
    # the index of bmm node in computation graph
    node_index = 2
    # strategy number of bmm node on 2d device mesh
    strategy_number = 7
    # construct input args
    input_args = [x1, x2]
    # construct meta arg names
    meta_arg_names = ['x1', 'x2']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={
        "x1": torch.rand(4, 8, 16).to('meta'),
        'x2': torch.rand(4, 16, 8).to('meta')
    })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(linear_mod_node)
...
@@ -96,27 +119,41 @@ def test_2d_device_mesh(module):
        output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')

        # make sure the sharding matches across different operation data
        print(input_sharding_spec.sharding_sequence, output_sharding_spec.sharding_sequence)
        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
        assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
    model = module()
def check_1d_device_mesh(rank, module, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = module().cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (1, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    x1 = torch.rand(4, 8, 16).cuda()
    x2 = torch.rand(4, 16, 8).cuda()
    # the index of bmm node in computation graph
    node_index = 2
    # strategy number of bmm node on 1d device mesh
    strategy_number = 1
    # construct input args
    input_args = [x1, x2]
    # construct meta arg names
    meta_arg_names = ['x1', 'x2']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={
        "x1": torch.rand(4, 8, 16).to('meta'),
        'x2': torch.rand(4, 16, 8).to('meta')
    })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (1, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(linear_mod_node)
...
@@ -166,6 +203,17 @@ def test_1d_device_mesh(module):
        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bmm_handler(module):
    world_size = 4
    run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
    mp.spawn(run_func_2d, nprocs=world_size)
    run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
    mp.spawn(run_func_1d, nprocs=world_size)


if __name__ == '__main__':
    test_1d_device_mesh()
    test_2d_device_mesh()
    test_bmm_handler()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py  (view file @ a4d1f59c)

...
@@ -31,11 +31,16 @@ def check_conv_module_handler(rank, bias, world_size, port):
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    # index of conv node in this graph
    # index of conv node in computation graph
    node_index = 1
    # total number of conv strategies
    strategy_number = 16
    numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, [input], ['input'])
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=[input],
                                     meta_arg_names=['input'])
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
    gm = ColoGraphModule(model, graph)
...
@@ -165,8 +170,13 @@ def check_conv_function_handler(rank, bias, world_size, port):
        bias_tensor = torch.rand(16).cuda()
        input_kwargs['bias'] = bias_tensor
        node_index += 1
    numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, input_args, meta_arg_names,
                                     input_kwargs)
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names,
                                     input_kwargs=input_kwargs)
    tracer = ColoTracer()
    # graph():
...
@@ -280,21 +290,27 @@ def check_conv_function_handler(rank, bias, world_size, port):
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]


@pytest.mark.skip("some cases need to be fixed")
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('bias', [True, False])
# We temporarily ban the bias option before doing bias add
# before all reduce communication may encounter correctness issue.
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias):
def test_conv_module_handler(bias=False):
    world_size = 4
    run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


@pytest.mark.skip("some cases need to be fixed")
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@parameterize('bias', [True, False])
# We temporarily ban the bias option before doing bias add
# before all reduce communication may encounter correctness issue.
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias):
def test_conv_function_handler(bias=False):
    world_size = 4
    run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py  (view file @ a4d1f59c)

from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \
    LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType,
                                                                     StrategiesVector)
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear


def test_ln_module_handler():
    model = nn.Sequential(nn.LayerNorm(16).to('meta'))
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


def check_ln_module_handler(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = nn.Sequential(nn.LayerNorm(16)).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    input = torch.rand(4, 16).cuda()
    # the index of bn node in computation graph
    node_index = 1
    # the total number of ln strategies
    strategy_number = 4
    # construct input args
    input_args = [input]
    # construct meta arg names
    meta_arg_names = ['input']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
...
@@ -18,10 +47,7 @@ def test_ln_module_handler():
    #     return _0
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    ln_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(ln_mod_node)
...
@@ -38,25 +64,21 @@ def test_ln_module_handler():
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 16])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([16])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([16])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 16])
    assert mapping['output'].type == OperationDataType.OUTPUT
...
@@ -74,5 +96,14 @@ def test_ln_module_handler():
    assert '[S01, R] = [S01, R] x [R]' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_ln_module_handler():
    world_size = 4
    run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_ln_module_handler()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py  (view file @ a4d1f59c)

from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self
...
@@ -11,22 +17,42 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


@parameterize('bias', [True, False])
def test_linear_module_handler(bias):
    model = nn.Sequential(nn.Linear(16, 32, bias=bias).to('meta'))
def check_linear_module_handler(rank, bias, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    input = torch.rand(2, 2, 4, 16).cuda()
    # the index of linear node in computation graph
    node_index = 1
    # strategy number of linear node
    strategy_number = 10
    # construct input args
    input_args = [input]
    # construct meta arg names
    meta_arg_names = ['input']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(linear_mod_node)
...
@@ -43,26 +69,22 @@ def test_linear_module_handler(bias):
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([16, 16])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([32, 16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16, 32])

    if bias:
        assert mapping['bias'].name == "bias"
        assert mapping['bias'].data.is_meta
        assert mapping['bias'].data.shape == torch.Size([32])
        assert mapping['bias'].type == OperationDataType.PARAM
        assert mapping['bias'].logical_shape == torch.Size([32])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
    assert mapping['output'].type == OperationDataType.OUTPUT
    assert mapping['output'].logical_shape == torch.Size([16, 32])
...
@@ -110,19 +132,49 @@ def test_linear_module_handler(bias):
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('bias', [True, False])
def test_linear_function_handler(bias):
    model = nn.Linear(16, 32, bias=bias).to('meta')
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    print(graph)
class LinearModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input, others, bias=None):
        x = nn.functional.linear(input, others, bias=bias)
        return x


def check_linear_function_handler(rank, bias, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = LinearModel().cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    input = torch.rand(2, 2, 4, 16).cuda()
    other = torch.rand(32, 16).cuda()
    # the index of linear node in computation graph
    node_index = 2
    # strategy number of linear node
    strategy_number = 10
    # construct input args
    input_args = [input, other]
    # construct meta arg names
    meta_arg_names = ['input', 'others']
    numerical_test_for_node_strategy(model=model,
                                     device_mesh=device_mesh,
                                     node_index=node_index,
                                     strategy_number=strategy_number,
                                     input_args=input_args,
                                     meta_arg_names=meta_arg_names)
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={
        "input": torch.rand(2, 2, 4, 16).to('meta'),
        'others': torch.rand(32, 16).to('meta')
    })
    gm = ColoGraphModule(model, graph)
    if bias:
        linear_func_node = list(graph.nodes)[3]
    else:
...
@@ -136,26 +188,22 @@ def test_linear_function_handler(bias):
    mapping = handler.get_operation_data_mapping()

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([16, 16])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].name == "others"
    assert mapping['other'].data.shape == torch.Size([32, 16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([16, 32])

    if bias:
        assert mapping['bias'].name == "bias"
        assert mapping['bias'].data.is_meta
        assert mapping['bias'].data.shape == torch.Size([32])
        assert mapping['bias'].type == OperationDataType.PARAM
        assert mapping['bias'].type == OperationDataType.ARG
        assert mapping['other'].logical_shape == torch.Size([16, 32])

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
    assert mapping['output'].type == OperationDataType.OUTPUT
...
@@ -187,7 +235,7 @@ def test_linear_function_handler(bias):
    for strategy in strategies_vector:
        strategy: ShardingStrategy
        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
        weight_sharding_spec = strategy.get_sharding_spec_by_name('others')
        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

        if bias:
...
@@ -202,6 +250,17 @@ def test_linear_function_handler(bias):
            assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


# @parameterize('bias', [True, False])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=False):
    world_size = 4
    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
    mp.spawn(run_func_module, nprocs=world_size)
    run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
    mp.spawn(run_func_function, nprocs=world_size)


if __name__ == '__main__':
    test_linear_module_handler()
    test_linear_function_handler()
    test_linear_handler()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py  (view file @ a4d1f59c)

...
@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, Strategi
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close
from colossalai.testing.comparison import assert_close, assert_close_loose


def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor],
...
@@ -31,7 +31,6 @@ def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tenso
        arg_to_compare = copy.deepcopy(input_tensor)
        arg_to_compare.requires_grad = True
        wrapper(arg_to_compare, arg_index)
        # arg_to_compare.register_hook(hook_fn)
        args_to_compare.append(arg_to_compare)

    for name, input_kwarg in input_kwargs.items():
...
@@ -68,8 +67,6 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
    model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
                                                                             grad_to_shard_dict)
    zero_tensor = torch.Tensor(0).cuda()

    tracer = ColoTracer()
    input_sample = {}
    for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
...
@@ -98,10 +95,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
                    origin_node_sharding_spec_dict=origin_spec_dict,
                    comm_actions_dict=comm_actions_dict,
                    **kwargs_to_shard)
        # except:
        #     print(gm)
        output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare)
        assert_close((output - output_to_compare).sum(), zero_tensor)
        assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output')

        # backward result compare
        loss = output.sum()
...
@@ -111,7 +106,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
        for key in grad_to_shard_dict.keys():
            grad_to_shard = grad_to_shard_dict[key]
            grad_to_compare = grad_to_compare_dict[key]
            assert_close((grad_to_shard - grad_to_compare).sum(), zero_tensor)
            assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad')

        # extract the strategy used in this iter
        strategy_in_use = target_node.strategies_vector[strategy_index]
...
@@ -123,4 +118,20 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
            grad_sharded = param_to_shard_dict[name].grad
            grad_to_compare = param_to_compare_dict[name].grad
            global_grad = to_global(grad_sharded, param_sharding_spec)
            assert_close((global_grad - grad_to_compare).sum(), zero_tensor)
            assert_close_helper(global_grad, grad_to_compare, strategy_index=strategy_index, type='param grad')


def assert_close_helper(first: torch.Tensor,
                        second: torch.Tensor,
                        rtol: float = 1e-2,
                        atol: float = 1e-2,
                        strategy_index: int = -1,
                        type: str = 'not defined'):
    """
    This method is used to check whether the average difference between two tensors is as close as expected.
    """
    # average_diff_tensor = ((first - second)/(second+0.1)).sum()/second.numel()
    try:
        assert_close(first, second, rtol=rtol, atol=atol)
    except:
        print(f'strategy index {strategy_index} encounter assert_close error on {type}')