testing_utils.py 1.61 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
import os
import random
import unittest
from distutils.util import strtobool

Patrick von Platen's avatar
Patrick von Platen committed
6
7
import torch

8
9
from packaging import version

Patrick von Platen's avatar
Patrick von Platen committed
10
11
12

# Default random generator for helpers such as `floats_tensor` when no `rng` is passed.
global_rng = random.Random()

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"
13
14
15
16
# Compare only the base version (local/dev suffixes dropped) of the installed torch against 1.12.
is_torch_higher_equal_than_1_12 = (
    version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
)

# `torch.backends.mps` is only queried on torch >= 1.12; when MPS is available it overrides the device
# chosen earlier, otherwise `torch_device` keeps its current value.
if is_torch_higher_equal_than_1_12 and torch.backends.mps.is_available():
    torch_device = "mps"
Patrick von Platen's avatar
Patrick von Platen committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61


def parse_flag_from_env(key, default=False):
    """
    Read a boolean flag from the environment variable `key`.

    Args:
        key: Name of the environment variable to inspect.
        default: Value returned unchanged when `key` is not set.

    Returns:
        `default` if the variable is unset. Otherwise `True` for `y`, `yes`, `t`, `true`, `on`, `1` and `False` for
        `n`, `no`, `f`, `false`, `off`, `0` (matched case-insensitively).

    Raises:
        ValueError: If the variable is set to any other value.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    # KEY is set, convert it to True or False. The accepted spellings mirror `distutils.util.strtobool`; they are
    # matched here directly because `distutils` is deprecated (PEP 632) and no longer ships with Python 3.12+.
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False

    # More values are supported, but let's keep the message simple.
    raise ValueError(f"If set, {key} must be yes or no.")


# Whether tests decorated with `slow` should run; read from RUN_SLOW once, when this module is imported.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """
    Build a random float32 tensor of the given `shape`.

    Each element is drawn as `rng.random() * scale`. When `rng` is not provided, the module-level `global_rng` is
    used. `name` is accepted but not used.
    """
    generator = global_rng if rng is None else rng

    # Number of elements is the product of all dimensions.
    num_elements = 1
    for extent in shape:
        num_elements *= extent

    samples = [generator.random() * scale for _ in range(num_elements)]

    flat = torch.tensor(data=samples, dtype=torch.float)
    shaped = flat.view(shape)
    return shaped.contiguous()


def slow(test_case):
    """
    Mark `test_case` as a slow test.

    Slow tests are skipped unless they have been enabled; set the RUN_SLOW environment variable to a truthy value
    to run them.

    """
    skip_unless_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_enabled(test_case)