common_utils.py 444 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
'''
Common utility functions for running the unit tests on the ROCm stack.
'''

import os
import unittest
from functools import wraps

import torch

# ROCm-specific skips are enabled only when the environment variable
# APEX_TEST_WITH_ROCM is set to exactly '1' (read once, at import time).
TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1'


def skipIfRocm(fn):
    """Decorate a test so that it is skipped when running on the ROCm stack.

    The module-level ``TEST_WITH_ROCM`` flag is consulted each time the
    decorated callable is invoked, not when the decorator is applied.

    Args:
        fn: The test function or method to wrap.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``). It
        raises ``unittest.SkipTest`` when ``TEST_WITH_ROCM`` is true and
        otherwise calls ``fn`` and returns its result.

    Raises:
        unittest.SkipTest: If ``TEST_WITH_ROCM`` is true at call time.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on ROCm stack.")
        # Forward the wrapped callable's return value so the decorator is
        # transparent to callers.
        return fn(*args, **kwargs)
    return wrapper