decorator.py 1.03 KB
Newer Older
1
2
3
4
5
6
7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Unittest decorator helpers."""

import os
import unittest
8
9
import functools
from pathlib import Path
10
11
12
13
14

def _env_skip(var_name, default, reason):
    """Build a ``unittest.skipIf`` decorator controlled by an environment variable.

    Args:
        var_name (str): Name of the environment variable to inspect.
        default (str): Value assumed when the variable is not set.
        reason (str): Skip message reported by unittest.

    Returns:
        func: decorator that skips the test when the variable's value is exactly '0'.
    """
    disabled = os.environ.get(var_name, default) == '0'
    return unittest.skipIf(disabled, reason)


# Run by default; skipped only when SB_TEST_CUDA is set to '0'.
cuda_test = _env_skip('SB_TEST_CUDA', '1', 'Skip CUDA tests.')
# Opt-in: skipped unless SB_TEST_ROCM is set to a value other than '0'.
rocm_test = _env_skip('SB_TEST_ROCM', '0', 'Skip ROCm tests.')

# Run by default; skipped only when SB_TEST_PYTORCH is set to '0'.
pytorch_test = _env_skip('SB_TEST_PYTORCH', '1', 'Skip PyTorch tests.')
# Opt-in: skipped unless SB_TEST_DIRECTX is set to a value other than '0'.
directx_test = _env_skip('SB_TEST_DIRECTX', '0', 'Skip DirectX tests.')
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37


def load_data(filepath, *, encoding='utf-8'):
    """Decorator factory that reads a data file once and passes its contents to the function.

    The file is read eagerly when ``load_data(...)`` itself is called, typically when
    the decorated test is defined. Every call of the decorated function reuses that
    same string.

    Args:
        filepath (str or os.PathLike): Data file path, e.g., tests/data/output.log.
        encoding (str): Text encoding used to decode the file. Defaults to 'utf-8' so
            the result does not depend on the platform's locale encoding.

    Returns:
        func: decorator. The decorated function receives the file contents (str) as an
            extra positional argument, appended after the caller's positional arguments
            and before any keyword arguments.
    """
    # An explicit encoding avoids locale-dependent decoding (e.g. cp1252 on Windows),
    # which can fail on or mis-decode non-ASCII bytes in UTF-8 fixture files.
    data = Path(filepath).read_text(encoding=encoding)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Inserted after the caller's positionals (i.e. after ``self`` for methods).
            return func(*args, data, **kwargs)

        return wrapper

    return decorator