Unverified Commit 80eba05b authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
......@@ -8,10 +8,10 @@ jobs:
detect:
name: Detect file change
if: |
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
github.event.pull_request.draft == false &&
github.base_ref == 'main' &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
outputs:
changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }}
anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
......@@ -27,10 +27,10 @@ jobs:
- name: Locate base commit
id: locate-base-sha
run: |
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
curBranch=$(git rev-parse --abbrev-ref HEAD)
commonCommit=$(git merge-base origin/main $curBranch)
echo $commonCommit
echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
- name: Find the changed extension-related files
id: find-extension-change
......@@ -63,7 +63,6 @@ jobs:
echo "$file was changed"
done
build:
name: Build and Test Colossal-AI
needs: detect
......@@ -124,7 +123,7 @@ jobs:
- name: Execute Unit Testing
if: needs.detect.outputs.anyLibraryFileChanged == 'true'
run: |
PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1
......
import os
import tempfile
from contextlib import nullcontext
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from coati.models.gpt import GPTActor
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
......@@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, strategy):
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, strategy=strategy)
if __name__ == '__main__':
......
import os
from copy import deepcopy
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from coati.experience_maker import NaiveExperienceMaker
from coati.models.base import RewardModel
from coati.models.gpt import GPTActor, GPTCritic
......@@ -13,8 +11,7 @@ from coati.replay_buffer import NaiveReplayBuffer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
......@@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy):
@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
@rerun_if_address_is_in_use()
def test_data(world_size, strategy):
run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, strategy=strategy)
if __name__ == '__main__':
......
......@@ -10,7 +10,8 @@ from colossalai.context import Config
from colossalai.context.random import reset_seeds
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import MultiTimer, free_port
from colossalai.testing import free_port
from colossalai.utils import MultiTimer
from .models import MLP
......
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus
from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal
from .pytest_wrapper import run_on_environment_flag
from .utils import (
clear_cache_before_run,
free_port,
parameterize,
rerun_if_address_is_in_use,
rerun_on_exception,
skip_if_not_enough_gpus,
spawn,
)
__all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus'
'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
'clear_cache_before_run', 'run_on_environment_flag'
]
import gc
import random
import re
import torch
from typing import Callable, List, Any
import socket
from functools import partial
from inspect import signature
from typing import Any, Callable, List
import torch
import torch.multiprocessing as mp
from packaging import version
......@@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable:
# > davis: hello
# > davis: bye
# > davis: stop
Args:
argument (str): the name of the argument to parameterize
values (List[Any]): a list of values to iterate for this argument
......@@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, max_try=None)
def test_method():
print('hey')
raise RuntimeError('Address already in use')
# rerun only the exception message is matched with pattern
# for infinite times if Runtime error occurs
@rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
......@@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
Args:
exception_type (Exception, Optional): The type of exception to detect for rerun
pattern (str, Optional): The pattern to match the exception message.
pattern (str, Optional): The pattern to match the exception message.
If the pattern is not None and matches the exception message,
the exception will be detected for rerun
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
max_try (int, Optional): Maximum reruns for this function. The default value is 5.
If max_try is None, it will rerun foreven if exception keeps occurings
"""
......@@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
return _execute_by_gpu_num
return _wrap_func
def free_port() -> int:
"""Get a free port on localhost.
Returns:
int: A free port on localhost.
"""
while True:
port = random.randint(20000, 65000)
try:
with socket.socket() as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", port))
return port
except OSError:
continue
def spawn(func, nprocs=1, **kwargs):
"""
This function is used to spawn processes for testing.
Usage:
# must contians arguments rank, world_size, port
def do_something(rank, world_size, port):
...
spawn(do_something, nprocs=8)
# can also pass other arguments
def do_something(rank, world_size, port, arg1, arg2):
...
spawn(do_something, nprocs=8, arg1=1, arg2=2)
Args:
func (Callable): The function to be spawned.
nprocs (int, optional): The number of processes to spawn. Defaults to 1.
"""
port = free_port()
wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs)
mp.spawn(wrapped_func, nprocs=nprocs)
def clear_cache_before_run():
"""
This function is a wrapper to clear CUDA and python cache before executing the function.
Usage:
@clear_cache_before_run()
def test_something():
...
"""
def _wrap_func(f):
def _clear_cache(*args, **kwargs):
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_max_memory_cached()
torch.cuda.synchronize()
gc.collect()
f(*args, **kwargs)
return _clear_cache
return _wrap_func
......@@ -7,7 +7,6 @@ from .common import (
count_zeros_fp32,
disposable,
ensure_path_exists,
free_port,
is_ddp_ignored,
is_dp_rank_0,
is_model_parallel_parameter,
......@@ -37,7 +36,6 @@ from .timer import MultiTimer, Timer
__all__ = [
'checkpoint',
'free_port',
'print_rank_0',
'sync_model_param',
'is_ddp_ignored',
......
......@@ -50,23 +50,6 @@ def ensure_path_exists(filename: str):
Path(dirpath).mkdir(parents=True, exist_ok=True)
def free_port() -> int:
"""Get a free port on localhost.
Returns:
int: A free port on localhost.
"""
while True:
port = random.randint(20000, 65000)
try:
with socket.socket() as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("localhost", port))
return port
except OSError:
continue
def sync_model_param(model, parallel_mode):
r"""Make sure data parameters are consistent during Data Parallel Mode.
......
......@@ -4,3 +4,4 @@ packaging
tensornvme
psutil
transformers
pytest
......@@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp
```python
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port, print_rank_0
from colossalai.utils import print_rank_0
from functools import partial
import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
from colossalai.utils import free_port
from colossalai.testing import spawn
import torch
......@@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port):
print_rank_0(f"shape {t1.shape}, {t1.data}")
def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist_tests, world_size)
if __name__ == '__main__':
test_dist_cases(4)
......
......@@ -57,12 +57,12 @@ ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.
```python
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port, print_rank_0
from colossalai.utils import print_rank_0
from functools import partial
import colossalai
from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
from colossalai.utils import free_port
from colossalai.testing import spawn
import torch
......@@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port):
print_rank_0(f"shape {t1.shape}, {t1.data}")
def test_dist_cases(world_size):
run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist_tests, world_size)
if __name__ == '__main__':
test_dist_cases(4)
......
import os
import random
from functools import partial
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from vit import get_training_components
......@@ -15,8 +13,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
......@@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp):
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_vit(world_size, use_ddp):
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, use_ddp=use_ddp)
if __name__ == '__main__':
......
import time
import pytest
import argparse
from functools import partial
import time
import pytest
import torch
from model_zoo import GPTLMLoss, get_gpt2_components
from torch.utils._pytree import tree_map
import torch.multiprocessing as mp
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.fx.profiler import parameter_size
from colossalai.utils import free_port, get_current_device
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from model_zoo import get_gpt2_components, GPTLMLoss
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import spawn
from colossalai.utils import get_current_device
def parse_args():
parser = argparse.ArgumentParser()
......@@ -24,6 +24,7 @@ def parse_args():
parser.add_argument('--memory_budget', type=float, default=16)
return parser.parse_args()
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
def train_gpt(args):
memory_budget = args.memory_budget * 1024 * 1024 * 1024
......@@ -33,13 +34,16 @@ def train_gpt(args):
# build model
model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
label = torch.randint(low=0, high=128, size=(
64,
8,
), device=get_current_device())
criterion = GPTLMLoss()
start_time = time.time()
model = model_builder()
model.train()
param_size = parameter_size(model) / 1024 ** 2 / 2
param_size = parameter_size(model) / 1024**2 / 2
init_time = time.time() - start_time
print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
......@@ -74,21 +78,20 @@ def train_gpt(args):
torch.cuda.synchronize()
exec_time = sum(sorted(time_list)[:5]) / 5
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
print(f'solver_type: {solver_type} | model_type: {model_type}')
print(
f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
)
print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
print(time_list)
def run(rank, world_size, port, args):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
train_gpt(args)
if __name__ == '__main__':
args = parse_args()
run_func = partial(run, world_size=1, port=free_port(), args=args)
mp.spawn(run_func, nprocs=1)
spawn(run, 1, args=args)
from functools import partial
from time import time
from typing import Dict, Optional, Tuple, Union
import psutil
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
......
import time
from argparse import ArgumentParser
from copy import deepcopy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench, data_gen_resnet
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port):
......@@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
def auto_activation_checkpoint_batchsize_benchmark():
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, 1)
if __name__ == "__main__":
......
......@@ -4,14 +4,13 @@ from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
from colossalai.testing import spawn
def _benchmark(rank, world_size, port, args):
......@@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
def auto_activation_checkpoint_benchmark(args):
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
mp.spawn(run_func_module, nprocs=world_size)
spawn(_benchmark, world_size, args=args)
if __name__ == "__main__":
......
......@@ -12,3 +12,4 @@ contexttimer
einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
......@@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_naive_amp():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, 1)
if __name__ == '__main__':
......
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
......@@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_torch_amp():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, 1)
if __name__ == '__main__':
......
......@@ -3,7 +3,7 @@ import torch
from packaging import version
from torch.utils.checkpoint import checkpoint
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
try:
from colossalai._analyzer.fx import symbolic_trace
......@@ -81,6 +81,7 @@ class AddmmModel(torch.nn.Module):
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@clear_cache_before_run()
@parameterize("bias", [True, False])
@parameterize("bias_addition_split", [True, False])
@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment