Unverified Commit 0e52f3d3 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[unittest] supported condititonal testing based on env var (#1701)

polish code
parent 8283e95d
import pytest
import os
def run_on_environment_flag(name: str):
"""
Conditionally run a test based on the environment variable. If this environment variable is set
to 1, this test will be executed. Otherwise, this test is skipped. The environment variable is default to 0.
"""
assert isinstance(name, str)
flag = os.environ.get(name.upper(), '0')
reason = f'Environment varialbe {name} is {flag}'
if flag == '1':
return pytest.mark.skipif(False, reason=reason)
else:
return pytest.mark.skipif(True, reason=reason)
...@@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int): ...@@ -193,11 +193,12 @@ def skip_if_not_enough_gpus(min_gpus: int):
""" """
def _wrap_func(f): def _wrap_func(f):
def _execute_by_gpu_num(*args, **kwargs): def _execute_by_gpu_num(*args, **kwargs):
num_avail_gpu = torch.cuda.device_count() num_avail_gpu = torch.cuda.device_count()
if num_avail_gpu >= min_gpus: if num_avail_gpu >= min_gpus:
f(*args, **kwargs) f(*args, **kwargs)
return _execute_by_gpu_num return _execute_by_gpu_num
return _wrap_func return _wrap_func
...@@ -7,6 +7,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptio ...@@ -7,6 +7,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptio
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module): class ConvModel(nn.Module):
...@@ -22,7 +23,7 @@ class ConvModel(nn.Module): ...@@ -22,7 +23,7 @@ class ConvModel(nn.Module):
return output return output
@pytest.mark.skip("temporarily skipped") @run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler(): def test_where_handler():
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
...@@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh ...@@ -18,6 +18,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.tensor_shard.deprecated import Solver from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module): class ConvModel(nn.Module):
...@@ -72,7 +73,7 @@ def check_apply(rank, world_size, port): ...@@ -72,7 +73,7 @@ def check_apply(rank, world_size, port):
assert output.equal(origin_output) assert output.equal(origin_output)
@pytest.mark.skip("for higher testing speed") @run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_apply(): def test_apply():
......
...@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import Grap ...@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import Grap
from copy import deepcopy from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.deprecated import Solver from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module): class ConvModel(nn.Module):
...@@ -33,7 +34,7 @@ class ConvModel(nn.Module): ...@@ -33,7 +34,7 @@ class ConvModel(nn.Module):
return x return x
@pytest.mark.skip("for higher testing speed") @run_on_environment_flag(name='AUTO_PARALLEL')
def test_solver(): def test_solver():
physical_mesh_id = torch.arange(0, 4) physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2) mesh_shape = (2, 2)
......
...@@ -15,12 +15,13 @@ import transformers ...@@ -15,12 +15,13 @@ import transformers
from colossalai.auto_parallel.tensor_shard.deprecated.constants import * from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
BATCH_SIZE = 8 BATCH_SIZE = 8
SEQ_LENGHT = 8 SEQ_LENGHT = 8
@pytest.mark.skip("for higher testing speed") @run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph(): def test_cost_graph():
physical_mesh_id = torch.arange(0, 8) physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4) mesh_shape = (2, 4)
......
...@@ -15,6 +15,7 @@ from torchvision.models import resnet34, resnet50 ...@@ -15,6 +15,7 @@ from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.tensor_shard.deprecated.constants import * from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class MLP(torch.nn.Module): class MLP(torch.nn.Module):
...@@ -34,7 +35,7 @@ class MLP(torch.nn.Module): ...@@ -34,7 +35,7 @@ class MLP(torch.nn.Module):
return x return x
@pytest.mark.skip("for higher testing speed") @run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph(): def test_cost_graph():
physical_mesh_id = torch.arange(0, 8) physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4) mesh_shape = (2, 4)
......
...@@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer, ColoGraphModule ...@@ -5,6 +5,7 @@ from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class BMMTensorMethodModule(nn.Module): class BMMTensorMethodModule(nn.Module):
...@@ -19,7 +20,7 @@ class BMMTorchFunctionModule(nn.Module): ...@@ -19,7 +20,7 @@ class BMMTorchFunctionModule(nn.Module):
return torch.bmm(x1, x2) return torch.bmm(x1, x2)
@pytest.mark.skip @run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module): def test_2d_device_mesh(module):
...@@ -90,7 +91,7 @@ def test_2d_device_mesh(module): ...@@ -90,7 +91,7 @@ def test_2d_device_mesh(module):
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list
@pytest.mark.skip @run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module): def test_1d_device_mesh(module):
model = module() model = module()
......
...@@ -6,9 +6,10 @@ from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import ...@@ -6,9 +6,10 @@ from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
import pytest import pytest
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@pytest.mark.skip("for higher testing speed") @run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler(): def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer() tracer = ColoTracer()
......
...@@ -15,9 +15,10 @@ from torchvision.models import resnet34, resnet50 ...@@ -15,9 +15,10 @@ from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.solver.constants import * from colossalai.auto_parallel.solver.constants import *
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.options import SolverOptions from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@pytest.mark.skip("for higher testing speed") @run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph(): def test_cost_graph():
physical_mesh_id = torch.arange(0, 8) physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4) mesh_shape = (2, 4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment