test_forward_pass.py 2.33 KB
Newer Older
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform


class TestGenericModel(unittest.TestCase):
    """Smoke tests for forward passes of the default-configured GenericModel."""

    def setUp(self):
        # The test draws random camera azimuths and a random target image;
        # seed the global RNG so any failure is reproducible.
        torch.manual_seed(42)

    def test_gm(self):
        """Run training- and evaluation-mode forward passes of GenericModel.

        Checks that:
          * a training pass with ``image_rgb=None`` warns that no main
            objective was found;
          * a training pass with a random target image yields a positive
            ``objective``;
          * an evaluation pass on a single camera (under ``torch.no_grad``)
            emits the same warning and returns ``images_render`` of shape
            ``(1, 3, render_image_height, render_image_width)``.
        """
        # Simple test of a forward pass of the default GenericModel.
        # Use the first GPU rather than a hard-coded second one so the test
        # runs on single-GPU machines, and skip cleanly when no GPU exists.
        if not torch.cuda.is_available():
            self.skipTest("GenericModel forward-pass test requires CUDA")
        device = torch.device("cuda:0")
        expand_args_fields(GenericModel)
        model = GenericModel()
        model.to(device)

        n_train_cameras = 2
        R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=R, T=T, device=device)

        # TODO: make these default to None?
        defaulted_args = {
            "fg_probability": None,
            "depth_map": None,
            "mask_crop": None,
            "sequence_name": None,
        }

        # Without a target image there is nothing to compute a loss against.
        with self.assertWarnsRegex(UserWarning, "No main objective found"):
            model(
                camera=cameras,
                evaluation_mode=EvaluationMode.TRAINING,
                **defaulted_args,
                image_rgb=None,
            )
        target_image_rgb = torch.rand(
            (n_train_cameras, 3, model.render_image_height, model.render_image_width),
            device=device,
        )
        train_preds = model(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            **defaulted_args,
        )
        self.assertGreater(train_preds["objective"].item(), 0)

        model.eval()
        with torch.no_grad():
            # TODO: perhaps this warning should be skipped in eval mode?
            with self.assertWarnsRegex(UserWarning, "No main objective found"):
                eval_preds = model(
                    camera=cameras[0],
                    **defaulted_args,
                    image_rgb=None,
                )
        self.assertEqual(
            eval_preds["images_render"].shape,
            (1, 3, model.render_image_height, model.render_image_width),
        )