test_mink_resnet.py 1.8 KB
Newer Older
VVsssssk's avatar
VVsssssk committed
1
2
3
4
5
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

6
from mmdet3d.registry import MODELS
VVsssssk's avatar
VVsssssk committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


def test_mink_resnet():
    """Check output count, feature shapes and strides of ``MinkResNet``.

    Builds a sparse batch of two seeded random point clouds and runs two
    depth-34 ``MinkResNet`` configurations over it. The test is skipped
    when CUDA or MinkowskiEngine is unavailable.
    """
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    try:
        import MinkowskiEngine as ME
    except ImportError:
        pytest.skip('test requires MinkowskiEngine installation')

    # Seeded inputs: the exact row counts asserted below are only valid for
    # the coordinates produced by seed 42.
    np.random.seed(42)
    coordinates, features = [], []
    # batch of 2 point clouds, 500 points each, 3-d coords in [0, 100)
    for _ in range(2):
        c = torch.from_numpy(np.random.rand(500, 3) * 100)
        coordinates.append(c.float().cuda())
        f = torch.from_numpy(np.random.rand(500, 3))
        features.append(f.float().cuda())
    tensor_coordinates, tensor_features = ME.utils.sparse_collate(
        coordinates, features)
    x = ME.SparseTensor(
        features=tensor_features, coordinates=tensor_coordinates)

    # MinkResNet34 with 4 outputs
    cfg = dict(type='MinkResNet', depth=34, in_channels=3)
    model = MODELS.build(cfg).cuda()
    model.init_weights()

    y = model(x)
    # (num_rows, num_channels, tensor_stride) per output stage
    expected = [(900, 64, 8), (472, 128, 16), (105, 256, 32), (16, 512, 64)]
    assert len(y) == len(expected)
    for out, (num_rows, channels, stride) in zip(y, expected):
        assert out.F.shape == torch.Size([num_rows, channels])
        assert out.tensor_stride[0] == stride

    # MinkResNet34 with 2 outputs and the initial pooling disabled
    cfg = dict(
        type='MinkResNet', depth=34, in_channels=3, num_stages=2, pool=False)
    model = MODELS.build(cfg).cuda()
    model.init_weights()

    y = model(x)
    expected = [(985, 64, 4), (900, 128, 8)]
    assert len(y) == len(expected)
    for out, (num_rows, channels, stride) in zip(y, expected):
        assert out.F.shape == torch.Size([num_rows, channels])
        assert out.tensor_stride[0] == stride