decorator.py 958 Bytes
Newer Older
1
2
3
4
5
6
7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Unittest decorator helpers."""

import os
import unittest
8
9
import functools
from pathlib import Path
10
11
12
13
14

# Each flag is read from the environment once, when this module is imported.
# CUDA tests are enabled unless SB_TEST_CUDA is explicitly set to '0'.
cuda_test = unittest.skipIf(
    os.getenv('SB_TEST_CUDA', '1') == '0',
    'Skip CUDA tests.',
)
# ROCm tests default to '0', so they are skipped unless SB_TEST_ROCM is set
# to some other value.
rocm_test = unittest.skipIf(
    os.getenv('SB_TEST_ROCM', '0') == '0',
    'Skip ROCm tests.',
)

# PyTorch tests are enabled unless SB_TEST_PYTORCH is explicitly set to '0'.
pytorch_test = unittest.skipIf(
    os.getenv('SB_TEST_PYTORCH', '1') == '0',
    'Skip PyTorch tests.',
)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36


def load_data(filepath):
    """Create a decorator that feeds a data file's content to a function.

    The file is read once, when ``load_data(filepath)`` is evaluated (i.e. at
    decoration time), and the resulting text is reused for every call of the
    decorated function.

    Args:
        filepath (str): Data file path, e.g., tests/data/output.log.

    Returns:
        func: decorator; the decorated function receives the file content as
            an extra positional argument appended after the caller's own
            positional arguments.
    """
    content = Path(filepath).read_text()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The loaded text always goes last among the positional arguments.
            positional = args + (content, )
            return func(*positional, **kwargs)

        return wrapper

    return decorator