test_ops_utils.py 2.73 KB
Newer Older
Roman Shapovalov's avatar
Roman Shapovalov committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest

import numpy as np
import torch

from common_testing import TestCaseMixin

from pytorch3d.ops import utils as oputil

class TestOpsUtils(TestCaseMixin, unittest.TestCase):
    """Unit tests for helpers in ``pytorch3d.ops.utils``."""

    def setUp(self) -> None:
        super().setUp()
        # Seed both RNGs so the random test inputs are reproducible.
        torch.manual_seed(42)
        np.random.seed(42)

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Return ``tensor`` detached and moved to CPU as a NumPy array."""
        return tensor.detach().cpu().numpy()

    def test_wmean(self):
        """Check ``oputil.wmean`` against ``np.average`` reference values.

        Covers unweighted and weighted means, ``keepdim``, boolean weights,
        broadcasting of weights against a batched ``x``, the ``ValueError``
        raised for incompatible weights, and an explicit ``dim`` given as an
        int and as a tuple.
        """
        # Use the GPU when one is present, but fall back to the CPU so the
        # test does not error out on machines without CUDA.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        n_points = 20

        x = torch.rand(n_points, 3, device=device)
        weight = torch.rand(n_points, device=device)
        x_np = self._to_numpy(x)
        weight_np = self._to_numpy(weight)

        # test unweighted
        mean = oputil.wmean(x, keepdim=False)
        mean_gt = np.average(x_np, axis=-2)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test weighted
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test keepdim: the result is indexed with [0] before comparing it
        # against the same (keepdim=False) reference
        mean = oputil.wmean(x, weight=weight, keepdim=True)
        self.assertClose(self._to_numpy(mean[0]), mean_gt)

        # test binary weights
        mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test broadcasting: the 1-D per-point weight is applied to a batch
        # of 10 point sets, x of shape (10, n_points, 3)
        x = torch.rand(10, n_points, 3, device=device)
        x_np = self._to_numpy(x)
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # weight with extra leading dims, shape (3, 1, n_points); the first
        # slice of the result must match the same reference
        weight = weight[None, None, :].repeat(3, 1, 1)
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        self.assertClose(self._to_numpy(mean[0]), mean_gt)

        # test failing broadcasting: a weight of length x.shape[0] (10) does
        # not match the n_points (20) axis, so wmean must reject it
        weight = torch.rand(x.shape[0], device=device)
        with self.assertRaises(ValueError) as context:
            oputil.wmean(x, weight=weight, keepdim=False)
        self.assertIn("weights are not compatible", str(context.exception))

        # test dim: reduce over axis 0 with a (10, n_points) weight.
        # np.average requires weights shaped like x when they are not 1-D,
        # so tile the weight across the last (coordinate) axis.
        weight = torch.rand(x.shape[0], n_points, device=device)
        weight_np = np.tile(
            self._to_numpy(weight[:, :, None]),
            (1, 1, x_np.shape[-1]),
        )
        mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=0, weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test dim tuple
        mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)