"torchvision/vscode:/vscode.git/clone" did not exist on "66ed6937a8325b26a7dea4bf8d1a84195f5d544a"
test_common_linear_with_repeat.py 1.05 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
12
13
from .common_testing import TestCaseMixin

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
    """Tests for ``LinearWithRepeat``."""

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so random inputs and parameter initialisation are
        # reproducible across runs.
        torch.manual_seed(42)

    def test_simple(self):
        """
        ``LinearWithRepeat`` applied to ``(x, y)`` must match a plain
        ``torch.nn.Linear`` with the same parameters applied to ``x``
        concatenated (on the last dim) with ``y`` repeated along dim -2.
        """
        x = torch.rand(4, 6, 7, 3)
        y = torch.rand(4, 6, 4)
        # Derive sizes from the inputs so the reference computation stays
        # consistent if the input shapes are ever changed.
        in_features = x.shape[-1] + y.shape[-1]
        out_features = 8

        linear = torch.nn.Linear(in_features, out_features)
        torch.nn.init.xavier_uniform_(linear.weight.data)
        linear.bias.data.uniform_()

        # Build the reference input explicitly: y gains a new dim at -2 which
        # is expanded to x's size there, so every entry along that dim of x is
        # paired with the same y vector.
        y_repeated = y.unsqueeze(-2).expand(*x.shape[:-1], y.shape[-1])
        equivalent = torch.cat([x, y_repeated], dim=-1)
        # Call the module (not .forward) so nn.Module.__call__ semantics apply.
        expected = linear(equivalent)

        linear_with_repeat = LinearWithRepeat(in_features, out_features)
        # Share the exact same weights/bias so the outputs are comparable.
        linear_with_repeat.load_state_dict(linear.state_dict())
        actual = linear_with_repeat((x, y))

        self.assertEqual(actual.shape, expected.shape)
        self.assertClose(actual, expected, rtol=1e-4)