Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent 6bc82161
import torch import torch
import torch.nn as nn import torch.nn as nn
from basic import BiLinearLSR, BiConv2d, FixedRadiusNNGraph, RelativePositionMessage
import torch.nn.functional as F import torch.nn.functional as F
from basic import (
BiConv2d,
BiLinearLSR,
FixedRadiusNNGraph,
RelativePositionMessage,
)
from dgl.geometry import farthest_point_sampler from dgl.geometry import farthest_point_sampler
class BiPointNetConv(nn.Module): class BiPointNetConv(nn.Module):
''' """
Feature aggregation Feature aggregation
''' """
def __init__(self, sizes, batch_size): def __init__(self, sizes, batch_size):
super(BiPointNetConv, self).__init__() super(BiPointNetConv, self).__init__()
self.batch_size = batch_size self.batch_size = batch_size
self.conv = nn.ModuleList() self.conv = nn.ModuleList()
self.bn = nn.ModuleList() self.bn = nn.ModuleList()
for i in range(1, len(sizes)): for i in range(1, len(sizes)):
self.conv.append(BiConv2d(sizes[i-1], sizes[i], 1)) self.conv.append(BiConv2d(sizes[i - 1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i])) self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes): def forward(self, nodes):
shape = nodes.mailbox['agg_feat'].shape shape = nodes.mailbox["agg_feat"].shape
h = nodes.mailbox['agg_feat'].view(self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 2, 1) h = (
nodes.mailbox["agg_feat"]
.view(self.batch_size, -1, shape[1], shape[2])
.permute(0, 3, 2, 1)
)
for conv, bn in zip(self.conv, self.bn): for conv, bn in zip(self.conv, self.bn):
h = conv(h) h = conv(h)
h = bn(h) h = bn(h)
...@@ -28,12 +38,12 @@ class BiPointNetConv(nn.Module): ...@@ -28,12 +38,12 @@ class BiPointNetConv(nn.Module):
h = torch.max(h, 2)[0] h = torch.max(h, 2)[0]
feat_dim = h.shape[1] feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim) h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {'new_feat': h} return {"new_feat": h}
def group_all(self, pos, feat): def group_all(self, pos, feat):
''' """
Feature aggregation and pooling for the non-sampling layer Feature aggregation and pooling for the non-sampling layer
''' """
if feat is not None: if feat is not None:
h = torch.cat([pos, feat], 2) h = torch.cat([pos, feat], 2)
else: else:
...@@ -49,12 +59,21 @@ class BiPointNetConv(nn.Module): ...@@ -49,12 +59,21 @@ class BiPointNetConv(nn.Module):
h = torch.max(h[:, :, :, 0], 2)[0] # [B,D] h = torch.max(h[:, :, :, 0], 2)[0] # [B,D]
return new_pos, h return new_pos, h
class BiSAModule(nn.Module): class BiSAModule(nn.Module):
""" """
The Set Abstraction Layer The Set Abstraction Layer
""" """
def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64,
group_all=False): def __init__(
self,
npoints,
batch_size,
radius,
mlp_sizes,
n_neighbor=64,
group_all=False,
):
super(BiSAModule, self).__init__() super(BiSAModule, self).__init__()
self.group_all = group_all self.group_all = group_all
if not group_all: if not group_all:
...@@ -72,22 +91,30 @@ class BiSAModule(nn.Module): ...@@ -72,22 +91,30 @@ class BiSAModule(nn.Module):
g = self.frnn_graph(pos, centroids, feat) g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv) g.update_all(self.message, self.conv)
mask = g.ndata['center'] == 1 mask = g.ndata["center"] == 1
pos_dim = g.ndata['pos'].shape[-1] pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata['new_feat'].shape[-1] feat_dim = g.ndata["new_feat"].shape[-1]
pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim) pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
feat_res = g.ndata['new_feat'][mask].view(self.batch_size, -1, feat_dim) feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
return pos_res, feat_res return pos_res, feat_res
class BiPointNet2SSGCls(nn.Module): class BiPointNet2SSGCls(nn.Module):
def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4): def __init__(
self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
):
super(BiPointNet2SSGCls, self).__init__() super(BiPointNet2SSGCls, self).__init__()
self.input_dims = input_dims self.input_dims = input_dims
self.sa_module1 = BiSAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128]) self.sa_module1 = BiSAModule(
self.sa_module2 = BiSAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256]) 512, batch_size, 0.2, [input_dims, 64, 64, 128]
self.sa_module3 = BiSAModule(None, batch_size, None, [256 + 3, 256, 512, 1024], )
group_all=True) self.sa_module2 = BiSAModule(
128, batch_size, 0.4, [128 + 3, 128, 128, 256]
)
self.sa_module3 = BiSAModule(
None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
)
self.mlp1 = BiLinearLSR(1024, 512) self.mlp1 = BiLinearLSR(1024, 512)
self.bn1 = nn.BatchNorm1d(512) self.bn1 = nn.BatchNorm1d(512)
......
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from basic import BiLinear from basic import BiLinear
from torch.autograd import Variable
offset_map = {1024: -3.2041, 2048: -3.4025, 4096: -3.5836}
offset_map = {
1024: -3.2041,
2048: -3.4025,
4096: -3.5836
}
class Conv1d(nn.Module): class Conv1d(nn.Module):
def __init__(self, inplane, outplane, Linear): def __init__(self, inplane, outplane, Linear):
...@@ -38,9 +35,16 @@ class EmaMaxPool(nn.Module): ...@@ -38,9 +35,16 @@ class EmaMaxPool(nn.Module):
x = torch.max(x, 2, keepdim=True)[0] - 0.3 x = torch.max(x, 2, keepdim=True)[0] - 0.3
return x return x
class BiPointNetCls(nn.Module): class BiPointNetCls(nn.Module):
def __init__(self, output_classes, input_dims=3, conv1_dim=64, def __init__(
use_transform=True, Linear=BiLinear): self,
output_classes,
input_dims=3,
conv1_dim=64,
use_transform=True,
Linear=BiLinear,
):
super(BiPointNetCls, self).__init__() super(BiPointNetCls, self).__init__()
self.input_dims = input_dims self.input_dims = input_dims
self.conv1 = nn.ModuleList() self.conv1 = nn.ModuleList()
...@@ -119,6 +123,7 @@ class BiPointNetCls(nn.Module): ...@@ -119,6 +123,7 @@ class BiPointNetCls(nn.Module):
out = self.mlp_out(h) out = self.mlp_out(h)
return out return out
class TransformNet(nn.Module): class TransformNet(nn.Module):
def __init__(self, input_dims=3, conv1_dim=64, Linear=BiLinear): def __init__(self, input_dims=3, conv1_dim=64, Linear=BiLinear):
super(TransformNet, self).__init__() super(TransformNet, self).__init__()
...@@ -153,7 +158,7 @@ class TransformNet(nn.Module): ...@@ -153,7 +158,7 @@ class TransformNet(nn.Module):
h = conv(h) h = conv(h)
h = bn(h) h = bn(h)
h = F.relu(h) h = F.relu(h)
h = self.maxpool(h).view(-1, self.pool_feat_len) h = self.maxpool(h).view(-1, self.pool_feat_len)
for mlp, bn in zip(self.mlp2, self.bn2): for mlp, bn in zip(self.mlp2, self.bn2):
h = mlp(h) h = mlp(h)
...@@ -162,8 +167,14 @@ class TransformNet(nn.Module): ...@@ -162,8 +167,14 @@ class TransformNet(nn.Module):
out = self.mlp_out(h) out = self.mlp_out(h)
iden = Variable(torch.from_numpy(np.eye(self.input_dims).flatten().astype(np.float32))) iden = Variable(
iden = iden.view(1, self.input_dims * self.input_dims).repeat(batch_size, 1) torch.from_numpy(
np.eye(self.input_dims).flatten().astype(np.float32)
)
)
iden = iden.view(1, self.input_dims * self.input_dims).repeat(
batch_size, 1
)
if out.is_cuda: if out.is_cuda:
iden = iden.cuda() iden = iden.cuda()
out = out + iden out = out + iden
......
from bipointnet_cls import BiPointNetCls
from bipointnet2 import BiPointNet2SSGCls
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse import argparse
import os import os
import urllib import urllib
import tqdm
from functools import partial from functools import partial
from dgl.data.utils import download, get_download_dir
import dgl import dgl
from torch.utils.data import DataLoader import provider
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from bipointnet2 import BiPointNet2SSGCls
from bipointnet_cls import BiPointNetCls
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader
from torch.utils.data import DataLoader
torch.backends.cudnn.enabled = False torch.backends.cudnn.enabled = False
# from dataset import ModelNet # from dataset import ModelNet
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='bipointnet') parser.add_argument("--model", type=str, default="bipointnet")
parser.add_argument('--dataset-path', type=str, default='') parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument('--load-model-path', type=str, default='') parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument('--save-model-path', type=str, default='') parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument('--num-epochs', type=int, default=200) parser.add_argument("--num-epochs", type=int, default=200)
parser.add_argument('--num-workers', type=int, default=0) parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args() args = parser.parse_args()
num_workers = args.num_workers num_workers = args.num_workers
batch_size = args.batch_size batch_size = args.batch_size
data_filename = 'modelnet40_normal_resampled.zip' data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename) download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join( local_path = args.dataset_path or os.path.join(
get_download_dir(), 'modelnet40_normal_resampled') get_download_dir(), "modelnet40_normal_resampled"
)
if not os.path.exists(local_path): if not os.path.exists(local_path):
download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip', download(
download_path, verify_ssl=False) "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
download_path,
verify_ssl=False,
)
from zipfile import ZipFile from zipfile import ZipFile
with ZipFile(download_path) as z: with ZipFile(download_path) as z:
z.extractall(path=get_download_dir()) z.extractall(path=get_download_dir())
...@@ -48,11 +55,11 @@ CustomDataLoader = partial( ...@@ -48,11 +55,11 @@ CustomDataLoader = partial(
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
drop_last=True) drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev): def train(net, opt, scheduler, train_loader, dev):
net.train() net.train()
total_loss = 0 total_loss = 0
...@@ -64,8 +71,7 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -64,8 +71,7 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq: for data, label in tq:
data = data.data.numpy() data = data.data.numpy()
data = provider.random_point_dropout(data) data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud( data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3]) data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3]) data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data) data = torch.tensor(data)
...@@ -88,9 +94,12 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -88,9 +94,12 @@ def train(net, opt, scheduler, train_loader, dev):
total_loss += loss total_loss += loss
total_correct += correct total_correct += correct
tq.set_postfix({ tq.set_postfix(
'AvgLoss': '%.5f' % (total_loss / num_batches), {
'AvgAcc': '%.5f' % (total_correct / count)}) "AvgLoss": "%.5f" % (total_loss / num_batches),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
scheduler.step() scheduler.step()
...@@ -113,17 +122,16 @@ def evaluate(net, test_loader, dev): ...@@ -113,17 +122,16 @@ def evaluate(net, test_loader, dev):
total_correct += correct total_correct += correct
count += num_examples count += num_examples
tq.set_postfix({ tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
'AvgAcc': '%.5f' % (total_correct / count)})
return total_correct / count return total_correct / count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.model == 'bipointnet': if args.model == "bipointnet":
net = BiPointNetCls(40, input_dims=6) net = BiPointNetCls(40, input_dims=6)
elif args.model == 'bipointnet2_ssg': elif args.model == "bipointnet2_ssg":
net = BiPointNet2SSGCls(40, batch_size, input_dims=6) net = BiPointNet2SSGCls(40, batch_size, input_dims=6)
net = net.to(dev) net = net.to(dev)
...@@ -134,23 +142,32 @@ opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4) ...@@ -134,23 +142,32 @@ opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7) scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)
train_dataset = ModelNetDataLoader(local_path, 1024, split='train') train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split='test') test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
test_loader = torch.utils.data.DataLoader( test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True) test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
drop_last=True,
)
best_test_acc = 0 best_test_acc = 0
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
train(net, opt, scheduler, train_loader, dev) train(net, opt, scheduler, train_loader, dev)
if (epoch + 1) % 1 == 0: if (epoch + 1) % 1 == 0:
print('Epoch #%d Testing' % epoch) print("Epoch #%d Testing" % epoch)
test_acc = evaluate(net, test_loader, dev) test_acc = evaluate(net, test_loader, dev)
if test_acc > best_test_acc: if test_acc > best_test_acc:
best_test_acc = test_acc best_test_acc = test_acc
if args.save_model_path: if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path) torch.save(net.state_dict(), args.save_model_path)
print('Current test acc: %.5f (best: %.5f)' % ( print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
test_acc, best_test_acc))
...@@ -8,11 +8,11 @@ import torch.nn as nn ...@@ -8,11 +8,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from model import Model, compute_loss
from modelnet import ModelNet
from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir from dgl.data.utils import download, get_download_dir
from model import compute_loss, Model
from modelnet import ModelNet
from torch.utils.data import DataLoader
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="") parser.add_argument("--dataset-path", type=str, default="")
......
...@@ -2,14 +2,14 @@ import json ...@@ -2,14 +2,14 @@ import json
import os import os
from zipfile import ZipFile from zipfile import ZipFile
import dgl
import numpy as np import numpy as np
import tqdm import tqdm
from dgl.data.utils import download, get_download_dir
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from torch.utils.data import Dataset from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object): class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True): def __init__(self, num_points=2048, normal_channel=True):
......
import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler from dgl.geometry import farthest_point_sampler
""" """
......
...@@ -7,12 +7,12 @@ import provider ...@@ -7,12 +7,12 @@ import provider
import torch import torch
import torch.nn as nn import torch.nn as nn
import tqdm import tqdm
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader from ModelNetDataLoader import ModelNetDataLoader
from pct import PointTransformerCLS from pct import PointTransformerCLS
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False torch.backends.cudnn.enabled = False
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -54,7 +54,6 @@ CustomDataLoader = partial( ...@@ -54,7 +54,6 @@ CustomDataLoader = partial(
def train(net, opt, scheduler, train_loader, dev): def train(net, opt, scheduler, train_loader, dev):
net.train() net.train()
total_loss = 0 total_loss = 0
......
...@@ -2,6 +2,8 @@ import argparse ...@@ -2,6 +2,8 @@ import argparse
import time import time
from functools import partial from functools import partial
import dgl
import numpy as np import numpy as np
import provider import provider
import torch import torch
...@@ -11,8 +13,6 @@ from pct import PartSegLoss, PointTransformerSeg ...@@ -11,8 +13,6 @@ from pct import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet from ShapeNet import ShapeNet
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import dgl
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="") parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="") parser.add_argument("--load-model-path", type=str, default="")
......
...@@ -2,14 +2,14 @@ import json ...@@ -2,14 +2,14 @@ import json
import os import os
from zipfile import ZipFile from zipfile import ZipFile
import dgl
import numpy as np import numpy as np
import tqdm import tqdm
from dgl.data.utils import download, get_download_dir
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from torch.utils.data import Dataset from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object): class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True): def __init__(self, num_points=2048, normal_channel=True):
......
import dgl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler from dgl.geometry import farthest_point_sampler
""" """
...@@ -270,7 +269,6 @@ class TransitionUp(nn.Module): ...@@ -270,7 +269,6 @@ class TransitionUp(nn.Module):
""" """
def __init__(self, dim1, dim2, dim_out): def __init__(self, dim1, dim2, dim_out):
super(TransitionUp, self).__init__() super(TransitionUp, self).__init__()
self.fc1 = nn.Sequential( self.fc1 = nn.Sequential(
nn.Linear(dim1, dim_out), nn.Linear(dim1, dim_out),
......
import numpy as np import numpy as np
import torch import torch
from helper import TransitionDown, TransitionUp, index_points, square_distance from helper import index_points, square_distance, TransitionDown, TransitionUp
from torch import nn from torch import nn
""" """
......
...@@ -7,12 +7,12 @@ import provider ...@@ -7,12 +7,12 @@ import provider
import torch import torch
import torch.nn as nn import torch.nn as nn
import tqdm import tqdm
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader from ModelNetDataLoader import ModelNetDataLoader
from point_transformer import PointTransformerCLS from point_transformer import PointTransformerCLS
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False torch.backends.cudnn.enabled = False
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -55,7 +55,6 @@ CustomDataLoader = partial( ...@@ -55,7 +55,6 @@ CustomDataLoader = partial(
def train(net, opt, scheduler, train_loader, dev): def train(net, opt, scheduler, train_loader, dev):
net.train() net.train()
total_loss = 0 total_loss = 0
......
...@@ -2,6 +2,8 @@ import argparse ...@@ -2,6 +2,8 @@ import argparse
import time import time
from functools import partial from functools import partial
import dgl
import numpy as np import numpy as np
import torch import torch
import torch.optim as optim import torch.optim as optim
...@@ -10,8 +12,6 @@ from point_transformer import PartSegLoss, PointTransformerSeg ...@@ -10,8 +12,6 @@ from point_transformer import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet from ShapeNet import ShapeNet
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import dgl
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="") parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="") parser.add_argument("--load-model-path", type=str, default="")
......
...@@ -2,14 +2,14 @@ import json ...@@ -2,14 +2,14 @@ import json
import os import os
from zipfile import ZipFile from zipfile import ZipFile
import dgl
import numpy as np import numpy as np
import tqdm import tqdm
from dgl.data.utils import download, get_download_dir
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from torch.utils.data import Dataset from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object): class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True): def __init__(self, num_points=2048, normal_channel=True):
......
import dgl
import dgl.function as fn
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable
import dgl
import dgl.function as fn
from dgl.geometry import ( from dgl.geometry import (
farthest_point_sampler, farthest_point_sampler,
) # dgl.geometry.pytorch -> dgl.geometry ) # dgl.geometry.pytorch -> dgl.geometry
from torch.autograd import Variable
""" """
Part of the code are adapted from Part of the code are adapted from
......
...@@ -3,20 +3,20 @@ import os ...@@ -3,20 +3,20 @@ import os
import urllib import urllib
from functools import partial from functools import partial
import dgl
import provider import provider
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader from ModelNetDataLoader import ModelNetDataLoader
from pointnet2 import PointNet2MSGCls, PointNet2SSGCls from pointnet2 import PointNet2MSGCls, PointNet2SSGCls
from pointnet_cls import PointNetCls from pointnet_cls import PointNetCls
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False torch.backends.cudnn.enabled = False
...@@ -62,7 +62,6 @@ CustomDataLoader = partial( ...@@ -62,7 +62,6 @@ CustomDataLoader = partial(
def train(net, opt, scheduler, train_loader, dev): def train(net, opt, scheduler, train_loader, dev):
net.train() net.train()
total_loss = 0 total_loss = 0
......
...@@ -4,20 +4,20 @@ import time ...@@ -4,20 +4,20 @@ import time
import urllib import urllib
from functools import partial from functools import partial
import dgl
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from dgl.data.utils import download, get_download_dir
from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg
from pointnet_partseg import PartSegLoss, PointNetPartSeg from pointnet_partseg import PartSegLoss, PointNetPartSeg
from ShapeNet import ShapeNet from ShapeNet import ShapeNet
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import download, get_download_dir
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="pointnet") parser.add_argument("--model", type=str, default="pointnet")
parser.add_argument("--dataset-path", type=str, default="") parser.add_argument("--dataset-path", type=str, default="")
...@@ -260,6 +260,8 @@ color_map = torch.tensor( ...@@ -260,6 +260,8 @@ color_map = torch.tensor(
[255, 105, 180], [255, 105, 180],
] ]
) )
# paint each point according to its pred # paint each point according to its pred
def paint(batched_points): def paint(batched_points):
B, N = batched_points.shape B, N = batched_points.shape
......
import torch
import dgl import dgl
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset from dgl.data import CiteseerGraphDataset, CoraGraphDataset
......
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics.functional as MF import torchmetrics.functional as MF
import dgl import tqdm
import dgl.function as fn
import dgl.nn as dglnn
from dgl.dataloading import NeighborSampler, DataLoader
from dgl import apply_each from dgl import apply_each
from dgl.dataloading import DataLoader, NeighborSampler
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
class HeteroGAT(nn.Module): class HeteroGAT(nn.Module):
def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4): def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):
super().__init__() super().__init__()
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.layers.append(dglnn.HeteroGraphConv({ self.layers.append(
etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads) dglnn.HeteroGraphConv(
for etype in etypes})) {
self.layers.append(dglnn.HeteroGraphConv({ etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads) for etype in etypes
for etype in etypes})) }
self.layers.append(dglnn.HeteroGraphConv({ )
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads) )
for etype in etypes})) self.layers.append(
dglnn.HeteroGraphConv(
{
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
for etype in etypes
}
)
)
self.layers.append(
dglnn.HeteroGraphConv(
{
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
for etype in etypes
}
)
)
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
self.linear = nn.Linear(hid_size, out_size) # Should be HeteroLinear self.linear = nn.Linear(hid_size, out_size) # Should be HeteroLinear
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
...@@ -32,19 +48,24 @@ class HeteroGAT(nn.Module): ...@@ -32,19 +48,24 @@ class HeteroGAT(nn.Module):
h = layer(block, h) h = layer(block, h)
# One thing is that h might return tensors with zero rows if the number of dst nodes # One thing is that h might return tensors with zero rows if the number of dst nodes
# of one node type is 0. x.view(x.shape[0], -1) wouldn't work in this case. # of one node type is 0. x.view(x.shape[0], -1) wouldn't work in this case.
h = apply_each(h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2])) h = apply_each(
h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2])
)
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = apply_each(h, F.relu) h = apply_each(h, F.relu)
h = apply_each(h, self.dropout) h = apply_each(h, self.dropout)
return self.linear(h['paper']) return self.linear(h["paper"])
def evaluate(model, dataloader, desc): def evaluate(model, dataloader, desc):
preds = [] preds = []
labels = [] labels = []
with torch.no_grad(): with torch.no_grad():
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc=desc): for input_nodes, output_nodes, blocks in tqdm.tqdm(
x = blocks[0].srcdata['feat'] dataloader, desc=desc
y = blocks[-1].dstdata['label']['paper'][:, 0] ):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]["paper"][:, 0]
y_hat = model(blocks, x) y_hat = model(blocks, x)
preds.append(y_hat.cpu()) preds.append(y_hat.cpu())
labels.append(y.cpu()) labels.append(y.cpu())
...@@ -53,6 +74,7 @@ def evaluate(model, dataloader, desc): ...@@ -53,6 +74,7 @@ def evaluate(model, dataloader, desc):
acc = MF.accuracy(preds, labels) acc = MF.accuracy(preds, labels)
return acc return acc
def train(train_loader, val_loader, test_loader, model): def train(train_loader, val_loader, test_loader, model):
# loss function and optimizer # loss function and optimizer
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
...@@ -62,9 +84,11 @@ def train(train_loader, val_loader, test_loader, model): ...@@ -62,9 +84,11 @@ def train(train_loader, val_loader, test_loader, model):
for epoch in range(10): for epoch in range(10):
model.train() model.train()
total_loss = 0 total_loss = 0
for it, (input_nodes, output_nodes, blocks) in enumerate(tqdm.tqdm(train_dataloader, desc="Train")): for it, (input_nodes, output_nodes, blocks) in enumerate(
x = blocks[0].srcdata['feat'] tqdm.tqdm(train_dataloader, desc="Train")
y = blocks[-1].dstdata['label']['paper'][:, 0] ):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]["paper"][:, 0]
y_hat = model(blocks, x) y_hat = model(blocks, x)
loss = loss_fcn(y_hat, y) loss = loss_fcn(y_hat, y)
opt.zero_grad() opt.zero_grad()
...@@ -72,51 +96,94 @@ def train(train_loader, val_loader, test_loader, model): ...@@ -72,51 +96,94 @@ def train(train_loader, val_loader, test_loader, model):
opt.step() opt.step()
total_loss += loss.item() total_loss += loss.item()
model.eval() model.eval()
val_acc = evaluate(model, val_dataloader, 'Val. ') val_acc = evaluate(model, val_dataloader, "Val. ")
test_acc = evaluate(model, test_dataloader, 'Test ') test_acc = evaluate(model, test_dataloader, "Test ")
print(f'Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}') print(
f"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}"
)
if __name__ == '__main__': if __name__ == "__main__":
print(f'Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules') print(
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') f"Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules"
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load and preprocess dataset # load and preprocess dataset
print('Loading data') print("Loading data")
dataset = DglNodePropPredDataset('ogbn-mag') dataset = DglNodePropPredDataset("ogbn-mag")
graph, labels = dataset[0] graph, labels = dataset[0]
graph.ndata['label'] = labels graph.ndata["label"] = labels
# add reverse edges in "cites" relation, and add reverse edge types for the rest etypes # add reverse edges in "cites" relation, and add reverse edge types for the rest etypes
graph = dgl.AddReverse()(graph) graph = dgl.AddReverse()(graph)
# precompute the author, topic, and institution features # precompute the author, topic, and institution features
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes') graph.update_all(
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic') fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="rev_writes"
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with') )
graph.update_all(
fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="has_topic"
)
graph.update_all(
fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="affiliated_with"
)
# find train/val/test indexes # find train/val/test indexes
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
train_idx, val_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] train_idx, val_idx, test_idx = (
split_idx["train"],
split_idx["valid"],
split_idx["test"],
)
train_idx = apply_each(train_idx, lambda x: x.to(device)) train_idx = apply_each(train_idx, lambda x: x.to(device))
val_idx = apply_each(val_idx, lambda x: x.to(device)) val_idx = apply_each(val_idx, lambda x: x.to(device))
test_idx = apply_each(test_idx, lambda x: x.to(device)) test_idx = apply_each(test_idx, lambda x: x.to(device))
# create RGAT model # create RGAT model
in_size = graph.ndata['feat']['paper'].shape[1] in_size = graph.ndata["feat"]["paper"].shape[1]
out_size = dataset.num_classes out_size = dataset.num_classes
model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device) model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device)
# dataloader + model training + testing # dataloader + model training + testing
train_sampler = NeighborSampler([5, 5, 5], train_sampler = NeighborSampler(
prefetch_node_feats={k: ['feat'] for k in graph.ntypes}, [5, 5, 5],
prefetch_labels={'paper': ['label']}) prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
val_sampler = NeighborSampler([10, 10, 10], prefetch_labels={"paper": ["label"]},
prefetch_node_feats={k: ['feat'] for k in graph.ntypes}, )
prefetch_labels={'paper': ['label']}) val_sampler = NeighborSampler(
train_dataloader = DataLoader(graph, train_idx, train_sampler, [10, 10, 10],
device=device, batch_size=1000, shuffle=True, prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
drop_last=False, num_workers=0, use_uva=torch.cuda.is_available()) prefetch_labels={"paper": ["label"]},
val_dataloader = DataLoader(graph, val_idx, val_sampler, )
device=device, batch_size=1000, shuffle=False, train_dataloader = DataLoader(
drop_last=False, num_workers=0, use_uva=torch.cuda.is_available()) graph,
test_dataloader = DataLoader(graph, test_idx, val_sampler, train_idx,
device=device, batch_size=1000, shuffle=False, train_sampler,
drop_last=False, num_workers=0, use_uva=torch.cuda.is_available()) device=device,
batch_size=1000,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=torch.cuda.is_available(),
)
val_dataloader = DataLoader(
graph,
val_idx,
val_sampler,
device=device,
batch_size=1000,
shuffle=False,
drop_last=False,
num_workers=0,
use_uva=torch.cuda.is_available(),
)
test_dataloader = DataLoader(
graph,
test_idx,
val_sampler,
device=device,
batch_size=1000,
shuffle=False,
drop_last=False,
num_workers=0,
use_uva=torch.cuda.is_available(),
)
train(train_dataloader, val_dataloader, test_dataloader, model) train(train_dataloader, val_dataloader, test_dataloader, model)
...@@ -9,9 +9,9 @@ import numpy as np ...@@ -9,9 +9,9 @@ import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from model import EntityClassify
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from model import EntityClassify
def main(args): def main(args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment