pointnet2_partseg.py
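"""PointNet++ part-segmentation models: single-scale grouping (SSG) and multi-scale grouping (MSG) variants."""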
import torch
import torch.nn as nn
import torch.nn.functional as F
from pointnet2 import SAModule, SAMSGModule, PointNet2FP

class PointNet2SSGPartSeg(nn.Module):
    def __init__(self, output_classes, batch_size, input_dims=6):
        super(PointNet2SSGPartSeg, self).__init__()
        # if normal_channel is True, input_dims = 6 + 3
        self.input_dims = input_dims

        self.sa_module1 = SAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128], n_neighbor=32)
        self.sa_module2 = SAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
        self.sa_module3 = SAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
                                   group_all=True)

        self.fp3 = PointNet2FP(1280, [256, 256])  # 1024 (l3) + 256 (l2) input channels
        self.fp2 = PointNet2FP(384, [256, 128])   # 256 (fp3 output) + 128 (l1)
        # 128 (fp2 output) + 16 (one-hot category) + 6 (xyz + raw point features);
        # if normal_channel is True, this becomes 128 + 16 + 6 + 3
        self.fp1 = PointNet2FP(128 + 16 + 6, [128, 128, 128])

        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, output_classes, 1)

    def forward(self, x, cat_vec=None):
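        """x: [B, N, C] point cloud with xyz in the first 3 channels; cat_vec: per-point one-hot
        object-category vector, [B, 16, N] by the channel math in __init__ (an inference, not
        documented in the original). cat_vec is effectively required despite the None default."""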
        # the first 3 channels are xyz; any extras (e.g. normals) ride along in the feature tensor
        l0_pos = x[:, :, :3] if x.shape[-1] > 3 else x
        l0_feat = x
        # Set Abstraction layers
        l1_pos, l1_feat = self.sa_module1(l0_pos, l0_feat)  # l1_feat: [B, N, D]
        l2_pos, l2_feat = self.sa_module2(l1_pos, l1_feat)
        l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat)  # [B, N, C], [B, D]
        # Feature Propagation layers
        l2_feat = self.fp3(l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1))  # l2_feat: [B, D, N]
        l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))
        # prepend the one-hot category vector ([B, 16, N] -> [B, N, 16]) to xyz and the raw features;
        # cat_vec is required here -- fp1's input width (128 + 16 + 6) accounts for its 16 channels
        l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)
        l0_feat = self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))
        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_feat)))
        out = self.drop1(feat)
        out = self.conv2(out)  # [B, output_classes, N]
        return out


class PointNet2MSGPartSeg(nn.Module):
    def __init__(self, output_classes, batch_size, input_dims=6):
        super(PointNet2MSGPartSeg, self).__init__()

        # three grouping radii; the per-scale MLP outputs (64 + 128 + 128 = 320) are concatenated
        self.sa_msg_module1 = SAMSGModule(512, batch_size, [0.1, 0.2, 0.4], [32, 64, 128],
                                          [[input_dims, 32, 32, 64], [input_dims, 64, 64, 128],
                                           [input_dims, 64, 96, 128]])
        # input width per scale = 64 + 128 + 128 (concatenated scales) + 3 (xyz)
        self.sa_msg_module2 = SAMSGModule(128, batch_size, [0.4, 0.8], [64, 128],
                                          [[128 + 128 + 64 + 3, 128, 128, 256],
                                           [128 + 128 + 64 + 3, 128, 196, 256]])
        # group_all: one global feature over all remaining points; 512 = 256 + 256 from the two scales
        self.sa_module3 = SAModule(None, batch_size, None, [512 + 3, 256, 512, 1024],
                                   group_all=True)

        self.fp3 = PointNet2FP(1536, [256, 256])  # 1024 (l3) + 512 (l2) input channels
        self.fp2 = PointNet2FP(576, [256, 128])   # 256 (fp3 output) + 320 (l1)
        # 128 (fp2 output) + 16 (one-hot category) + 6 (xyz + raw point features);
        # if normal_channel is True, this becomes 150 + 3
        self.fp1 = PointNet2FP(150, [128, 128])

        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, output_classes, 1)

    def forward(self, x, cat_vec=None):
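        """Same contract as PointNet2SSGPartSeg.forward: x is [B, N, C] with xyz first,
        cat_vec is the [B, 16, N] one-hot category vector (required in practice)."""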
        # the first 3 channels are xyz; any extras (e.g. normals) ride along in the feature tensor
        l0_pos = x[:, :, :3] if x.shape[-1] > 3 else x
        l0_feat = x
        # Set Abstraction layers
        l1_pos, l1_feat = self.sa_msg_module1(l0_pos, l0_feat)
        l2_pos, l2_feat = self.sa_msg_module2(l1_pos, l1_feat)
        l3_pos, l3_feat = self.sa_module3(l2_pos, l2_feat)
        # Feature Propagation layers
        l2_feat = self.fp3(l2_pos, l3_pos, l2_feat, l3_feat.unsqueeze(1))
        l1_feat = self.fp2(l1_pos, l2_pos, l1_feat, l2_feat.permute(0, 2, 1))
        # prepend the one-hot category vector ([B, 16, N] -> [B, N, 16]), as in the SSG head
        l0_feat = torch.cat([cat_vec.permute(0, 2, 1), l0_pos, l0_feat], 2)
        l0_feat = self.fp1(l0_pos, l1_pos, l0_feat, l1_feat.permute(0, 2, 1))
        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_feat)))
        out = self.drop1(feat)
        out = self.conv2(out)
        return out
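

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the original file. Shapes follow the channel
    # math above: xyz-only input (the default fp1 width of 128 + 16 + 6 assumes no normals) and
    # a 16-way one-hot category vector; 50 output classes matches the ShapeNet part labels.
    B, N = 2, 2048
    x = torch.rand(B, N, 3)          # [B, N, 3] xyz coordinates
    cat_vec = torch.zeros(B, 16, N)  # one-hot object category, repeated across points
    cat_vec[:, 0, :] = 1.0
    model = PointNet2SSGPartSeg(output_classes=50, batch_size=B)
    out = model(x, cat_vec)
    print(out.shape)                 # expected: [B, 50, N]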