Unverified Commit be8763fa authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
@@ -6,7 +6,8 @@ import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv
#Semantic attention in the metapath-based aggregation (the same as that in the HAN)
# Semantic attention in the metapath-based aggregation (the same as that in the HAN)
class SemanticAttention(nn.Module):
def __init__(self, in_size, hidden_size=128):
super(SemanticAttention, self).__init__()
@@ -14,179 +15,232 @@ class SemanticAttention(nn.Module):
self.project = nn.Sequential(
nn.Linear(in_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, 1, bias=False)
nn.Linear(hidden_size, 1, bias=False),
)
def forward(self, z):
'''
"""
Shape of z: (N, M , D*K)
N: number of nodes
M: number of metapath patterns
D: hidden_size
K: number of heads
'''
w = self.project(z).mean(0) # (M, 1)
beta = torch.softmax(w, dim=0) # (M, 1)
beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
"""
w = self.project(z).mean(0) # (M, 1)
beta = torch.softmax(w, dim=0) # (M, 1)
beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
return (beta * z).sum(1) # (N, D * K)
return (beta * z).sum(1) # (N, D * K)
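For reference, a minimal shape check for the semantic attention block above; it assumes the class is importable from the TAHIN module (as the training script's `from TAHIN import TAHIN` suggests), and the sizes are purely illustrative.
import torch

from TAHIN import SemanticAttention  # import path assumed, not confirmed by this diff

N, M, D, K = 32, 3, 16, 4  # nodes, metapath patterns, hidden size, attention heads
z = torch.randn(N, M, D * K)  # stacked metapath-specific embeddings
attn = SemanticAttention(in_size=D * K)
fused = attn(z)  # softmax over the M metapath scores, then a weighted sum
assert fused.shape == (N, D * K)  # one fused embedding per node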
#Metapath-based aggregation (the same as the HANLayer)
# Metapath-based aggregation (the same as the HANLayer)
class HANLayer(nn.Module):
def __init__(self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout):
def __init__(
self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout
):
super(HANLayer, self).__init__()
# One GAT layer for each meta path based adjacency matrix
self.gat_layers = nn.ModuleList()
for i in range(len(meta_path_patterns)):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu,
allow_zero_in_degree=True))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.meta_path_patterns = list(tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns)
self.gat_layers.append(
GATConv(
in_size,
out_size,
layer_num_heads,
dropout,
dropout,
activation=F.elu,
allow_zero_in_degree=True,
)
)
self.semantic_attention = SemanticAttention(
in_size=out_size * layer_num_heads
)
self.meta_path_patterns = list(
tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns
)
self._cached_graph = None
self._cached_coalesced_graph = {}
def forward(self, g, h):
semantic_embeddings = []
#obtain metapath reachable graph
# obtain metapath reachable graph
if self._cached_graph is None or self._cached_graph is not g:
self._cached_graph = g
self._cached_coalesced_graph.clear()
for meta_path_pattern in self.meta_path_patterns:
self._cached_coalesced_graph[meta_path_pattern] = dgl.metapath_reachable_graph(
g, meta_path_pattern)
self._cached_coalesced_graph[
meta_path_pattern
] = dgl.metapath_reachable_graph(g, meta_path_pattern)
for i, meta_path_pattern in enumerate(self.meta_path_patterns):
new_g = self._cached_coalesced_graph[meta_path_pattern]
semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)
semantic_embeddings = torch.stack(
semantic_embeddings, dim=1
) # (N, M, D * K)
return self.semantic_attention(semantic_embeddings) # (N, D * K)
return self.semantic_attention(semantic_embeddings) # (N, D * K)
#Relational neighbor aggregation
# Relational neighbor aggregation
class RelationalAGG(nn.Module):
def __init__(self, g, in_size, out_size, dropout=0.1):
super(RelationalAGG, self).__init__()
self.in_size = in_size
self.out_size = out_size
#Transform weights for different types of edges
self.W_T = nn.ModuleDict({
name : nn.Linear(in_size, out_size, bias = False) for name in g.etypes
})
# Transform weights for different types of edges
self.W_T = nn.ModuleDict(
{
name: nn.Linear(in_size, out_size, bias=False)
for name in g.etypes
}
)
#Attention weights for different types of edges
self.W_A = nn.ModuleDict({
name : nn.Linear(out_size, 1, bias = False) for name in g.etypes
})
# Attention weights for different types of edges
self.W_A = nn.ModuleDict(
{name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}
)
#layernorm
# layernorm
self.layernorm = nn.LayerNorm(out_size)
#dropout layer
# dropout layer
self.dropout = nn.Dropout(dropout)
def forward(self, g, feat_dict):
funcs={}
funcs = {}
for srctype, etype, dsttype in g.canonical_etypes:
g.nodes[dsttype].data['h'] = feat_dict[dsttype] #nodes' original feature
g.nodes[srctype].data['h'] = feat_dict[srctype]
g.nodes[srctype].data['t_h'] = self.W_T[etype](feat_dict[srctype]) #src nodes' transformed feature
#compute the attention numerator (exp)
g.apply_edges(fn.u_mul_v('t_h','h','x'),etype=etype)
g.edges[etype].data['x'] = torch.exp(self.W_A[etype](g.edges[etype].data['x']))
#first update to compute the attention denominator (\sum exp)
funcs[etype] = (fn.copy_e('x', 'm'), fn.sum('m', 'att'))
g.multi_update_all(funcs, 'sum')
funcs={}
g.nodes[dsttype].data["h"] = feat_dict[
dsttype
] # nodes' original feature
g.nodes[srctype].data["h"] = feat_dict[srctype]
g.nodes[srctype].data["t_h"] = self.W_T[etype](
feat_dict[srctype]
) # src nodes' transformed feature
# compute the attention numerator (exp)
g.apply_edges(fn.u_mul_v("t_h", "h", "x"), etype=etype)
g.edges[etype].data["x"] = torch.exp(
self.W_A[etype](g.edges[etype].data["x"])
)
# first update to compute the attention denominator (\sum exp)
funcs[etype] = (fn.copy_e("x", "m"), fn.sum("m", "att"))
g.multi_update_all(funcs, "sum")
funcs = {}
for srctype, etype, dsttype in g.canonical_etypes:
g.apply_edges(fn.e_div_v('x', 'att', 'att'),etype=etype) #compute attention weights (numerator/denominator)
funcs[etype] = (fn.u_mul_e('h', 'att', 'm'), fn.sum('m', 'h')) #\sum(h0*att) -> h1
#second update to obtain h1
g.multi_update_all(funcs, 'sum')
#apply activation, layernorm, and dropout
feat_dict={}
g.apply_edges(
fn.e_div_v("x", "att", "att"), etype=etype
) # compute attention weights (numerator/denominator)
funcs[etype] = (
fn.u_mul_e("h", "att", "m"),
fn.sum("m", "h"),
) # \sum(h0*att) -> h1
# second update to obtain h1
g.multi_update_all(funcs, "sum")
# apply activation, layernorm, and dropout
feat_dict = {}
for ntype in g.ntypes:
feat_dict[ntype] = self.dropout(self.layernorm(F.relu_(g.nodes[ntype].data['h']))) #apply activation, layernorm, and dropout
feat_dict[ntype] = self.dropout(
self.layernorm(F.relu_(g.nodes[ntype].data["h"]))
) # apply activation, layernorm, and dropout
return feat_dict
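In equation form (a hedged restatement of the two multi_update_all rounds above, not text from the source), for an edge u -> v of relation r the block computes

    alpha_{uv}^{r} = exp(w_r^T (W_r h_u ⊙ h_v)) / Σ_{r'} Σ_{u' ∈ N_{r'}(v)} exp(w_{r'}^T (W_{r'} h_{u'} ⊙ h_v))
    h_v^{(1)} = Dropout(LayerNorm(ReLU(Σ_{r} Σ_{u ∈ N_r(v)} alpha_{uv}^{r} h_u)))

where W_r is W_T[r] and w_r is W_A[r]; because both passes use the "sum" cross-type reducer, the softmax denominator runs over all incoming edges of v regardless of relation type.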
class TAHIN(nn.Module):
def __init__(self, g, meta_path_patterns, in_size, out_size, num_heads, dropout):
def __init__(
self, g, meta_path_patterns, in_size, out_size, num_heads, dropout
):
super(TAHIN, self).__init__()
#embeddings for different types of nodes, h0
# embeddings for different types of nodes, h0
self.initializer = nn.init.xavier_uniform_
self.feature_dict = nn.ParameterDict({
ntype: nn.Parameter(self.initializer(torch.empty(g.num_nodes(ntype), in_size))) for ntype in g.ntypes
})
self.feature_dict = nn.ParameterDict(
{
ntype: nn.Parameter(
self.initializer(torch.empty(g.num_nodes(ntype), in_size))
)
for ntype in g.ntypes
}
)
#relational neighbor aggregation, this produces h1
# relational neighbor aggregation, this produces h1
self.RelationalAGG = RelationalAGG(g, in_size, out_size)
#metapath-based aggregation modules for user and item, this produces h2
self.meta_path_patterns = meta_path_patterns
#one HANLayer for user, one HANLayer for item
self.hans = nn.ModuleDict({
key: HANLayer(value, in_size, out_size, num_heads, dropout) for key, value in self.meta_path_patterns.items()
})
#layers to combine h0, h1, and h2
#used to update node embeddings
self.user_layer1 = nn.Linear((num_heads+1)*out_size, out_size, bias=True)
self.user_layer2 = nn.Linear(2*out_size, out_size, bias=True)
self.item_layer1 = nn.Linear((num_heads+1)*out_size, out_size, bias=True)
self.item_layer2 = nn.Linear(2*out_size, out_size, bias=True)
#layernorm
# metapath-based aggregation modules for user and item, this produces h2
self.meta_path_patterns = meta_path_patterns
# one HANLayer for user, one HANLayer for item
self.hans = nn.ModuleDict(
{
key: HANLayer(value, in_size, out_size, num_heads, dropout)
for key, value in self.meta_path_patterns.items()
}
)
# layers to combine h0, h1, and h2
# used to update node embeddings
self.user_layer1 = nn.Linear(
(num_heads + 1) * out_size, out_size, bias=True
)
self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
self.item_layer1 = nn.Linear(
(num_heads + 1) * out_size, out_size, bias=True
)
self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
# layernorm
self.layernorm = nn.LayerNorm(out_size)
#network to score the node pairs
# network to score the node pairs
self.pred = nn.Linear(out_size, out_size)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(out_size, 1)
def forward(self, g, user_key, item_key, user_idx, item_idx):
#relational neighbor aggregation, h1
# relational neighbor aggregation, h1
h1 = self.RelationalAGG(g, self.feature_dict)
#metapath-based aggregation, h2
# metapath-based aggregation, h2
h2 = {}
for key in self.meta_path_patterns.keys():
h2[key] = self.hans[key](g, self.feature_dict[key])
#update node embeddings
# update node embeddings
user_emb = torch.cat((h1[user_key], h2[user_key]), 1)
item_emb = torch.cat((h1[item_key], h2[item_key]), 1)
user_emb = self.user_layer1(user_emb)
item_emb = self.item_layer1(item_emb)
user_emb = self.user_layer2(torch.cat((user_emb, self.feature_dict[user_key]), 1))
item_emb = self.item_layer2(torch.cat((item_emb, self.feature_dict[item_key]), 1))
user_emb = self.user_layer2(
torch.cat((user_emb, self.feature_dict[user_key]), 1)
)
item_emb = self.item_layer2(
torch.cat((item_emb, self.feature_dict[item_key]), 1)
)
#Relu
# Relu
user_emb = F.relu_(user_emb)
item_emb = F.relu_(item_emb)
#layer norm
# layer norm
user_emb = self.layernorm(user_emb)
item_emb = self.layernorm(item_emb)
#obtain users/items embeddings and their interactions
# obtain users/items embeddings and their interactions
user_feat = user_emb[user_idx]
item_feat = item_emb[item_idx]
interaction = user_feat*item_feat
interaction = user_feat * item_feat
#score the node pairs
# score the node pairs
pred = self.pred(interaction)
pred = self.dropout(pred) #dropout
pred = self.dropout(pred) # dropout
pred = self.fc(pred)
pred = torch.sigmoid(pred)
return pred.squeeze(1)
\ No newline at end of file
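A small, self-contained sketch of the scoring head at the end of forward above: it mirrors the interaction -> pred -> fc -> sigmoid chain with freshly initialized layers rather than reusing the module, and the dimensions are illustrative.
import torch
import torch.nn as nn

out_size, n_pairs = 128, 4
user_feat = torch.randn(n_pairs, out_size)  # stands in for user_emb[user_idx]
item_feat = torch.randn(n_pairs, out_size)  # stands in for item_emb[item_idx]
pred_layer, fc = nn.Linear(out_size, out_size), nn.Linear(out_size, 1)
score = torch.sigmoid(fc(pred_layer(user_feat * item_feat))).squeeze(1)  # shape (n_pairs,), values in (0, 1)
# main.py feeds scores of exactly this shape to nn.BCELoss against 0/1 labels.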
import torch
import argparse
import pickle as pkl
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import dgl
import numpy as np
import pickle as pkl
import argparse
from data_loader import load_data
from TAHIN import TAHIN
from utils import evaluate_auc, evaluate_acc, evaluate_f1_score, evaluate_logloss
from utils import (
evaluate_acc,
evaluate_auc,
evaluate_f1_score,
evaluate_logloss,
)
import dgl
def main(args):
#step 1: Check device
# step 1: Check device
if args.gpu >= 0 and torch.cuda.is_available():
device = 'cuda:{}'.format(args.gpu)
device = "cuda:{}".format(args.gpu)
else:
device = 'cpu'
#step 2: Load data
g, train_loader, eval_loader, test_loader, meta_paths, user_key, item_key = load_data(args.dataset, args.batch, args.num_workers, args.path)
device = "cpu"
# step 2: Load data
(
g,
train_loader,
eval_loader,
test_loader,
meta_paths,
user_key,
item_key,
) = load_data(args.dataset, args.batch, args.num_workers, args.path)
g = g.to(device)
print('Data loaded.')
print("Data loaded.")
#step 3: Create model and training components
# step 3: Create model and training components
model = TAHIN(
g, meta_paths, args.in_size, args.out_size, args.num_heads, args.dropout
)
)
model = model.to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
print('Model created.')
print("Model created.")
#step 4: Training
print('Start training.')
# step 4: Training
print("Start training.")
best_acc = 0.0
kill_cnt = 0
for epoch in range(args.epochs):
@@ -63,33 +79,39 @@ def main(args):
# compute loss
val_loss = criterion(logits, label)
val_acc = evaluate_acc(logits.detach().cpu().numpy(), label.detach().cpu().numpy())
val_acc = evaluate_acc(
logits.detach().cpu().numpy(), label.detach().cpu().numpy()
)
validate_loss.append(val_loss)
validate_acc.append(val_acc)
validate_loss = torch.stack(validate_loss).sum().cpu().item()
validate_acc = np.mean(validate_acc)
#validate
# validate
if validate_acc > best_acc:
best_acc = validate_acc
best_epoch = epoch
torch.save(model.state_dict(), 'TAHIN'+'_'+args.dataset)
torch.save(model.state_dict(), "TAHIN" + "_" + args.dataset)
kill_cnt = 0
print("saving model...")
else:
kill_cnt += 1
if kill_cnt > args.early_stop:
print('early stop.')
print("early stop.")
print("best epoch:{}".format(best_epoch))
break
print("In epoch {}, Train Loss: {:.4f}, Valid Loss: {:.5}\n, Valid ACC: {:.5}".format(epoch, train_loss, validate_loss, validate_acc))
print(
"In epoch {}, Train Loss: {:.4f}, Valid Loss: {:.5}\n, Valid ACC: {:.5}".format(
epoch, train_loss, validate_loss, validate_acc
)
)
#test use the best model
# test using the best model
model.eval()
with torch.no_grad():
model.load_state_dict(torch.load('TAHIN'+'_'+args.dataset))
model.load_state_dict(torch.load("TAHIN" + "_" + args.dataset))
test_loss = []
test_acc = []
test_auc = []
@@ -101,11 +123,19 @@ def main(args):
# compute loss
loss = criterion(logits, label)
acc = evaluate_acc(logits.detach().cpu().numpy(), label.detach().cpu().numpy())
auc = evaluate_auc(logits.detach().cpu().numpy(), label.detach().cpu().numpy())
f1 = evaluate_f1_score(logits.detach().cpu().numpy(), label.detach().cpu().numpy())
log_loss = evaluate_logloss(logits.detach().cpu().numpy(), label.detach().cpu().numpy())
acc = evaluate_acc(
logits.detach().cpu().numpy(), label.detach().cpu().numpy()
)
auc = evaluate_auc(
logits.detach().cpu().numpy(), label.detach().cpu().numpy()
)
f1 = evaluate_f1_score(
logits.detach().cpu().numpy(), label.detach().cpu().numpy()
)
log_loss = evaluate_logloss(
logits.detach().cpu().numpy(), label.detach().cpu().numpy()
)
test_loss.append(loss)
test_acc.append(acc)
test_auc.append(auc)
@@ -117,33 +147,73 @@ def main(args):
test_auc = np.mean(test_auc)
test_f1 = np.mean(test_f1)
test_logloss = np.mean(test_logloss)
print("Test Loss: {:.5}\n, Test ACC: {:.5}\n, AUC: {:.5}\n, F1: {:.5}\n, Logloss: {:.5}\n".format(test_loss, test_acc, test_auc, test_f1, test_logloss))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Parser For Arguments', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset', default='movielens', help='Dataset to use, default: movielens')
parser.add_argument('--path', default='./data', help='Path to save the data')
parser.add_argument('--model', default='TAHIN', help='Model Name')
parser.add_argument('--batch', default=128, type=int, help='Batch size')
parser.add_argument('--gpu', type=int, default='0', help='Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0')
parser.add_argument('--epochs', type=int, default=500, help='Maximum number of epochs')
parser.add_argument('--wd', type=float, default=0, help='L2 Regularization for Optimizer')
parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate')
parser.add_argument('--num_workers', type=int, default=10, help='Number of processes to construct batches')
parser.add_argument('--early_stop', default=15, type=int, help='Patience for early stop.')
parser.add_argument('--in_size', default=128, type=int, help='Initial dimension size for entities.')
parser.add_argument('--out_size', default=128, type=int, help='Output dimension size for entities.')
parser.add_argument('--num_heads', default=1, type=int, help='Number of attention heads')
parser.add_argument('--dropout', default=0.1, type=float, help='Dropout.')
print(
"Test Loss: {:.5}\n, Test ACC: {:.5}\n, AUC: {:.5}\n, F1: {:.5}\n, Logloss: {:.5}\n".format(
test_loss, test_acc, test_auc, test_f1, test_logloss
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Parser For Arguments",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dataset",
default="movielens",
help="Dataset to use, default: movielens",
)
parser.add_argument(
"--path", default="./data", help="Path to save the data"
)
parser.add_argument("--model", default="TAHIN", help="Model Name")
parser.add_argument("--batch", default=128, type=int, help="Batch size")
parser.add_argument(
"--gpu",
type=int,
default="0",
help="Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0",
)
parser.add_argument(
"--epochs", type=int, default=500, help="Maximum number of epochs"
)
parser.add_argument(
"--wd", type=float, default=0, help="L2 Regularization for Optimizer"
)
parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate")
parser.add_argument(
"--num_workers",
type=int,
default=10,
help="Number of processes to construct batches",
)
parser.add_argument(
"--early_stop", default=15, type=int, help="Patience for early stop."
)
parser.add_argument(
"--in_size",
default=128,
type=int,
help="Initial dimension size for entities.",
)
parser.add_argument(
"--out_size",
default=128,
type=int,
help="Output dimension size for entities.",
)
parser.add_argument(
"--num_heads", default=1, type=int, help="Number of attention heads"
)
parser.add_argument("--dropout", default=0.1, type=float, help="Dropout.")
args = parser.parse_args()
print(args)
main(args)
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, accuracy_score, log_loss, f1_score, average_precision_score, ndcg_score
from sklearn.metrics import (
accuracy_score,
average_precision_score,
f1_score,
log_loss,
ndcg_score,
roc_auc_score,
)
def evaluate_auc(pred, label):
res=roc_auc_score(y_score=pred, y_true=label)
res = roc_auc_score(y_score=pred, y_true=label)
return res
def evaluate_acc(pred, label):
res = []
for _value in pred:
@@ -15,6 +24,7 @@ def evaluate_acc(pred, label):
res.append(0)
return accuracy_score(y_pred=res, y_true=label)
def evaluate_f1_score(pred, label):
res = []
for _value in pred:
@@ -24,7 +34,7 @@ def evaluate_f1_score(pred, label):
res.append(0)
return f1_score(y_pred=res, y_true=label)
def evaluate_logloss(pred, label):
res = log_loss(y_true=label, y_pred=pred,eps=1e-7, normalize=True)
res = log_loss(y_true=label, y_pred=pred, eps=1e-7, normalize=True)
return res
\ No newline at end of file
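For context, a minimal call of the metric helpers above with toy numpy arrays, assuming sigmoid scores and 0/1 labels as produced by the training script; the score-to-label thresholding inside evaluate_acc and evaluate_f1_score sits in the collapsed lines.
import numpy as np

from utils import evaluate_acc, evaluate_auc, evaluate_f1_score, evaluate_logloss

pred = np.array([0.9, 0.2, 0.7, 0.4])  # sigmoid outputs
label = np.array([1, 0, 1, 0])  # ground-truth interactions
print(evaluate_auc(pred, label))  # ROC-AUC on the raw scores
print(evaluate_acc(pred, label))  # accuracy after converting scores to hard labels
print(evaluate_f1_score(pred, label))  # F1 after the same conversion
print(evaluate_logloss(pred, label))  # binary cross-entropy of the raw scores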
import pickle
import argparse
import os
import pickle
import evaluation
import layers
import numpy as np
import sampler as sampler_module
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchtext
import dgl
import os
import tqdm
import layers
import sampler as sampler_module
import evaluation
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import dgl
class PinSAGEModel(nn.Module):
def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
super().__init__()
self.proj = layers.LinearProjector(full_graph, ntype, textsets, hidden_dims)
self.proj = layers.LinearProjector(
full_graph, ntype, textsets, hidden_dims
)
self.sage = layers.SAGENet(hidden_dims, n_layers)
self.scorer = layers.ItemToItemScorer(full_graph, ntype)
@@ -34,22 +38,23 @@ class PinSAGEModel(nn.Module):
h_item_dst = self.proj(blocks[-1].dstdata)
return h_item_dst + self.sage(blocks, h_item)
def train(dataset, args):
g = dataset['train-graph']
val_matrix = dataset['val-matrix'].tocsr()
test_matrix = dataset['test-matrix'].tocsr()
item_texts = dataset['item-texts']
user_ntype = dataset['user-type']
item_ntype = dataset['item-type']
user_to_item_etype = dataset['user-to-item-type']
timestamp = dataset['timestamp-edge-column']
g = dataset["train-graph"]
val_matrix = dataset["val-matrix"].tocsr()
test_matrix = dataset["test-matrix"].tocsr()
item_texts = dataset["item-texts"]
user_ntype = dataset["user-type"]
item_ntype = dataset["item-type"]
user_to_item_etype = dataset["user-to-item-type"]
timestamp = dataset["timestamp-edge-column"]
device = torch.device(args.device)
# Assign user and movie IDs and use them as features (to learn an individual trainable
# embedding for each entity)
g.nodes[user_ntype].data['id'] = torch.arange(g.num_nodes(user_ntype))
g.nodes[item_ntype].data['id'] = torch.arange(g.num_nodes(item_ntype))
g.nodes[user_ntype].data["id"] = torch.arange(g.num_nodes(user_ntype))
g.nodes[item_ntype].data["id"] = torch.arange(g.num_nodes(item_ntype))
# Prepare torchtext dataset and Vocabulary
textset = {}
@@ -63,30 +68,50 @@ def train(dataset, args):
l = tokenizer(item_texts[key][i].lower())
textlist.append(l)
for key, field in item_texts.items():
vocab2 = build_vocab_from_iterator(textlist, specials=["<unk>","<pad>"])
textset[key] = (textlist, vocab2, vocab2.get_stoi()['<pad>'], batch_first)
vocab2 = build_vocab_from_iterator(
textlist, specials=["<unk>", "<pad>"]
)
textset[key] = (
textlist,
vocab2,
vocab2.get_stoi()["<pad>"],
batch_first,
)
# Sampler
batch_sampler = sampler_module.ItemToItemBatchSampler(
g, user_ntype, item_ntype, args.batch_size)
g, user_ntype, item_ntype, args.batch_size
)
neighbor_sampler = sampler_module.NeighborSampler(
g, user_ntype, item_ntype, args.random_walk_length,
args.random_walk_restart_prob, args.num_random_walks, args.num_neighbors,
args.num_layers)
collator = sampler_module.PinSAGECollator(neighbor_sampler, g, item_ntype, textset)
g,
user_ntype,
item_ntype,
args.random_walk_length,
args.random_walk_restart_prob,
args.num_random_walks,
args.num_neighbors,
args.num_layers,
)
collator = sampler_module.PinSAGECollator(
neighbor_sampler, g, item_ntype, textset
)
dataloader = DataLoader(
batch_sampler,
collate_fn=collator.collate_train,
num_workers=args.num_workers)
num_workers=args.num_workers,
)
dataloader_test = DataLoader(
torch.arange(g.num_nodes(item_ntype)),
batch_size=args.batch_size,
collate_fn=collator.collate_test,
num_workers=args.num_workers)
num_workers=args.num_workers,
)
dataloader_it = iter(dataloader)
# Model
model = PinSAGEModel(g, item_ntype, textset, args.hidden_dims, args.num_layers).to(device)
model = PinSAGEModel(
g, item_ntype, textset, args.hidden_dims, args.num_layers
).to(device)
# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -109,7 +134,9 @@ def train(dataset, args):
# Evaluate
model.eval()
with torch.no_grad():
item_batches = torch.arange(g.num_nodes(item_ntype)).split(args.batch_size)
item_batches = torch.arange(g.num_nodes(item_ntype)).split(
args.batch_size
)
h_item_batches = []
for blocks in dataloader_test:
for i in range(len(blocks)):
@@ -118,32 +145,37 @@ def train(dataset, args):
h_item_batches.append(model.get_repr(blocks))
h_item = torch.cat(h_item_batches, 0)
print(evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size))
print(
evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)
)
if __name__ == '__main__':
if __name__ == "__main__":
# Arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path', type=str)
parser.add_argument('--random-walk-length', type=int, default=2)
parser.add_argument('--random-walk-restart-prob', type=float, default=0.5)
parser.add_argument('--num-random-walks', type=int, default=10)
parser.add_argument('--num-neighbors', type=int, default=3)
parser.add_argument('--num-layers', type=int, default=2)
parser.add_argument('--hidden-dims', type=int, default=16)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--device', type=str, default='cpu') # can also be "cuda:0"
parser.add_argument('--num-epochs', type=int, default=1)
parser.add_argument('--batches-per-epoch', type=int, default=20000)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--lr', type=float, default=3e-5)
parser.add_argument('-k', type=int, default=10)
parser.add_argument("dataset_path", type=str)
parser.add_argument("--random-walk-length", type=int, default=2)
parser.add_argument("--random-walk-restart-prob", type=float, default=0.5)
parser.add_argument("--num-random-walks", type=int, default=10)
parser.add_argument("--num-neighbors", type=int, default=3)
parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument("--hidden-dims", type=int, default=16)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument(
"--device", type=str, default="cpu"
) # can also be "cuda:0"
parser.add_argument("--num-epochs", type=int, default=1)
parser.add_argument("--batches-per-epoch", type=int, default=20000)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--lr", type=float, default=3e-5)
parser.add_argument("-k", type=int, default=10)
args = parser.parse_args()
# Load dataset
data_info_path = os.path.join(args.dataset_path, 'data.pkl')
with open(data_info_path, 'rb') as f:
data_info_path = os.path.join(args.dataset_path, "data.pkl")
with open(data_info_path, "rb") as f:
dataset = pickle.load(f)
train_g_path = os.path.join(args.dataset_path, 'train_g.bin')
train_g_path = os.path.join(args.dataset_path, "train_g.bin")
g_list, _ = dgl.load_graphs(train_g_path)
dataset['train-graph'] = g_list[0]
dataset["train-graph"] = g_list[0]
train(dataset, args)
@@ -12,23 +12,25 @@ instead put variable-length features into a more suitable container
(e.g. torchtext to handle list of texts)
"""
import os
import re
import argparse
import os
import pickle
import pandas as pd
import re
import numpy as np
import pandas as pd
import scipy.sparse as ssp
import dgl
import torch
import torchtext
from builder import PandasGraphBuilder
from data_utils import *
if __name__ == '__main__':
import dgl
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('directory', type=str)
parser.add_argument('out_directory', type=str)
parser.add_argument("directory", type=str)
parser.add_argument("out_directory", type=str)
args = parser.parse_args()
directory = args.directory
out_directory = args.out_directory
@@ -38,97 +40,133 @@ if __name__ == '__main__':
# Load data
users = []
with open(os.path.join(directory, 'users.dat'), encoding='latin1') as f:
with open(os.path.join(directory, "users.dat"), encoding="latin1") as f:
for l in f:
id_, gender, age, occupation, zip_ = l.strip().split('::')
users.append({
'user_id': int(id_),
'gender': gender,
'age': age,
'occupation': occupation,
'zip': zip_,
})
users = pd.DataFrame(users).astype('category')
id_, gender, age, occupation, zip_ = l.strip().split("::")
users.append(
{
"user_id": int(id_),
"gender": gender,
"age": age,
"occupation": occupation,
"zip": zip_,
}
)
users = pd.DataFrame(users).astype("category")
movies = []
with open(os.path.join(directory, 'movies.dat'), encoding='latin1') as f:
with open(os.path.join(directory, "movies.dat"), encoding="latin1") as f:
for l in f:
id_, title, genres = l.strip().split('::')
genres_set = set(genres.split('|'))
id_, title, genres = l.strip().split("::")
genres_set = set(genres.split("|"))
# extract year
assert re.match(r'.*\([0-9]{4}\)$', title)
assert re.match(r".*\([0-9]{4}\)$", title)
year = title[-5:-1]
title = title[:-6].strip()
data = {'movie_id': int(id_), 'title': title, 'year': year}
data = {"movie_id": int(id_), "title": title, "year": year}
for g in genres_set:
data[g] = True
movies.append(data)
movies = pd.DataFrame(movies).astype({'year': 'category'})
movies = pd.DataFrame(movies).astype({"year": "category"})
ratings = []
with open(os.path.join(directory, 'ratings.dat'), encoding='latin1') as f:
with open(os.path.join(directory, "ratings.dat"), encoding="latin1") as f:
for l in f:
user_id, movie_id, rating, timestamp = [int(_) for _ in l.split('::')]
ratings.append({
'user_id': user_id,
'movie_id': movie_id,
'rating': rating,
'timestamp': timestamp,
})
user_id, movie_id, rating, timestamp = [
int(_) for _ in l.split("::")
]
ratings.append(
{
"user_id": user_id,
"movie_id": movie_id,
"rating": rating,
"timestamp": timestamp,
}
)
ratings = pd.DataFrame(ratings)
# Filter the users and items that never appear in the rating table.
distinct_users_in_ratings = ratings['user_id'].unique()
distinct_movies_in_ratings = ratings['movie_id'].unique()
users = users[users['user_id'].isin(distinct_users_in_ratings)]
movies = movies[movies['movie_id'].isin(distinct_movies_in_ratings)]
distinct_users_in_ratings = ratings["user_id"].unique()
distinct_movies_in_ratings = ratings["movie_id"].unique()
users = users[users["user_id"].isin(distinct_users_in_ratings)]
movies = movies[movies["movie_id"].isin(distinct_movies_in_ratings)]
# Group the movie features into genres (a vector), year (a category), title (a string)
genre_columns = movies.columns.drop(['movie_id', 'title', 'year'])
movies[genre_columns] = movies[genre_columns].fillna(False).astype('bool')
movies_categorical = movies.drop('title', axis=1)
genre_columns = movies.columns.drop(["movie_id", "title", "year"])
movies[genre_columns] = movies[genre_columns].fillna(False).astype("bool")
movies_categorical = movies.drop("title", axis=1)
# Build graph
graph_builder = PandasGraphBuilder()
graph_builder.add_entities(users, 'user_id', 'user')
graph_builder.add_entities(movies_categorical, 'movie_id', 'movie')
graph_builder.add_binary_relations(ratings, 'user_id', 'movie_id', 'watched')
graph_builder.add_binary_relations(ratings, 'movie_id', 'user_id', 'watched-by')
graph_builder.add_entities(users, "user_id", "user")
graph_builder.add_entities(movies_categorical, "movie_id", "movie")
graph_builder.add_binary_relations(
ratings, "user_id", "movie_id", "watched"
)
graph_builder.add_binary_relations(
ratings, "movie_id", "user_id", "watched-by"
)
g = graph_builder.build()
# Assign features.
# Note that variable-sized features such as texts or images are handled elsewhere.
g.nodes['user'].data['gender'] = torch.LongTensor(users['gender'].cat.codes.values)
g.nodes['user'].data['age'] = torch.LongTensor(users['age'].cat.codes.values)
g.nodes['user'].data['occupation'] = torch.LongTensor(users['occupation'].cat.codes.values)
g.nodes['user'].data['zip'] = torch.LongTensor(users['zip'].cat.codes.values)
g.nodes['movie'].data['year'] = torch.LongTensor(movies['year'].cat.codes.values)
g.nodes['movie'].data['genre'] = torch.FloatTensor(movies[genre_columns].values)
g.edges['watched'].data['rating'] = torch.LongTensor(ratings['rating'].values)
g.edges['watched'].data['timestamp'] = torch.LongTensor(ratings['timestamp'].values)
g.edges['watched-by'].data['rating'] = torch.LongTensor(ratings['rating'].values)
g.edges['watched-by'].data['timestamp'] = torch.LongTensor(ratings['timestamp'].values)
g.nodes["user"].data["gender"] = torch.LongTensor(
users["gender"].cat.codes.values
)
g.nodes["user"].data["age"] = torch.LongTensor(
users["age"].cat.codes.values
)
g.nodes["user"].data["occupation"] = torch.LongTensor(
users["occupation"].cat.codes.values
)
g.nodes["user"].data["zip"] = torch.LongTensor(
users["zip"].cat.codes.values
)
g.nodes["movie"].data["year"] = torch.LongTensor(
movies["year"].cat.codes.values
)
g.nodes["movie"].data["genre"] = torch.FloatTensor(
movies[genre_columns].values
)
g.edges["watched"].data["rating"] = torch.LongTensor(
ratings["rating"].values
)
g.edges["watched"].data["timestamp"] = torch.LongTensor(
ratings["timestamp"].values
)
g.edges["watched-by"].data["rating"] = torch.LongTensor(
ratings["rating"].values
)
g.edges["watched-by"].data["timestamp"] = torch.LongTensor(
ratings["timestamp"].values
)
# Train-validation-test split
# This is a little bit tricky as we want to select the last interaction for test, and the
# second-to-last interaction for validation.
train_indices, val_indices, test_indices = train_test_split_by_time(ratings, 'timestamp', 'user_id')
train_indices, val_indices, test_indices = train_test_split_by_time(
ratings, "timestamp", "user_id"
)
# Build the graph with training interactions only.
train_g = build_train_graph(g, train_indices, 'user', 'movie', 'watched', 'watched-by')
assert train_g.out_degrees(etype='watched').min() > 0
train_g = build_train_graph(
g, train_indices, "user", "movie", "watched", "watched-by"
)
assert train_g.out_degrees(etype="watched").min() > 0
# Build the user-item sparse matrix for validation and test set.
val_matrix, test_matrix = build_val_test_matrix(g, val_indices, test_indices, 'user', 'movie', 'watched')
val_matrix, test_matrix = build_val_test_matrix(
g, val_indices, test_indices, "user", "movie", "watched"
)
## Build title set
movie_textual_dataset = {'title': movies['title'].values}
movie_textual_dataset = {"title": movies["title"].values}
# The model should build their own vocabulary and process the texts. Here is one example
# of using torchtext to pad and numericalize a batch of strings.
@@ -140,18 +178,19 @@ if __name__ == '__main__':
## Dump the graph and the datasets
dgl.save_graphs(os.path.join(out_directory, 'train_g.bin'), train_g)
dgl.save_graphs(os.path.join(out_directory, "train_g.bin"), train_g)
dataset = {
'val-matrix': val_matrix,
'test-matrix': test_matrix,
'item-texts': movie_textual_dataset,
'item-images': None,
'user-type': 'user',
'item-type': 'movie',
'user-to-item-type': 'watched',
'item-to-user-type': 'watched-by',
'timestamp-edge-column': 'timestamp'}
with open(os.path.join(out_directory, 'data.pkl'), 'wb') as f:
"val-matrix": val_matrix,
"test-matrix": test_matrix,
"item-texts": movie_textual_dataset,
"item-images": None,
"user-type": "user",
"item-type": "movie",
"user-to-item-type": "watched",
"item-to-user-type": "watched-by",
"timestamp-edge-column": "timestamp",
}
with open(os.path.join(out_directory, "data.pkl"), "wb") as f:
pickle.dump(dataset, f)
@@ -3,76 +3,102 @@ Script that reads from raw Nowplaying-RS data and dumps into a pickle
file a heterogeneous graph with categorical and numeric features.
"""
import os
import argparse
import dgl
import os
import pickle
import pandas as pd
import scipy.sparse as ssp
import pickle
from data_utils import *
from builder import PandasGraphBuilder
from data_utils import *
import dgl
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('directory', type=str)
parser.add_argument('out_directory', type=str)
parser.add_argument("directory", type=str)
parser.add_argument("out_directory", type=str)
args = parser.parse_args()
directory = args.directory
out_directory = args.out_directory
os.makedirs(out_directory, exist_ok=True)
data = pd.read_csv(os.path.join(directory, 'context_content_features.csv'))
data = pd.read_csv(os.path.join(directory, "context_content_features.csv"))
track_feature_cols = list(data.columns[1:13])
data = data[['user_id', 'track_id', 'created_at'] + track_feature_cols].dropna()
data = data[
["user_id", "track_id", "created_at"] + track_feature_cols
].dropna()
users = data[['user_id']].drop_duplicates()
tracks = data[['track_id'] + track_feature_cols].drop_duplicates()
assert tracks['track_id'].value_counts().max() == 1
tracks = tracks.astype({'mode': 'int64', 'key': 'int64', 'artist_id': 'category'})
events = data[['user_id', 'track_id', 'created_at']]
events['created_at'] = events['created_at'].values.astype('datetime64[s]').astype('int64')
users = data[["user_id"]].drop_duplicates()
tracks = data[["track_id"] + track_feature_cols].drop_duplicates()
assert tracks["track_id"].value_counts().max() == 1
tracks = tracks.astype(
{"mode": "int64", "key": "int64", "artist_id": "category"}
)
events = data[["user_id", "track_id", "created_at"]]
events["created_at"] = (
events["created_at"].values.astype("datetime64[s]").astype("int64")
)
graph_builder = PandasGraphBuilder()
graph_builder.add_entities(users, 'user_id', 'user')
graph_builder.add_entities(tracks, 'track_id', 'track')
graph_builder.add_binary_relations(events, 'user_id', 'track_id', 'listened')
graph_builder.add_binary_relations(events, 'track_id', 'user_id', 'listened-by')
graph_builder.add_entities(users, "user_id", "user")
graph_builder.add_entities(tracks, "track_id", "track")
graph_builder.add_binary_relations(
events, "user_id", "track_id", "listened"
)
graph_builder.add_binary_relations(
events, "track_id", "user_id", "listened-by"
)
g = graph_builder.build()
float_cols = []
for col in tracks.columns:
if col == 'track_id':
if col == "track_id":
continue
elif col == 'artist_id':
g.nodes['track'].data[col] = torch.LongTensor(tracks[col].cat.codes.values)
elif tracks.dtypes[col] == 'float64':
elif col == "artist_id":
g.nodes["track"].data[col] = torch.LongTensor(
tracks[col].cat.codes.values
)
elif tracks.dtypes[col] == "float64":
float_cols.append(col)
else:
g.nodes['track'].data[col] = torch.LongTensor(tracks[col].values)
g.nodes['track'].data['song_features'] = torch.FloatTensor(linear_normalize(tracks[float_cols].values))
g.edges['listened'].data['created_at'] = torch.LongTensor(events['created_at'].values)
g.edges['listened-by'].data['created_at'] = torch.LongTensor(events['created_at'].values)
g.nodes["track"].data[col] = torch.LongTensor(tracks[col].values)
g.nodes["track"].data["song_features"] = torch.FloatTensor(
linear_normalize(tracks[float_cols].values)
)
g.edges["listened"].data["created_at"] = torch.LongTensor(
events["created_at"].values
)
g.edges["listened-by"].data["created_at"] = torch.LongTensor(
events["created_at"].values
)
n_edges = g.num_edges('listened')
train_indices, val_indices, test_indices = train_test_split_by_time(events, 'created_at', 'user_id')
train_g = build_train_graph(g, train_indices, 'user', 'track', 'listened', 'listened-by')
assert train_g.out_degrees(etype='listened').min() > 0
n_edges = g.num_edges("listened")
train_indices, val_indices, test_indices = train_test_split_by_time(
events, "created_at", "user_id"
)
train_g = build_train_graph(
g, train_indices, "user", "track", "listened", "listened-by"
)
assert train_g.out_degrees(etype="listened").min() > 0
val_matrix, test_matrix = build_val_test_matrix(
g, val_indices, test_indices, 'user', 'track', 'listened')
g, val_indices, test_indices, "user", "track", "listened"
)
dgl.save_graphs(os.path.join(out_directory, 'train_g.bin'), train_g)
dgl.save_graphs(os.path.join(out_directory, "train_g.bin"), train_g)
dataset = {
'val-matrix': val_matrix,
'test-matrix': test_matrix,
'item-texts': {},
'item-images': None,
'user-type': 'user',
'item-type': 'track',
'user-to-item-type': 'listened',
'item-to-user-type': 'listened-by',
'timestamp-edge-column': 'created_at'}
"val-matrix": val_matrix,
"test-matrix": test_matrix,
"item-texts": {},
"item-images": None,
"user-type": "user",
"item-type": "track",
"user-to-item-type": "listened",
"item-to-user-type": "listened-by",
"timestamp-edge-column": "created_at",
}
with open(os.path.join(out_directory, 'data.pkl'), 'wb') as f:
with open(os.path.join(out_directory, "data.pkl"), "wb") as f:
pickle.dump(dataset, f)
import numpy as np
import dgl
import torch
from torch.utils.data import IterableDataset, DataLoader
from torch.utils.data import DataLoader, IterableDataset
from torchtext.data.functional import numericalize_tokens_from_iterator
import dgl
def padding(array, yy, val):
"""
:param array: torch tensor array
@@ -15,7 +17,10 @@ def padding(array, yy, val):
b = 0
bb = yy - b - w
return torch.nn.functional.pad(array, pad=(b, bb), mode='constant', value=val)
return torch.nn.functional.pad(
array, pad=(b, bb), mode="constant", value=val
)
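A quick illustration of padding above, assuming the `w` computed in the collapsed lines is the input length; the tensor values are made up.
import torch

x = torch.tensor([5, 7, 9])
# right-pad to length 6 with pad value 1, i.e. what padding(x, 6, 1) is expected to return
padded = torch.nn.functional.pad(x, pad=(0, 6 - x.shape[0]), mode="constant", value=1)
print(padded)  # tensor([5, 7, 9, 1, 1, 1])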
def compact_and_copy(frontier, seeds):
block = dgl.to_block(frontier, seeds)
@@ -25,6 +30,7 @@ def compact_and_copy(frontier, seeds):
block.edata[col] = data[block.edata[dgl.EID]]
return block
class ItemToItemBatchSampler(IterableDataset):
def __init__(self, g, user_type, item_type, batch_size):
self.g = g
@@ -36,42 +42,69 @@ class ItemToItemBatchSampler(IterableDataset):
def __iter__(self):
while True:
heads = torch.randint(0, self.g.num_nodes(self.item_type), (self.batch_size,))
heads = torch.randint(
0, self.g.num_nodes(self.item_type), (self.batch_size,)
)
tails = dgl.sampling.random_walk(
self.g,
heads,
metapath=[self.item_to_user_etype, self.user_to_item_etype])[0][:, 2]
neg_tails = torch.randint(0, self.g.num_nodes(self.item_type), (self.batch_size,))
metapath=[self.item_to_user_etype, self.user_to_item_etype],
)[0][:, 2]
neg_tails = torch.randint(
0, self.g.num_nodes(self.item_type), (self.batch_size,)
)
mask = (tails != -1)
mask = tails != -1
yield heads[mask], tails[mask], neg_tails[mask]
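To make the sampling step above concrete, here is a toy metapath random walk on a tiny user-item graph; the node and edge type names are invented for the example, and dead-end walks come back as -1, which is exactly why __iter__ masks them out.
import dgl
import torch

g = dgl.heterograph(
    {
        ("user", "bought", "item"): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),
        ("item", "bought-by", "user"): (torch.tensor([0, 1]), torch.tensor([0, 0])),
    }
)
heads = torch.tensor([0, 1, 2])  # item seeds
traces, _ = dgl.sampling.random_walk(g, heads, metapath=["bought-by", "bought"])
tails = traces[:, 2]  # items reached via item -> user -> item; -1 where the walk hit a dead end (item 2)
mask = tails != -1  # same masking as in __iter__ above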
class NeighborSampler(object):
def __init__(self, g, user_type, item_type, random_walk_length, random_walk_restart_prob,
num_random_walks, num_neighbors, num_layers):
def __init__(
self,
g,
user_type,
item_type,
random_walk_length,
random_walk_restart_prob,
num_random_walks,
num_neighbors,
num_layers,
):
self.g = g
self.user_type = user_type
self.item_type = item_type
self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]
self.samplers = [
dgl.sampling.PinSAGESampler(g, item_type, user_type, random_walk_length,
random_walk_restart_prob, num_random_walks, num_neighbors)
for _ in range(num_layers)]
dgl.sampling.PinSAGESampler(
g,
item_type,
user_type,
random_walk_length,
random_walk_restart_prob,
num_random_walks,
num_neighbors,
)
for _ in range(num_layers)
]
def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None):
blocks = []
for sampler in self.samplers:
frontier = sampler(seeds)
if heads is not None:
eids = frontier.edge_ids(torch.cat([heads, heads]), torch.cat([tails, neg_tails]), return_uv=True)[2]
eids = frontier.edge_ids(
torch.cat([heads, heads]),
torch.cat([tails, neg_tails]),
return_uv=True,
)[2]
if len(eids) > 0:
old_frontier = frontier
frontier = dgl.remove_edges(old_frontier, eids)
#print(old_frontier)
#print(frontier)
#print(frontier.edata['weights'])
#frontier.edata['weights'] = old_frontier.edata['weights'][frontier.edata[dgl.EID]]
# print(old_frontier)
# print(frontier)
# print(frontier.edata['weights'])
# frontier.edata['weights'] = old_frontier.edata['weights'][frontier.edata[dgl.EID]]
block = compact_and_copy(frontier, seeds)
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
@@ -81,17 +114,18 @@ class NeighborSampler(object):
# Create a graph with positive connections only and another graph with negative
# connections only.
pos_graph = dgl.graph(
(heads, tails),
num_nodes=self.g.num_nodes(self.item_type))
(heads, tails), num_nodes=self.g.num_nodes(self.item_type)
)
neg_graph = dgl.graph(
(heads, neg_tails),
num_nodes=self.g.num_nodes(self.item_type))
(heads, neg_tails), num_nodes=self.g.num_nodes(self.item_type)
)
pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
seeds = pos_graph.ndata[dgl.NID]
blocks = self.sample_blocks(seeds, heads, tails, neg_tails)
return pos_graph, neg_graph, blocks
def assign_simple_node_features(ndata, g, ntype, assign_id=False):
"""
Copies data to the given block from the corresponding nodes in the original graph.
@@ -102,6 +136,7 @@ def assign_simple_node_features(ndata, g, ntype, assign_id=False):
induced_nodes = ndata[dgl.NID]
ndata[col] = g.nodes[ntype].data[col][induced_nodes]
def assign_textual_node_features(ndata, textset, ntype):
"""
Assigns numericalized tokens from a torchtext dataset to given block.
@@ -127,10 +162,10 @@ def assign_textual_node_features(ndata, textset, ntype):
for field_name, field in textset.items():
textlist, vocab, pad_var, batch_first = field
examples = [textlist[i] for i in node_ids]
ids_iter = numericalize_tokens_from_iterator(vocab, examples)
maxsize = max([len(textlist[i]) for i in node_ids])
ids = next(ids_iter)
x = torch.asarray([num for num in ids])
@@ -141,14 +176,15 @@ def assign_textual_node_features(ndata, textset, ntype):
x = torch.asarray([num for num in ids])
l = torch.tensor([len(x)])
y = padding(x, maxsize, pad_var)
tokens = torch.vstack((tokens,y))
tokens = torch.vstack((tokens, y))
lengths = torch.cat((lengths, l))
if not batch_first:
tokens = tokens.t()
ndata[field_name] = tokens
ndata[field_name + '__len'] = lengths
ndata[field_name + "__len"] = lengths
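A compact sketch of the tokenize -> vocab -> numericalize -> pad pipeline the loop above assumes, using the same torchtext helpers as the rest of this example; the tokens and sizes are invented.
import torch
from torchtext.data.functional import numericalize_tokens_from_iterator
from torchtext.vocab import build_vocab_from_iterator

textlist = [["an", "example", "title"], ["another", "title"]]
vocab = build_vocab_from_iterator(textlist, specials=["<unk>", "<pad>"])
pad_id = vocab.get_stoi()["<pad>"]

rows = [
    torch.asarray(list(ids))
    for ids in numericalize_tokens_from_iterator(vocab, textlist)
]
maxsize = max(len(r) for r in rows)
tokens = torch.vstack(
    [torch.nn.functional.pad(r, pad=(0, maxsize - len(r)), value=pad_id) for r in rows]
)
lengths = torch.tensor([len(r) for r in rows])  # per-text lengths, cf. field_name + "__len"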
def assign_features_to_blocks(blocks, g, textset, ntype):
# For the first block (which is closest to the input), copy the features from
@@ -158,6 +194,7 @@ def assign_features_to_blocks(blocks, g, textset, ntype):
assign_simple_node_features(blocks[-1].dstdata, g, ntype)
assign_textual_node_features(blocks[-1].dstdata, textset, ntype)
class PinSAGECollator(object):
def __init__(self, sampler, g, ntype, textset):
self.sampler = sampler
@@ -168,7 +205,9 @@ class PinSAGECollator(object):
def collate_train(self, batches):
heads, tails, neg_tails = batches[0]
# Construct multilayer neighborhood via PinSAGE...
pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(heads, tails, neg_tails)
pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(
heads, tails, neg_tails
)
assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)
return pos_graph, neg_graph, blocks
import argparse
import os
import urllib
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from modelnet import ModelNet
import tqdm
from model import Model, compute_loss
from dgl.data.utils import download, get_download_dir
from modelnet import ModelNet
from torch.utils.data import DataLoader
from functools import partial
import tqdm
import urllib
import os
import argparse
from dgl.data.utils import download, get_download_dir
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=100)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=100)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
data_filename = 'modelnet40-sampled-2048.h5'
local_path = args.dataset_path or os.path.join(get_download_dir(), data_filename)
data_filename = "modelnet40-sampled-2048.h5"
local_path = args.dataset_path or os.path.join(
get_download_dir(), data_filename
)
if not os.path.exists(local_path):
download('https://data.dgl.ai/dataset/modelnet40-sampled-2048.h5', local_path)
download(
"https://data.dgl.ai/dataset/modelnet40-sampled-2048.h5", local_path
)
CustomDataLoader = partial(
DataLoader,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)
DataLoader,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)
def train(model, opt, scheduler, train_loader, dev):
scheduler.step()
@@ -65,11 +72,15 @@ def train(model, opt, scheduler, train_loader, dev):
total_loss += loss
total_correct += correct
tq.set_postfix({
'Loss': '%.5f' % loss,
'AvgLoss': '%.5f' % (total_loss / num_batches),
'Acc': '%.5f' % (correct / num_examples),
'AvgAcc': '%.5f' % (total_correct / count)})
tq.set_postfix(
{
"Loss": "%.5f" % loss,
"AvgLoss": "%.5f" % (total_loss / num_batches),
"Acc": "%.5f" % (correct / num_examples),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
def evaluate(model, test_loader, dev):
model.eval()
@@ -89,9 +100,12 @@ def evaluate(model, test_loader, dev):
total_correct += correct
count += num_examples
tq.set_postfix({
'Acc': '%.5f' % (correct / num_examples),
'AvgAcc': '%.5f' % (total_correct / count)})
tq.set_postfix(
{
"Acc": "%.5f" % (correct / num_examples),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
return total_correct / count
@@ -105,7 +119,9 @@ if args.load_model_path:
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, args.num_epochs, eta_min=0.001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
opt, args.num_epochs, eta_min=0.001
)
modelnet = ModelNet(local_path, 1024)
@@ -117,7 +133,7 @@ best_valid_acc = 0
best_test_acc = 0
for epoch in range(args.num_epochs):
print('Epoch #%d Validating' % epoch)
print("Epoch #%d Validating" % epoch)
valid_acc = evaluate(model, valid_loader, dev)
test_acc = evaluate(model, test_loader, dev)
if valid_acc > best_valid_acc:
@@ -125,7 +141,9 @@ for epoch in range(args.num_epochs):
best_test_acc = test_acc
if args.save_model_path:
torch.save(model.state_dict(), args.save_model_path)
print('Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)' % (
valid_acc, best_valid_acc, test_acc, best_test_acc))
print(
"Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)"
% (valid_acc, best_valid_acc, test_acc, best_test_acc)
)
train(model, opt, scheduler, train_loader, dev)
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import KNNGraph, EdgeConv
from dgl.nn.pytorch import EdgeConv, KNNGraph
class Model(nn.Module):
def __init__(self, k, feature_dims, emb_dims, output_classes, input_dims=3,
dropout_prob=0.5):
def __init__(
self,
k,
feature_dims,
emb_dims,
output_classes,
input_dims=3,
dropout_prob=0.5,
):
super(Model, self).__init__()
self.nng = KNNGraph(k)
@@ -13,10 +22,13 @@ class Model(nn.Module):
self.num_layers = len(feature_dims)
for i in range(self.num_layers):
self.conv.append(EdgeConv(
feature_dims[i - 1] if i > 0 else input_dims,
feature_dims[i],
batch_norm=True))
self.conv.append(
EdgeConv(
feature_dims[i - 1] if i > 0 else input_dims,
feature_dims[i],
batch_norm=True,
)
)
self.proj = nn.Linear(sum(feature_dims), emb_dims[0])
@@ -26,10 +38,13 @@ class Model(nn.Module):
self.num_embs = len(emb_dims) - 1
for i in range(1, self.num_embs + 1):
self.embs.append(nn.Linear(
# * 2 because of concatenation of max- and mean-pooling
emb_dims[i - 1] if i > 1 else (emb_dims[i - 1] * 2),
emb_dims[i]))
self.embs.append(
nn.Linear(
# * 2 because of concatenation of max- and mean-pooling
emb_dims[i - 1] if i > 1 else (emb_dims[i - 1] * 2),
emb_dims[i],
)
)
self.bn_embs.append(nn.BatchNorm1d(emb_dims[i]))
self.dropouts.append(nn.Dropout(dropout_prob))
import numpy as np
from torch.utils.data import Dataset
class ModelNet(object):
def __init__(self, path, num_points):
import h5py
self.f = h5py.File(path)
self.num_points = num_points
self.n_train = self.f['train/data'].shape[0]
self.n_train = self.f["train/data"].shape[0]
self.n_valid = int(self.n_train / 5)
self.n_train -= self.n_valid
self.n_test = self.f['test/data'].shape[0]
self.n_test = self.f["test/data"].shape[0]
def train(self):
return ModelNetDataset(self, 'train')
return ModelNetDataset(self, "train")
def valid(self):
return ModelNetDataset(self, 'valid')
return ModelNetDataset(self, "valid")
def test(self):
return ModelNetDataset(self, 'test')
return ModelNetDataset(self, "test")
class ModelNetDataset(Dataset):
def __init__(self, modelnet, mode):
@@ -27,29 +30,29 @@ class ModelNetDataset(Dataset):
self.num_points = modelnet.num_points
self.mode = mode
if mode == 'train':
self.data = modelnet.f['train/data'][:modelnet.n_train]
self.label = modelnet.f['train/label'][:modelnet.n_train]
elif mode == 'valid':
self.data = modelnet.f['train/data'][modelnet.n_train:]
self.label = modelnet.f['train/label'][modelnet.n_train:]
elif mode == 'test':
self.data = modelnet.f['test/data'].value
self.label = modelnet.f['test/label'].value
def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2)):
if mode == "train":
self.data = modelnet.f["train/data"][: modelnet.n_train]
self.label = modelnet.f["train/label"][: modelnet.n_train]
elif mode == "valid":
self.data = modelnet.f["train/data"][modelnet.n_train :]
self.label = modelnet.f["train/label"][modelnet.n_train :]
elif mode == "test":
self.data = modelnet.f["test/data"].value
self.label = modelnet.f["test/label"].value
def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2)):
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[3])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[3])
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
return x
def __len__(self):
return self.data.shape[0]
def __getitem__(self, i):
x = self.data[i][:self.num_points]
x = self.data[i][: self.num_points]
y = self.label[i]
if self.mode == 'train':
if self.mode == "train":
x = self.translate(x)
np.random.shuffle(x)
return x, y
import numpy as np
import warnings
import os
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")
def pc_normalize(pc):
@@ -43,11 +45,18 @@ def farthest_point_sample(point, npoint):
class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', fps=False,
normal_channel=True, cache_size=15000):
def __init__(
self,
root,
npoint=1024,
split="train",
fps=False,
normal_channel=True,
cache_size=15000,
):
"""
Input:
root: the root path to the local data files
root: the root path to the local data files
npoint: number of points from each cloud
split: which split of the data, 'train' or 'test'
fps: whether to sample points with farthest point sampler
......@@ -57,24 +66,34 @@ class ModelNetDataLoader(Dataset):
self.root = root
self.npoints = npoint
self.fps = fps
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel
shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open(
os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(
os.path.join(self.root, 'modelnet40_test.txt'))]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
shape_ids["train"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_train.txt"))
]
shape_ids["test"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert split == "train" or split == "test"
shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d' % (split, len(self.datapath)))
self.datapath = [
(
shape_names[i],
os.path.join(self.root, shape_names[i], shape_ids[split][i])
+ ".txt",
)
for i in range(len(shape_ids[split]))
]
print("The size of %s data is %d" % (split, len(self.datapath)))
self.cache_size = cache_size
self.cache = {}
@@ -89,11 +108,11 @@ class ModelNetDataLoader(Dataset):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.fps:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
point_set = point_set[0 : self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
import os, json, tqdm
import numpy as np
import dgl
import json
import os
from zipfile import ZipFile
from torch.utils.data import Dataset
import numpy as np
import tqdm
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True):
self.num_points = num_points
@@ -13,8 +18,13 @@ class ShapeNet(object):
SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
download_path = get_download_dir()
data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal")
data_filename = (
"shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
)
data_path = os.path.join(
download_path,
"shapenetcore_partanno_segmentation_benchmark_v0_normal",
)
if not os.path.exists(data_path):
local_path = os.path.join(download_path, data_filename)
if not os.path.exists(local_path):
......@@ -24,52 +34,72 @@ class ShapeNet(object):
synset_file = "synsetoffset2category.txt"
with open(os.path.join(data_path, synset_file)) as f:
synset = [t.split('\n')[0].split('\t') for t in f.readlines()]
synset = [t.split("\n")[0].split("\t") for t in f.readlines()]
self.synset_dict = {}
for syn in synset:
self.synset_dict[syn[1]] = syn[0]
self.seg_classes = {'Airplane': [0, 1, 2, 3],
'Bag': [4, 5],
'Cap': [6, 7],
'Car': [8, 9, 10, 11],
'Chair': [12, 13, 14, 15],
'Earphone': [16, 17, 18],
'Guitar': [19, 20, 21],
'Knife': [22, 23],
'Lamp': [24, 25, 26, 27],
'Laptop': [28, 29],
'Motorbike': [30, 31, 32, 33, 34, 35],
'Mug': [36, 37],
'Pistol': [38, 39, 40],
'Rocket': [41, 42, 43],
'Skateboard': [44, 45, 46],
'Table': [47, 48, 49]}
train_split_json = 'shuffled_train_file_list.json'
val_split_json = 'shuffled_val_file_list.json'
test_split_json = 'shuffled_test_file_list.json'
split_path = os.path.join(data_path, 'train_test_split')
self.seg_classes = {
"Airplane": [0, 1, 2, 3],
"Bag": [4, 5],
"Cap": [6, 7],
"Car": [8, 9, 10, 11],
"Chair": [12, 13, 14, 15],
"Earphone": [16, 17, 18],
"Guitar": [19, 20, 21],
"Knife": [22, 23],
"Lamp": [24, 25, 26, 27],
"Laptop": [28, 29],
"Motorbike": [30, 31, 32, 33, 34, 35],
"Mug": [36, 37],
"Pistol": [38, 39, 40],
"Rocket": [41, 42, 43],
"Skateboard": [44, 45, 46],
"Table": [47, 48, 49],
}
train_split_json = "shuffled_train_file_list.json"
val_split_json = "shuffled_val_file_list.json"
test_split_json = "shuffled_test_file_list.json"
split_path = os.path.join(data_path, "train_test_split")
with open(os.path.join(split_path, train_split_json)) as f:
tmp = f.read()
self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
self.train_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, val_split_json)) as f:
tmp = f.read()
self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
self.val_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, test_split_json)) as f:
tmp = f.read()
self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
self.test_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
def train(self):
return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "train", self.num_points, self.normal_channel
)
def valid(self):
return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "valid", self.num_points, self.normal_channel
)
def trainval(self):
return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "trainval", self.num_points, self.normal_channel
)
def test(self):
return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "test", self.num_points, self.normal_channel
)
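# Usage sketch (assuming the archive above has already been downloaded and
# extracted): the ShapeNet wrapper builds the split file lists once and each
# accessor wraps one of them in a ShapeNetDataset.
shapenet = ShapeNet(num_points=2048, normal_channel=True)
train_ds = shapenet.train()  # backed by shuffled_train_file_list.json
val_ds = shapenet.valid()  # backed by shuffled_val_file_list.json
test_ds = shapenet.test()  # backed by shuffled_test_file_list.json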
class ShapeNetDataset(Dataset):
def __init__(self, shapenet, mode, num_points, normal_channel=True):
......@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset):
else:
self.dim = 6
if mode == 'train':
if mode == "train":
self.file_list = shapenet.train_file_list
elif mode == 'valid':
elif mode == "valid":
self.file_list = shapenet.val_file_list
elif mode == 'test':
elif mode == "test":
self.file_list = shapenet.test_file_list
elif mode == 'trainval':
elif mode == "trainval":
self.file_list = shapenet.train_file_list + shapenet.val_file_list
else:
raise "Not supported `mode`"
......@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset):
data_list = []
label_list = []
category_list = []
print('Loading data from split ' + self.mode)
print("Loading data from split " + self.mode)
for fn in tqdm.tqdm(self.file_list, ascii=True):
with open(fn) as f:
data = np.array([t.split('\n')[0].split(' ') for t in f.readlines()]).astype(np.float)
data_list.append(data[:, 0:self.dim])
data = np.array(
[t.split("\n")[0].split(" ") for t in f.readlines()]
).astype(np.float)
data_list.append(data[:, 0 : self.dim])
label_list.append(data[:, 6].astype(np.int))
category_list.append(shapenet.synset_dict[fn.split('/')[-2]])
category_list.append(shapenet.synset_dict[fn.split("/")[-2]])
self.data = data_list
self.label = label_list
self.category = category_list
def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3):
def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
return x
def __len__(self):
return len(self.data)
def __getitem__(self, i):
inds = np.random.choice(self.data[i].shape[0], self.num_points, replace=True)
x = self.data[i][inds,:self.dim]
inds = np.random.choice(
self.data[i].shape[0], self.num_points, replace=True
)
x = self.data[i][inds, : self.dim]
y = self.label[i][inds]
cat = self.category[i]
if self.mode == 'train':
if self.mode == "train":
x = self.translate(x, size=self.dim)
x = x.astype(np.float)
y = y.astype(np.int)
......
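# Sketch of what __getitem__ above does to one cloud: resample a fixed number
# of points with replacement, then (in train mode) apply the per-axis random
# scale/shift from translate(); small fake arrays assumed for illustration.
import numpy as np

points = np.random.rand(3000, 6).astype("float32")  # xyz + normals of one shape
num_points, dim = 2048, 6
inds = np.random.choice(points.shape[0], num_points, replace=True)
x = points[inds, :dim]  # (2048, 6)
xyz1 = np.random.uniform(2 / 3, 3 / 2, size=[dim])  # per-axis scale
xyz2 = np.random.uniform(-0.2, 0.2, size=[dim])  # per-axis shift
x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
print(x.shape)  # (2048, 6), same shape regardless of the original point count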
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.geometry import farthest_point_sampler
'''
"""
Part of the code is adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
"""
def square_distance(src, dst):
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist
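# square_distance relies on the expansion |s - d|^2 = |s|^2 + |d|^2 - 2 * s.d;
# a quick consistency check against the brute-force pairwise distance
# (small random tensors assumed):
src = torch.randn(2, 5, 3)
dst = torch.randn(2, 7, 3)
brute = ((src[:, :, None, :] - dst[:, None, :, :]) ** 2).sum(-1)  # (2, 5, 7)
assert torch.allclose(square_distance(src, dst), brute, atol=1e-5)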
def index_points(points, idx):
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(
device).view(view_shape).repeat(repeat_shape)
batch_indices = (
torch.arange(B, dtype=torch.long)
.to(device)
.view(view_shape)
.repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :]
return new_points
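# index_points is a batched gather: for each batch b it picks points[b, idx[b]].
# Equivalent loop it replaces (tiny random tensors assumed):
points = torch.randn(2, 6, 3)
idx = torch.randint(0, 6, (2, 4))
gathered = index_points(points, idx)  # (2, 4, 3)
looped = torch.stack([points[b, idx[b]] for b in range(points.shape[0])])
assert torch.equal(gathered, looped)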
class KNearNeighbors(nn.Module):
'''
"""
Find the k nearest neighbors
'''
"""
def __init__(self, n_neighbor):
super(KNearNeighbors, self).__init__()
self.n_neighbor = n_neighbor
def forward(self, pos, centroids):
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
'''
"""
center_pos = index_points(pos, centroids)
sqrdists = square_distance(center_pos, pos)
group_idx = sqrdists.argsort(dim=-1)[:, :, :self.n_neighbor]
group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
return group_idx
class KNNGraphBuilder(nn.Module):
'''
"""
Build NN graph
'''
"""
def __init__(self, n_neighbor):
super(KNNGraphBuilder, self).__init__()
......@@ -76,59 +81,68 @@ class KNNGraphBuilder(nn.Module):
center = torch.zeros((N)).to(dev)
center[centroids[i]] = 1
src = group_idx[i].contiguous().view(-1)
dst = centroids[i].view(-1, 1).repeat(1, min(self.n_neighbor,
src.shape[0] // centroids.shape[1])).view(-1)
dst = (
centroids[i]
.view(-1, 1)
.repeat(
1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])
)
.view(-1)
)
unified = torch.cat([src, dst])
uniq, inv_idx = torch.unique(unified, return_inverse=True)
src_idx = inv_idx[:src.shape[0]]
dst_idx = inv_idx[src.shape[0]:]
src_idx = inv_idx[: src.shape[0]]
dst_idx = inv_idx[src.shape[0] :]
g = dgl.graph((src_idx, dst_idx))
g.ndata['pos'] = pos[i][uniq]
g.ndata['center'] = center[uniq]
g.ndata["pos"] = pos[i][uniq]
g.ndata["center"] = center[uniq]
if feat is not None:
g.ndata['feat'] = feat[i][uniq]
g.ndata["feat"] = feat[i][uniq]
glist.append(g)
bg = dgl.batch(glist)
return bg
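# The builder above turns each cloud's (src, dst) index pairs into a DGLGraph
# and merges the per-cloud graphs with dgl.batch; a tiny standalone sketch of
# that pattern (toy edges assumed):
g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g1.ndata["pos"] = torch.randn(3, 3)
g2 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))
g2.ndata["pos"] = torch.randn(2, 3)
bg = dgl.batch([g1, g2])  # node data is concatenated in graph order
print(bg.num_nodes(), bg.batch_size)  # 5 2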
class KNNMessage(nn.Module):
'''
"""
Compute the input feature from neighbors
'''
"""
def __init__(self, n_neighbor):
super(KNNMessage, self).__init__()
self.n_neighbor = n_neighbor
def forward(self, edges):
norm = edges.src['feat'] - edges.dst['feat']
if 'feat' in edges.src:
res = torch.cat([norm, edges.src['feat']], 1)
norm = edges.src["feat"] - edges.dst["feat"]
if "feat" in edges.src:
res = torch.cat([norm, edges.src["feat"]], 1)
else:
res = norm
return {'agg_feat': res}
return {"agg_feat": res}
class KNNConv(nn.Module):
'''
"""
Feature aggregation
'''
"""
def __init__(self, sizes):
super(KNNConv, self).__init__()
self.conv = nn.ModuleList()
self.bn = nn.ModuleList()
for i in range(1, len(sizes)):
self.conv.append(nn.Conv2d(sizes[i-1], sizes[i], 1))
self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes):
shape = nodes.mailbox['agg_feat'].shape
h = nodes.mailbox['agg_feat'].view(
shape[0], -1, shape[1], shape[2]).permute(0, 3, 2, 1)
shape = nodes.mailbox["agg_feat"].shape
h = (
nodes.mailbox["agg_feat"]
.view(shape[0], -1, shape[1], shape[2])
.permute(0, 3, 2, 1)
)
for conv, bn in zip(self.conv, self.bn):
h = conv(h)
h = bn(h)
......@@ -136,7 +150,7 @@ class KNNConv(nn.Module):
h = torch.max(h, 2)[0]
feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {'new_feat': h}
return {"new_feat": h}
class TransitionDown(nn.Module):
......@@ -156,10 +170,9 @@ class TransitionDown(nn.Module):
g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv)
mask = g.ndata['center'] == 1
pos_dim = g.ndata['pos'].shape[-1]
feat_dim = g.ndata['new_feat'].shape[-1]
pos_res = g.ndata['pos'][mask].view(batch_size, -1, pos_dim)
feat_res = g.ndata['new_feat'][mask].view(
batch_size, -1, feat_dim)
mask = g.ndata["center"] == 1
pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata["new_feat"].shape[-1]
pos_res = g.ndata["pos"][mask].view(batch_size, -1, pos_dim)
feat_res = g.ndata["new_feat"][mask].view(batch_size, -1, feat_dim)
return pos_res, feat_res
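# TransitionDown above drives message passing with user-defined message
# (KNNMessage) and reduce (KNNConv) functions via g.update_all; the same call
# works with DGL built-ins, e.g. copy the source feature and max-pool the
# mailbox (toy graph assumed):
import dgl.function as fn

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.ndata["feat"] = torch.randn(3, 4)
g.update_all(fn.copy_u("feat", "m"), fn.max("m", "new_feat"))
print(g.ndata["new_feat"].shape)  # torch.Size([3, 4])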
import torch
from torch import nn
from helper import TransitionDown
from torch import nn
'''
"""
Part of the code is adapted from
https://github.com/MenghaoGuo/PCT
'''
"""
class PCTPositionEmbedding(nn.Module):
def __init__(self, channels=256):
......@@ -98,15 +98,21 @@ class PointTransformerCLS(nn.Module):
self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(64)
self.g_op0 = TransitionDown(in_channels=128, out_channels=128, n_neighbor=32)
self.g_op1 = TransitionDown(in_channels=256, out_channels=256, n_neighbor=32)
self.g_op0 = TransitionDown(
in_channels=128, out_channels=128, n_neighbor=32
)
self.g_op1 = TransitionDown(
in_channels=256, out_channels=256, n_neighbor=32
)
self.pt_last = PCTPositionEmbedding()
self.relu = nn.ReLU()
self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
nn.BatchNorm1d(1024),
nn.LeakyReLU(negative_slope=0.2))
self.conv_fuse = nn.Sequential(
nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
nn.BatchNorm1d(1024),
nn.LeakyReLU(negative_slope=0.2),
)
self.linear1 = nn.Linear(1024, 512, bias=False)
self.bn6 = nn.BatchNorm1d(512)
......@@ -159,13 +165,17 @@ class PointTransformerSeg(nn.Module):
self.sa3 = SALayerSeg(128)
self.sa4 = SALayerSeg(128)
self.conv_fuse = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
nn.BatchNorm1d(1024),
nn.LeakyReLU(negative_slope=0.2))
self.conv_fuse = nn.Sequential(
nn.Conv1d(512, 1024, kernel_size=1, bias=False),
nn.BatchNorm1d(1024),
nn.LeakyReLU(negative_slope=0.2),
)
self.label_conv = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
nn.BatchNorm1d(64),
nn.LeakyReLU(negative_slope=0.2))
self.label_conv = nn.Sequential(
nn.Conv1d(16, 64, kernel_size=1, bias=False),
nn.BatchNorm1d(64),
nn.LeakyReLU(negative_slope=0.2),
)
self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1)
self.dp1 = nn.Dropout(0.5)
......@@ -189,13 +199,12 @@ class PointTransformerSeg(nn.Module):
x = self.conv_fuse(x)
x_max, _ = torch.max(x, 2)
x_avg = torch.mean(x, 2)
x_max_feature = x_max.view(
batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
x_avg_feature = x_avg.view(
batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
x_max_feature = x_max.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
x_avg_feature = x_avg.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
cls_label_feature = self.label_conv(cls_label).repeat(1, 1, N)
x_global_feature = torch.cat(
(x_max_feature, x_avg_feature, cls_label_feature), 1)
(x_max_feature, x_avg_feature, cls_label_feature), 1
)
x = torch.cat((x, x_global_feature), 1)
x = self.relu(self.bns1(self.convs1(x)))
x = self.dp1(x)
......
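# The segmentation head above builds per-point global context by max- and
# mean-pooling over the point dimension, tiling both back to N points and
# concatenating the tiled class-label embedding; standalone shape sketch:
import torch

B, C, N = 2, 1024, 512
x = torch.randn(B, C, N)
x_max = x.max(2)[0].unsqueeze(-1).repeat(1, 1, N)  # (B, C, N)
x_avg = x.mean(2).unsqueeze(-1).repeat(1, 1, N)  # (B, C, N)
cls_feat = torch.randn(B, 64, N)  # stands in for label_conv(cls_label) tiled to N
x_global = torch.cat((x_max, x_avg, cls_feat), 1)
print(x_global.shape)  # torch.Size([2, 2112, 512]) == (B, 2 * C + 64, N)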
'''
"""
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
'''
"""
import numpy as np
def normalize_data(batch_data):
""" Normalize the batch data, use coordinates of the block centered at origin,
Input:
BxNxC array
Output:
BxNxC array
"""Normalize the batch data, use coordinates of the block centered at origin,
Input:
BxNxC array
Output:
BxNxC array
"""
B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C))
......@@ -16,237 +17,296 @@ def normalize_data(batch_data):
pc = batch_data[b]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
normal_data[b] = pc
return normal_data
def shuffle_data(data, labels):
""" Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""
idx = np.arange(len(labels))
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data):
""" Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
Output:
BxNxC array
"""Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
Output:
BxNxC array
"""
idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx)
return batch_data[:,idx,:]
return batch_data[:, idx, :]
def rotate_point_cloud(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
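# The rotation above is about the up (y) axis; quick check that the matrix is
# orthonormal and therefore preserves point norms (random angle assumed):
theta = np.random.uniform() * 2 * np.pi
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
assert np.allclose(R @ R.T, np.eye(3))
pts = np.random.randn(10, 3)
assert np.allclose(
    np.linalg.norm(pts @ R, axis=1), np.linalg.norm(pts, axis=1)
)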
def rotate_point_cloud_z(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1]])
rotation_matrix = np.array(
[[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
)
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
''' Randomly rotate XYZ, normal point cloud.
Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
Output:
B,N,6, rotated XYZ, normal point cloud
'''
"""Randomly rotate XYZ, normal point cloud.
Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
Output:
B,N,6, rotated XYZ, normal point cloud
"""
for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_xyz_normal[k,:,0:3]
shape_normal = batch_xyz_normal[k,:,3:6]
batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_xyz_normal[k, :, 0:3]
shape_normal = batch_xyz_normal[k, :, 3:6]
batch_xyz_normal[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
batch_xyz_normal[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
BxNx3 array, rotated batch of point clouds
def rotate_perturbation_point_cloud_with_normal(
batch_data, angle_sigma=0.06, angle_clip=0.18
):
"""Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
angles = np.clip(
angle_sigma * np.random.randn(3), -angle_clip, angle_clip
)
Rx = np.array(
[
[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])],
]
)
Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
# rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, :, 0:3]
rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
Return:
BxNx6 array, rotated batch of point clouds with normal
"""Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
Return:
BxNx6 array, rotated batch of point clouds with normal
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
# rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
rotation_matrix = np.array(
[[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
)
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
rotated_data[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return rotated_data
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
def rotate_perturbation_point_cloud(
batch_data, angle_sigma=0.06, angle_clip=0.18
):
"""Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
angles = np.clip(
angle_sigma * np.random.randn(3), -angle_clip, angle_clip
)
Rx = np.array(
[
[1, 0, 0],
[0, np.cos(angles[0]), -np.sin(angles[0])],
[0, np.sin(angles[0]), np.cos(angles[0])],
]
)
Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""
B, N, C = batch_data.shape
assert(clip > 0)
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
assert clip > 0
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
jittered_data += batch_data
return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3))
shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:]
batch_data[batch_index, :, :] += shifts[batch_index, :]
return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index]
batch_data[batch_index, :, :] *= scales[batch_index]
return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 '''
"""batch_pc: BxNx3"""
for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
if len(drop_idx)>0:
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 # not need
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(
np.random.random((batch_pc.shape[1])) <= dropout_ratio
)[0]
if len(drop_idx) > 0:
dropout_ratio = (
np.random.random() * max_dropout_ratio
) # 0~0.875 # not need
batch_pc[b, drop_idx, :] = batch_pc[
b, 0, :
] # set to the first point
return batch_pc
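# random_point_dropout keeps the tensor shape fixed: dropped points are not
# removed but overwritten with the cloud's first point. Tiny sketch of the idea
# with a hand-picked drop index:
pc = np.arange(12, dtype=np.float32).reshape(1, 4, 3)  # one cloud of 4 points
drop_idx = np.array([2, 3])  # pretend these indices were drawn
pc[0, drop_idx, :] = pc[0, 0, :]
print(pc[0])  # rows 2 and 3 now equal row 0; shape is still (4, 3)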
from pct import PointTransformerCLS
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse
import os
import tqdm
import time
from functools import partial
from dgl.data.utils import download, get_download_dir
from torch.utils.data import DataLoader
import torch.nn as nn
import provider
import torch
import time
import torch.nn as nn
import tqdm
from ModelNetDataLoader import ModelNetDataLoader
from pct import PointTransformerCLS
from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=250)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=250)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()
num_workers = args.num_workers
batch_size = args.batch_size
data_filename = 'modelnet40_normal_resampled.zip'
data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join(
get_download_dir(), 'modelnet40_normal_resampled')
get_download_dir(), "modelnet40_normal_resampled"
)
if not os.path.exists(local_path):
download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
download_path, verify_ssl=False)
download(
"https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
download_path,
verify_ssl=False,
)
from zipfile import ZipFile
with ZipFile(download_path) as z:
z.extractall(path=get_download_dir())
......@@ -42,7 +49,8 @@ CustomDataLoader = partial(
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)
drop_last=True,
)
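# CustomDataLoader pre-binds the DataLoader keyword arguments with
# functools.partial, so later call sites only pass a dataset; minimal sketch of
# the same pattern on a toy dataset:
from torch.utils.data import TensorDataset

Loader = partial(DataLoader, batch_size=4, shuffle=True, drop_last=True)
toy = TensorDataset(torch.randn(10, 3), torch.zeros(10))
for xb, yb in Loader(toy):
    print(xb.shape)  # torch.Size([4, 3])
    break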
def train(net, opt, scheduler, train_loader, dev):
......@@ -59,8 +67,7 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq:
data = data.data.numpy()
data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud(
data[:, :, 0:3])
data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data)
......@@ -83,11 +90,19 @@ def train(net, opt, scheduler, train_loader, dev):
total_loss += loss
total_correct += correct
tq.set_postfix({
'AvgLoss': '%.5f' % (total_loss / num_batches),
'AvgAcc': '%.5f' % (total_correct / count)})
print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(total_loss /
num_batches, total_correct / count, time.time() - start_time))
tq.set_postfix(
{
"AvgLoss": "%.5f" % (total_loss / num_batches),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
print(
"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
total_loss / num_batches,
total_correct / count,
time.time() - start_time,
)
)
scheduler.step()
......@@ -110,10 +125,12 @@ def evaluate(net, test_loader, dev):
total_correct += correct
count += num_examples
tq.set_postfix({
'AvgAcc': '%.5f' % (total_correct / count)})
print("[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
total_correct / count, time.time() - start_time))
tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
print(
"[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
total_correct / count, time.time() - start_time
)
)
return total_correct / count
......@@ -126,21 +143,29 @@ if args.load_model_path:
opt = torch.optim.SGD(
net.parameters(),
lr=0.01,
weight_decay=1e-4,
momentum=0.9
net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
opt, T_max=args.num_epochs)
opt, T_max=args.num_epochs
)
train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True)
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
drop_last=True,
)
best_test_acc = 0
......@@ -153,6 +178,5 @@ for epoch in range(args.num_epochs):
best_test_acc = test_acc
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print('Current test acc: %.5f (best: %.5f)' % (
test_acc, best_test_acc))
print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
print()
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import dgl
from functools import partial
import tqdm
import argparse
import time
from functools import partial
import numpy as np
import provider
import torch
import torch.optim as optim
import tqdm
from pct import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet
from pct import PointTransformerSeg, PartSegLoss
from torch.utils.data import DataLoader
import dgl
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='')
parser.add_argument('--load-model-path', type=str, default='')
parser.add_argument('--save-model-path', type=str, default='')
parser.add_argument('--num-epochs', type=int, default=500)
parser.add_argument('--num-workers', type=int, default=8)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--tensorboard', action='store_true')
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument("--num-epochs", type=int, default=500)
parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--tensorboard", action="store_true")
args = parser.parse_args()
num_workers = args.num_workers
......@@ -37,7 +37,8 @@ CustomDataLoader = partial(
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=True)
drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev):
......@@ -59,7 +60,8 @@ def train(net, opt, scheduler, train_loader, dev):
cat_ind = [category_list.index(c) for c in cat]
# A one-hot encoding for the object category
cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
dev, dtype=torch.float)
dev, dtype=torch.float
)
cat_tensor = cat_tensor.view(num_examples, 16, 1)
logits = net(data, cat_tensor)
loss = L(logits, label)
......@@ -78,14 +80,17 @@ def train(net, opt, scheduler, train_loader, dev):
AvgLoss = total_loss / num_batches
AvgAcc = total_correct / count
tq.set_postfix({
'AvgLoss': '%.5f' % AvgLoss,
'AvgAcc': '%.5f' % AvgAcc})
tq.set_postfix(
{"AvgLoss": "%.5f" % AvgLoss, "AvgAcc": "%.5f" % AvgAcc}
)
scheduler.step()
end = time.time()
print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(total_loss /
num_batches, total_correct / count, end - start))
return data, preds, AvgLoss, AvgAcc, end-start
print(
"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
total_loss / num_batches, total_correct / count, end - start
)
)
return data, preds, AvgLoss, AvgAcc, end - start
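# The category conditioning above indexes an identity matrix with the category
# ids to get one-hot vectors; eye_mat is assumed to be np.eye(16) built in the
# elided part of this script. Standalone sketch:
eye_mat = np.eye(16)
cat_ind = [3, 0, 7]  # per-example category indices
cat_tensor = torch.tensor(eye_mat[cat_ind], dtype=torch.float)
cat_tensor = cat_tensor.view(len(cat_ind), 16, 1)  # matches the usage above
print(cat_tensor.shape)  # torch.Size([3, 16, 1])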
def mIoU(preds, label, cat, cat_miou, seg_classes):
......@@ -129,26 +134,36 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
label = label.to(dev, dtype=torch.long)
cat_ind = [category_list.index(c) for c in cat]
cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
dev, dtype=torch.float)
cat_tensor = cat_tensor.view(
num_examples, 16, 1)
dev, dtype=torch.float
)
cat_tensor = cat_tensor.view(num_examples, 16, 1)
logits = net(data, cat_tensor)
_, preds = logits.max(1)
cat_miou = mIoU(preds.cpu().numpy(),
label.view(num_examples, -1).cpu().numpy(),
cat, cat_miou, shapenet.seg_classes)
cat_miou = mIoU(
preds.cpu().numpy(),
label.view(num_examples, -1).cpu().numpy(),
cat,
cat_miou,
shapenet.seg_classes,
)
for _, v in cat_miou.items():
if v[1] > 0:
miou += v[0]
count += v[1]
per_cat_miou += v[0] / v[1]
per_cat_count += 1
tq.set_postfix({
'mIoU': '%.5f' % (miou / count),
'per Category mIoU': '%.5f' % (per_cat_miou / per_cat_count)})
print("[Test] mIoU: %.5f, per Category mIoU: %.5f" %
(miou / count, per_cat_miou / per_cat_count))
tq.set_postfix(
{
"mIoU": "%.5f" % (miou / count),
"per Category mIoU": "%.5f"
% (per_cat_miou / per_cat_count),
}
)
print(
"[Test] mIoU: %.5f, per Category mIoU: %.5f"
% (miou / count, per_cat_miou / per_cat_count)
)
if per_cat_verbose:
print("-" * 60)
print("Per-Category mIoU:")
......@@ -169,14 +184,12 @@ if args.load_model_path:
net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
opt = torch.optim.SGD(
net.parameters(),
lr=0.01,
weight_decay=1e-4,
momentum=0.9
net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
opt, T_max=args.num_epochs)
opt, T_max=args.num_epochs
)
L = PartSegLoss()
......@@ -190,20 +203,63 @@ if args.tensorboard:
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
writer = SummaryWriter()
# Select 50 distinct colors for different parts
color_map = torch.tensor([
[47, 79, 79], [139, 69, 19], [112, 128, 144], [85, 107, 47], [139, 0, 0], [
128, 128, 0], [72, 61, 139], [0, 128, 0], [188, 143, 143], [60, 179, 113],
[205, 133, 63], [0, 139, 139], [70, 130, 180], [205, 92, 92], [154, 205, 50], [
0, 0, 139], [50, 205, 50], [250, 250, 250], [218, 165, 32], [139, 0, 139],
[10, 10, 10], [176, 48, 96], [72, 209, 204], [153, 50, 204], [255, 69, 0], [
255, 145, 0], [0, 0, 205], [255, 255, 0], [0, 255, 0], [233, 150, 122],
[220, 20, 60], [0, 191, 255], [160, 32, 240], [192, 192, 192], [173, 255, 47], [
218, 112, 214], [216, 191, 216], [255, 127, 80], [255, 0, 255], [100, 149, 237],
[128, 128, 128], [221, 160, 221], [144, 238, 144], [123, 104, 238], [255, 160, 122], [
175, 238, 238], [238, 130, 238], [127, 255, 212], [255, 218, 185], [255, 105, 180],
])
color_map = torch.tensor(
[
[47, 79, 79],
[139, 69, 19],
[112, 128, 144],
[85, 107, 47],
[139, 0, 0],
[128, 128, 0],
[72, 61, 139],
[0, 128, 0],
[188, 143, 143],
[60, 179, 113],
[205, 133, 63],
[0, 139, 139],
[70, 130, 180],
[205, 92, 92],
[154, 205, 50],
[0, 0, 139],
[50, 205, 50],
[250, 250, 250],
[218, 165, 32],
[139, 0, 139],
[10, 10, 10],
[176, 48, 96],
[72, 209, 204],
[153, 50, 204],
[255, 69, 0],
[255, 145, 0],
[0, 0, 205],
[255, 255, 0],
[0, 255, 0],
[233, 150, 122],
[220, 20, 60],
[0, 191, 255],
[160, 32, 240],
[192, 192, 192],
[173, 255, 47],
[218, 112, 214],
[216, 191, 216],
[255, 127, 80],
[255, 0, 255],
[100, 149, 237],
[128, 128, 128],
[221, 160, 221],
[144, 238, 144],
[123, 104, 238],
[255, 160, 122],
[175, 238, 238],
[238, 130, 238],
[127, 255, 212],
[255, 218, 185],
[255, 105, 180],
]
)
# paint each point according to its pred
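# The paint helper itself is elided here; a plausible sketch (hypothetical, the
# real implementation may differ) is a per-point lookup of color_map by the
# predicted part id, giving the (B, N, 3) colors that add_mesh expects below:
def paint_sketch(preds):
    # preds: (B, N) long tensor of part labels in [0, 50)
    return color_map[preds.cpu()]  # (B, N, 3)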
......@@ -219,28 +275,38 @@ best_test_per_cat_miou = 0
for epoch in range(args.num_epochs):
print("Epoch #{}: ".format(epoch))
data, preds, AvgLoss, AvgAcc, training_time = train(
net, opt, scheduler, train_loader, dev)
net, opt, scheduler, train_loader, dev
)
if (epoch + 1) % 5 == 0 or epoch == 0:
test_miou, test_per_cat_miou = evaluate(
net, test_loader, dev, True)
test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, True)
if test_miou > best_test_miou:
best_test_miou = test_miou
best_test_per_cat_miou = test_per_cat_miou
if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path)
print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % (
test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou))
print(
"Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)"
% (
test_miou,
best_test_miou,
test_per_cat_miou,
best_test_per_cat_miou,
)
)
# Tensorboard
if args.tensorboard:
colored = paint(preds)
writer.add_mesh('data', vertices=data,
colors=colored, global_step=epoch)
writer.add_scalar('training time for one epoch',
training_time, global_step=epoch)
writer.add_scalar('AvgLoss', AvgLoss, global_step=epoch)
writer.add_scalar('AvgAcc', AvgAcc, global_step=epoch)
writer.add_mesh(
"data", vertices=data, colors=colored, global_step=epoch
)
writer.add_scalar(
"training time for one epoch", training_time, global_step=epoch
)
writer.add_scalar("AvgLoss", AvgLoss, global_step=epoch)
writer.add_scalar("AvgAcc", AvgAcc, global_step=epoch)
if (epoch + 1) % 5 == 0:
writer.add_scalar('test mIoU', test_miou, global_step=epoch)
writer.add_scalar('best test mIoU',
best_test_miou, global_step=epoch)
writer.add_scalar("test mIoU", test_miou, global_step=epoch)
writer.add_scalar(
"best test mIoU", best_test_miou, global_step=epoch
)
print()
import numpy as np
import warnings
import os
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")
def pc_normalize(pc):
......@@ -43,11 +45,18 @@ def farthest_point_sample(point, npoint):
class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', fps=False,
normal_channel=True, cache_size=15000):
def __init__(
self,
root,
npoint=1024,
split="train",
fps=False,
normal_channel=True,
cache_size=15000,
):
"""
Input:
root: the root path to the local data files
root: the root path to the local data files
npoint: number of points from each cloud
split: which split of the data, 'train' or 'test'
fps: whether to sample points with farthest point sampler
......@@ -57,24 +66,34 @@ class ModelNetDataLoader(Dataset):
self.root = root
self.npoints = npoint
self.fps = fps
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel
shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open(
os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(
os.path.join(self.root, 'modelnet40_test.txt'))]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
shape_ids["train"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_train.txt"))
]
shape_ids["test"] = [
line.rstrip()
for line in open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert split == "train" or split == "test"
shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d' % (split, len(self.datapath)))
self.datapath = [
(
shape_names[i],
os.path.join(self.root, shape_names[i], shape_ids[split][i])
+ ".txt",
)
for i in range(len(shape_ids[split]))
]
print("The size of %s data is %d" % (split, len(self.datapath)))
self.cache_size = cache_size
self.cache = {}
......@@ -89,11 +108,11 @@ class ModelNetDataLoader(Dataset):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.fps:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
point_set = point_set[0 : self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
......
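# Usage sketch mirroring the classification script above (the local folder name
# "modelnet40_normal_resampled" is assumed to be already downloaded and
# extracted next to the script):
from torch.utils.data import DataLoader

train_dataset = ModelNetDataLoader(
    "modelnet40_normal_resampled", 1024, split="train"
)
train_loader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, drop_last=True
)
points, label = next(iter(train_loader))  # points: (32, 1024, C), label: (32, 1)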
import os, json, tqdm
import numpy as np
import dgl
import json
import os
from zipfile import ZipFile
from torch.utils.data import Dataset
import numpy as np
import tqdm
from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir
class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True):
self.num_points = num_points
......@@ -13,8 +18,13 @@ class ShapeNet(object):
SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
download_path = get_download_dir()
data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal")
data_filename = (
"shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
)
data_path = os.path.join(
download_path,
"shapenetcore_partanno_segmentation_benchmark_v0_normal",
)
if not os.path.exists(data_path):
local_path = os.path.join(download_path, data_filename)
if not os.path.exists(local_path):
......@@ -24,52 +34,72 @@ class ShapeNet(object):
synset_file = "synsetoffset2category.txt"
with open(os.path.join(data_path, synset_file)) as f:
synset = [t.split('\n')[0].split('\t') for t in f.readlines()]
synset = [t.split("\n")[0].split("\t") for t in f.readlines()]
self.synset_dict = {}
for syn in synset:
self.synset_dict[syn[1]] = syn[0]
self.seg_classes = {'Airplane': [0, 1, 2, 3],
'Bag': [4, 5],
'Cap': [6, 7],
'Car': [8, 9, 10, 11],
'Chair': [12, 13, 14, 15],
'Earphone': [16, 17, 18],
'Guitar': [19, 20, 21],
'Knife': [22, 23],
'Lamp': [24, 25, 26, 27],
'Laptop': [28, 29],
'Motorbike': [30, 31, 32, 33, 34, 35],
'Mug': [36, 37],
'Pistol': [38, 39, 40],
'Rocket': [41, 42, 43],
'Skateboard': [44, 45, 46],
'Table': [47, 48, 49]}
train_split_json = 'shuffled_train_file_list.json'
val_split_json = 'shuffled_val_file_list.json'
test_split_json = 'shuffled_test_file_list.json'
split_path = os.path.join(data_path, 'train_test_split')
self.seg_classes = {
"Airplane": [0, 1, 2, 3],
"Bag": [4, 5],
"Cap": [6, 7],
"Car": [8, 9, 10, 11],
"Chair": [12, 13, 14, 15],
"Earphone": [16, 17, 18],
"Guitar": [19, 20, 21],
"Knife": [22, 23],
"Lamp": [24, 25, 26, 27],
"Laptop": [28, 29],
"Motorbike": [30, 31, 32, 33, 34, 35],
"Mug": [36, 37],
"Pistol": [38, 39, 40],
"Rocket": [41, 42, 43],
"Skateboard": [44, 45, 46],
"Table": [47, 48, 49],
}
train_split_json = "shuffled_train_file_list.json"
val_split_json = "shuffled_val_file_list.json"
test_split_json = "shuffled_test_file_list.json"
split_path = os.path.join(data_path, "train_test_split")
with open(os.path.join(split_path, train_split_json)) as f:
tmp = f.read()
self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
self.train_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, val_split_json)) as f:
tmp = f.read()
self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
self.val_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, test_split_json)) as f:
tmp = f.read()
self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
self.test_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
def train(self):
return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "train", self.num_points, self.normal_channel
)
def valid(self):
return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "valid", self.num_points, self.normal_channel
)
def trainval(self):
return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "trainval", self.num_points, self.normal_channel
)
def test(self):
return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel)
return ShapeNetDataset(
self, "test", self.num_points, self.normal_channel
)
class ShapeNetDataset(Dataset):
def __init__(self, shapenet, mode, num_points, normal_channel=True):
......@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset):
else:
self.dim = 6
if mode == 'train':
if mode == "train":
self.file_list = shapenet.train_file_list
elif mode == 'valid':
elif mode == "valid":
self.file_list = shapenet.val_file_list
elif mode == 'test':
elif mode == "test":
self.file_list = shapenet.test_file_list
elif mode == 'trainval':
elif mode == "trainval":
self.file_list = shapenet.train_file_list + shapenet.val_file_list
else:
raise "Not supported `mode`"
......@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset):
data_list = []
label_list = []
category_list = []
print('Loading data from split ' + self.mode)
print("Loading data from split " + self.mode)
for fn in tqdm.tqdm(self.file_list, ascii=True):
with open(fn) as f:
data = np.array([t.split('\n')[0].split(' ') for t in f.readlines()]).astype(np.float)
data_list.append(data[:, 0:self.dim])
data = np.array(
[t.split("\n")[0].split(" ") for t in f.readlines()]
).astype(np.float)
data_list.append(data[:, 0 : self.dim])
label_list.append(data[:, 6].astype(np.int))
category_list.append(shapenet.synset_dict[fn.split('/')[-2]])
category_list.append(shapenet.synset_dict[fn.split("/")[-2]])
self.data = data_list
self.label = label_list
self.category = category_list
def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3):
def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
return x
def __len__(self):
return len(self.data)
def __getitem__(self, i):
inds = np.random.choice(self.data[i].shape[0], self.num_points, replace=True)
x = self.data[i][inds,:self.dim]
inds = np.random.choice(
self.data[i].shape[0], self.num_points, replace=True
)
x = self.data[i][inds, : self.dim]
y = self.label[i][inds]
cat = self.category[i]
if self.mode == 'train':
if self.mode == "train":
x = self.translate(x, size=self.dim)
x = x.astype(np.float)
y = y.astype(np.int)
......