Unverified Commit bffd85bf authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

added testing module (#435)

parent dbdc9a77
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
from .utils import parameterize
__all__ = ['assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize']
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
def assert_equal(a: Tensor, b: Tensor):
assert torch.all(a == b), f'expected a and b to be equal but they are not, {a} vs {b}'
def assert_not_equal(a: Tensor, b: Tensor):
assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}'
def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-2, atol: float = 1e-3):
assert_close(a, b, rtol, atol)
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
# all gather tensors from different ranks
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor, group=process_group)
# check if they are equal one by one
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i+1]
assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}'
from typing import List, Any
from functools import partial
def parameterize(argument: str, values: List[Any]):
"""
This function is to simulate the same behavior as pytest.mark.parameterize. As
we want to avoid the number of distributed network initialization, we need to have
this extra decorator on the function launched by torch.multiprocessing.
If a function is wrapped with this wrapper, non-paramterized arguments must be keyword arguments,
positioanl arguments are not allowed.
Example 1:
@parameterize('person', ['xavier', 'davis'])
def say_something(person, msg):
print(f'{person}: {msg}')
say_something(msg='hello')
This will generate output:
> xavier: hello
> davis: hello
Exampel 2:
@parameterize('person', ['xavier', 'davis'])
@parameterize('msg', ['hello', 'bye', 'stop'])
def say_something(person, msg):
print(f'{person}: {msg}')
say_something()
This will generate output:
> xavier: hello
> xavier: bye
> xavier: stop
> davis: hello
> davis: bye
> davis: stop
"""
def _wrapper(func):
def _execute_function_by_param(**kwargs):
for val in values:
arg_map = {argument: val}
partial_func = partial(func, **arg_map)
partial_func(**kwargs)
return _execute_function_by_param
return _wrapper
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment