Unverified Commit be8763fa authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
@@ -6,7 +6,8 @@ import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv


# Semantic attention in the metapath-based aggregation (the same as that in the HAN)
class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()
@@ -14,179 +15,232 @@ class SemanticAttention(nn.Module):
        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, z):
        """
        Shape of z: (N, M , D*K)
        N: number of nodes
        M: number of metapath patterns
        D: hidden_size
        K: number of heads
        """
        w = self.project(z).mean(0)  # (M, 1)
        beta = torch.softmax(w, dim=0)  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)

        return (beta * z).sum(1)  # (N, D * K)
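
# Added note (illustrative sketch, not part of the original file): a quick shape
# check of SemanticAttention with N=4 nodes, M=3 metapaths and D*K=8:
#     att = SemanticAttention(in_size=8)
#     z = torch.randn(4, 3, 8)
#     out = att(z)
#     assert out.shape == (4, 8)
# i.e. the per-metapath embeddings are fused into a single (N, D*K) tensor.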


# Metapath-based aggregation (the same as the HANLayer)
class HANLayer(nn.Module):
    def __init__(
        self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_path_patterns)):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.meta_path_patterns = list(
            tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns
        )

        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        semantic_embeddings = []
        # obtain metapath reachable graph
        if self._cached_graph is None or self._cached_graph is not g:
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path_pattern in self.meta_path_patterns:
                self._cached_coalesced_graph[
                    meta_path_pattern
                ] = dgl.metapath_reachable_graph(g, meta_path_pattern)

        for i, meta_path_pattern in enumerate(self.meta_path_patterns):
            new_g = self._cached_coalesced_graph[meta_path_pattern]
            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)
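
# Added note (not part of the original file): meta_path_patterns is expected to
# be an iterable of metapaths, each a sequence of edge types, e.g. something
# like [["ui", "iu"], ["ua", "au"]] for purely hypothetical edge-type names;
# dgl.metapath_reachable_graph() collapses each metapath into a homogeneous
# graph on which the corresponding GATConv is run.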


# Relational neighbor aggregation
class RelationalAGG(nn.Module):
    def __init__(self, g, in_size, out_size, dropout=0.1):
        super(RelationalAGG, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Transform weights for different types of edges
        self.W_T = nn.ModuleDict(
            {
                name: nn.Linear(in_size, out_size, bias=False)
                for name in g.etypes
            }
        )

        # Attention weights for different types of edges
        self.W_A = nn.ModuleDict(
            {name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}
        )

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.nodes[dsttype].data["h"] = feat_dict[
                dsttype
            ]  # nodes' original feature
            g.nodes[srctype].data["h"] = feat_dict[srctype]
            g.nodes[srctype].data["t_h"] = self.W_T[etype](
                feat_dict[srctype]
            )  # src nodes' transformed feature

            # compute the attention numerator (exp)
            g.apply_edges(fn.u_mul_v("t_h", "h", "x"), etype=etype)
            g.edges[etype].data["x"] = torch.exp(
                self.W_A[etype](g.edges[etype].data["x"])
            )

            # first update to compute the attention denominator (\sum exp)
            funcs[etype] = (fn.copy_e("x", "m"), fn.sum("m", "att"))
        g.multi_update_all(funcs, "sum")

        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.apply_edges(
                fn.e_div_v("x", "att", "att"), etype=etype
            )  # compute attention weights (numerator/denominator)
            funcs[etype] = (
                fn.u_mul_e("h", "att", "m"),
                fn.sum("m", "h"),
            )  # \sum(h0*att) -> h1
        # second update to obtain h1
        g.multi_update_all(funcs, "sum")

        # apply activation, layernorm, and dropout
        feat_dict = {}
        for ntype in g.ntypes:
            feat_dict[ntype] = self.dropout(
                self.layernorm(F.relu_(g.nodes[ntype].data["h"]))
            )  # apply activation, layernorm, and dropout

        return feat_dict
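
# Added sketch (not part of the original file) of what the forward pass above
# computes for an edge u -> v of relation r, with W_r = W_T[r], a_r = W_A[r]
# and * the element-wise product:
#     e_r(u, v)   = exp(a_r^T (W_r h_u * h_v))
#     att_r(u, v) = e_r(u, v) / sum of e over all edges entering v (any relation)
#     h'_v        = dropout(layernorm(relu(sum over entering edges of att_r(u, v) * h_u)))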


class TAHIN(nn.Module):
    def __init__(
        self, g, meta_path_patterns, in_size, out_size, num_heads, dropout
    ):
        super(TAHIN, self).__init__()

        # embeddings for different types of nodes, h0
        self.initializer = nn.init.xavier_uniform_
        self.feature_dict = nn.ParameterDict(
            {
                ntype: nn.Parameter(
                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
                )
                for ntype in g.ntypes
            }
        )

        # relational neighbor aggregation, this produces h1
        self.RelationalAGG = RelationalAGG(g, in_size, out_size)

        # metapath-based aggregation modules for user and item, this produces h2
        self.meta_path_patterns = meta_path_patterns
        # one HANLayer for user, one HANLayer for item
        self.hans = nn.ModuleDict(
            {
                key: HANLayer(value, in_size, out_size, num_heads, dropout)
                for key, value in self.meta_path_patterns.items()
            }
        )

        # layers to combine h0, h1, and h2
        # used to update node embeddings
        self.user_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
        self.item_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # network to score the node pairs
        self.pred = nn.Linear(out_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_size, 1)

    def forward(self, g, user_key, item_key, user_idx, item_idx):
        # relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, self.feature_dict)

        # metapath-based aggregation, h2
        h2 = {}
        for key in self.meta_path_patterns.keys():
            h2[key] = self.hans[key](g, self.feature_dict[key])

        # update node embeddings
        user_emb = torch.cat((h1[user_key], h2[user_key]), 1)
        item_emb = torch.cat((h1[item_key], h2[item_key]), 1)
        user_emb = self.user_layer1(user_emb)
        item_emb = self.item_layer1(item_emb)
        user_emb = self.user_layer2(
            torch.cat((user_emb, self.feature_dict[user_key]), 1)
        )
        item_emb = self.item_layer2(
            torch.cat((item_emb, self.feature_dict[item_key]), 1)
        )

        # Relu
        user_emb = F.relu_(user_emb)
        item_emb = F.relu_(item_emb)

        # layer norm
        user_emb = self.layernorm(user_emb)
        item_emb = self.layernorm(item_emb)

        # obtain users/items embeddings and their interactions
        user_feat = user_emb[user_idx]
        item_feat = item_emb[item_idx]
        interaction = user_feat * item_feat

        # score the node pairs
        pred = self.pred(interaction)
        pred = self.dropout(pred)  # dropout
        pred = self.fc(pred)
        pred = torch.sigmoid(pred)

        return pred.squeeze(1)
\ No newline at end of file
This diff is collapsed.
import argparse
import pickle as pkl

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import load_data
from TAHIN import TAHIN
from utils import (
    evaluate_acc,
    evaluate_auc,
    evaluate_f1_score,
    evaluate_logloss,
)

import dgl


def main(args):
    # step 1: Check device
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # step 2: Load data
    (
        g,
        train_loader,
        eval_loader,
        test_loader,
        meta_paths,
        user_key,
        item_key,
    ) = load_data(args.dataset, args.batch, args.num_workers, args.path)
    g = g.to(device)
    print("Data loaded.")

    # step 3: Create model and training components
    model = TAHIN(
        g, meta_paths, args.in_size, args.out_size, args.num_heads, args.dropout
    )
    model = model.to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    print("Model created.")

    # step 4: Training
    print("Start training.")
    best_acc = 0.0
    kill_cnt = 0
    for epoch in range(args.epochs):
@@ -63,33 +79,39 @@ def main(args):
                # compute loss
                val_loss = criterion(logits, label)
                val_acc = evaluate_acc(
                    logits.detach().cpu().numpy(), label.detach().cpu().numpy()
                )
                validate_loss.append(val_loss)
                validate_acc.append(val_acc)

            validate_loss = torch.stack(validate_loss).sum().cpu().item()
            validate_acc = np.mean(validate_acc)

            # validate
            if validate_acc > best_acc:
                best_acc = validate_acc
                best_epoch = epoch
                torch.save(model.state_dict(), "TAHIN" + "_" + args.dataset)
                kill_cnt = 0
                print("saving model...")
            else:
                kill_cnt += 1
                if kill_cnt > args.early_stop:
                    print("early stop.")
                    print("best epoch:{}".format(best_epoch))
                    break

            print(
                "In epoch {}, Train Loss: {:.4f}, Valid Loss: {:.5}\n, Valid ACC: {:.5}".format(
                    epoch, train_loss, validate_loss, validate_acc
                )
            )
    # test use the best model
    model.eval()
    with torch.no_grad():
        model.load_state_dict(torch.load("TAHIN" + "_" + args.dataset))
        test_loss = []
        test_acc = []
        test_auc = []
@@ -101,11 +123,19 @@ def main(args):
            # compute loss
            loss = criterion(logits, label)
            acc = evaluate_acc(
                logits.detach().cpu().numpy(), label.detach().cpu().numpy()
            )
            auc = evaluate_auc(
                logits.detach().cpu().numpy(), label.detach().cpu().numpy()
            )
            f1 = evaluate_f1_score(
                logits.detach().cpu().numpy(), label.detach().cpu().numpy()
            )
            log_loss = evaluate_logloss(
                logits.detach().cpu().numpy(), label.detach().cpu().numpy()
            )
            test_loss.append(loss)
            test_acc.append(acc)
            test_auc.append(auc)
@@ -117,33 +147,73 @@ def main(args):
        test_auc = np.mean(test_auc)
        test_f1 = np.mean(test_f1)
        test_logloss = np.mean(test_logloss)
        print(
            "Test Loss: {:.5}\n, Test ACC: {:.5}\n, AUC: {:.5}\n, F1: {:.5}\n, Logloss: {:.5}\n".format(
                test_loss, test_acc, test_auc, test_f1, test_logloss
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Parser For Arguments",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--dataset",
        default="movielens",
        help="Dataset to use, default: movielens",
    )
    parser.add_argument(
        "--path", default="./data", help="Path to save the data"
    )
    parser.add_argument("--model", default="TAHIN", help="Model Name")

    parser.add_argument("--batch", default=128, type=int, help="Batch size")
    parser.add_argument(
        "--gpu",
        type=int,
        default="0",
        help="Set GPU Ids : Eg: For CPU = -1, For Single GPU = 0",
    )
    parser.add_argument(
        "--epochs", type=int, default=500, help="Maximum number of epochs"
    )
    parser.add_argument(
        "--wd", type=float, default=0, help="L2 Regularization for Optimizer"
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning Rate")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=10,
        help="Number of processes to construct batches",
    )
    parser.add_argument(
        "--early_stop", default=15, type=int, help="Patience for early stop."
    )

    parser.add_argument(
        "--in_size",
        default=128,
        type=int,
        help="Initial dimension size for entities.",
    )
    parser.add_argument(
        "--out_size",
        default=128,
        type=int,
        help="Output dimension size for entities.",
    )

    parser.add_argument(
        "--num_heads", default=1, type=int, help="Number of attention heads"
    )
    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout.")

    args = parser.parse_args()
    print(args)

    main(args)
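
# Example invocation (illustrative only; the actual script name may differ):
#     python main.py --dataset movielens --gpu 0 --batch 128 --lr 0.001 --epochs 500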
import numpy as np
import torch
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    log_loss,
    ndcg_score,
    roc_auc_score,
)


def evaluate_auc(pred, label):
    res = roc_auc_score(y_score=pred, y_true=label)
    return res


def evaluate_acc(pred, label):
    res = []
    for _value in pred:
@@ -15,6 +24,7 @@ def evaluate_acc(pred, label):
            res.append(0)
    return accuracy_score(y_pred=res, y_true=label)


def evaluate_f1_score(pred, label):
    res = []
    for _value in pred:
@@ -24,7 +34,7 @@ def evaluate_f1_score(pred, label):
            res.append(0)
    return f1_score(y_pred=res, y_true=label)


def evaluate_logloss(pred, label):
    res = log_loss(y_true=label, y_pred=pred, eps=1e-7, normalize=True)
    return res
\ No newline at end of file
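
# Quick sanity check (illustrative, not part of the original file):
#     evaluate_auc(np.array([0.9, 0.2, 0.7]), np.array([1, 0, 1]))  # -> 1.0
# evaluate_acc and evaluate_f1_score first binarize the scores (the elided
# branch appends 1 for high scores and 0 otherwise) and then call the
# corresponding sklearn metric on the hard predictions.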
import argparse
import os
import pickle

import evaluation
import layers
import numpy as np
import sampler as sampler_module
import torch
import torch.nn as nn
import torchtext
import tqdm
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import dgl


class PinSAGEModel(nn.Module):
    def __init__(self, full_graph, ntype, textsets, hidden_dims, n_layers):
        super().__init__()

        self.proj = layers.LinearProjector(
            full_graph, ntype, textsets, hidden_dims
        )
        self.sage = layers.SAGENet(hidden_dims, n_layers)
        self.scorer = layers.ItemToItemScorer(full_graph, ntype)
@@ -34,22 +38,23 @@ class PinSAGEModel(nn.Module):
        h_item_dst = self.proj(blocks[-1].dstdata)
        return h_item_dst + self.sage(blocks, h_item)


def train(dataset, args):
    g = dataset["train-graph"]
    val_matrix = dataset["val-matrix"].tocsr()
    test_matrix = dataset["test-matrix"].tocsr()
    item_texts = dataset["item-texts"]
    user_ntype = dataset["user-type"]
    item_ntype = dataset["item-type"]
    user_to_item_etype = dataset["user-to-item-type"]
    timestamp = dataset["timestamp-edge-column"]

    device = torch.device(args.device)

    # Assign user and movie IDs and use them as features (to learn an individual trainable
    # embedding for each entity)
    g.nodes[user_ntype].data["id"] = torch.arange(g.num_nodes(user_ntype))
    g.nodes[item_ntype].data["id"] = torch.arange(g.num_nodes(item_ntype))

    # Prepare torchtext dataset and Vocabulary
    textset = {}
@@ -63,30 +68,50 @@ def train(dataset, args):
            l = tokenizer(item_texts[key][i].lower())
            textlist.append(l)
    for key, field in item_texts.items():
        vocab2 = build_vocab_from_iterator(
            textlist, specials=["<unk>", "<pad>"]
        )
        textset[key] = (
            textlist,
            vocab2,
            vocab2.get_stoi()["<pad>"],
            batch_first,
        )

    # Sampler
    batch_sampler = sampler_module.ItemToItemBatchSampler(
        g, user_ntype, item_ntype, args.batch_size
    )
    neighbor_sampler = sampler_module.NeighborSampler(
        g,
        user_ntype,
        item_ntype,
        args.random_walk_length,
        args.random_walk_restart_prob,
        args.num_random_walks,
        args.num_neighbors,
        args.num_layers,
    )
    collator = sampler_module.PinSAGECollator(
        neighbor_sampler, g, item_ntype, textset
    )
    dataloader = DataLoader(
        batch_sampler,
        collate_fn=collator.collate_train,
        num_workers=args.num_workers,
    )
    dataloader_test = DataLoader(
        torch.arange(g.num_nodes(item_ntype)),
        batch_size=args.batch_size,
        collate_fn=collator.collate_test,
        num_workers=args.num_workers,
    )
    dataloader_it = iter(dataloader)

    # Model
    model = PinSAGEModel(
        g, item_ntype, textset, args.hidden_dims, args.num_layers
    ).to(device)
    # Optimizer
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -109,7 +134,9 @@ def train(dataset, args):
        # Evaluate
        model.eval()
        with torch.no_grad():
            item_batches = torch.arange(g.num_nodes(item_ntype)).split(
                args.batch_size
            )
            h_item_batches = []
            for blocks in dataloader_test:
                for i in range(len(blocks)):
@@ -118,32 +145,37 @@ def train(dataset, args):
                h_item_batches.append(model.get_repr(blocks))
            h_item = torch.cat(h_item_batches, 0)

            print(
                evaluation.evaluate_nn(dataset, h_item, args.k, args.batch_size)
            )


if __name__ == "__main__":
    # Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_path", type=str)
    parser.add_argument("--random-walk-length", type=int, default=2)
    parser.add_argument("--random-walk-restart-prob", type=float, default=0.5)
    parser.add_argument("--num-random-walks", type=int, default=10)
    parser.add_argument("--num-neighbors", type=int, default=3)
    parser.add_argument("--num-layers", type=int, default=2)
    parser.add_argument("--hidden-dims", type=int, default=16)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument(
        "--device", type=str, default="cpu"
    )  # can also be "cuda:0"
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--batches-per-epoch", type=int, default=20000)
    parser.add_argument("--num-workers", type=int, default=0)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("-k", type=int, default=10)
    args = parser.parse_args()

    # Load dataset
    data_info_path = os.path.join(args.dataset_path, "data.pkl")
    with open(data_info_path, "rb") as f:
        dataset = pickle.load(f)
    train_g_path = os.path.join(args.dataset_path, "train_g.bin")
    g_list, _ = dgl.load_graphs(train_g_path)
    dataset["train-graph"] = g_list[0]
    train(dataset, args)
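
# Example invocation (illustrative only; the script and data-directory names
# are assumptions):
#     python model.py ./processed_movielens --num-epochs 10 --device cuda:0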
@@ -12,23 +12,25 @@ instead put variable-length features into a more suitable container
(e.g. torchtext to handle list of texts)
"""

import argparse
import os
import pickle
import re

import numpy as np
import pandas as pd
import scipy.sparse as ssp
import torch
import torchtext
from builder import PandasGraphBuilder
from data_utils import *

import dgl

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str)
    parser.add_argument("out_directory", type=str)
    args = parser.parse_args()

    directory = args.directory
    out_directory = args.out_directory
@@ -38,97 +40,133 @@ if __name__ == '__main__':
    # Load data
    users = []
    with open(os.path.join(directory, "users.dat"), encoding="latin1") as f:
        for l in f:
            id_, gender, age, occupation, zip_ = l.strip().split("::")
            users.append(
                {
                    "user_id": int(id_),
                    "gender": gender,
                    "age": age,
                    "occupation": occupation,
                    "zip": zip_,
                }
            )
    users = pd.DataFrame(users).astype("category")
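
    # For reference (added comment, not part of the original file): each line of
    # users.dat has the form "user_id::gender::age::occupation::zip",
    # e.g. "1::F::1::10::48067" in MovieLens-1M.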

    movies = []
    with open(os.path.join(directory, "movies.dat"), encoding="latin1") as f:
        for l in f:
            id_, title, genres = l.strip().split("::")
            genres_set = set(genres.split("|"))

            # extract year
            assert re.match(r".*\([0-9]{4}\)$", title)
            year = title[-5:-1]
            title = title[:-6].strip()

            data = {"movie_id": int(id_), "title": title, "year": year}
            for g in genres_set:
                data[g] = True
            movies.append(data)
    movies = pd.DataFrame(movies).astype({"year": "category"})

    ratings = []
    with open(os.path.join(directory, "ratings.dat"), encoding="latin1") as f:
        for l in f:
            user_id, movie_id, rating, timestamp = [
                int(_) for _ in l.split("::")
            ]
            ratings.append(
                {
                    "user_id": user_id,
                    "movie_id": movie_id,
                    "rating": rating,
                    "timestamp": timestamp,
                }
            )
    ratings = pd.DataFrame(ratings)

    # Filter the users and items that never appear in the rating table.
    distinct_users_in_ratings = ratings["user_id"].unique()
    distinct_movies_in_ratings = ratings["movie_id"].unique()
    users = users[users["user_id"].isin(distinct_users_in_ratings)]
    movies = movies[movies["movie_id"].isin(distinct_movies_in_ratings)]

    # Group the movie features into genres (a vector), year (a category), title (a string)
    genre_columns = movies.columns.drop(["movie_id", "title", "year"])
    movies[genre_columns] = movies[genre_columns].fillna(False).astype("bool")
    movies_categorical = movies.drop("title", axis=1)

    # Build graph
    graph_builder = PandasGraphBuilder()
    graph_builder.add_entities(users, "user_id", "user")
    graph_builder.add_entities(movies_categorical, "movie_id", "movie")
    graph_builder.add_binary_relations(
        ratings, "user_id", "movie_id", "watched"
    )
    graph_builder.add_binary_relations(
        ratings, "movie_id", "user_id", "watched-by"
    )

    g = graph_builder.build()

    # Assign features.
    # Note that variable-sized features such as texts or images are handled elsewhere.
    g.nodes["user"].data["gender"] = torch.LongTensor(
        users["gender"].cat.codes.values
    )
    g.nodes["user"].data["age"] = torch.LongTensor(
        users["age"].cat.codes.values
    )
    g.nodes["user"].data["occupation"] = torch.LongTensor(
        users["occupation"].cat.codes.values
    )
    g.nodes["user"].data["zip"] = torch.LongTensor(
        users["zip"].cat.codes.values
    )

    g.nodes["movie"].data["year"] = torch.LongTensor(
        movies["year"].cat.codes.values
    )
    g.nodes["movie"].data["genre"] = torch.FloatTensor(
        movies[genre_columns].values
    )

    g.edges["watched"].data["rating"] = torch.LongTensor(
        ratings["rating"].values
    )
    g.edges["watched"].data["timestamp"] = torch.LongTensor(
        ratings["timestamp"].values
    )
    g.edges["watched-by"].data["rating"] = torch.LongTensor(
        ratings["rating"].values
    )
    g.edges["watched-by"].data["timestamp"] = torch.LongTensor(
        ratings["timestamp"].values
    )

    # Train-validation-test split
    # This is a little bit tricky as we want to select the last interaction for test, and the
    # second-to-last interaction for validation.
    train_indices, val_indices, test_indices = train_test_split_by_time(
        ratings, "timestamp", "user_id"
    )

    # Build the graph with training interactions only.
    train_g = build_train_graph(
        g, train_indices, "user", "movie", "watched", "watched-by"
    )
    assert train_g.out_degrees(etype="watched").min() > 0

    # Build the user-item sparse matrix for validation and test set.
    val_matrix, test_matrix = build_val_test_matrix(
        g, val_indices, test_indices, "user", "movie", "watched"
    )

    ## Build title set
    movie_textual_dataset = {"title": movies["title"].values}

    # The model should build their own vocabulary and process the texts. Here is one example
    # of using torchtext to pad and numericalize a batch of strings.
@@ -140,18 +178,19 @@ if __name__ == '__main__':
    ## Dump the graph and the datasets

    dgl.save_graphs(os.path.join(out_directory, "train_g.bin"), train_g)

    dataset = {
        "val-matrix": val_matrix,
        "test-matrix": test_matrix,
        "item-texts": movie_textual_dataset,
        "item-images": None,
        "user-type": "user",
        "item-type": "movie",
        "user-to-item-type": "watched",
        "item-to-user-type": "watched-by",
        "timestamp-edge-column": "timestamp",
    }

    with open(os.path.join(out_directory, "data.pkl"), "wb") as f:
        pickle.dump(dataset, f)
@@ -3,76 +3,102 @@ Script that reads from raw Nowplaying-RS data and dumps into a pickle
file a heterogeneous graph with categorical and numeric features.
"""

import argparse
import os
import pickle

import pandas as pd
import scipy.sparse as ssp
from builder import PandasGraphBuilder
from data_utils import *

import dgl

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("directory", type=str)
    parser.add_argument("out_directory", type=str)
    args = parser.parse_args()

    directory = args.directory
    out_directory = args.out_directory
    os.makedirs(out_directory, exist_ok=True)

    data = pd.read_csv(os.path.join(directory, "context_content_features.csv"))
    track_feature_cols = list(data.columns[1:13])
    data = data[
        ["user_id", "track_id", "created_at"] + track_feature_cols
    ].dropna()

    users = data[["user_id"]].drop_duplicates()
    tracks = data[["track_id"] + track_feature_cols].drop_duplicates()
    assert tracks["track_id"].value_counts().max() == 1
    tracks = tracks.astype(
        {"mode": "int64", "key": "int64", "artist_id": "category"}
    )
    events = data[["user_id", "track_id", "created_at"]]
    events["created_at"] = (
        events["created_at"].values.astype("datetime64[s]").astype("int64")
    )

    graph_builder = PandasGraphBuilder()
    graph_builder.add_entities(users, "user_id", "user")
    graph_builder.add_entities(tracks, "track_id", "track")
    graph_builder.add_binary_relations(
        events, "user_id", "track_id", "listened"
    )
    graph_builder.add_binary_relations(
        events, "track_id", "user_id", "listened-by"
    )

    g = graph_builder.build()

    float_cols = []
    for col in tracks.columns:
        if col == "track_id":
            continue
        elif col == "artist_id":
            g.nodes["track"].data[col] = torch.LongTensor(
                tracks[col].cat.codes.values
            )
        elif tracks.dtypes[col] == "float64":
            float_cols.append(col)
        else:
            g.nodes["track"].data[col] = torch.LongTensor(tracks[col].values)
    g.nodes["track"].data["song_features"] = torch.FloatTensor(
        linear_normalize(tracks[float_cols].values)
    )
    g.edges["listened"].data["created_at"] = torch.LongTensor(
        events["created_at"].values
    )
    g.edges["listened-by"].data["created_at"] = torch.LongTensor(
        events["created_at"].values
    )

    n_edges = g.num_edges("listened")
    train_indices, val_indices, test_indices = train_test_split_by_time(
        events, "created_at", "user_id"
    )
    train_g = build_train_graph(
        g, train_indices, "user", "track", "listened", "listened-by"
    )
    assert train_g.out_degrees(etype="listened").min() > 0
    val_matrix, test_matrix = build_val_test_matrix(
        g, val_indices, test_indices, "user", "track", "listened"
    )

    dgl.save_graphs(os.path.join(out_directory, "train_g.bin"), train_g)

    dataset = {
        "val-matrix": val_matrix,
        "test-matrix": test_matrix,
        "item-texts": {},
        "item-images": None,
        "user-type": "user",
        "item-type": "track",
        "user-to-item-type": "listened",
        "item-to-user-type": "listened-by",
        "timestamp-edge-column": "created_at",
    }

    with open(os.path.join(out_directory, "data.pkl"), "wb") as f:
        pickle.dump(dataset, f)
import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset
from torchtext.data.functional import numericalize_tokens_from_iterator

import dgl


def padding(array, yy, val):
    """
    :param array: torch tensor array
@@ -15,7 +17,10 @@ def padding(array, yy, val):
    b = 0
    bb = yy - b - w

    return torch.nn.functional.pad(
        array, pad=(b, bb), mode="constant", value=val
    )
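
# Added example (illustrative, not part of the original file): with w taken
# from the input length in the elided part of the function,
#     padding(torch.tensor([3, 1, 4]), 5, 0)
# right-pads to length 5 and returns tensor([3, 1, 4, 0, 0]).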


def compact_and_copy(frontier, seeds):
    block = dgl.to_block(frontier, seeds)
@@ -25,6 +30,7 @@ def compact_and_copy(frontier, seeds):
        block.edata[col] = data[block.edata[dgl.EID]]
    return block


class ItemToItemBatchSampler(IterableDataset):
    def __init__(self, g, user_type, item_type, batch_size):
        self.g = g
@@ -36,42 +42,69 @@ class ItemToItemBatchSampler(IterableDataset):
    def __iter__(self):
        while True:
            heads = torch.randint(
                0, self.g.num_nodes(self.item_type), (self.batch_size,)
            )
            tails = dgl.sampling.random_walk(
                self.g,
                heads,
                metapath=[self.item_to_user_etype, self.user_to_item_etype],
            )[0][:, 2]
            neg_tails = torch.randint(
                0, self.g.num_nodes(self.item_type), (self.batch_size,)
            )

            mask = tails != -1
            yield heads[mask], tails[mask], neg_tails[mask]
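
# Added note (not part of the original file): each yielded triple is
# (heads, tails, neg_tails) of item IDs; tails are positives reached by an
# item-user-item random walk from heads, neg_tails are uniformly sampled
# negatives, and failed walks (reported as -1 by dgl.sampling.random_walk)
# are masked out of all three tensors.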


class NeighborSampler(object):
    def __init__(
        self,
        g,
        user_type,
        item_type,
        random_walk_length,
        random_walk_restart_prob,
        num_random_walks,
        num_neighbors,
        num_layers,
    ):
        self.g = g
        self.user_type = user_type
        self.item_type = item_type
        self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
        self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]
        self.samplers = [
            dgl.sampling.PinSAGESampler(
                g,
                item_type,
                user_type,
                random_walk_length,
                random_walk_restart_prob,
                num_random_walks,
                num_neighbors,
            )
            for _ in range(num_layers)
        ]

    def sample_blocks(self, seeds, heads=None, tails=None, neg_tails=None):
        blocks = []
        for sampler in self.samplers:
            frontier = sampler(seeds)
            if heads is not None:
                eids = frontier.edge_ids(
                    torch.cat([heads, heads]),
                    torch.cat([tails, neg_tails]),
                    return_uv=True,
                )[2]
                if len(eids) > 0:
                    old_frontier = frontier
                    frontier = dgl.remove_edges(old_frontier, eids)
                    # print(old_frontier)
                    # print(frontier)
                    # print(frontier.edata['weights'])
                    # frontier.edata['weights'] = old_frontier.edata['weights'][frontier.edata[dgl.EID]]
            block = compact_and_copy(frontier, seeds)
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
@@ -81,17 +114,18 @@ class NeighborSampler(object):
        # Create a graph with positive connections only and another graph with negative
        # connections only.
        pos_graph = dgl.graph(
            (heads, tails), num_nodes=self.g.num_nodes(self.item_type)
        )
        neg_graph = dgl.graph(
            (heads, neg_tails), num_nodes=self.g.num_nodes(self.item_type)
        )
        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
        seeds = pos_graph.ndata[dgl.NID]

        blocks = self.sample_blocks(seeds, heads, tails, neg_tails)
        return pos_graph, neg_graph, blocks
def assign_simple_node_features(ndata, g, ntype, assign_id=False): def assign_simple_node_features(ndata, g, ntype, assign_id=False):
""" """
Copies data to the given block from the corresponding nodes in the original graph. Copies data to the given block from the corresponding nodes in the original graph.
...@@ -102,6 +136,7 @@ def assign_simple_node_features(ndata, g, ntype, assign_id=False): ...@@ -102,6 +136,7 @@ def assign_simple_node_features(ndata, g, ntype, assign_id=False):
induced_nodes = ndata[dgl.NID] induced_nodes = ndata[dgl.NID]
ndata[col] = g.nodes[ntype].data[col][induced_nodes] ndata[col] = g.nodes[ntype].data[col][induced_nodes]
def assign_textual_node_features(ndata, textset, ntype): def assign_textual_node_features(ndata, textset, ntype):
""" """
Assigns numericalized tokens from a torchtext dataset to given block. Assigns numericalized tokens from a torchtext dataset to given block.
...@@ -127,10 +162,10 @@ def assign_textual_node_features(ndata, textset, ntype): ...@@ -127,10 +162,10 @@ def assign_textual_node_features(ndata, textset, ntype):
for field_name, field in textset.items(): for field_name, field in textset.items():
textlist, vocab, pad_var, batch_first = field textlist, vocab, pad_var, batch_first = field
examples = [textlist[i] for i in node_ids] examples = [textlist[i] for i in node_ids]
ids_iter = numericalize_tokens_from_iterator(vocab, examples) ids_iter = numericalize_tokens_from_iterator(vocab, examples)
maxsize = max([len(textlist[i]) for i in node_ids]) maxsize = max([len(textlist[i]) for i in node_ids])
ids = next(ids_iter) ids = next(ids_iter)
x = torch.asarray([num for num in ids]) x = torch.asarray([num for num in ids])
...@@ -141,14 +176,15 @@ def assign_textual_node_features(ndata, textset, ntype): ...@@ -141,14 +176,15 @@ def assign_textual_node_features(ndata, textset, ntype):
x = torch.asarray([num for num in ids]) x = torch.asarray([num for num in ids])
l = torch.tensor([len(x)]) l = torch.tensor([len(x)])
y = padding(x, maxsize, pad_var) y = padding(x, maxsize, pad_var)
tokens = torch.vstack((tokens,y)) tokens = torch.vstack((tokens, y))
lengths = torch.cat((lengths, l)) lengths = torch.cat((lengths, l))
if not batch_first: if not batch_first:
tokens = tokens.t() tokens = tokens.t()
ndata[field_name] = tokens ndata[field_name] = tokens
ndata[field_name + '__len'] = lengths ndata[field_name + "__len"] = lengths
def assign_features_to_blocks(blocks, g, textset, ntype): def assign_features_to_blocks(blocks, g, textset, ntype):
# For the first block (which is closest to the input), copy the features from # For the first block (which is closest to the input), copy the features from
...@@ -158,6 +194,7 @@ def assign_features_to_blocks(blocks, g, textset, ntype): ...@@ -158,6 +194,7 @@ def assign_features_to_blocks(blocks, g, textset, ntype):
assign_simple_node_features(blocks[-1].dstdata, g, ntype) assign_simple_node_features(blocks[-1].dstdata, g, ntype)
assign_textual_node_features(blocks[-1].dstdata, textset, ntype) assign_textual_node_features(blocks[-1].dstdata, textset, ntype)
class PinSAGECollator(object): class PinSAGECollator(object):
def __init__(self, sampler, g, ntype, textset): def __init__(self, sampler, g, ntype, textset):
self.sampler = sampler self.sampler = sampler
...@@ -168,7 +205,9 @@ class PinSAGECollator(object): ...@@ -168,7 +205,9 @@ class PinSAGECollator(object):
def collate_train(self, batches): def collate_train(self, batches):
heads, tails, neg_tails = batches[0] heads, tails, neg_tails = batches[0]
# Construct multilayer neighborhood via PinSAGE... # Construct multilayer neighborhood via PinSAGE...
pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(heads, tails, neg_tails) pos_graph, neg_graph, blocks = self.sampler.sample_from_item_pairs(
heads, tails, neg_tails
)
assign_features_to_blocks(blocks, self.g, self.textset, self.ntype) assign_features_to_blocks(blocks, self.g, self.textset, self.ntype)
return pos_graph, neg_graph, blocks return pos_graph, neg_graph, blocks
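For reference, a minimal usage sketch of the collator above; the names sampler, batch_sampler, g, textset and item_ntype are assumptions standing in for the objects built elsewhere in this PinSAGE example.

# Minimal sketch (assumption: sampler, batch_sampler, g, textset and
# item_ntype are constructed as elsewhere in this example).
from torch.utils.data import DataLoader

collator = PinSAGECollator(sampler, g, item_ntype, textset)
dataloader = DataLoader(
    batch_sampler,
    collate_fn=collator.collate_train,
    num_workers=2,
)
for pos_graph, neg_graph, blocks in dataloader:
    # blocks[-1].dstdata now carries the copied node features and the
    # numericalized text fields assigned by assign_features_to_blocks.
    break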
......
import argparse
import os
import urllib
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader import tqdm
from modelnet import ModelNet
from model import Model, compute_loss from model import Model, compute_loss
from dgl.data.utils import download, get_download_dir from modelnet import ModelNet
from torch.utils.data import DataLoader
from functools import partial from dgl.data.utils import download, get_download_dir
import tqdm
import urllib
import os
import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='') parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument('--load-model-path', type=str, default='') parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument('--save-model-path', type=str, default='') parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument('--num-epochs', type=int, default=100) parser.add_argument("--num-epochs", type=int, default=100)
parser.add_argument('--num-workers', type=int, default=0) parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args() args = parser.parse_args()
num_workers = args.num_workers num_workers = args.num_workers
batch_size = args.batch_size batch_size = args.batch_size
data_filename = 'modelnet40-sampled-2048.h5' data_filename = "modelnet40-sampled-2048.h5"
local_path = args.dataset_path or os.path.join(get_download_dir(), data_filename) local_path = args.dataset_path or os.path.join(
get_download_dir(), data_filename
)
if not os.path.exists(local_path): if not os.path.exists(local_path):
download('https://data.dgl.ai/dataset/modelnet40-sampled-2048.h5', local_path) download(
"https://data.dgl.ai/dataset/modelnet40-sampled-2048.h5", local_path
)
CustomDataLoader = partial( CustomDataLoader = partial(
DataLoader, DataLoader,
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
drop_last=True) drop_last=True,
)
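The partial above fixes the loader options once; a hedged sketch of how it is then applied to the ModelNet splits defined in modelnet.py below (the actual script does this in code elided from the diff).

# Sketch only: apply CustomDataLoader to the three dataset splits.
modelnet = ModelNet(local_path, 1024)
train_loader = CustomDataLoader(modelnet.train())
valid_loader = CustomDataLoader(modelnet.valid())
test_loader = CustomDataLoader(modelnet.test())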
def train(model, opt, scheduler, train_loader, dev): def train(model, opt, scheduler, train_loader, dev):
scheduler.step() scheduler.step()
...@@ -65,11 +72,15 @@ def train(model, opt, scheduler, train_loader, dev): ...@@ -65,11 +72,15 @@ def train(model, opt, scheduler, train_loader, dev):
total_loss += loss total_loss += loss
total_correct += correct total_correct += correct
tq.set_postfix({ tq.set_postfix(
'Loss': '%.5f' % loss, {
'AvgLoss': '%.5f' % (total_loss / num_batches), "Loss": "%.5f" % loss,
'Acc': '%.5f' % (correct / num_examples), "AvgLoss": "%.5f" % (total_loss / num_batches),
'AvgAcc': '%.5f' % (total_correct / count)}) "Acc": "%.5f" % (correct / num_examples),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
def evaluate(model, test_loader, dev): def evaluate(model, test_loader, dev):
model.eval() model.eval()
...@@ -89,9 +100,12 @@ def evaluate(model, test_loader, dev): ...@@ -89,9 +100,12 @@ def evaluate(model, test_loader, dev):
total_correct += correct total_correct += correct
count += num_examples count += num_examples
tq.set_postfix({ tq.set_postfix(
'Acc': '%.5f' % (correct / num_examples), {
'AvgAcc': '%.5f' % (total_correct / count)}) "Acc": "%.5f" % (correct / num_examples),
"AvgAcc": "%.5f" % (total_correct / count),
}
)
return total_correct / count return total_correct / count
...@@ -105,7 +119,9 @@ if args.load_model_path: ...@@ -105,7 +119,9 @@ if args.load_model_path:
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, args.num_epochs, eta_min=0.001) scheduler = optim.lr_scheduler.CosineAnnealingLR(
opt, args.num_epochs, eta_min=0.001
)
modelnet = ModelNet(local_path, 1024) modelnet = ModelNet(local_path, 1024)
...@@ -117,7 +133,7 @@ best_valid_acc = 0 ...@@ -117,7 +133,7 @@ best_valid_acc = 0
best_test_acc = 0 best_test_acc = 0
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
print('Epoch #%d Validating' % epoch) print("Epoch #%d Validating" % epoch)
valid_acc = evaluate(model, valid_loader, dev) valid_acc = evaluate(model, valid_loader, dev)
test_acc = evaluate(model, test_loader, dev) test_acc = evaluate(model, test_loader, dev)
if valid_acc > best_valid_acc: if valid_acc > best_valid_acc:
...@@ -125,7 +141,9 @@ for epoch in range(args.num_epochs): ...@@ -125,7 +141,9 @@ for epoch in range(args.num_epochs):
best_test_acc = test_acc best_test_acc = test_acc
if args.save_model_path: if args.save_model_path:
torch.save(model.state_dict(), args.save_model_path) torch.save(model.state_dict(), args.save_model_path)
print('Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)' % ( print(
valid_acc, best_valid_acc, test_acc, best_test_acc)) "Current validation acc: %.5f (best: %.5f), test acc: %.5f (best: %.5f)"
% (valid_acc, best_valid_acc, test_acc, best_test_acc)
)
train(model, opt, scheduler, train_loader, dev) train(model, opt, scheduler, train_loader, dev)
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.nn.pytorch import KNNGraph, EdgeConv
from dgl.nn.pytorch import EdgeConv, KNNGraph
class Model(nn.Module): class Model(nn.Module):
def __init__(self, k, feature_dims, emb_dims, output_classes, input_dims=3, def __init__(
dropout_prob=0.5): self,
k,
feature_dims,
emb_dims,
output_classes,
input_dims=3,
dropout_prob=0.5,
):
super(Model, self).__init__() super(Model, self).__init__()
self.nng = KNNGraph(k) self.nng = KNNGraph(k)
...@@ -13,10 +22,13 @@ class Model(nn.Module): ...@@ -13,10 +22,13 @@ class Model(nn.Module):
self.num_layers = len(feature_dims) self.num_layers = len(feature_dims)
for i in range(self.num_layers): for i in range(self.num_layers):
self.conv.append(EdgeConv( self.conv.append(
feature_dims[i - 1] if i > 0 else input_dims, EdgeConv(
feature_dims[i], feature_dims[i - 1] if i > 0 else input_dims,
batch_norm=True)) feature_dims[i],
batch_norm=True,
)
)
self.proj = nn.Linear(sum(feature_dims), emb_dims[0]) self.proj = nn.Linear(sum(feature_dims), emb_dims[0])
...@@ -26,10 +38,13 @@ class Model(nn.Module): ...@@ -26,10 +38,13 @@ class Model(nn.Module):
self.num_embs = len(emb_dims) - 1 self.num_embs = len(emb_dims) - 1
for i in range(1, self.num_embs + 1): for i in range(1, self.num_embs + 1):
self.embs.append(nn.Linear( self.embs.append(
# * 2 because of concatenation of max- and mean-pooling nn.Linear(
emb_dims[i - 1] if i > 1 else (emb_dims[i - 1] * 2), # * 2 because of concatenation of max- and mean-pooling
emb_dims[i])) emb_dims[i - 1] if i > 1 else (emb_dims[i - 1] * 2),
emb_dims[i],
)
)
self.bn_embs.append(nn.BatchNorm1d(emb_dims[i])) self.bn_embs.append(nn.BatchNorm1d(emb_dims[i]))
self.dropouts.append(nn.Dropout(dropout_prob)) self.dropouts.append(nn.Dropout(dropout_prob))
......
import numpy as np import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
class ModelNet(object): class ModelNet(object):
def __init__(self, path, num_points): def __init__(self, path, num_points):
import h5py import h5py
self.f = h5py.File(path) self.f = h5py.File(path)
self.num_points = num_points self.num_points = num_points
self.n_train = self.f['train/data'].shape[0] self.n_train = self.f["train/data"].shape[0]
self.n_valid = int(self.n_train / 5) self.n_valid = int(self.n_train / 5)
self.n_train -= self.n_valid self.n_train -= self.n_valid
self.n_test = self.f['test/data'].shape[0] self.n_test = self.f["test/data"].shape[0]
def train(self): def train(self):
return ModelNetDataset(self, 'train') return ModelNetDataset(self, "train")
def valid(self): def valid(self):
return ModelNetDataset(self, 'valid') return ModelNetDataset(self, "valid")
def test(self): def test(self):
return ModelNetDataset(self, 'test') return ModelNetDataset(self, "test")
class ModelNetDataset(Dataset): class ModelNetDataset(Dataset):
def __init__(self, modelnet, mode): def __init__(self, modelnet, mode):
...@@ -27,29 +30,29 @@ class ModelNetDataset(Dataset): ...@@ -27,29 +30,29 @@ class ModelNetDataset(Dataset):
self.num_points = modelnet.num_points self.num_points = modelnet.num_points
self.mode = mode self.mode = mode
if mode == 'train': if mode == "train":
self.data = modelnet.f['train/data'][:modelnet.n_train] self.data = modelnet.f["train/data"][: modelnet.n_train]
self.label = modelnet.f['train/label'][:modelnet.n_train] self.label = modelnet.f["train/label"][: modelnet.n_train]
elif mode == 'valid': elif mode == "valid":
self.data = modelnet.f['train/data'][modelnet.n_train:] self.data = modelnet.f["train/data"][modelnet.n_train :]
self.label = modelnet.f['train/label'][modelnet.n_train:] self.label = modelnet.f["train/label"][modelnet.n_train :]
elif mode == 'test': elif mode == "test":
self.data = modelnet.f['test/data'].value self.data = modelnet.f["test/data"].value
self.label = modelnet.f['test/label'].value self.label = modelnet.f["test/label"].value
def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2)): def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2)):
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[3]) xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[3])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[3]) xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[3])
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32') x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
return x return x
def __len__(self): def __len__(self):
return self.data.shape[0] return self.data.shape[0]
def __getitem__(self, i): def __getitem__(self, i):
x = self.data[i][:self.num_points] x = self.data[i][: self.num_points]
y = self.label[i] y = self.label[i]
if self.mode == 'train': if self.mode == "train":
x = self.translate(x) x = self.translate(x)
np.random.shuffle(x) np.random.shuffle(x)
return x, y return x, y
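A small, hedged usage sketch of the two classes above, assuming modelnet40-sampled-2048.h5 has been downloaded into the DGL download directory as in the training script.

import os
from dgl.data.utils import get_download_dir
from modelnet import ModelNet

local_path = os.path.join(get_download_dir(), "modelnet40-sampled-2048.h5")
modelnet = ModelNet(local_path, num_points=1024)
train_set = modelnet.train()
x, y = train_set[0]   # x: (1024, 3) float32 points, augmented in 'train' mode
print(x.shape, len(train_set), modelnet.n_valid)   # 80/20 train/valid split sizes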
import numpy as np
import warnings
import os import os
import warnings
import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")
def pc_normalize(pc): def pc_normalize(pc):
...@@ -43,11 +45,18 @@ def farthest_point_sample(point, npoint): ...@@ -43,11 +45,18 @@ def farthest_point_sample(point, npoint):
class ModelNetDataLoader(Dataset): class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', fps=False, def __init__(
normal_channel=True, cache_size=15000): self,
root,
npoint=1024,
split="train",
fps=False,
normal_channel=True,
cache_size=15000,
):
""" """
Input: Input:
root: the root path to the local data files root: the root path to the local data files
npoint: number of points from each cloud npoint: number of points from each cloud
split: which split of the data, 'train' or 'test' split: which split of the data, 'train' or 'test'
fps: whether to sample points with farthest point sampler fps: whether to sample points with farthest point sampler
...@@ -57,24 +66,34 @@ class ModelNetDataLoader(Dataset): ...@@ -57,24 +66,34 @@ class ModelNetDataLoader(Dataset):
self.root = root self.root = root
self.npoints = npoint self.npoints = npoint
self.fps = fps self.fps = fps
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)] self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat)))) self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel self.normal_channel = normal_channel
shape_ids = {} shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open( shape_ids["train"] = [
os.path.join(self.root, 'modelnet40_train.txt'))] line.rstrip()
shape_ids['test'] = [line.rstrip() for line in open( for line in open(os.path.join(self.root, "modelnet40_train.txt"))
os.path.join(self.root, 'modelnet40_test.txt'))] ]
shape_ids["test"] = [
assert (split == 'train' or split == 'test') line.rstrip()
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] for line in open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert split == "train" or split == "test"
shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple # list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i self.datapath = [
in range(len(shape_ids[split]))] (
print('The size of %s data is %d' % (split, len(self.datapath))) shape_names[i],
os.path.join(self.root, shape_names[i], shape_ids[split][i])
+ ".txt",
)
for i in range(len(shape_ids[split]))
]
print("The size of %s data is %d" % (split, len(self.datapath)))
self.cache_size = cache_size self.cache_size = cache_size
self.cache = {} self.cache = {}
...@@ -89,11 +108,11 @@ class ModelNetDataLoader(Dataset): ...@@ -89,11 +108,11 @@ class ModelNetDataLoader(Dataset):
fn = self.datapath[index] fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]] cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32) cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.fps: if self.fps:
point_set = farthest_point_sample(point_set, self.npoints) point_set = farthest_point_sample(point_set, self.npoints)
else: else:
point_set = point_set[0:self.npoints, :] point_set = point_set[0 : self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
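A hedged instantiation sketch for the loader above, assuming the modelnet40_normal_resampled folder has already been extracted into the DGL download directory (as the training script below does); the return value of indexing is an assumption based on the visible code.

import os
from dgl.data.utils import get_download_dir
from ModelNetDataLoader import ModelNetDataLoader

root = os.path.join(get_download_dir(), "modelnet40_normal_resampled")
dataset = ModelNetDataLoader(root, npoint=1024, split="train",
                             fps=False, normal_channel=True)
points, label = dataset[0]   # assumption: indexing returns (points, class index)
print(points.shape, label)   # normalized xyz (+ normals) and the class index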
......
import os, json, tqdm import json
import numpy as np import os
import dgl
from zipfile import ZipFile from zipfile import ZipFile
from torch.utils.data import Dataset
import numpy as np
import tqdm
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir from dgl.data.utils import download, get_download_dir
class ShapeNet(object): class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True): def __init__(self, num_points=2048, normal_channel=True):
self.num_points = num_points self.num_points = num_points
...@@ -13,8 +18,13 @@ class ShapeNet(object): ...@@ -13,8 +18,13 @@ class ShapeNet(object):
SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip" SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
download_path = get_download_dir() download_path = get_download_dir()
data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip" data_filename = (
data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal") "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
)
data_path = os.path.join(
download_path,
"shapenetcore_partanno_segmentation_benchmark_v0_normal",
)
if not os.path.exists(data_path): if not os.path.exists(data_path):
local_path = os.path.join(download_path, data_filename) local_path = os.path.join(download_path, data_filename)
if not os.path.exists(local_path): if not os.path.exists(local_path):
...@@ -24,52 +34,72 @@ class ShapeNet(object): ...@@ -24,52 +34,72 @@ class ShapeNet(object):
synset_file = "synsetoffset2category.txt" synset_file = "synsetoffset2category.txt"
with open(os.path.join(data_path, synset_file)) as f: with open(os.path.join(data_path, synset_file)) as f:
synset = [t.split('\n')[0].split('\t') for t in f.readlines()] synset = [t.split("\n")[0].split("\t") for t in f.readlines()]
self.synset_dict = {} self.synset_dict = {}
for syn in synset: for syn in synset:
self.synset_dict[syn[1]] = syn[0] self.synset_dict[syn[1]] = syn[0]
self.seg_classes = {'Airplane': [0, 1, 2, 3], self.seg_classes = {
'Bag': [4, 5], "Airplane": [0, 1, 2, 3],
'Cap': [6, 7], "Bag": [4, 5],
'Car': [8, 9, 10, 11], "Cap": [6, 7],
'Chair': [12, 13, 14, 15], "Car": [8, 9, 10, 11],
'Earphone': [16, 17, 18], "Chair": [12, 13, 14, 15],
'Guitar': [19, 20, 21], "Earphone": [16, 17, 18],
'Knife': [22, 23], "Guitar": [19, 20, 21],
'Lamp': [24, 25, 26, 27], "Knife": [22, 23],
'Laptop': [28, 29], "Lamp": [24, 25, 26, 27],
'Motorbike': [30, 31, 32, 33, 34, 35], "Laptop": [28, 29],
'Mug': [36, 37], "Motorbike": [30, 31, 32, 33, 34, 35],
'Pistol': [38, 39, 40], "Mug": [36, 37],
'Rocket': [41, 42, 43], "Pistol": [38, 39, 40],
'Skateboard': [44, 45, 46], "Rocket": [41, 42, 43],
'Table': [47, 48, 49]} "Skateboard": [44, 45, 46],
"Table": [47, 48, 49],
train_split_json = 'shuffled_train_file_list.json' }
val_split_json = 'shuffled_val_file_list.json'
test_split_json = 'shuffled_test_file_list.json' train_split_json = "shuffled_train_file_list.json"
split_path = os.path.join(data_path, 'train_test_split') val_split_json = "shuffled_val_file_list.json"
test_split_json = "shuffled_test_file_list.json"
split_path = os.path.join(data_path, "train_test_split")
with open(os.path.join(split_path, train_split_json)) as f: with open(os.path.join(split_path, train_split_json)) as f:
tmp = f.read() tmp = f.read()
self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)] self.train_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, val_split_json)) as f: with open(os.path.join(split_path, val_split_json)) as f:
tmp = f.read() tmp = f.read()
self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)] self.val_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, test_split_json)) as f: with open(os.path.join(split_path, test_split_json)) as f:
tmp = f.read() tmp = f.read()
self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)] self.test_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
def train(self): def train(self):
return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "train", self.num_points, self.normal_channel
)
def valid(self): def valid(self):
return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "valid", self.num_points, self.normal_channel
)
def trainval(self): def trainval(self):
return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "trainval", self.num_points, self.normal_channel
)
def test(self): def test(self):
return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "test", self.num_points, self.normal_channel
)
class ShapeNetDataset(Dataset): class ShapeNetDataset(Dataset):
def __init__(self, shapenet, mode, num_points, normal_channel=True): def __init__(self, shapenet, mode, num_points, normal_channel=True):
...@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset): ...@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset):
else: else:
self.dim = 6 self.dim = 6
if mode == 'train': if mode == "train":
self.file_list = shapenet.train_file_list self.file_list = shapenet.train_file_list
elif mode == 'valid': elif mode == "valid":
self.file_list = shapenet.val_file_list self.file_list = shapenet.val_file_list
elif mode == 'test': elif mode == "test":
self.file_list = shapenet.test_file_list self.file_list = shapenet.test_file_list
elif mode == 'trainval': elif mode == "trainval":
self.file_list = shapenet.train_file_list + shapenet.val_file_list self.file_list = shapenet.train_file_list + shapenet.val_file_list
else: else:
raise "Not supported `mode`" raise "Not supported `mode`"
...@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset): ...@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset):
data_list = [] data_list = []
label_list = [] label_list = []
category_list = [] category_list = []
print('Loading data from split ' + self.mode) print("Loading data from split " + self.mode)
for fn in tqdm.tqdm(self.file_list, ascii=True): for fn in tqdm.tqdm(self.file_list, ascii=True):
with open(fn) as f: with open(fn) as f:
data = np.array([t.split('\n')[0].split(' ') for t in f.readlines()]).astype(np.float) data = np.array(
data_list.append(data[:, 0:self.dim]) [t.split("\n")[0].split(" ") for t in f.readlines()]
).astype(np.float)
data_list.append(data[:, 0 : self.dim])
label_list.append(data[:, 6].astype(np.int)) label_list.append(data[:, 6].astype(np.int))
category_list.append(shapenet.synset_dict[fn.split('/')[-2]]) category_list.append(shapenet.synset_dict[fn.split("/")[-2]])
self.data = data_list self.data = data_list
self.label = label_list self.label = label_list
self.category = category_list self.category = category_list
def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3): def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size]) xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size]) xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32') x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
return x return x
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
def __getitem__(self, i): def __getitem__(self, i):
inds = np.random.choice(self.data[i].shape[0], self.num_points, replace=True) inds = np.random.choice(
x = self.data[i][inds,:self.dim] self.data[i].shape[0], self.num_points, replace=True
)
x = self.data[i][inds, : self.dim]
y = self.label[i][inds] y = self.label[i][inds]
cat = self.category[i] cat = self.category[i]
if self.mode == 'train': if self.mode == "train":
x = self.translate(x, size=self.dim) x = self.translate(x, size=self.dim)
x = x.astype(np.float) x = x.astype(np.float)
y = y.astype(np.int) y = y.astype(np.int)
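A hedged end-to-end sketch of the two classes above; the visible code suggests each item is (points, per-point part labels, category name), and the first ShapeNet() call downloads and unpacks the archive.

from ShapeNet import ShapeNet

shapenet = ShapeNet(num_points=2048, normal_channel=False)
train_set = shapenet.train()
x, y, cat = train_set[0]       # assumption on the return order of __getitem__
print(x.shape, y.shape, cat)   # (2048, 3) points, (2048,) part labels, e.g. "Airplane"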
......
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
from dgl.geometry import farthest_point_sampler from dgl.geometry import farthest_point_sampler
''' """
Part of the code are adapted from Part of the code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
def square_distance(src, dst): def square_distance(src, dst):
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
B, N, _ = src.shape B, N, _ = src.shape
_, M, _ = dst.shape _, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1) dist += torch.sum(src**2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M) dist += torch.sum(dst**2, -1).view(B, 1, M)
return dist return dist
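square_distance expands the pairwise identity ||s - d||^2 = ||s||^2 - 2 s·d + ||d||^2; a quick numerical check against torch.cdist (not part of the example code).

import torch
from helper import square_distance

src, dst = torch.randn(2, 5, 3), torch.randn(2, 7, 3)
print(torch.allclose(square_distance(src, dst), torch.cdist(src, dst) ** 2,
                     atol=1e-5))   # True up to floating-point error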
def index_points(points, idx): def index_points(points, idx):
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
device = points.device device = points.device
B = points.shape[0] B = points.shape[0]
view_shape = list(idx.shape) view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1) view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape) repeat_shape = list(idx.shape)
repeat_shape[0] = 1 repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to( batch_indices = (
device).view(view_shape).repeat(repeat_shape) torch.arange(B, dtype=torch.long)
.to(device)
.view(view_shape)
.repeat(repeat_shape)
)
new_points = points[batch_indices, idx, :] new_points = points[batch_indices, idx, :]
return new_points return new_points
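index_points gathers rows per batch element; an equivalent explicit loop makes the indexing pattern concrete.

import torch
from helper import index_points

points = torch.randn(2, 10, 3)
idx = torch.randint(0, 10, (2, 4))
manual = torch.stack([points[b, idx[b]] for b in range(points.shape[0])])
print(torch.equal(index_points(points, idx), manual))   # True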
class KNearNeighbors(nn.Module): class KNearNeighbors(nn.Module):
''' """
Find the k nearest neighbors Find the k nearest neighbors
''' """
def __init__(self, n_neighbor): def __init__(self, n_neighbor):
super(KNearNeighbors, self).__init__() super(KNearNeighbors, self).__init__()
self.n_neighbor = n_neighbor self.n_neighbor = n_neighbor
def forward(self, pos, centroids): def forward(self, pos, centroids):
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
''' """
center_pos = index_points(pos, centroids) center_pos = index_points(pos, centroids)
sqrdists = square_distance(center_pos, pos) sqrdists = square_distance(center_pos, pos)
group_idx = sqrdists.argsort(dim=-1)[:, :, :self.n_neighbor] group_idx = sqrdists.argsort(dim=-1)[:, :, : self.n_neighbor]
return group_idx return group_idx
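A small sketch combining DGL's farthest_point_sampler with the module above to group each sampled centroid's k nearest neighbors.

import torch
from dgl.geometry import farthest_point_sampler
from helper import KNearNeighbors

pos = torch.randn(2, 128, 3)                  # batch of point clouds
centroids = farthest_point_sampler(pos, 16)   # (2, 16) centroid indices
group_idx = KNearNeighbors(n_neighbor=8)(pos, centroids)
print(group_idx.shape)                        # torch.Size([2, 16, 8])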
class KNNGraphBuilder(nn.Module): class KNNGraphBuilder(nn.Module):
''' """
Build NN graph Build NN graph
''' """
def __init__(self, n_neighbor): def __init__(self, n_neighbor):
super(KNNGraphBuilder, self).__init__() super(KNNGraphBuilder, self).__init__()
...@@ -76,59 +81,68 @@ class KNNGraphBuilder(nn.Module): ...@@ -76,59 +81,68 @@ class KNNGraphBuilder(nn.Module):
center = torch.zeros((N)).to(dev) center = torch.zeros((N)).to(dev)
center[centroids[i]] = 1 center[centroids[i]] = 1
src = group_idx[i].contiguous().view(-1) src = group_idx[i].contiguous().view(-1)
dst = centroids[i].view(-1, 1).repeat(1, min(self.n_neighbor, dst = (
src.shape[0] // centroids.shape[1])).view(-1) centroids[i]
.view(-1, 1)
.repeat(
1, min(self.n_neighbor, src.shape[0] // centroids.shape[1])
)
.view(-1)
)
unified = torch.cat([src, dst]) unified = torch.cat([src, dst])
uniq, inv_idx = torch.unique(unified, return_inverse=True) uniq, inv_idx = torch.unique(unified, return_inverse=True)
src_idx = inv_idx[:src.shape[0]] src_idx = inv_idx[: src.shape[0]]
dst_idx = inv_idx[src.shape[0]:] dst_idx = inv_idx[src.shape[0] :]
g = dgl.graph((src_idx, dst_idx)) g = dgl.graph((src_idx, dst_idx))
g.ndata['pos'] = pos[i][uniq] g.ndata["pos"] = pos[i][uniq]
g.ndata['center'] = center[uniq] g.ndata["center"] = center[uniq]
if feat is not None: if feat is not None:
g.ndata['feat'] = feat[i][uniq] g.ndata["feat"] = feat[i][uniq]
glist.append(g) glist.append(g)
bg = dgl.batch(glist) bg = dgl.batch(glist)
return bg return bg
class KNNMessage(nn.Module): class KNNMessage(nn.Module):
''' """
Compute the input feature from neighbors Compute the input feature from neighbors
''' """
def __init__(self, n_neighbor): def __init__(self, n_neighbor):
super(KNNMessage, self).__init__() super(KNNMessage, self).__init__()
self.n_neighbor = n_neighbor self.n_neighbor = n_neighbor
def forward(self, edges): def forward(self, edges):
norm = edges.src['feat'] - edges.dst['feat'] norm = edges.src["feat"] - edges.dst["feat"]
if 'feat' in edges.src: if "feat" in edges.src:
res = torch.cat([norm, edges.src['feat']], 1) res = torch.cat([norm, edges.src["feat"]], 1)
else: else:
res = norm res = norm
return {'agg_feat': res} return {"agg_feat": res}
class KNNConv(nn.Module): class KNNConv(nn.Module):
''' """
Feature aggregation Feature aggregation
''' """
def __init__(self, sizes): def __init__(self, sizes):
super(KNNConv, self).__init__() super(KNNConv, self).__init__()
self.conv = nn.ModuleList() self.conv = nn.ModuleList()
self.bn = nn.ModuleList() self.bn = nn.ModuleList()
for i in range(1, len(sizes)): for i in range(1, len(sizes)):
self.conv.append(nn.Conv2d(sizes[i-1], sizes[i], 1)) self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
self.bn.append(nn.BatchNorm2d(sizes[i])) self.bn.append(nn.BatchNorm2d(sizes[i]))
def forward(self, nodes): def forward(self, nodes):
shape = nodes.mailbox['agg_feat'].shape shape = nodes.mailbox["agg_feat"].shape
h = nodes.mailbox['agg_feat'].view( h = (
shape[0], -1, shape[1], shape[2]).permute(0, 3, 2, 1) nodes.mailbox["agg_feat"]
.view(shape[0], -1, shape[1], shape[2])
.permute(0, 3, 2, 1)
)
for conv, bn in zip(self.conv, self.bn): for conv, bn in zip(self.conv, self.bn):
h = conv(h) h = conv(h)
h = bn(h) h = bn(h)
...@@ -136,7 +150,7 @@ class KNNConv(nn.Module): ...@@ -136,7 +150,7 @@ class KNNConv(nn.Module):
h = torch.max(h, 2)[0] h = torch.max(h, 2)[0]
feat_dim = h.shape[1] feat_dim = h.shape[1]
h = h.permute(0, 2, 1).reshape(-1, feat_dim) h = h.permute(0, 2, 1).reshape(-1, feat_dim)
return {'new_feat': h} return {"new_feat": h}
class TransitionDown(nn.Module): class TransitionDown(nn.Module):
...@@ -156,10 +170,9 @@ class TransitionDown(nn.Module): ...@@ -156,10 +170,9 @@ class TransitionDown(nn.Module):
g = self.frnn_graph(pos, centroids, feat) g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv) g.update_all(self.message, self.conv)
mask = g.ndata['center'] == 1 mask = g.ndata["center"] == 1
pos_dim = g.ndata['pos'].shape[-1] pos_dim = g.ndata["pos"].shape[-1]
feat_dim = g.ndata['new_feat'].shape[-1] feat_dim = g.ndata["new_feat"].shape[-1]
pos_res = g.ndata['pos'][mask].view(batch_size, -1, pos_dim) pos_res = g.ndata["pos"][mask].view(batch_size, -1, pos_dim)
feat_res = g.ndata['new_feat'][mask].view( feat_res = g.ndata["new_feat"][mask].view(batch_size, -1, feat_dim)
batch_size, -1, feat_dim)
return pos_res, feat_res return pos_res, feat_res
import torch import torch
from torch import nn
from helper import TransitionDown from helper import TransitionDown
from torch import nn
''' """
Part of the code are adapted from Part of the code are adapted from
https://github.com/MenghaoGuo/PCT https://github.com/MenghaoGuo/PCT
''' """
class PCTPositionEmbedding(nn.Module): class PCTPositionEmbedding(nn.Module):
def __init__(self, channels=256): def __init__(self, channels=256):
...@@ -98,15 +98,21 @@ class PointTransformerCLS(nn.Module): ...@@ -98,15 +98,21 @@ class PointTransformerCLS(nn.Module):
self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm1d(64) self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(64) self.bn2 = nn.BatchNorm1d(64)
self.g_op0 = TransitionDown(in_channels=128, out_channels=128, n_neighbor=32) self.g_op0 = TransitionDown(
self.g_op1 = TransitionDown(in_channels=256, out_channels=256, n_neighbor=32) in_channels=128, out_channels=128, n_neighbor=32
)
self.g_op1 = TransitionDown(
in_channels=256, out_channels=256, n_neighbor=32
)
self.pt_last = PCTPositionEmbedding() self.pt_last = PCTPositionEmbedding()
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), self.conv_fuse = nn.Sequential(
nn.BatchNorm1d(1024), nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
nn.LeakyReLU(negative_slope=0.2)) nn.BatchNorm1d(1024),
nn.LeakyReLU(negative_slope=0.2),
)
self.linear1 = nn.Linear(1024, 512, bias=False) self.linear1 = nn.Linear(1024, 512, bias=False)
self.bn6 = nn.BatchNorm1d(512) self.bn6 = nn.BatchNorm1d(512)
...@@ -159,13 +165,17 @@ class PointTransformerSeg(nn.Module): ...@@ -159,13 +165,17 @@ class PointTransformerSeg(nn.Module):
self.sa3 = SALayerSeg(128) self.sa3 = SALayerSeg(128)
self.sa4 = SALayerSeg(128) self.sa4 = SALayerSeg(128)
self.conv_fuse = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), self.conv_fuse = nn.Sequential(
nn.BatchNorm1d(1024), nn.Conv1d(512, 1024, kernel_size=1, bias=False),
nn.LeakyReLU(negative_slope=0.2)) nn.BatchNorm1d(1024),
nn.LeakyReLU(negative_slope=0.2),
)
self.label_conv = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False), self.label_conv = nn.Sequential(
nn.BatchNorm1d(64), nn.Conv1d(16, 64, kernel_size=1, bias=False),
nn.LeakyReLU(negative_slope=0.2)) nn.BatchNorm1d(64),
nn.LeakyReLU(negative_slope=0.2),
)
self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1) self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1)
self.dp1 = nn.Dropout(0.5) self.dp1 = nn.Dropout(0.5)
...@@ -189,13 +199,12 @@ class PointTransformerSeg(nn.Module): ...@@ -189,13 +199,12 @@ class PointTransformerSeg(nn.Module):
x = self.conv_fuse(x) x = self.conv_fuse(x)
x_max, _ = torch.max(x, 2) x_max, _ = torch.max(x, 2)
x_avg = torch.mean(x, 2) x_avg = torch.mean(x, 2)
x_max_feature = x_max.view( x_max_feature = x_max.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
batch_size, -1).unsqueeze(-1).repeat(1, 1, N) x_avg_feature = x_avg.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
x_avg_feature = x_avg.view(
batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
cls_label_feature = self.label_conv(cls_label).repeat(1, 1, N) cls_label_feature = self.label_conv(cls_label).repeat(1, 1, N)
x_global_feature = torch.cat( x_global_feature = torch.cat(
(x_max_feature, x_avg_feature, cls_label_feature), 1) (x_max_feature, x_avg_feature, cls_label_feature), 1
)
x = torch.cat((x, x_global_feature), 1) x = torch.cat((x, x_global_feature), 1)
x = self.relu(self.bns1(self.convs1(x))) x = self.relu(self.bns1(self.convs1(x)))
x = self.dp1(x) x = self.dp1(x)
......
''' """
Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/provider.py
''' """
import numpy as np import numpy as np
def normalize_data(batch_data): def normalize_data(batch_data):
""" Normalize the batch data, use coordinates of the block centered at origin, """Normalize the batch data, use coordinates of the block centered at origin,
Input: Input:
BxNxC array BxNxC array
Output: Output:
BxNxC array BxNxC array
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C)) normal_data = np.zeros((B, N, C))
...@@ -16,237 +17,296 @@ def normalize_data(batch_data): ...@@ -16,237 +17,296 @@ def normalize_data(batch_data):
pc = batch_data[b] pc = batch_data[b]
centroid = np.mean(pc, axis=0) centroid = np.mean(pc, axis=0)
pc = pc - centroid pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m pc = pc / m
normal_data[b] = pc normal_data[b] = pc
return normal_data return normal_data
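A quick check of the invariant established above: every cloud ends up centered with maximum radius 1 (not part of the example code).

import numpy as np
import provider

batch = np.random.rand(4, 100, 3)
out = provider.normalize_data(batch)
print(np.allclose(out.mean(axis=1), 0.0, atol=1e-6))               # centered
print(np.allclose(np.linalg.norm(out, axis=2).max(axis=1), 1.0))   # unit radius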
def shuffle_data(data, labels): def shuffle_data(data, labels):
""" Shuffle data and labels. """Shuffle data and labels.
Input: Input:
data: B,N,... numpy array data: B,N,... numpy array
label: B,... numpy array label: B,... numpy array
Return: Return:
shuffled data, label and shuffle indices shuffled data, label and shuffle indices
""" """
idx = np.arange(len(labels)) idx = np.arange(len(labels))
np.random.shuffle(idx) np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data): def shuffle_points(batch_data):
""" Shuffle orders of points in each point cloud -- changes FPS behavior. """Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch. Use the same shuffling idx for the entire batch.
Input: Input:
BxNxC array BxNxC array
Output: Output:
BxNxC array BxNxC array
""" """
idx = np.arange(batch_data.shape[1]) idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx) np.random.shuffle(idx)
return batch_data[:,idx,:] return batch_data[:, idx, :]
def rotate_point_cloud(batch_data): def rotate_point_cloud(batch_data):
""" Randomly rotate the point clouds to augument the dataset """Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction rotation is per shape based along up direction
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, rotated batch of point clouds BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_data[k, ...] shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
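Since the per-shape rotation matrix is orthogonal, each point keeps its distance to the origin; a quick check.

import numpy as np
import provider

batch = np.random.rand(2, 64, 3).astype(np.float32)
rotated = provider.rotate_point_cloud(batch)
print(np.allclose(np.linalg.norm(batch, axis=2),
                  np.linalg.norm(rotated, axis=2), atol=1e-5))   # True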
def rotate_point_cloud_z(batch_data): def rotate_point_cloud_z(batch_data):
""" Randomly rotate the point clouds to augument the dataset """Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction rotation is per shape based along up direction
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, rotated batch of point clouds BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0], rotation_matrix = np.array(
[-sinval, cosval, 0], [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]]
[0, 0, 1]]) )
shape_pc = batch_data[k, ...] shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, ...] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal): def rotate_point_cloud_with_normal(batch_xyz_normal):
''' Randomly rotate XYZ, normal point cloud. """Randomly rotate XYZ, normal point cloud.
Input: Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 are the normals batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 are the normals
Output: Output:
B,N,6, rotated XYZ, normal point cloud B,N,6, rotated XYZ, normal point cloud
''' """
for k in range(batch_xyz_normal.shape[0]): for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_xyz_normal[k,:,0:3] shape_pc = batch_xyz_normal[k, :, 0:3]
shape_normal = batch_xyz_normal[k,:,3:6] shape_normal = batch_xyz_normal[k, :, 3:6]
batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) batch_xyz_normal[k, :, 0:3] = np.dot(
batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) shape_pc.reshape((-1, 3)), rotation_matrix
)
batch_xyz_normal[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return batch_xyz_normal return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations def rotate_perturbation_point_cloud_with_normal(
Input: batch_data, angle_sigma=0.06, angle_clip=0.18
BxNx6 array, original batch of point clouds and point normals ):
Return: """Randomly perturb the point clouds by small rotations
BxNx6 array, rotated batch of point clouds and point normals Input:
BxNx6 array, original batch of point clouds and point normals
Return:
BxNx6 array, rotated batch of point clouds and point normals
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) angles = np.clip(
Rx = np.array([[1,0,0], angle_sigma * np.random.randn(3), -angle_clip, angle_clip
[0,np.cos(angles[0]),-np.sin(angles[0])], )
[0,np.sin(angles[0]),np.cos(angles[0])]]) Rx = np.array(
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], [
[0,1,0], [1, 0, 0],
[-np.sin(angles[1]),0,np.cos(angles[1])]]) [0, np.cos(angles[0]), -np.sin(angles[0])],
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], [0, np.sin(angles[0]), np.cos(angles[0])],
[np.sin(angles[2]),np.cos(angles[2]),0], ]
[0,0,1]]) )
R = np.dot(Rz, np.dot(Ry,Rx)) Ry = np.array(
shape_pc = batch_data[k,:,0:3] [
shape_normal = batch_data[k,:,3:6] [np.cos(angles[1]), 0, np.sin(angles[1])],
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) [0, 1, 0],
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) [-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k, :, 3:6]
rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data return rotated_data
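The composed R = Rz @ Ry @ Rx above is still a rigid rotation, so the xyz norms are preserved; a quick check (not part of the example code).

import numpy as np
import provider

batch = np.random.rand(2, 32, 6).astype(np.float32)
out = provider.rotate_perturbation_point_cloud_with_normal(batch)
print(np.allclose(np.linalg.norm(batch[:, :, 0:3], axis=2),
                  np.linalg.norm(out[:, :, 0:3], axis=2), atol=1e-5))   # True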
def rotate_point_cloud_by_angle(batch_data, rotation_angle): def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle. """Rotate the point cloud along up direction with certain angle.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, rotated batch of point clouds BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi # rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_data[k,:,0:3] shape_pc = batch_data[k, :, 0:3]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, :, 0:3] = np.dot(
shape_pc.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle. """Rotate the point cloud along up direction with certain angle.
Input: Input:
BxNx6 array, original batch of point clouds with normal BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation scalar, angle of rotation
Return: Return:
BxNx6 array, rotated batch of point clouds with normal BxNx6 array, rotated batch of point clouds with normal
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi # rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle) cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle) sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval], rotation_matrix = np.array(
[0, 1, 0], [[cosval, 0, sinval], [0, 1, 0], [-sinval, 0, cosval]]
[-sinval, 0, cosval]]) )
shape_pc = batch_data[k,:,0:3] shape_pc = batch_data[k, :, 0:3]
shape_normal = batch_data[k,:,3:6] shape_normal = batch_data[k, :, 3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) rotated_data[k, :, 0:3] = np.dot(
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) shape_pc.reshape((-1, 3)), rotation_matrix
)
rotated_data[k, :, 3:6] = np.dot(
shape_normal.reshape((-1, 3)), rotation_matrix
)
return rotated_data return rotated_data
def rotate_perturbation_point_cloud(
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): batch_data, angle_sigma=0.06, angle_clip=0.18
""" Randomly perturb the point clouds by small rotations ):
Input: """Randomly perturb the point clouds by small rotations
BxNx3 array, original batch of point clouds Input:
Return: BxNx3 array, original batch of point clouds
BxNx3 array, rotated batch of point clouds Return:
BxNx3 array, rotated batch of point clouds
""" """
rotated_data = np.zeros(batch_data.shape, dtype=np.float32) rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]): for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) angles = np.clip(
Rx = np.array([[1,0,0], angle_sigma * np.random.randn(3), -angle_clip, angle_clip
[0,np.cos(angles[0]),-np.sin(angles[0])], )
[0,np.sin(angles[0]),np.cos(angles[0])]]) Rx = np.array(
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], [
[0,1,0], [1, 0, 0],
[-np.sin(angles[1]),0,np.cos(angles[1])]]) [0, np.cos(angles[0]), -np.sin(angles[0])],
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], [0, np.sin(angles[0]), np.cos(angles[0])],
[np.sin(angles[2]),np.cos(angles[2]),0], ]
[0,0,1]]) )
R = np.dot(Rz, np.dot(Ry,Rx)) Ry = np.array(
[
[np.cos(angles[1]), 0, np.sin(angles[1])],
[0, 1, 0],
[-np.sin(angles[1]), 0, np.cos(angles[1])],
]
)
Rz = np.array(
[
[np.cos(angles[2]), -np.sin(angles[2]), 0],
[np.sin(angles[2]), np.cos(angles[2]), 0],
[0, 0, 1],
]
)
R = np.dot(Rz, np.dot(Ry, Rx))
shape_pc = batch_data[k, ...] shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point. """Randomly jitter points. jittering is per point.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, jittered batch of point clouds BxNx3 array, jittered batch of point clouds
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
assert(clip > 0) assert clip > 0
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
jittered_data += batch_data jittered_data += batch_data
return jittered_data return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1): def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud. """Randomly shift point cloud. Shift is per point cloud.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, shifted batch of point clouds BxNx3 array, shifted batch of point clouds
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3)) shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
for batch_index in range(B): for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:] batch_data[batch_index, :, :] += shifts[batch_index, :]
return batch_data return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud. """Randomly scale the point cloud. Scale is per point cloud.
Input: Input:
BxNx3 array, original batch of point clouds BxNx3 array, original batch of point clouds
Return: Return:
BxNx3 array, scaled batch of point clouds BxNx3 array, scaled batch of point clouds
""" """
B, N, C = batch_data.shape B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B) scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B): for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index] batch_data[batch_index, :, :] *= scales[batch_index]
return batch_data return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875): def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 ''' """batch_pc: BxNx3"""
for b in range(batch_pc.shape[0]): for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] drop_idx = np.where(
if len(drop_idx)>0: np.random.random((batch_pc.shape[1])) <= dropout_ratio
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 # not needed )[0]
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point if len(drop_idx) > 0:
dropout_ratio = (
np.random.random() * max_dropout_ratio
) # 0~0.875 # not needed
batch_pc[b, drop_idx, :] = batch_pc[
b, 0, :
] # set to the first point
return batch_pc return batch_pc
from pct import PointTransformerCLS
from ModelNetDataLoader import ModelNetDataLoader
import provider
import argparse import argparse
import os import os
import tqdm import time
from functools import partial from functools import partial
from dgl.data.utils import download, get_download_dir
from torch.utils.data import DataLoader import provider
import torch.nn as nn
import torch import torch
import time import torch.nn as nn
import tqdm
from ModelNetDataLoader import ModelNetDataLoader
from pct import PointTransformerCLS
from torch.utils.data import DataLoader
from dgl.data.utils import download, get_download_dir
torch.backends.cudnn.enabled = False torch.backends.cudnn.enabled = False
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='') parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument('--load-model-path', type=str, default='') parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument('--save-model-path', type=str, default='') parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument('--num-epochs', type=int, default=250) parser.add_argument("--num-epochs", type=int, default=250)
parser.add_argument('--num-workers', type=int, default=8) parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument('--batch-size', type=int, default=32) parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args() args = parser.parse_args()
num_workers = args.num_workers num_workers = args.num_workers
batch_size = args.batch_size batch_size = args.batch_size
data_filename = 'modelnet40_normal_resampled.zip' data_filename = "modelnet40_normal_resampled.zip"
download_path = os.path.join(get_download_dir(), data_filename) download_path = os.path.join(get_download_dir(), data_filename)
local_path = args.dataset_path or os.path.join( local_path = args.dataset_path or os.path.join(
get_download_dir(), 'modelnet40_normal_resampled') get_download_dir(), "modelnet40_normal_resampled"
)
if not os.path.exists(local_path): if not os.path.exists(local_path):
download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip', download(
download_path, verify_ssl=False) "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
download_path,
verify_ssl=False,
)
from zipfile import ZipFile from zipfile import ZipFile
with ZipFile(download_path) as z: with ZipFile(download_path) as z:
z.extractall(path=get_download_dir()) z.extractall(path=get_download_dir())
...@@ -42,7 +49,8 @@ CustomDataLoader = partial( ...@@ -42,7 +49,8 @@ CustomDataLoader = partial(
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
drop_last=True) drop_last=True,
)
def train(net, opt, scheduler, train_loader, dev): def train(net, opt, scheduler, train_loader, dev):
...@@ -59,8 +67,7 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -59,8 +67,7 @@ def train(net, opt, scheduler, train_loader, dev):
for data, label in tq: for data, label in tq:
data = data.data.numpy() data = data.data.numpy()
data = provider.random_point_dropout(data) data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud( data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3]) data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3]) data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data) data = torch.tensor(data)
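The same augmentation chain can be exercised standalone on a random batch (sketch only, mirroring the loop above).

import numpy as np
import torch
import provider

data = np.random.rand(8, 1024, 3).astype(np.float32)
data = provider.random_point_dropout(data)
data[:, :, 0:3] = provider.random_scale_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
data = torch.tensor(data)
print(data.shape)   # torch.Size([8, 1024, 3])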
...@@ -83,11 +90,19 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -83,11 +90,19 @@ def train(net, opt, scheduler, train_loader, dev):
total_loss += loss total_loss += loss
total_correct += correct total_correct += correct
tq.set_postfix({ tq.set_postfix(
'AvgLoss': '%.5f' % (total_loss / num_batches), {
'AvgAcc': '%.5f' % (total_correct / count)}) "AvgLoss": "%.5f" % (total_loss / num_batches),
print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(total_loss / "AvgAcc": "%.5f" % (total_correct / count),
num_batches, total_correct / count, time.time() - start_time)) }
)
print(
"[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
total_loss / num_batches,
total_correct / count,
time.time() - start_time,
)
)
scheduler.step() scheduler.step()
...@@ -110,10 +125,12 @@ def evaluate(net, test_loader, dev): ...@@ -110,10 +125,12 @@ def evaluate(net, test_loader, dev):
total_correct += correct total_correct += correct
count += num_examples count += num_examples
tq.set_postfix({ tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
'AvgAcc': '%.5f' % (total_correct / count)}) print(
print("[Test] AvgAcc: {:.5}, Time: {:.5}s".format( "[Test] AvgAcc: {:.5}, Time: {:.5}s".format(
total_correct / count, time.time() - start_time)) total_correct / count, time.time() - start_time
)
)
return total_correct / count return total_correct / count
...@@ -126,21 +143,29 @@ if args.load_model_path: ...@@ -126,21 +143,29 @@ if args.load_model_path:
opt = torch.optim.SGD( opt = torch.optim.SGD(
net.parameters(), net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9
lr=0.01,
weight_decay=1e-4,
momentum=0.9
) )
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
opt, T_max=args.num_epochs) opt, T_max=args.num_epochs
)
train_dataset = ModelNetDataLoader(local_path, 1024, split='train') train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
test_dataset = ModelNetDataLoader(local_path, 1024, split='test') test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
test_loader = torch.utils.data.DataLoader( test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True) test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
drop_last=True,
)
best_test_acc = 0 best_test_acc = 0
...@@ -153,6 +178,5 @@ for epoch in range(args.num_epochs): ...@@ -153,6 +178,5 @@ for epoch in range(args.num_epochs):
best_test_acc = test_acc best_test_acc = test_acc
if args.save_model_path: if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path) torch.save(net.state_dict(), args.save_model_path)
print('Current test acc: %.5f (best: %.5f)' % ( print("Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc))
test_acc, best_test_acc))
print() print()
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import dgl
from functools import partial
import tqdm
import argparse import argparse
import time import time
from functools import partial
import numpy as np
import provider import provider
import torch
import torch.optim as optim
import tqdm
from pct import PartSegLoss, PointTransformerSeg
from ShapeNet import ShapeNet from ShapeNet import ShapeNet
from pct import PointTransformerSeg, PartSegLoss from torch.utils.data import DataLoader
import dgl
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--dataset-path', type=str, default='') parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument('--load-model-path', type=str, default='') parser.add_argument("--load-model-path", type=str, default="")
parser.add_argument('--save-model-path', type=str, default='') parser.add_argument("--save-model-path", type=str, default="")
parser.add_argument('--num-epochs', type=int, default=500) parser.add_argument("--num-epochs", type=int, default=500)
parser.add_argument('--num-workers', type=int, default=8) parser.add_argument("--num-workers", type=int, default=8)
parser.add_argument('--batch-size', type=int, default=16) parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument('--tensorboard', action='store_true') parser.add_argument("--tensorboard", action="store_true")
args = parser.parse_args() args = parser.parse_args()
num_workers = args.num_workers num_workers = args.num_workers
...@@ -37,7 +37,8 @@ CustomDataLoader = partial( ...@@ -37,7 +37,8 @@ CustomDataLoader = partial(
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
drop_last=True) drop_last=True,
)
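For clarity, the `CustomDataLoader` defined above uses `functools.partial` to pre-bind the loader settings so that later call sites only pass a dataset; a minimal sketch with placeholder values for `num_workers` and `batch_size`:

from functools import partial
from torch.utils.data import DataLoader

CustomDataLoader = partial(
    DataLoader, num_workers=8, batch_size=16, shuffle=True, drop_last=True
)
# train_loader = CustomDataLoader(train_dataset)  # same as DataLoader(train_dataset, ...)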
def train(net, opt, scheduler, train_loader, dev): def train(net, opt, scheduler, train_loader, dev):
...@@ -59,7 +60,8 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -59,7 +60,8 @@ def train(net, opt, scheduler, train_loader, dev):
cat_ind = [category_list.index(c) for c in cat] cat_ind = [category_list.index(c) for c in cat]
# A one-hot encoding for the object category # A one-hot encoding for the object category
cat_tensor = torch.tensor(eye_mat[cat_ind]).to( cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
dev, dtype=torch.float) dev, dtype=torch.float
)
cat_tensor = cat_tensor.view(num_examples, 16, 1) cat_tensor = cat_tensor.view(num_examples, 16, 1)
logits = net(data, cat_tensor) logits = net(data, cat_tensor)
loss = L(logits, label) loss = L(logits, label)
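The one-hot category tensor above is built by indexing an identity matrix with the per-sample category indices; a small self-contained example with hypothetical indices (the real ones come from `category_list`):

import numpy as np
import torch

eye_mat = np.eye(16)                    # one row per ShapeNet object category
cat_ind = [0, 3, 3]                     # hypothetical: three samples with categories 0, 3, 3
cat_tensor = torch.tensor(eye_mat[cat_ind], dtype=torch.float)  # shape (3, 16)
cat_tensor = cat_tensor.view(3, 16, 1)  # broadcastable against per-point features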
...@@ -78,14 +80,17 @@ def train(net, opt, scheduler, train_loader, dev): ...@@ -78,14 +80,17 @@ def train(net, opt, scheduler, train_loader, dev):
AvgLoss = total_loss / num_batches AvgLoss = total_loss / num_batches
AvgAcc = total_correct / count AvgAcc = total_correct / count
tq.set_postfix({ tq.set_postfix(
'AvgLoss': '%.5f' % AvgLoss, {"AvgLoss": "%.5f" % AvgLoss, "AvgAcc": "%.5f" % AvgAcc}
'AvgAcc': '%.5f' % AvgAcc}) )
scheduler.step() scheduler.step()
end = time.time() end = time.time()
print("[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(total_loss / print(
num_batches, total_correct / count, end - start)) "[Train] AvgLoss: {:.5}, AvgAcc: {:.5}, Time: {:.5}s".format(
return data, preds, AvgLoss, AvgAcc, end-start total_loss / num_batches, total_correct / count, end - start
)
)
return data, preds, AvgLoss, AvgAcc, end - start
def mIoU(preds, label, cat, cat_miou, seg_classes): def mIoU(preds, label, cat, cat_miou, seg_classes):
...@@ -129,26 +134,36 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False): ...@@ -129,26 +134,36 @@ def evaluate(net, test_loader, dev, per_cat_verbose=False):
label = label.to(dev, dtype=torch.long) label = label.to(dev, dtype=torch.long)
cat_ind = [category_list.index(c) for c in cat] cat_ind = [category_list.index(c) for c in cat]
cat_tensor = torch.tensor(eye_mat[cat_ind]).to( cat_tensor = torch.tensor(eye_mat[cat_ind]).to(
dev, dtype=torch.float) dev, dtype=torch.float
cat_tensor = cat_tensor.view( )
num_examples, 16, 1) cat_tensor = cat_tensor.view(num_examples, 16, 1)
logits = net(data, cat_tensor) logits = net(data, cat_tensor)
_, preds = logits.max(1) _, preds = logits.max(1)
cat_miou = mIoU(preds.cpu().numpy(), cat_miou = mIoU(
label.view(num_examples, -1).cpu().numpy(), preds.cpu().numpy(),
cat, cat_miou, shapenet.seg_classes) label.view(num_examples, -1).cpu().numpy(),
cat,
cat_miou,
shapenet.seg_classes,
)
for _, v in cat_miou.items(): for _, v in cat_miou.items():
if v[1] > 0: if v[1] > 0:
miou += v[0] miou += v[0]
count += v[1] count += v[1]
per_cat_miou += v[0] / v[1] per_cat_miou += v[0] / v[1]
per_cat_count += 1 per_cat_count += 1
tq.set_postfix({ tq.set_postfix(
'mIoU': '%.5f' % (miou / count), {
'per Category mIoU': '%.5f' % (per_cat_miou / per_cat_count)}) "mIoU": "%.5f" % (miou / count),
print("[Test] mIoU: %.5f, per Category mIoU: %.5f" % "per Category mIoU": "%.5f"
(miou / count, per_cat_miou / per_cat_count)) % (per_cat_miou / per_cat_count),
}
)
print(
"[Test] mIoU: %.5f, per Category mIoU: %.5f"
% (miou / count, per_cat_miou / per_cat_count)
)
if per_cat_verbose: if per_cat_verbose:
print("-" * 60) print("-" * 60)
print("Per-Category mIoU:") print("Per-Category mIoU:")
...@@ -169,14 +184,12 @@ if args.load_model_path: ...@@ -169,14 +184,12 @@ if args.load_model_path:
net.load_state_dict(torch.load(args.load_model_path, map_location=dev)) net.load_state_dict(torch.load(args.load_model_path, map_location=dev))
opt = torch.optim.SGD( opt = torch.optim.SGD(
net.parameters(), net.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9
lr=0.01,
weight_decay=1e-4,
momentum=0.9
) )
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
opt, T_max=args.num_epochs) opt, T_max=args.num_epochs
)
L = PartSegLoss() L = PartSegLoss()
...@@ -190,20 +203,63 @@ if args.tensorboard: ...@@ -190,20 +203,63 @@ if args.tensorboard:
import torchvision import torchvision
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms from torchvision import datasets, transforms
writer = SummaryWriter() writer = SummaryWriter()
# Select 50 distinct colors for different parts # Select 50 distinct colors for different parts
color_map = torch.tensor([ color_map = torch.tensor(
[47, 79, 79], [139, 69, 19], [112, 128, 144], [85, 107, 47], [139, 0, 0], [ [
128, 128, 0], [72, 61, 139], [0, 128, 0], [188, 143, 143], [60, 179, 113], [47, 79, 79],
[205, 133, 63], [0, 139, 139], [70, 130, 180], [205, 92, 92], [154, 205, 50], [ [139, 69, 19],
0, 0, 139], [50, 205, 50], [250, 250, 250], [218, 165, 32], [139, 0, 139], [112, 128, 144],
[10, 10, 10], [176, 48, 96], [72, 209, 204], [153, 50, 204], [255, 69, 0], [ [85, 107, 47],
255, 145, 0], [0, 0, 205], [255, 255, 0], [0, 255, 0], [233, 150, 122], [139, 0, 0],
[220, 20, 60], [0, 191, 255], [160, 32, 240], [192, 192, 192], [173, 255, 47], [ [128, 128, 0],
218, 112, 214], [216, 191, 216], [255, 127, 80], [255, 0, 255], [100, 149, 237], [72, 61, 139],
[128, 128, 128], [221, 160, 221], [144, 238, 144], [123, 104, 238], [255, 160, 122], [ [0, 128, 0],
175, 238, 238], [238, 130, 238], [127, 255, 212], [255, 218, 185], [255, 105, 180], [188, 143, 143],
]) [60, 179, 113],
[205, 133, 63],
[0, 139, 139],
[70, 130, 180],
[205, 92, 92],
[154, 205, 50],
[0, 0, 139],
[50, 205, 50],
[250, 250, 250],
[218, 165, 32],
[139, 0, 139],
[10, 10, 10],
[176, 48, 96],
[72, 209, 204],
[153, 50, 204],
[255, 69, 0],
[255, 145, 0],
[0, 0, 205],
[255, 255, 0],
[0, 255, 0],
[233, 150, 122],
[220, 20, 60],
[0, 191, 255],
[160, 32, 240],
[192, 192, 192],
[173, 255, 47],
[218, 112, 214],
[216, 191, 216],
[255, 127, 80],
[255, 0, 255],
[100, 149, 237],
[128, 128, 128],
[221, 160, 221],
[144, 238, 144],
[123, 104, 238],
[255, 160, 122],
[175, 238, 238],
[238, 130, 238],
[127, 255, 212],
[255, 218, 185],
[255, 105, 180],
]
)
# paint each point according to its pred # paint each point according to its pred
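The `paint` function is collapsed in this diff; given the `color_map` tensor above, it presumably just indexes the table with the predicted part labels, roughly:

import torch

def paint_sketch(preds, color_map):
    # preds: (B, N) long tensor of part labels; returns (B, N, 3) uint8 colors for writer.add_mesh
    return color_map[preds].to(torch.uint8)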
...@@ -219,28 +275,38 @@ best_test_per_cat_miou = 0 ...@@ -219,28 +275,38 @@ best_test_per_cat_miou = 0
for epoch in range(args.num_epochs): for epoch in range(args.num_epochs):
print("Epoch #{}: ".format(epoch)) print("Epoch #{}: ".format(epoch))
data, preds, AvgLoss, AvgAcc, training_time = train( data, preds, AvgLoss, AvgAcc, training_time = train(
net, opt, scheduler, train_loader, dev) net, opt, scheduler, train_loader, dev
)
if (epoch + 1) % 5 == 0 or epoch == 0: if (epoch + 1) % 5 == 0 or epoch == 0:
test_miou, test_per_cat_miou = evaluate( test_miou, test_per_cat_miou = evaluate(net, test_loader, dev, True)
net, test_loader, dev, True)
if test_miou > best_test_miou: if test_miou > best_test_miou:
best_test_miou = test_miou best_test_miou = test_miou
best_test_per_cat_miou = test_per_cat_miou best_test_per_cat_miou = test_per_cat_miou
if args.save_model_path: if args.save_model_path:
torch.save(net.state_dict(), args.save_model_path) torch.save(net.state_dict(), args.save_model_path)
print('Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)' % ( print(
test_miou, best_test_miou, test_per_cat_miou, best_test_per_cat_miou)) "Current test mIoU: %.5f (best: %.5f), per-Category mIoU: %.5f (best: %.5f)"
% (
test_miou,
best_test_miou,
test_per_cat_miou,
best_test_per_cat_miou,
)
)
# Tensorboard # Tensorboard
if args.tensorboard: if args.tensorboard:
colored = paint(preds) colored = paint(preds)
writer.add_mesh('data', vertices=data, writer.add_mesh(
colors=colored, global_step=epoch) "data", vertices=data, colors=colored, global_step=epoch
writer.add_scalar('training time for one epoch', )
training_time, global_step=epoch) writer.add_scalar(
writer.add_scalar('AvgLoss', AvgLoss, global_step=epoch) "training time for one epoch", training_time, global_step=epoch
writer.add_scalar('AvgAcc', AvgAcc, global_step=epoch) )
writer.add_scalar("AvgLoss", AvgLoss, global_step=epoch)
writer.add_scalar("AvgAcc", AvgAcc, global_step=epoch)
if (epoch + 1) % 5 == 0: if (epoch + 1) % 5 == 0:
writer.add_scalar('test mIoU', test_miou, global_step=epoch) writer.add_scalar("test mIoU", test_miou, global_step=epoch)
writer.add_scalar('best test mIoU', writer.add_scalar(
best_test_miou, global_step=epoch) "best test mIoU", best_test_miou, global_step=epoch
)
print() print()
import numpy as np
import warnings
import os import os
import warnings
import numpy as np
from torch.utils.data import Dataset from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")
def pc_normalize(pc): def pc_normalize(pc):
...@@ -43,11 +45,18 @@ def farthest_point_sample(point, npoint): ...@@ -43,11 +45,18 @@ def farthest_point_sample(point, npoint):
class ModelNetDataLoader(Dataset): class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', fps=False, def __init__(
normal_channel=True, cache_size=15000): self,
root,
npoint=1024,
split="train",
fps=False,
normal_channel=True,
cache_size=15000,
):
""" """
Input: Input:
root: the root path to the local data files root: the root path to the local data files
npoint: number of points from each cloud npoint: number of points from each cloud
split: which split of the data, 'train' or 'test' split: which split of the data, 'train' or 'test'
fps: whether to sample points with farthest point sampler fps: whether to sample points with farthest point sampler
...@@ -57,24 +66,34 @@ class ModelNetDataLoader(Dataset): ...@@ -57,24 +66,34 @@ class ModelNetDataLoader(Dataset):
self.root = root self.root = root
self.npoints = npoint self.npoints = npoint
self.fps = fps self.fps = fps
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")
self.cat = [line.rstrip() for line in open(self.catfile)] self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat)))) self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel self.normal_channel = normal_channel
shape_ids = {} shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open( shape_ids["train"] = [
os.path.join(self.root, 'modelnet40_train.txt'))] line.rstrip()
shape_ids['test'] = [line.rstrip() for line in open( for line in open(os.path.join(self.root, "modelnet40_train.txt"))
os.path.join(self.root, 'modelnet40_test.txt'))] ]
shape_ids["test"] = [
assert (split == 'train' or split == 'test') line.rstrip()
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] for line in open(os.path.join(self.root, "modelnet40_test.txt"))
]
assert split == "train" or split == "test"
shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple # list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i self.datapath = [
in range(len(shape_ids[split]))] (
print('The size of %s data is %d' % (split, len(self.datapath))) shape_names[i],
os.path.join(self.root, shape_names[i], shape_ids[split][i])
+ ".txt",
)
for i in range(len(shape_ids[split]))
]
print("The size of %s data is %d" % (split, len(self.datapath)))
self.cache_size = cache_size self.cache_size = cache_size
self.cache = {} self.cache = {}
...@@ -89,11 +108,11 @@ class ModelNetDataLoader(Dataset): ...@@ -89,11 +108,11 @@ class ModelNetDataLoader(Dataset):
fn = self.datapath[index] fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]] cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32) cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
if self.fps: if self.fps:
point_set = farthest_point_sample(point_set, self.npoints) point_set = farthest_point_sample(point_set, self.npoints)
else: else:
point_set = point_set[0:self.npoints, :] point_set = point_set[0 : self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
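`pc_normalize` is collapsed in this diff; it is assumed to be the standard centering and scaling of the cloud to the unit sphere:

import numpy as np

def pc_normalize_sketch(pc):
    # pc: (N, 3) array; center it, then scale so the farthest point lies on the unit sphere
    pc = pc - pc.mean(axis=0)
    scale = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    return pc / scale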
......
import os, json, tqdm import json
import numpy as np import os
import dgl
from zipfile import ZipFile from zipfile import ZipFile
from torch.utils.data import Dataset
import numpy as np
import tqdm
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, get_download_dir from dgl.data.utils import download, get_download_dir
class ShapeNet(object): class ShapeNet(object):
def __init__(self, num_points=2048, normal_channel=True): def __init__(self, num_points=2048, normal_channel=True):
self.num_points = num_points self.num_points = num_points
...@@ -13,8 +18,13 @@ class ShapeNet(object): ...@@ -13,8 +18,13 @@ class ShapeNet(object):
SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip" SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
download_path = get_download_dir() download_path = get_download_dir()
data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip" data_filename = (
data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal") "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
)
data_path = os.path.join(
download_path,
"shapenetcore_partanno_segmentation_benchmark_v0_normal",
)
if not os.path.exists(data_path): if not os.path.exists(data_path):
local_path = os.path.join(download_path, data_filename) local_path = os.path.join(download_path, data_filename)
if not os.path.exists(local_path): if not os.path.exists(local_path):
...@@ -24,52 +34,72 @@ class ShapeNet(object): ...@@ -24,52 +34,72 @@ class ShapeNet(object):
synset_file = "synsetoffset2category.txt" synset_file = "synsetoffset2category.txt"
with open(os.path.join(data_path, synset_file)) as f: with open(os.path.join(data_path, synset_file)) as f:
synset = [t.split('\n')[0].split('\t') for t in f.readlines()] synset = [t.split("\n")[0].split("\t") for t in f.readlines()]
self.synset_dict = {} self.synset_dict = {}
for syn in synset: for syn in synset:
self.synset_dict[syn[1]] = syn[0] self.synset_dict[syn[1]] = syn[0]
self.seg_classes = {'Airplane': [0, 1, 2, 3], self.seg_classes = {
'Bag': [4, 5], "Airplane": [0, 1, 2, 3],
'Cap': [6, 7], "Bag": [4, 5],
'Car': [8, 9, 10, 11], "Cap": [6, 7],
'Chair': [12, 13, 14, 15], "Car": [8, 9, 10, 11],
'Earphone': [16, 17, 18], "Chair": [12, 13, 14, 15],
'Guitar': [19, 20, 21], "Earphone": [16, 17, 18],
'Knife': [22, 23], "Guitar": [19, 20, 21],
'Lamp': [24, 25, 26, 27], "Knife": [22, 23],
'Laptop': [28, 29], "Lamp": [24, 25, 26, 27],
'Motorbike': [30, 31, 32, 33, 34, 35], "Laptop": [28, 29],
'Mug': [36, 37], "Motorbike": [30, 31, 32, 33, 34, 35],
'Pistol': [38, 39, 40], "Mug": [36, 37],
'Rocket': [41, 42, 43], "Pistol": [38, 39, 40],
'Skateboard': [44, 45, 46], "Rocket": [41, 42, 43],
'Table': [47, 48, 49]} "Skateboard": [44, 45, 46],
"Table": [47, 48, 49],
train_split_json = 'shuffled_train_file_list.json' }
val_split_json = 'shuffled_val_file_list.json'
test_split_json = 'shuffled_test_file_list.json' train_split_json = "shuffled_train_file_list.json"
split_path = os.path.join(data_path, 'train_test_split') val_split_json = "shuffled_val_file_list.json"
test_split_json = "shuffled_test_file_list.json"
split_path = os.path.join(data_path, "train_test_split")
with open(os.path.join(split_path, train_split_json)) as f: with open(os.path.join(split_path, train_split_json)) as f:
tmp = f.read() tmp = f.read()
self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)] self.train_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, val_split_json)) as f: with open(os.path.join(split_path, val_split_json)) as f:
tmp = f.read() tmp = f.read()
self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)] self.val_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
with open(os.path.join(split_path, test_split_json)) as f: with open(os.path.join(split_path, test_split_json)) as f:
tmp = f.read() tmp = f.read()
self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)] self.test_file_list = [
os.path.join(data_path, t.replace("shape_data/", "") + ".txt")
for t in json.loads(tmp)
]
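If a part label ever needs to be mapped back to its object category, the `seg_classes` table above can be inverted; a short illustration with an excerpt of the table:

seg_classes = {"Airplane": [0, 1, 2, 3], "Chair": [12, 13, 14, 15]}  # excerpt
label_to_cat = {label: cat for cat, labels in seg_classes.items() for label in labels}
assert label_to_cat[13] == "Chair"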
def train(self): def train(self):
return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "train", self.num_points, self.normal_channel
)
def valid(self): def valid(self):
return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "valid", self.num_points, self.normal_channel
)
def trainval(self): def trainval(self):
return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "trainval", self.num_points, self.normal_channel
)
def test(self): def test(self):
return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel) return ShapeNetDataset(
self, "test", self.num_points, self.normal_channel
)
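Typical usage of the wrapper above (a sketch; the batch size is arbitrary, and `DataLoader` is already imported at the top of this file):

shapenet = ShapeNet(num_points=2048, normal_channel=True)
train_loader = DataLoader(shapenet.train(), batch_size=16, shuffle=True, drop_last=True)
test_loader = DataLoader(shapenet.test(), batch_size=16, shuffle=False)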
class ShapeNetDataset(Dataset): class ShapeNetDataset(Dataset):
def __init__(self, shapenet, mode, num_points, normal_channel=True): def __init__(self, shapenet, mode, num_points, normal_channel=True):
...@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset): ...@@ -81,13 +111,13 @@ class ShapeNetDataset(Dataset):
else: else:
self.dim = 6 self.dim = 6
if mode == 'train': if mode == "train":
self.file_list = shapenet.train_file_list self.file_list = shapenet.train_file_list
elif mode == 'valid': elif mode == "valid":
self.file_list = shapenet.val_file_list self.file_list = shapenet.val_file_list
elif mode == 'test': elif mode == "test":
self.file_list = shapenet.test_file_list self.file_list = shapenet.test_file_list
elif mode == 'trainval': elif mode == "trainval":
self.file_list = shapenet.train_file_list + shapenet.val_file_list self.file_list = shapenet.train_file_list + shapenet.val_file_list
else: else:
raise "Not supported `mode`" raise "Not supported `mode`"
...@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset): ...@@ -95,32 +125,36 @@ class ShapeNetDataset(Dataset):
data_list = [] data_list = []
label_list = [] label_list = []
category_list = [] category_list = []
print('Loading data from split ' + self.mode) print("Loading data from split " + self.mode)
for fn in tqdm.tqdm(self.file_list, ascii=True): for fn in tqdm.tqdm(self.file_list, ascii=True):
with open(fn) as f: with open(fn) as f:
data = np.array([t.split('\n')[0].split(' ') for t in f.readlines()]).astype(float) data = np.array(
data_list.append(data[:, 0:self.dim]) [t.split("\n")[0].split(" ") for t in f.readlines()]
).astype(float)
data_list.append(data[:, 0 : self.dim])
label_list.append(data[:, 6].astype(np.int64)) label_list.append(data[:, 6].astype(np.int64))
category_list.append(shapenet.synset_dict[fn.split('/')[-2]]) category_list.append(shapenet.synset_dict[fn.split("/")[-2]])
self.data = data_list self.data = data_list
self.label = label_list self.label = label_list
self.category = category_list self.category = category_list
def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3): def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):
xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size]) xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size]) xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
x = np.add(np.multiply(x, xyz1), xyz2).astype('float32') x = np.add(np.multiply(x, xyz1), xyz2).astype("float32")
return x return x
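As a quick numeric check of `translate` with hand-picked values: if xyz1 = [1.2, 0.8, 1.0] and xyz2 = [0.1, -0.05, 0.0], the point [1.0, 1.0, 1.0] maps to [1.0*1.2 + 0.1, 1.0*0.8 - 0.05, 1.0*1.0 + 0.0] = [1.3, 0.75, 1.0].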
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
def __getitem__(self, i): def __getitem__(self, i):
inds = np.random.choice(self.data[i].shape[0], self.num_points, replace=True) inds = np.random.choice(
x = self.data[i][inds,:self.dim] self.data[i].shape[0], self.num_points, replace=True
)
x = self.data[i][inds, : self.dim]
y = self.label[i][inds] y = self.label[i][inds]
cat = self.category[i] cat = self.category[i]
if self.mode == 'train': if self.mode == "train":
x = self.translate(x, size=self.dim) x = self.translate(x, size=self.dim)
x = x.astype(float) x = x.astype(float)
y = y.astype(np.int64) y = y.astype(np.int64)
......