# check_v2_dataset_warnings.py
import re

import pytest


def test_warns_if_imported_from_datasets(mocker):
    """Check that importing ``wrap_dataset_for_transforms_v2`` from
    ``torchvision.datasets`` emits the beta-transforms ``UserWarning`` when
    ``torchvision._WARN_ABOUT_BETA_TRANSFORMS`` is enabled.

    NOTE(review): the warning is presumably raised at module import time, so this
    test likely only passes if the module that emits it has not already been
    imported in this process. Confirm how the runner isolates this file.
    """
    # Set the flag to a real ``True``. ``return_value=True`` would replace the
    # flag with a MagicMock that is merely truthy; ``new`` sets the value itself.
    mocker.patch("torchvision._WARN_ABOUT_BETA_TRANSFORMS", new=True)

    import torchvision

    # ``match`` is a regular expression searched against the warning message.
    # Escape the literal text so regex metacharacters in it are matched verbatim.
    expected_message = re.escape(torchvision._BETA_TRANSFORMS_WARNING)
    with pytest.warns(UserWarning, match=expected_message):
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        assert callable(wrap_dataset_for_transforms_v2)


@pytest.mark.filterwarnings("error")
def test_no_warns_if_imported_from_datasets():
    """Check that importing ``wrap_dataset_for_transforms_v2`` from
    ``torchvision.datasets`` emits no warning.

    The ``filterwarnings("error")`` mark turns any warning raised during the test
    into an exception, so a warning during the import fails the test.

    NOTE(review): nothing here patches or disables the beta-transforms warning.
    This presumably relies on it being disabled or already emitted elsewhere
    (e.g. a conftest or an earlier test). Confirm.
    """
    from torchvision.datasets import wrap_dataset_for_transforms_v2 as wrapper

    assert callable(wrapper)