"...chart/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "602352ce190bcb02013c62c2337e8b8678015699"
Unverified commit 80eba05b, authored by Frank Lee and committed by GitHub

[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
parent 62f4e2eb
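
The change is mechanical across the test suite: every test used to build a `functools.partial` closing over `world_size` and a freshly allocated port, then call `torch.multiprocessing.spawn` itself; the new `colossalai.testing.spawn` helper does both. A minimal before/after sketch of the pattern (the `run_dist` body and the `test_something_*` names are placeholders, not code from this commit):

```python
from functools import partial

import torch.multiprocessing as mp

from colossalai.testing import spawn
from colossalai.utils import free_port    # pre-PR location; this commit moves it to colossalai.testing


def run_dist(rank, world_size, port):
    ...    # placeholder: launch colossalai and run the actual assertions


# before this PR: boilerplate repeated in every test
def test_something_before():
    run_func = partial(run_dist, world_size=4, port=free_port())
    mp.spawn(run_func, nprocs=4)


# after this PR: one call, port allocation handled by the helper
def test_something_after():
    spawn(run_dist, 4)
```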
@@ -8,10 +8,10 @@ jobs:
   detect:
     name: Detect file change
     if: |
       github.event.pull_request.draft == false &&
       github.base_ref == 'main' &&
       github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' &&
       contains( github.event.pull_request.labels.*.name, 'Run Build and Test')
     outputs:
       changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }}
       anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }}
@@ -27,10 +27,10 @@ jobs:
       - name: Locate base commit
         id: locate-base-sha
         run: |
           curBranch=$(git rev-parse --abbrev-ref HEAD)
           commonCommit=$(git merge-base origin/main $curBranch)
           echo $commonCommit
           echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT
       - name: Find the changed extension-related files
         id: find-extension-change
@@ -63,7 +63,6 @@ jobs:
             echo "$file was changed"
           done
   build:
     name: Build and Test Colossal-AI
     needs: detect
@@ -124,7 +123,7 @@ jobs:
       - name: Execute Unit Testing
         if: needs.detect.outputs.anyLibraryFileChanged == 'true'
         run: |
-          PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
+          CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/
         env:
           DATA: /data/scratch/cifar-10
           NCCL_SHM_DISABLE: 1
...
 import os
 import tempfile
 from contextlib import nullcontext
-from functools import partial

 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from coati.models.gpt import GPTActor
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config

 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
@@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy):
 @pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
 @rerun_if_address_is_in_use()
 def test_checkpoint(world_size, strategy):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, strategy=strategy)

 if __name__ == '__main__':
...
 import os
 from copy import deepcopy
-from functools import partial

 import pytest
 import torch
 import torch.distributed as dist
-import torch.multiprocessing as mp
 from coati.experience_maker import NaiveExperienceMaker
 from coati.models.base import RewardModel
 from coati.models.gpt import GPTActor, GPTCritic
@@ -13,8 +11,7 @@ from coati.replay_buffer import NaiveReplayBuffer
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config

-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn

 GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
@@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy):
 @pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
 @rerun_if_address_is_in_use()
 def test_data(world_size, strategy):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, strategy=strategy)

 if __name__ == '__main__':
...
@@ -10,7 +10,8 @@ from colossalai.context import Config
 from colossalai.context.random import reset_seeds
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.utils import MultiTimer, free_port
+from colossalai.testing import free_port
+from colossalai.utils import MultiTimer

 from .models import MLP
...
-from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
-from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus
+from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal
+from .pytest_wrapper import run_on_environment_flag
+from .utils import (
+    clear_cache_before_run,
+    free_port,
+    parameterize,
+    rerun_if_address_is_in_use,
+    rerun_on_exception,
+    skip_if_not_enough_gpus,
+    spawn,
+)

 __all__ = [
     'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
-    'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus'
+    'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn',
+    'clear_cache_before_run', 'run_on_environment_flag'
 ]
+import gc
+import random
 import re
-import torch
-from typing import Callable, List, Any
+import socket
 from functools import partial
 from inspect import signature
+from typing import Any, Callable, List
+
+import torch
+import torch.multiprocessing as mp
 from packaging import version
@@ -43,7 +48,7 @@ def parameterize(argument: str, values: List[Any]) -> Callable:
     # > davis: hello
     # > davis: bye
     # > davis: stop

     Args:
         argument (str): the name of the argument to parameterize
         values (List[Any]): a list of values to iterate for this argument
@@ -85,13 +90,13 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
         def test_method():
             print('hey')
             raise RuntimeError('Address already in use')

         # rerun for infinite times if Runtime error occurs
         @rerun_on_exception(exception_type=RuntimeError, max_try=None)
         def test_method():
             print('hey')
             raise RuntimeError('Address already in use')

         # rerun only if the exception message matches the pattern,
         # for infinite times if Runtime error occurs
         @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
@@ -101,10 +106,10 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
     Args:
         exception_type (Exception, Optional): The type of exception to detect for rerun
         pattern (str, Optional): The pattern to match the exception message.
             If the pattern is not None and matches the exception message,
             the exception will be detected for rerun
         max_try (int, Optional): Maximum reruns for this function. The default value is 5.
             If max_try is None, it will rerun forever if the exception keeps occurring
     """
@@ -202,3 +207,72 @@ def skip_if_not_enough_gpus(min_gpus: int):
         return _execute_by_gpu_num

     return _wrap_func
+
+
+def free_port() -> int:
+    """Get a free port on localhost.
+
+    Returns:
+        int: A free port on localhost.
+    """
+    while True:
+        port = random.randint(20000, 65000)
+        try:
+            with socket.socket() as sock:
+                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+                sock.bind(("localhost", port))
+                return port
+        except OSError:
+            continue
+
+
+def spawn(func, nprocs=1, **kwargs):
+    """
+    This function is used to spawn processes for testing.
+
+    Usage:
+        # the test function must accept the arguments rank, world_size and port
+        def do_something(rank, world_size, port):
+            ...
+
+        spawn(do_something, nprocs=8)
+
+        # other keyword arguments can be passed through as well
+        def do_something(rank, world_size, port, arg1, arg2):
+            ...
+
+        spawn(do_something, nprocs=8, arg1=1, arg2=2)
+
+    Args:
+        func (Callable): The function to be spawned.
+        nprocs (int, optional): The number of processes to spawn. Defaults to 1.
+    """
+    port = free_port()
+    # bind world_size, port and any extra kwargs; mp.spawn injects rank positionally
+    wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs)
+    mp.spawn(wrapped_func, nprocs=nprocs)
+
+
+def clear_cache_before_run():
+    """
+    This decorator clears the CUDA and Python caches before executing the wrapped function.
+
+    Usage:
+        @clear_cache_before_run()
+        def test_something():
+            ...
+    """

+    def _wrap_func(f):
+
+        def _clear_cache(*args, **kwargs):
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+            torch.cuda.reset_max_memory_allocated()
+            torch.cuda.reset_max_memory_cached()
+            torch.cuda.synchronize()
+            gc.collect()
+            f(*args, **kwargs)
+
+        return _clear_cache
+
+    return _wrap_func
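
Taken together, the new utilities let a distributed test be written as below. This is a usage sketch, not code from the commit; `run_dist`, `test_world_size` and the world size of 2 are illustrative:

```python
import torch.distributed as dist

import colossalai
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # spawn() supplies world_size and a free port as keyword arguments;
    # rank is injected positionally by torch.multiprocessing.spawn.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    assert dist.get_world_size() == world_size


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_world_size():
    spawn(run_dist, 2)
```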
@@ -7,7 +7,6 @@ from .common import (
     count_zeros_fp32,
     disposable,
     ensure_path_exists,
-    free_port,
     is_ddp_ignored,
     is_dp_rank_0,
     is_model_parallel_parameter,
@@ -37,7 +36,6 @@ from .timer import MultiTimer, Timer
 __all__ = [
     'checkpoint',
-    'free_port',
     'print_rank_0',
     'sync_model_param',
     'is_ddp_ignored',
...
@@ -50,23 +50,6 @@ def ensure_path_exists(filename: str):
         Path(dirpath).mkdir(parents=True, exist_ok=True)

-def free_port() -> int:
-    """Get a free port on localhost.
-
-    Returns:
-        int: A free port on localhost.
-    """
-    while True:
-        port = random.randint(20000, 65000)
-        try:
-            with socket.socket() as sock:
-                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-                sock.bind(("localhost", port))
-                return port
-        except OSError:
-            continue
-
 def sync_model_param(model, parallel_mode):
     r"""Make sure data parameters are consistent during Data Parallel Mode.
...
@@ -4,3 +4,4 @@ packaging
 tensornvme
 psutil
 transformers
+pytest
@@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp
 ```python
 import torch
 import torch.multiprocessing as mp
-from colossalai.utils import free_port, print_rank_0
+from colossalai.utils import print_rank_0
 from functools import partial

 import colossalai
 from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
-from colossalai.utils import free_port
+from colossalai.testing import spawn

 import torch
@@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port):
     print_rank_0(f"shape {t1.shape}, {t1.data}")

 def test_dist_cases(world_size):
-    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_tests, world_size)

 if __name__ == '__main__':
     test_dist_cases(4)
...
@@ -57,12 +57,12 @@ ColoTensor carries the additional attribute [ColoTensorSpec](https://colossalai.readthedocs.
 ```python
 import torch
 import torch.multiprocessing as mp
-from colossalai.utils import free_port, print_rank_0
+from colossalai.utils import print_rank_0
 from functools import partial

 import colossalai
 from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
-from colossalai.utils import free_port
+from colossalai.testing import spawn

 import torch
@@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port):
     print_rank_0(f"shape {t1.shape}, {t1.data}")

 def test_dist_cases(world_size):
-    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist_tests, world_size)

 if __name__ == '__main__':
     test_dist_cases(4)
...
 import os
 import random
-from functools import partial

 import numpy as np
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from vit import get_training_components

@@ -15,8 +13,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import ColoInitContext

@@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp):
 @pytest.mark.parametrize('use_ddp', [False, True])
 @rerun_if_address_is_in_use()
 def test_vit(world_size, use_ddp):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, world_size, use_ddp=use_ddp)

 if __name__ == '__main__':
...
-import time
-import pytest
 import argparse
-from functools import partial
+import time
+
+import pytest
 import torch
+from model_zoo import GPTLMLoss, get_gpt2_components
 from torch.utils._pytree import tree_map
-import torch.multiprocessing as mp

 import colossalai
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.fx.profiler import parameter_size
-from colossalai.utils import free_port, get_current_device
 from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
 from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
-from model_zoo import get_gpt2_components, GPTLMLoss
+from colossalai.fx.profiler import parameter_size
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import spawn
+from colossalai.utils import get_current_device

 def parse_args():
     parser = argparse.ArgumentParser()
@@ -24,6 +24,7 @@ def parse_args():
     parser.add_argument('--memory_budget', type=float, default=16)
     return parser.parse_args()

+
 @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
 def train_gpt(args):
     memory_budget = args.memory_budget * 1024 * 1024 * 1024
@@ -33,13 +34,16 @@ def train_gpt(args):
     # build model
     model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
-    label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
+    label = torch.randint(low=0, high=128, size=(
+        64,
+        8,
+    ), device=get_current_device())
     criterion = GPTLMLoss()

     start_time = time.time()
     model = model_builder()
     model.train()
-    param_size = parameter_size(model) / 1024 ** 2 / 2
+    param_size = parameter_size(model) / 1024**2 / 2
     init_time = time.time() - start_time
     print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
@@ -74,21 +78,20 @@ def train_gpt(args):
     torch.cuda.synchronize()
     exec_time = sum(sorted(time_list)[:5]) / 5
-    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
-    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
+    runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2
+    runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2
     print(f'solver_type: {solver_type} | model_type: {model_type}')
-    print(
-        f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
-        f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
-    )
+    print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
+          f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|')
     print(time_list)

 def run(rank, world_size, port, args):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     train_gpt(args)

 if __name__ == '__main__':
     args = parse_args()
-    run_func = partial(run, world_size=1, port=free_port(), args=args)
-    mp.spawn(run_func, nprocs=1)
+    spawn(run, 1, args=args)
 from functools import partial
 from time import time
-from typing import Dict, Optional, Tuple, Union

 import psutil
 import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
 import transformers
 from gpt_modules import GPT2LMHeadModel, GPTLMLoss
-from torch.fx import GraphModule

-from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
+from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
 from colossalai.core import global_context as gpc
-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch_from_torch
 from colossalai.logging import disable_existing_loggers, get_dist_logger
...
-import time
-from argparse import ArgumentParser
 from copy import deepcopy
 from functools import partial

-import matplotlib.pyplot as plt
-import numpy as np
 import torch
-import torch.multiprocessing as mp
 import torchvision.models as tm
 from bench_utils import bench, data_gen_resnet

 import colossalai
 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace, symbolic_trace
-from colossalai.utils import free_port
+from colossalai.testing import spawn

 def _benchmark(rank, world_size, port):
@@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port):
 def auto_activation_checkpoint_batchsize_benchmark():
-    world_size = 1
-    run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_benchmark, 1)

 if __name__ == "__main__":
...
@@ -4,14 +4,13 @@ from functools import partial

 import matplotlib.pyplot as plt
 import torch
-import torch.multiprocessing as mp
 import torchvision.models as tm
 from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium

 import colossalai
 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace, symbolic_trace
-from colossalai.utils import free_port
+from colossalai.testing import spawn

 def _benchmark(rank, world_size, port, args):
@@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args):
 def auto_activation_checkpoint_benchmark(args):
     world_size = 1
-    run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
-    mp.spawn(run_func_module, nprocs=world_size)
+    spawn(_benchmark, world_size, args=args)

 if __name__ == "__main__":
...
@@ -12,3 +12,4 @@ contexttimer
 einops
 triton==2.0.0.dev20221202
 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
+requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
 import copy
-from functools import partial

 import pytest
 import torch
-import torch.multiprocessing as mp

 import colossalai
 from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp
-from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from tests.components_to_test.registry import non_distributed_component_funcs

@@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_naive_amp():
-    world_size = 1
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 1)

 if __name__ == '__main__':
...
 import copy
-from functools import partial

 import pytest
 import torch
-import torch.multiprocessing as mp

 import colossalai
 from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp
-from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use
-from colossalai.utils import free_port
+from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
 from tests.components_to_test.registry import non_distributed_component_funcs

@@ -87,10 +84,9 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
+@clear_cache_before_run()
 def test_torch_amp():
-    world_size = 1
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, 1)

 if __name__ == '__main__':
...
@@ -3,7 +3,7 @@ import torch
 from packaging import version
 from torch.utils.checkpoint import checkpoint

-from colossalai.testing.utils import parameterize
+from colossalai.testing.utils import clear_cache_before_run, parameterize

 try:
     from colossalai._analyzer.fx import symbolic_trace
@@ -81,6 +81,7 @@ class AddmmModel(torch.nn.Module):
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@clear_cache_before_run()
 @parameterize("bias", [True, False])
 @parameterize("bias_addition_split", [True, False])
 @parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
...