Unverified Commit 4ea49cb5 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] added a decorator for address already in use error with backward compatibility (#760)

* [test] added a decorator for address already in use error with backward compatibility

* [test] added a decorator for address already in use error with backward compatibility
parent 10ef8afd
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
from .utils import parameterize, rerun_on_exception from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use
__all__ = [ __all__ = [
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
'rerun_on_exception' 'rerun_on_exception', 'rerun_if_address_is_in_use'
] ]
import re import re
import torch
from typing import Callable, List, Any from typing import Callable, List, Any
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from packaging import version
def parameterize(argument: str, values: List[Any]) -> Callable: def parameterize(argument: str, values: List[Any]) -> Callable:
...@@ -144,3 +146,29 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non ...@@ -144,3 +146,29 @@ def rerun_on_exception(exception_type: Exception = Exception, pattern: str = Non
return _run_until_success return _run_until_success
return _wrapper return _wrapper
def rerun_if_address_is_in_use():
"""
This function reruns a wrapped function if "address already in use" occurs
in testing spawned with torch.multiprocessing
Usage::
@rerun_if_address_is_in_use()
def test_something():
...
"""
# check version
torch_version = version.parse(torch.__version__)
assert torch_version.major == 1
# only torch >= 1.8 has ProcessRaisedException
if torch_version.minor >= 8:
exception = torch.multiprocessing.ProcessRaisedException
else:
exception = Exception
func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
return func_wrapper
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment