Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0e52f3d3
Unverified
Commit
0e52f3d3
authored
Oct 13, 2022
by
Frank Lee
Committed by
GitHub
Oct 13, 2022
Browse files
[unittest] supported conditional testing based on env var (#1701)
polish code
parent
8283e95d
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
36 additions
and
10 deletions
+36
-10
colossalai/testing/pytest_wrapper.py
colossalai/testing/pytest_wrapper.py
+17
-0
colossalai/testing/utils.py
colossalai/testing/utils.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py
...st_deprecated_op_handler/test_deprecated_where_handler.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py
...test_deprecated/test_deprecated_shape_consistency_pass.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
...st_tensor_shard/test_deprecated/test_deprecated_solver.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py
..._shard/test_deprecated/test_deprecated_solver_with_gpt.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py
..._shard/test_deprecated/test_deprecated_solver_with_mlp.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
...l/test_tensor_shard/test_node_handler/test_bmm_handler.py
+3
-2
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
...nsor_shard/test_node_handler/test_norm_pooling_handler.py
+2
-1
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
..._parallel/test_tensor_shard/test_solver_with_resnet_v2.py
+2
-1
No files found.
colossalai/testing/pytest_wrapper.py
0 → 100644
View file @
0e52f3d3
import
pytest
import
os
def
run_on_environment_flag
(
name
:
str
):
"""
Conditionally run a test based on the environment variable. If this environment variable is set
to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0.
"""
assert
isinstance
(
name
,
str
)
flag
=
os
.
environ
.
get
(
name
.
upper
(),
'0'
)
reason
=
f
'Environment varialbe
{
name
}
is
{
flag
}
'
if
flag
==
'1'
:
return
pytest
.
mark
.
skipif
(
False
,
reason
=
reason
)
else
:
return
pytest
.
mark
.
skipif
(
True
,
reason
=
reason
)
colossalai/testing/utils.py
View file @
0e52f3d3
...
@@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int):
...
@@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int):
"""
"""
def
_wrap_func
(
f
):
def
_wrap_func
(
f
):
def
_execute_by_gpu_num
(
*
args
,
**
kwargs
):
def
_execute_by_gpu_num
(
*
args
,
**
kwargs
):
num_avail_gpu
=
torch
.
cuda
.
device_count
()
num_avail_gpu
=
torch
.
cuda
.
device_count
()
if
num_avail_gpu
>=
min_gpus
:
if
num_avail_gpu
>=
min_gpus
:
f
(
*
args
,
**
kwargs
)
f
(
*
args
,
**
kwargs
)
return
_execute_by_gpu_num
return
_execute_by_gpu_num
return
_wrap_func
return
_wrap_func
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py
View file @
0e52f3d3
...
@@ -7,6 +7,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptio
...
@@ -7,6 +7,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptio
from
colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor
import
StrategiesConstructor
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
class
ConvModel
(
nn
.
Module
):
class
ConvModel
(
nn
.
Module
):
...
@@ -22,7 +23,7 @@ class ConvModel(nn.Module):
...
@@ -22,7 +23,7 @@ class ConvModel(nn.Module):
return
output
return
output
@
pytest
.
mark
.
skip
(
"temporarily skipped"
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
def
test_where_handler
():
def
test_where_handler
():
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
mesh_shape
=
(
2
,
2
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py
View file @
0e52f3d3
...
@@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh
...
@@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh
from
colossalai.fx.passes.experimental.adding_shape_consistency_pass
import
shape_consistency_pass
,
solution_annotatation_pass
from
colossalai.fx.passes.experimental.adding_shape_consistency_pass
import
shape_consistency_pass
,
solution_annotatation_pass
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
class
ConvModel
(
nn
.
Module
):
class
ConvModel
(
nn
.
Module
):
...
@@ -72,7 +73,7 @@ def check_apply(rank, world_size, port):
...
@@ -72,7 +73,7 @@ def check_apply(rank, world_size, port):
assert
output
.
equal
(
origin_output
)
assert
output
.
equal
(
origin_output
)
@
pytest
.
mark
.
skip
(
"for higher testing speed"
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_apply
():
def
test_apply
():
...
...
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
View file @
0e52f3d3
...
@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import Grap
...
@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import Grap
from
copy
import
deepcopy
from
copy
import
deepcopy
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.auto_parallel.tensor_shard.deprecated
import
Solver
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
class
ConvModel
(
nn
.
Module
):
class
ConvModel
(
nn
.
Module
):
...
@@ -33,7 +34,7 @@ class ConvModel(nn.Module):
...
@@ -33,7 +34,7 @@ class ConvModel(nn.Module):
return
x
return
x
@
pytest
.
mark
.
skip
(
"for higher testing speed"
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
def
test_solver
():
def
test_solver
():
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
mesh_shape
=
(
2
,
2
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py
View file @
0e52f3d3
...
@@ -15,12 +15,13 @@ import transformers
...
@@ -15,12 +15,13 @@ import transformers
from
colossalai.auto_parallel.tensor_shard.deprecated.constants
import
*
from
colossalai.auto_parallel.tensor_shard.deprecated.constants
import
*
from
colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
BATCH_SIZE
=
8
BATCH_SIZE
=
8
SEQ_LENGHT
=
8
SEQ_LENGHT
=
8
@
pytest
.
mark
.
skip
(
"for higher testing speed"
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
def
test_cost_graph
():
def
test_cost_graph
():
physical_mesh_id
=
torch
.
arange
(
0
,
8
)
physical_mesh_id
=
torch
.
arange
(
0
,
8
)
mesh_shape
=
(
2
,
4
)
mesh_shape
=
(
2
,
4
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py
View file @
0e52f3d3
...
@@ -15,6 +15,7 @@ from torchvision.models import resnet34, resnet50
...
@@ -15,6 +15,7 @@ from torchvision.models import resnet34, resnet50
from
colossalai.auto_parallel.tensor_shard.deprecated.constants
import
*
from
colossalai.auto_parallel.tensor_shard.deprecated.constants
import
*
from
colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.deprecated.options
import
SolverOptions
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
class
MLP
(
torch
.
nn
.
Module
):
class
MLP
(
torch
.
nn
.
Module
):
...
@@ -34,7 +35,7 @@ class MLP(torch.nn.Module):
...
@@ -34,7 +35,7 @@ class MLP(torch.nn.Module):
return
x
return
x
@
pytest
.
mark
.
skip
(
"for higher testing speed"
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
def
test_cost_graph
():
def
test_cost_graph
():
physical_mesh_id
=
torch
.
arange
(
0
,
8
)
physical_mesh_id
=
torch
.
arange
(
0
,
8
)
mesh_shape
=
(
2
,
4
)
mesh_shape
=
(
2
,
4
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
View file @
0e52f3d3
...
@@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer, ColoGraphModule
...
@@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer, ColoGraphModule
from
colossalai.auto_parallel.solver.node_handler.dot_handler
import
BMMFunctionHandler
from
colossalai.auto_parallel.solver.node_handler.dot_handler
import
BMMFunctionHandler
from
colossalai.auto_parallel.solver.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.auto_parallel.solver.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
class
BMMTensorMethodModule
(
nn
.
Module
):
class
BMMTensorMethodModule
(
nn
.
Module
):
...
@@ -19,7 +20,7 @@ class BMMTorchFunctionModule(nn.Module):
...
@@ -19,7 +20,7 @@ class BMMTorchFunctionModule(nn.Module):
return
torch
.
bmm
(
x1
,
x2
)
return
torch
.
bmm
(
x1
,
x2
)
@
pytest
.
mark
.
skip
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
pytest
.
mark
.
parametrize
(
'module'
,
[
BMMTensorMethodModule
,
BMMTorchFunctionModule
])
@
pytest
.
mark
.
parametrize
(
'module'
,
[
BMMTensorMethodModule
,
BMMTorchFunctionModule
])
def
test_2d_device_mesh
(
module
):
def
test_2d_device_mesh
(
module
):
...
@@ -90,7 +91,7 @@ def test_2d_device_mesh(module):
...
@@ -90,7 +91,7 @@ def test_2d_device_mesh(module):
assert
'Sb1R = Sb1Sk0 x Sb1Sk0'
in
strategy_name_list
assert
'Sb1R = Sb1Sk0 x Sb1Sk0'
in
strategy_name_list
@
pytest
.
mark
.
skip
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
pytest
.
mark
.
parametrize
(
'module'
,
[
BMMTensorMethodModule
,
BMMTorchFunctionModule
])
@
pytest
.
mark
.
parametrize
(
'module'
,
[
BMMTensorMethodModule
,
BMMTorchFunctionModule
])
def
test_1d_device_mesh
(
module
):
def
test_1d_device_mesh
(
module
):
model
=
module
()
model
=
module
()
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
View file @
0e52f3d3
...
@@ -6,9 +6,10 @@ from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import
...
@@ -6,9 +6,10 @@ from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import
from
colossalai.auto_parallel.solver.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.auto_parallel.solver.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
import
pytest
import
pytest
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
@
pytest
.
mark
.
skip
(
"for higher testing speed"
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
def
test_norm_pool_handler
():
def
test_norm_pool_handler
():
model
=
nn
.
Sequential
(
nn
.
MaxPool2d
(
4
,
padding
=
1
).
to
(
'meta'
))
model
=
nn
.
Sequential
(
nn
.
MaxPool2d
(
4
,
padding
=
1
).
to
(
'meta'
))
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
...
...
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
View file @
0e52f3d3
...
@@ -15,9 +15,10 @@ from torchvision.models import resnet34, resnet50
...
@@ -15,9 +15,10 @@ from torchvision.models import resnet34, resnet50
from
colossalai.auto_parallel.solver.constants
import
*
from
colossalai.auto_parallel.solver.constants
import
*
from
colossalai.auto_parallel.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.solver.options
import
SolverOptions
from
colossalai.auto_parallel.solver.options
import
SolverOptions
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
@
pytest
.
mark
.
skip
(
"for higher testing speed"
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
def
test_cost_graph
():
def
test_cost_graph
():
physical_mesh_id
=
torch
.
arange
(
0
,
8
)
physical_mesh_id
=
torch
.
arange
(
0
,
8
)
mesh_shape
=
(
2
,
4
)
mesh_shape
=
(
2
,
4
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment