# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# tests directory-specific settings - this file is run automatically by pytest before any tests are run

import sys
import pytest
import os
from os.path import abspath, dirname, join
import torch
import warnings

# Set this environment variable for the T5 inference unittest(s) (e.g. google/t5-v1_1-small).
# It is assigned at import time, before any test runs, and unconditionally overrides any
# value already present in the environment.
os.environ.update(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python')

# Allow having multiple repository checkouts without needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests:
# the "src" directory that is a sibling of this file's directory is inserted at
# sys.path index 1, ahead of later entries such as an installed copy.
git_repo_path = join(dirname(dirname(__file__)), "src")
git_repo_path = abspath(git_repo_path)
sys.path.insert(1, git_repo_path)


def pytest_configure(config):
    """Apply this test suite's default reporting options.

    Equivalent to always running with colored output, verbose mode and
    ``--durations=0 --durations-min=1`` (report every slow test phase that took
    at least one second). The values are assigned unconditionally, so they
    override whatever was passed on the command line.
    """
    overrides = {
        "color": "yes",
        "durations": 0,
        "durations_min": 1,
        "verbose": True,
    }
    for option_name, value in overrides.items():
        setattr(config.option, option_name, value)


def pytest_addoption(parser):
    """Register the ``--torch_ver`` and ``--cuda_ver`` command-line options.

    Both are parsed as strings and default to None, meaning "do not verify".
    """
    for flag in ("--torch_ver", "--cuda_ver"):
        parser.addoption(flag, default=None, type=str)


def validate_version(expected, found):
    """Return True if ``found`` matches ``expected`` up to ``expected``'s precision.

    ``found`` is truncated to as many dot-separated components as ``expected``
    has, so ``expected="1.13"`` accepts ``found="1.13.1"`` and ``expected="11"``
    accepts ``found="11.8"``.

    Args:
        expected: requested version string, e.g. "2.0", "2.0.1" or "11.8".
        found: installed version string, e.g. ``torch.__version__``. May be None
            (e.g. ``torch.version.cuda`` on a torch build without CUDA).

    Returns:
        True if the versions agree, False otherwise (including when ``found``
        is None).
    """
    if found is None:
        # No installed version to compare against: treat it as a mismatch.
        return False
    if '+' not in expected:
        # Drop a local version label such as "+cu118" (as in "2.0.1+cu118") unless
        # the caller explicitly asked for one; otherwise a full-depth "2.0.1" could
        # never match a CUDA wheel's version string.
        found = found.split('+', 1)[0]
    version_depth = expected.count('.') + 1
    found = '.'.join(found.split('.')[:version_depth])
    return found == expected


def _verify_version(component, flag, expected, found):
    """Check one installed version against the version requested on the command line.

    Args:
        component: name used in the messages ("torch" or "cuda").
        flag: command-line option that supplies ``expected`` (quoted in the warning).
        expected: requested version string, or None if the option was not given.
        found: installed version string; may be None (e.g. ``torch.version.cuda``
            on a torch build without CUDA).

    If ``expected`` is None, only a warning is emitted. Otherwise the pytest session
    is aborted with exit code 2 when ``found`` is missing or does not match.
    """
    if expected is None:
        warnings.warn(f"Running test without verifying {component} version, "
                      f"please provide an expected {component} version with {flag}")
    elif found is None or not validate_version(expected, found):
        # Handle a missing installed version here instead of relying on
        # validate_version to accept None.
        pytest.exit(f"expected {component} version {expected} did not match found {component} version {found}",
                    returncode=2)


@pytest.fixture(scope="session", autouse=True)
def check_environment(pytestconfig):
    """Verify the installed torch / CUDA versions once per test session.

    The expected versions come from ``--torch_ver`` and ``--cuda_ver``. An option
    that was not given only triggers a warning; a given version that does not
    match the installed one aborts the whole run with exit code 2.
    """
    _verify_version("torch", "--torch_ver", pytestconfig.getoption("torch_ver"), torch.__version__)
    _verify_version("cuda", "--cuda_ver", pytestconfig.getoption("cuda_ver"), torch.version.cuda)


# Override of pytest "runtest" for DistributedTest class
# This hook is run before the default pytest_runtest_call
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    """Run tests of ``is_dist_test`` classes through the class's own launcher.

    The test class is instantiated and called with the item's fixture request.
    The item's ``runtest`` is then replaced with a no-op, so the default
    ``pytest_runtest_call`` that runs afterwards does not execute the test body
    a second time. Items whose class is not marked are left untouched.
    """
    # This hook runs for every collected item, and not every item type has a
    # ``cls`` attribute (e.g. doctest items), so look it up defensively.
    test_cls = getattr(item, "cls", None)
    if getattr(test_cls, "is_dist_test", False):
        dist_test_class = test_cls()
        dist_test_class(item._request)
        item.runtest = lambda: True  # Dummy function so test is not run twice


@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    """Launch distributed fixtures before pytest's regular fixture setup.

    If the fixture function carries a truthy ``is_dist_fixture`` attribute, it is
    called with no arguments, and the resulting object is invoked with the fixture
    ``request``. The hook always returns None, so pytest goes on to perform its
    normal setup for the fixture afterwards (this hook is first-result).
    """
    fixture_func = fixturedef.func
    if not getattr(fixture_func, "is_dist_fixture", False):
        return None
    launcher = fixture_func()
    launcher(request)