test_trace.py 612 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from distutils.version import LooseVersion

import pytest
import torch

from mmcv.utils import is_jit_tracing


@pytest.mark.skipif(
    LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
    reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():

    def foo(x):
        if is_jit_tracing():
            return x
        else:
            return x.tolist()

    x = torch.rand(3)
    # test without trace
    assert isinstance(foo(x), list)

    # test with trace
    traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
    assert isinstance(traced_foo(x), torch.Tensor)