Unverified commit 80eba05b authored by Frank Lee, committed by GitHub

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
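
Every file touched below follows the same pattern: the functools.partial / free_port / torch.multiprocessing boilerplate each distributed test used to launch its workers is replaced by a single spawn helper from colossalai.testing, and several single-process tests gain a @clear_cache_before_run() decorator. A minimal before/after sketch of the migration (the spawn signature is inferred from the call sites in this diff, not quoted from the library; _worker and the test names are placeholders):

# Hypothetical placeholder for any distributed worker in this commit.
def _worker(rank, world_size, port):
    ...

# Before: each test assembled its own launcher.
from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port

def test_something_distributed_old():
    world_size = 4
    run_func = partial(_worker, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

# After: the helper owns port allocation and the process launch.
from colossalai.testing import spawn

def test_something_distributed_new():
    spawn(_worker, 4)
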
 import torch
-import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
@@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
 from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
 from colossalai.auto_parallel.tensor_shard.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing import parameterize
+from colossalai.testing import clear_cache_before_run, parameterize
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
@@ -20,6 +19,7 @@ HIDDEN_DIM = 384
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
 @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
 def test_self_attention_block(model_cls):
     config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)
...
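
The @clear_cache_before_run() decorator added above (and in most metainfo tests below) is introduced by this commit, but its implementation is not shown in the diff. A plausible minimal sketch of such a decorator, under the assumption that it simply frees cached state so each parameterized run starts clean (the body is hypothetical, not the actual colossalai.testing code):

import gc
from functools import wraps

import torch

def clear_cache_before_run():
    # Hypothetical body: the real colossalai.testing implementation may differ.
    def decorator(func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            gc.collect()    # drop Python-level garbage from earlier runs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()    # release cached CUDA allocator blocks
            return func(*args, **kwargs)

        return wrapper

    return decorator
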
@@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule
 from colossalai._analyzer.fx.passes import shape_prop_pass
 from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing import clear_cache_before_run
 class LinearModel(nn.Module):
@@ -26,6 +28,7 @@ class LinearModel(nn.Module):
 @pytest.mark.skip('meta tensor has some bugs in 1.11')
+@clear_cache_before_run()
 def test_liveness_analysis():
     model = LinearModel()
     tracer = ColoTracer(bias_addition_split=True)
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
 from colossalai.auto_parallel.meta_profiler import meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
+from colossalai.testing.utils import clear_cache_before_run, parameterize
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 @parameterize('func', [
     torch.nn.functional.softmax,
     torch.nn.functional.relu,
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai.device.device_mesh import DeviceMesh
@@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_binary_elementwise_mem_test, 4)
 if __name__ == '__main__':
...
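
The hunk above is the canonical shape of the spawn migration: spawn(_binary_elementwise_mem_test, 4) replaces three lines of launcher boilerplate. A rough sketch of how such a helper could be built on torch.multiprocessing, inferred purely from the call sites in this commit rather than from the actual colossalai.testing source:

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port

def spawn(worker, nprocs=1, **kwargs):
    # Hypothetical sketch: run worker(rank, world_size, port, **kwargs) on
    # nprocs processes, picking a free rendezvous port automatically.
    run_func = partial(worker, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(run_func, nprocs=nprocs)    # mp.spawn supplies the rank

Because world_size and port are bound by keyword, mp.spawn's positional rank lands in the worker's first parameter, which matches worker signatures of the form (rank, world_size, port, ...).
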
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module):
         return nn.functional.conv2d(input, self.conv_weight)
-def _conv_module_mem_test(rank, bias, world_size, port):
+def _conv_module_mem_test(rank, world_size, port, bias):
     """This function is for conv memory test
     Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_conv_meta_concrete_info_match(bias=False):
-    world_size = 4
-    run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_conv_module_mem_test, 4, bias=bias)
 def _conv_function_mem_test(rank, world_size, port):
@@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_conv_function_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_conv_function_mem_test, 4)
 if __name__ == '__main__':
...
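
Note the signature reorder in the file above, from (rank, bias, world_size, port) to (rank, world_size, port, bias); it recurs throughout the commit. Test-specific parameters move behind the three launcher-supplied ones, because with a helper like the sketch earlier the extras must be forwarded by keyword, so a fixed positional prefix of (rank, world_size, port) keeps every call site uniform. A hypothetical illustration of the convention:

def _conv_module_mem_test(rank, world_size, port, bias):
    ...    # rank, world_size and port come from spawn; bias from the test

def test_conv_meta_concrete_info_match(bias=False):
    spawn(_conv_module_mem_test, 4, bias=bias)    # bias forwarded by keyword
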
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
+from colossalai.testing.utils import clear_cache_before_run
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import meta_register
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 def test_embedding_meta_info():
     meta_func = meta_register.get(torch.nn.Embedding)
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
-if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
 class MyModule(nn.Module):
@@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_module_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_linear_module_mem_test, 4)
 def _linear_function_mem_test(rank, world_size, port):
@@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_function_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_linear_function_mem_test, 4)
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
+from colossalai.testing.utils import clear_cache_before_run, parameterize
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
@@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0':
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 @parameterize(
     'tensor_shapes',
     [
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
 if torch.__version__ >= '1.12.0':
-    from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
+    from colossalai.auto_parallel.meta_profiler import meta_register
 def _batchnorm_module_mem_test(rank, world_size, port):
@@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_batchnorm_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_batchnorm_module_mem_test, 4)
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_adaptiveavgpool_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_adaptiveavgpool_module_mem_test, 4)
 def _maxpool_module_mem_test(rank, world_size, port):
@@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_maxpool_meta_concrete_info_match():
-    world_size = 4
-    run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_maxpool_module_mem_test, 4)
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
+from colossalai.testing.utils import clear_cache_before_run
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
@@ -37,6 +20,7 @@ class SplitModule(nn.Module):
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 def test_tensor_meta_info():
     """test tensor related meta information
     We will just use torch.Tensor.split for the test
...
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
+from colossalai.testing.utils import clear_cache_before_run
 from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
 if torch.__version__ >= '1.12.0':
@@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0':
 @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@clear_cache_before_run()
 def test_where_meta_info():
     meta_func = meta_register.get(torch.where)
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
@@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
         return output
-def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
+def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = module(using_kwargs).cuda()
@@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
 @parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
 def test_2d_device_mesh(module, bias_shape, using_kwargs):
-    world_size = 4
-    run_func = partial(check_2d_device_mesh,
-                       module=module,
-                       bias_shape=bias_shape,
-                       world_size=world_size,
-                       using_kwargs=using_kwargs,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(
+        check_2d_device_mesh,
+        4,
+        module=module,
+        bias_shape=bias_shape,
+        using_kwargs=using_kwargs,
+    )
 @pytest.mark.skip("skip due to bias cases not ready")
@@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs):
 @parameterize('using_kwargs', [True, False])
 @rerun_if_address_is_in_use()
 def test_1d_device_mesh(module, bias_shape, using_kwargs):
-    world_size = 4
-    run_func = partial(check_1d_device_mesh,
-                       module=module,
-                       bias_shape=bias_shape,
-                       using_kwargs=using_kwargs,
-                       world_size=world_size,
-                       port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(
+        check_1d_device_mesh,
+        4,
+        module=module,
+        bias_shape=bias_shape,
+        using_kwargs=using_kwargs,
+    )
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
         return x
-def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
+def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if model_cls == AddmmModel:
@@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
 @parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
 @rerun_if_address_is_in_use()
 def test_addmm_handler(input_shape, model_cls):
-    world_size = 4
-    run_func_function = partial(check_addmm_function_handler,
-                                input_shape=input_shape,
-                                model_cls=model_cls,
-                                world_size=world_size,
-                                port=free_port())
-    mp.spawn(run_func_function, nprocs=world_size)
+    spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls)
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bn_module_handler():
-    world_size = 4
-    run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_bn_module_handler, 4)
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
 import torch.nn.functional as F
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
 WEIGHT_SHAPE = (32, 16)
@@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler():
-    world_size = 4
-    run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(check_linear_module_handler)
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
 from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
 from colossalai._analyzer.fx.tracer.tracer import ColoTracer
-from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module):
         return x
-def check_linear_module_handler(rank, bias, world_size, port):
+def check_linear_module_handler(rank, world_size, port, bias):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = LinearModule(16, 32, bias=bias).cuda()
@@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(bias=True):
-    world_size = 4
-    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(check_linear_module_handler, bias=bias)
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
-def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
+def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module):
         return out
-def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
+def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler_with_tensor(op, other_dim):
-    world_size = 4
-    run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
-                              op=op,
-                              other_dim=other_dim,
-                              world_size=world_size,
-                              port=free_port())
-    mp.spawn(run_func_tensor, nprocs=world_size)
+    spawn(
+        check_binary_elementwise_handler_with_tensor,
+        4,
+        op=op,
+        other_dim=other_dim,
+    )
 @run_on_environment_flag(name='AUTO_PARALLEL')
@@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
-    world_size = 4
-    run_func_int = partial(check_binary_elementwise_handler_with_int,
-                           op=op,
-                           model_cls=model_cls,
-                           other_dim=other_dim,
-                           world_size=world_size,
-                           port=free_port())
-    mp.spawn(run_func_int, nprocs=world_size)
+    spawn(
+        check_binary_elementwise_handler_with_int,
+        4,
+        op=op,
+        model_cls=model_cls,
+        other_dim=other_dim,
+    )
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bmm_handler(module):
-    world_size = 4
-    run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_2d, nprocs=world_size)
-    run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_1d, nprocs=world_size)
+    spawn(check_2d_device_mesh, 4, module=module)
+    spawn(check_1d_device_mesh, 4, module=module)
 if __name__ == '__main__':
...
-from functools import partial
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
-def check_conv_module_handler(rank, bias, world_size, port):
+def check_conv_module_handler(rank, world_size, port, bias):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
@@ -155,7 +150,7 @@ class ConvModel(nn.Module):
         return x
-def check_conv_function_handler(rank, bias, world_size, port):
+def check_conv_function_handler(rank, world_size, port, bias):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = ConvModel().cuda()
@@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
 # @parameterize('bias', [True, False])
 @rerun_if_address_is_in_use()
 def test_conv_module_handler(bias=False):
-    world_size = 4
-    run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_conv_module_handler, 4, bias=bias)
 @run_on_environment_flag(name='AUTO_PARALLEL')
@@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False):
 # @parameterize('bias', [True, False])
 @rerun_if_address_is_in_use()
 def test_conv_function_handler(bias=False):
-    world_size = 4
-    run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_conv_function_handler, 4, bias=bias)
 if __name__ == '__main__':
...