"vscode:/vscode.git/clone" did not exist on "060783b7d2b7386e798a47fc19aaf0a9b3260545"
test_fid_inception.py 1.04 KB
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmgen.models.architectures import InceptionV3


class TestFIDInception:
    """Output-shape checks for the ``InceptionV3`` architecture.

    Every test builds the network with the ``load_fid_inception`` flag that
    ``setup_class`` stores on the class, feeds it a batch of two random
    3-channel images at 256x256 and then at 512x512, and asserts that index 0
    of the forward result has shape ``(2, 2048, 1, 1)`` in both cases.
    """

    @classmethod
    def setup_class(cls):
        """Store the ``load_fid_inception`` flag shared by all tests."""
        cls.load_fid_inception = False

    def test_fid_inception(self):
        """Check the output shape without moving anything to CUDA."""
        model = InceptionV3(load_fid_inception=self.load_fid_inception)

        # Same network instance is reused for both input resolutions.
        for side in (256, 512):
            batch = torch.randn((2, 3, side, side))
            feats = model(batch)[0]
            assert feats.shape == (2, 2048, 1, 1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
    def test_fid_inception_cuda(self):
        """Check the output shape with model and inputs moved to CUDA."""
        model = InceptionV3(load_fid_inception=self.load_fid_inception)
        model = model.cuda()

        # Same network instance is reused for both input resolutions.
        for side in (256, 512):
            batch = torch.randn((2, 3, side, side)).cuda()
            feats = model(batch)[0]
            assert feats.shape == (2, 2048, 1, 1)