# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.utils import digit_version, is_jit_tracing


@pytest.mark.skipif(
    digit_version(torch.__version__) < digit_version('1.6.0'),
    reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():
    """Check that ``is_jit_tracing`` reports whether a JIT trace is active.

    ``foo`` branches on ``is_jit_tracing()``. When called eagerly it converts
    the tensor to a Python list. Under ``torch.jit.trace`` the tracing branch
    is the one recorded, so the traced function returns the input tensor.
    """

    def foo(x):
        # Whichever branch runs during tracing is the one baked into the
        # traced graph.
        if is_jit_tracing():
            return x
        else:
            return x.tolist()

    x = torch.rand(3)
    # test without trace: the eager branch converts the tensor to a list
    assert isinstance(foo(x), list)

    # test with trace: only the ``return x`` branch was recorded, so calling
    # the traced function yields the input tensor unchanged. The example
    # input used for tracing (length 1) differs in length from ``x``
    # (length 3); an identity graph does not depend on the traced shape.
    traced_foo = torch.jit.trace(foo, (torch.rand(1), ))
    out = traced_foo(x)
    assert isinstance(out, torch.Tensor)
    assert torch.equal(out, x)