Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent 6bc82161
import torch
import torch.nn as nn
from basic import BiLinearLSR, BiConv2d, FixedRadiusNNGraph, RelativePositionMessage
import torch.nn.functional as F
from basic import (
BiConv2d,
BiLinearLSR,
FixedRadiusNNGraph,
RelativePositionMessage,
)
from dgl.geometry import farthest_point_sampler
class BiPointNetConv(nn.Module):
'''
"""
Feature aggregation
'''
"""
def __init__(self, sizes, batch_size):
super(BiPointNetConv, self).__init__()
self.batch_size = batch_size
self.conv = nn.ModuleList()
self.bn = nn.ModuleList()
for i in range(1, len(sizes)):
self.conv.append(BiConv2d(sizes[i-1], sizes[i], 1))
self.conv.append(BiConv2d(sizes[i - 1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes):
shape = nodes.mailbox['agg_feat'].shape
h = nodes.mailbox['agg_feat'].view(self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 2, 1)
shape = nodes.mailbox["agg_feat"].shape
h = (
nodes.mailbox["agg_feat"]
.view(self.batch_size, -1, shape[1], shape[2])
.permute(0, 3, 2, 1)
)
for conv, bn in zip(self.conv, self.bn):
h = conv(h)
h = bn(h)
......@@ -28,12 +38,12 @@ class BiPointNetConv(nn.Module):
h = torch.max(h, 2)[0]
feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {'new_feat': h}
return {"new_feat": h}
def group_all(self, pos, feat):
'''
"""
Feature aggregation and pooling for the non-sampling layer
'''
"""
if feat is not None:
h = torch.cat([pos, feat], 2)
else:
......@@ -49,12 +59,21 @@ class BiPointNetConv(nn.Module):
h = torch.max(h[:, :, :, 0], 2)[0] # [B,D]
return new_pos, h
class BiSAModule(nn.Module):
"""
The Set Abstraction Layer
"""
def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64,
group_all=False):
def __init__(
self,
npoints,
batch_size,
radius,
mlp_sizes,
n_neighbor=64,
group_all=False,
):
super(BiSAModule, self).__init__()
self.group_all = group_all
if not group_all:
......@@ -72,22 +91,30 @@ class BiSAModule(nn.Module):
g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv)
mask = g.ndata['center'] == 1
pos_dim = g.ndata['pos'].shape[-1]
feat_dim = g.ndata['new_feat'].shape[-1]
pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim)
feat_res = g.ndata['new_feat'][mask].view(self.batch_size, -1, feat_dim)
mask = g.ndata["center"] == 1
pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata["new_feat"].shape[-1]
pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
feat_res = g.ndata["new_feat"][mask].view(self.batch_size, -1, feat_dim)
return pos_res, feat_res
class BiPointNet2SSGCls(nn.Module):
def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
def __init__(
self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
):
super(BiPointNet2SSGCls, self).__init__()
self.input_dims = input_dims
self.sa_module1 = BiSAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128])
self.sa_module2 = BiSAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
self.sa_module3 = BiSAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
group_all=True)
self.sa_module1 = BiSAModule(
512, batch_size, 0.2, [input_dims, 64, 64, 128]
)
self.sa_module2 = BiSAModule(
128, batch_size, 0.4, [128 + 3, 128, 128, 256]
)
self.sa_module3 = BiSAModule(
None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
)
self.mlp1 = BiLinearLSR(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
......
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from basic import BiLinear
from torch.autograd import Variable
offset_map = {1024: -3.2041, 2048: -3.4025, 4096: -3.5836}
offset_map = {
1024: -3.2041,
2048: -3.4025,
4096: -3.5836
}
class Conv1d(nn.Module):
def __init__(self, inplane, outplane, Linear):
......@@ -38,9 +35,16 @@ class EmaMaxPool(nn.Module):
x = torch.max(x, 2, keepdim=True)[0] - 0.3
return x
class BiPointNetCls(nn.Module):
def __init__(self, output_classes, input_dims=3, conv1_dim=64,
use_transform=True, Linear=BiLinear):
def __init__(
self,
output_classes,
input_dims=3,
conv1_dim=64,
use_transform=True,
Linear=BiLinear,
):
super(BiPointNetCls, self).__init__()
self.input_dims = input_dims
self.conv1 = nn.ModuleList()
......@@ -119,6 +123,7 @@ class BiPointNetCls(nn.Module):
out = self.mlp_out(h)
return out
class TransformNet(nn.Module):
def __init__(self, input_dims=3, conv1_dim=64, Linear=BiLinear):
super(TransformNet, self).__init__()
......@@ -153,7 +158,7 @@ class TransformNet(nn.Module):
h = conv(h)
h = bn(h)
h = F.relu(h)
h = self.maxpool(h).view(-1, self.pool_feat_len)
for mlp, bn in zip(self.mlp2, self.bn2):
h = mlp(h)
......@@ -162,8 +167,14 @@ class TransformNet(nn.Module):
out = self.mlp_out(h)
iden = Variable(torch.from_numpy(np.eye(self.input_dims).flatten().astype(np.float32)))
iden = iden.view(1, self.input_dims * self.input_dims).repeat(batch_size, 1)
iden = Variable(
torch.from_numpy(
np.eye(self.input_dims).flatten().astype(np.float32)
)
)
iden = iden.view(1, self.input_dims * self.input_dims).repeat(
batch_size, 1
)
if out.is_cuda:
iden = iden.cuda()
out = out + iden
......
from bipointnet_cls import BiPointNetCls
from bipointnet2 import BiPointNet2SSGCls
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import urllib
import tqdm
from functools import partial
from dgl.data.utils import download, get_download_dir
import dgl
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import provider
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from bipointnet2 import BiPointNet2SSGCls
from bipointnet_cls import BiPointNetCls
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader
from torch.utils.data import DataLoader
torch.backends.cudnn.enabled = False
# from dataset import ModelNet
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='bipointnet')
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=200)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument("--model", type=str, default="bipointnet")
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=200)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
data_filename = 'modelnet40_normal_resampled.zip'
data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
get_download_dir(), 'modelnet40_normal_resampled')
get_download_dir(), "modelnet40_normal_resampled"
)
if not os.path.exists(local_path):
download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
download_path, verify_ssl=False)
download(
"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
download_path,
verify_ssl=False,
)
from zipfile import ZipFile
with ZipFile(download_path) as z:
z.extractall(path=get_download_dir())
......@@ -48,11 +55,11 @@ CustomDataLoader = partial(
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)
drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
net.train()
total_loss = 0
......@@ -64,8 +71,7 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq:
data = data.data.numpy()
data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud(
data[:, :, 0:3])
data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data)
......@@ -88,9 +94,12 @@ def train(net, opt, scheduler, train_loader, dev):
total_loss += loss
total_correct += correct
tq.set_postfix({
'AvgLoss': '%.5f' % (total_loss / num_batches),
'AvgAcc': '%.5f' % (total_correct / count)})
tq.set_postfix(
{
"AvgLoss": "%.5f" % (total_loss / num_batches),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
scheduler.step()
......@@ -113,17 +122,16 @@ def evaluate(net, test_loader, dev):
total_correct += correct
count += num_examples
tq.set_postfix({
'AvgAcc': '%.5f' % (total_correct / count)})
tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
return total_correct / count
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.model == 'bipointnet':
if args.model == "bipointnet":
net = BiPointNetCls(40, input_dims=6)
elif args.model == 'bipointnet2_ssg':
elif args.model == "bipointnet2_ssg":
net = BiPointNet2SSGCls(40, batch_size, input_dims=6)
net = net.to(dev)
......@@ -134,23 +142,32 @@ opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
drop_last=True,
)
best_test_acc = 0
for epoch in range(args.num_epochs):
train(net, opt, scheduler, train_loader, dev)
if (epoch + 1) % 1 == 0:
print('Epoch #%d Testing' % epoch)
print("Epoch #%d Testing" % epoch)
test_acc = evaluate(net, test_loader, dev)
if test_acc > best_test_acc:
best_test_acc = test_acc
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print('Current test acc: %.5f (best: %.5f)' % (
test_acc, best_test_acc))
print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
......@@ -8,11 +8,11 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from model import Model, compute_loss
from modelnet import ModelNet
from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
from model import compute_loss, Model
from modelnet import ModelNet
from torch.utils.data import DataLoader
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
......
......@@ -2,14 +2,14 @@ import json
import os
from zipfile import ZipFile
import dgl
import numpy as np
import tqdm
from dgl.data.utils import download, get_download_dir
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True):
......
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler
"""
......
......@@ -7,12 +7,12 @@ import provider
import torch
import torch.nn as nn
import tqdm
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader
from pct import PointTransformerCLS
from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False
parser = argparse.ArgumentParser()
......@@ -54,7 +54,6 @@ CustomDataLoader = partial(
def train(net, opt, scheduler, train_loader, dev):
net.train()
total_loss = 0
......
......@@ -2,6 +2,8 @@ import argparse
import time
from functools import partial
import dgl
import numpy as np
import provider
import torch
......@@ -11,8 +13,6 @@ from pct import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet
from torch.utils.data import DataLoader
import dgl
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
......
......@@ -2,14 +2,14 @@ import json
import os
from zipfile import ZipFile
import dgl
import numpy as np
import tqdm
from dgl.data.utils import download, get_download_dir
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True):
......
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler
"""
......@@ -270,7 +269,6 @@ class TransitionUp(nn.Module):
"""
def __init__(self, dim1, dim2, dim_out):
super(TransitionUp, self).__init__()
self.fc1 = nn.Sequential(
nn.Linear(dim1, dim_out),
......
import numpy as np
import torch
from helper import TransitionDown, TransitionUp, index_points, square_distance
from helper import index_points, square_distance, TransitionDown, TransitionUp
from torch import nn
"""
......
......@@ -7,12 +7,12 @@ import provider
import torch
import torch.nn as nn
import tqdm
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader
from point_transformer import PointTransformerCLS
from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False
parser = argparse.ArgumentParser()
......@@ -55,7 +55,6 @@ CustomDataLoader = partial(
def train(net, opt, scheduler, train_loader, dev):
net.train()
total_loss = 0
......
......@@ -2,6 +2,8 @@ import argparse
import time
from functools import partial
import dgl
import numpy as np
import torch
import torch.optim as optim
......@@ -10,8 +12,6 @@ from point_transformer import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet
from torch.utils.data import DataLoader
import dgl
parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
......
......@@ -2,14 +2,14 @@ import json
import os
from zipfile import ZipFile
import dgl
import numpy as np
import tqdm
from dgl.data.utils import download, get_download_dir
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True):
......
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import dgl
import dgl.function as fn
from dgl.geometry import (
farthest_point_sampler,
) # dgl.geometry.pytorch -> dgl.geometry
from torch.autograd import Variable
"""
Part of the code are adapted from
......
......@@ -3,20 +3,20 @@ import os
import urllib
from functools import partial
import dgl
import provider
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from dgl.data.utils import download, get_download_dir
from ModelNetDataLoader import ModelNetDataLoader
from pointnet2 import PointNet2MSGCls, PointNet2SSGCls
from pointnet_cls import PointNetCls
from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False
......@@ -62,7 +62,6 @@ CustomDataLoader = partial(
def train(net, opt, scheduler, train_loader, dev):
net.train()
total_loss = 0
......
......@@ -4,20 +4,20 @@ import time
import urllib
from functools import partial
import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from dgl.data.utils import download, get_download_dir
from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg
from pointnet_partseg import PartSegLoss, PointNetPartSeg
from ShapeNet import ShapeNet
from torch.utils.data import DataLoader
import dgl
from dgl.data.utils import download, get_download_dir
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="pointnet")
parser.add_argument("--dataset-path", type=str, default="")
......@@ -260,6 +260,8 @@ color_map = torch.tensor(
[255, 105, 180],
]
)
# paint each point according to its pred
def paint(batched_points):
B, N = batched_points.shape
......
import torch
import dgl
import torch
from dgl.data import CiteseerGraphDataset, CoraGraphDataset
......
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.dataloading import NeighborSampler, DataLoader
import tqdm
from dgl import apply_each
from dgl.dataloading import DataLoader, NeighborSampler
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
class HeteroGAT(nn.Module):
def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(dglnn.HeteroGraphConv({
etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)
for etype in etypes}))
self.layers.append(dglnn.HeteroGraphConv({
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
for etype in etypes}))
self.layers.append(dglnn.HeteroGraphConv({
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
for etype in etypes}))
self.layers.append(
dglnn.HeteroGraphConv(
{
etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)
for etype in etypes
}
)
)
self.layers.append(
dglnn.HeteroGraphConv(
{
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
for etype in etypes
}
)
)
self.layers.append(
dglnn.HeteroGraphConv(
{
etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
for etype in etypes
}
)
)
self.dropout = nn.Dropout(0.5)
self.linear = nn.Linear(hid_size, out_size) # Should be HeteroLinear
self.linear = nn.Linear(hid_size, out_size) # Should be HeteroLinear
def forward(self, blocks, x):
h = x
......@@ -32,19 +48,24 @@ class HeteroGAT(nn.Module):
h = layer(block, h)
# One thing is that h might return tensors with zero rows if the number of dst nodes
# of one node type is 0. x.view(x.shape[0], -1) wouldn't work in this case.
h = apply_each(h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2]))
h = apply_each(
h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2])
)
if l != len(self.layers) - 1:
h = apply_each(h, F.relu)
h = apply_each(h, self.dropout)
return self.linear(h['paper'])
return self.linear(h["paper"])
def evaluate(model, dataloader, desc):
preds = []
labels = []
with torch.no_grad():
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc=desc):
x = blocks[0].srcdata['feat']
y = blocks[-1].dstdata['label']['paper'][:, 0]
for input_nodes, output_nodes, blocks in tqdm.tqdm(
dataloader, desc=desc
):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]["paper"][:, 0]
y_hat = model(blocks, x)
preds.append(y_hat.cpu())
labels.append(y.cpu())
......@@ -53,6 +74,7 @@ def evaluate(model, dataloader, desc):
acc = MF.accuracy(preds, labels)
return acc
def train(train_loader, val_loader, test_loader, model):
# loss function and optimizer
loss_fcn = nn.CrossEntropyLoss()
......@@ -62,9 +84,11 @@ def train(train_loader, val_loader, test_loader, model):
for epoch in range(10):
model.train()
total_loss = 0
for it, (input_nodes, output_nodes, blocks) in enumerate(tqdm.tqdm(train_dataloader, desc="Train")):
x = blocks[0].srcdata['feat']
y = blocks[-1].dstdata['label']['paper'][:, 0]
for it, (input_nodes, output_nodes, blocks) in enumerate(
tqdm.tqdm(train_dataloader, desc="Train")
):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]["paper"][:, 0]
y_hat = model(blocks, x)
loss = loss_fcn(y_hat, y)
opt.zero_grad()
......@@ -72,51 +96,94 @@ def train(train_loader, val_loader, test_loader, model):
opt.step()
total_loss += loss.item()
model.eval()
val_acc = evaluate(model, val_dataloader, 'Val. ')
test_acc = evaluate(model, test_dataloader, 'Test ')
print(f'Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}')
val_acc = evaluate(model, val_dataloader, "Val. ")
test_acc = evaluate(model, test_dataloader, "Test ")
print(
f"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}"
)
if __name__ == '__main__':
print(f'Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if __name__ == "__main__":
print(
f"Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules"
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load and preprocess dataset
print('Loading data')
dataset = DglNodePropPredDataset('ogbn-mag')
print("Loading data")
dataset = DglNodePropPredDataset("ogbn-mag")
graph, labels = dataset[0]
graph.ndata['label'] = labels
graph.ndata["label"] = labels
# add reverse edges in "cites" relation, and add reverse edge types for the rest etypes
graph = dgl.AddReverse()(graph)
# precompute the author, topic, and institution features
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes')
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic')
graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with')
graph.update_all(
fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="rev_writes"
)
graph.update_all(
fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="has_topic"
)
graph.update_all(
fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="affiliated_with"
)
# find train/val/test indexes
split_idx = dataset.get_idx_split()
train_idx, val_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
train_idx, val_idx, test_idx = (
split_idx["train"],
split_idx["valid"],
split_idx["test"],
)
train_idx = apply_each(train_idx, lambda x: x.to(device))
val_idx = apply_each(val_idx, lambda x: x.to(device))
test_idx = apply_each(test_idx, lambda x: x.to(device))
# create RGAT model
in_size = graph.ndata['feat']['paper'].shape[1]
in_size = graph.ndata["feat"]["paper"].shape[1]
out_size = dataset.num_classes
model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device)
# dataloader + model training + testing
train_sampler = NeighborSampler([5, 5, 5],
prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
prefetch_labels={'paper': ['label']})
val_sampler = NeighborSampler([10, 10, 10],
prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
prefetch_labels={'paper': ['label']})
train_dataloader = DataLoader(graph, train_idx, train_sampler,
device=device, batch_size=1000, shuffle=True,
drop_last=False, num_workers=0, use_uva=torch.cuda.is_available())
val_dataloader = DataLoader(graph, val_idx, val_sampler,
device=device, batch_size=1000, shuffle=False,
drop_last=False, num_workers=0, use_uva=torch.cuda.is_available())
test_dataloader = DataLoader(graph, test_idx, val_sampler,
device=device, batch_size=1000, shuffle=False,
drop_last=False, num_workers=0, use_uva=torch.cuda.is_available())
train_sampler = NeighborSampler(
[5, 5, 5],
prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
prefetch_labels={"paper": ["label"]},
)
val_sampler = NeighborSampler(
[10, 10, 10],
prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
prefetch_labels={"paper": ["label"]},
)
train_dataloader = DataLoader(
graph,
train_idx,
train_sampler,
device=device,
batch_size=1000,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=torch.cuda.is_available(),
)
val_dataloader = DataLoader(
graph,
val_idx,
val_sampler,
device=device,
batch_size=1000,
shuffle=False,
drop_last=False,
num_workers=0,
use_uva=torch.cuda.is_available(),
)
test_dataloader = DataLoader(
graph,
test_idx,
val_sampler,
device=device,
batch_size=1000,
shuffle=False,
drop_last=False,
num_workers=0,
use_uva=torch.cuda.is_available(),
)
train(train_dataloader, val_dataloader, test_dataloader, model)
......@@ -9,9 +9,9 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from model import EntityClassify
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from model import EntityClassify
def main(args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment