Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
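The commit message says this is an automated Black formatting pass (with import reordering, judging by the shuffled import blocks below). As a rough, non-authoritative sketch of how such a pass could be reproduced with Black's Python API — the 79-character line length and the examples/pytorch target directory are assumptions inferred from the wrapped lines in this diff, not stated in the commit:

import pathlib

import black

# Assumed settings: 79-character line length (not stated in the commit).
mode = black.Mode(line_length=79)
for path in pathlib.Path("examples/pytorch").rglob("*.py"):
    source = path.read_text()
    formatted = black.format_str(source, mode=mode)
    if formatted != source:
        path.write_text(formatted)  # rewrite only files Black would change

The import reordering seen below would come from a separate isort run; this snippet covers only the Black reformatting.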
import argparse
import warnings

import torch as th
from dataset import load

import dgl
from dgl.dataloading import GraphDataLoader

warnings.filterwarnings("ignore")

from model import MVGRL
from utils import linearsvc

parser = argparse.ArgumentParser(description="mvgrl")
parser.add_argument(
    "--dataname", type=str, default="MUTAG", help="Name of dataset."
)
parser.add_argument(
    "--gpu", type=int, default=-1, help="GPU index. Default: -1, using cpu."
)
parser.add_argument(
    "--epochs", type=int, default=200, help=" Number of training periods."
)
parser.add_argument(
    "--patience", type=int, default=20, help="Early stopping steps."
)
parser.add_argument(
    "--lr", type=float, default=0.001, help="Learning rate of mvgrl."
)
parser.add_argument(
    "--wd", type=float, default=0.0, help="Weight decay of mvgrl."
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
parser.add_argument(
    "--n_layers", type=int, default=4, help="Number of GNN layers."
)
parser.add_argument("--hid_dim", type=int, default=32, help="Hidden layer dim.")

args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = "cuda:{}".format(args.gpu)
else:
    args.device = "cpu"


def collate(samples):
    """collate function for building the graph dataloader"""
    graphs, diff_graphs, labels = map(list, zip(*samples))

    # generate batched graphs and labels
@@ -45,30 +60,33 @@ def collate(samples):
    graph_id = th.arange(n_graphs)
    graph_id = dgl.broadcast_nodes(batched_graph, graph_id)

    batched_graph.ndata["graph_id"] = graph_id

    return batched_graph, batched_diff_graph, batched_labels


if __name__ == "__main__":
    # Step 1: Prepare data =================================================================== #
    dataset = load(args.dataname)
    graphs, diff_graphs, labels = map(list, zip(*dataset))
    print("Number of graphs:", len(graphs))

    # generate a full-graph with all examples for evaluation
    wholegraph = dgl.batch(graphs)
    whole_dg = dgl.batch(diff_graphs)

    # create dataloader for batch training
    dataloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    in_dim = wholegraph.ndata["feat"].shape[1]

    # Step 2: Create model =================================================================== #
    model = MVGRL(in_dim, args.hid_dim, args.n_layers)
@@ -77,19 +95,19 @@ if __name__ == '__main__':
    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    print("===== Before training ======")

    wholegraph = wholegraph.to(args.device)
    whole_dg = whole_dg.to(args.device)
    wholefeat = wholegraph.ndata.pop("feat")
    whole_weight = whole_dg.edata.pop("edge_weight")

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)
    lbls = th.LongTensor(labels)

    acc_mean, acc_std = linearsvc(embs, lbls)
    print("accuracy_mean, {:.4f}".format(acc_mean))

    best = float("inf")
    cnt_wait = 0

    # Step 4: Training epochs =============================================================== #
    for epoch in range(args.epochs):
@@ -100,9 +118,9 @@ if __name__ == '__main__':
        graph = graph.to(args.device)
        diff_graph = diff_graph.to(args.device)

        feat = graph.ndata["feat"]
        graph_id = graph.ndata["graph_id"]
        edge_weight = diff_graph.edata["edge_weight"]

        n_graph = label.shape[0]

        optimizer.zero_grad()
@@ -111,25 +129,25 @@ if __name__ == '__main__':
        loss.backward()
        optimizer.step()

        print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))

        if loss < best:
            best = loss
            best_t = epoch
            cnt_wait = 0
            th.save(model.state_dict(), f"{args.dataname}.pkl")
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break

    print("Training End")

    # Step 5: Linear evaluation ========================================================== #
    model.load_state_dict(th.load(f"{args.dataname}.pkl"))

    embs = model.get_embedding(wholegraph, whole_dg, wholefeat, whole_weight)

    acc_mean, acc_std = linearsvc(embs, lbls)
    print("accuracy_mean, {:.4f}".format(acc_mean))
import torch as th
import torch.nn as nn
from utils import local_global_loss_

from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.glob import SumPooling


class MLP(nn.Module):
    def __init__(self, in_dim, out_dim):
@@ -15,7 +15,7 @@ class MLP(nn.Module):
            nn.Linear(out_dim, out_dim),
            nn.PReLU(),
            nn.Linear(out_dim, out_dim),
            nn.PReLU(),
        )

        self.linear_shortcut = nn.Linear(in_dim, out_dim)
@@ -30,13 +30,25 @@ class GCN(nn.Module):
        self.num_layers = num_layers

        self.layers = nn.ModuleList()
        self.layers.append(
            GraphConv(
                in_dim, out_dim, bias=False, norm=norm, activation=nn.PReLU()
            )
        )

        self.pooling = SumPooling()

        for _ in range(num_layers - 1):
            self.layers.append(
                GraphConv(
                    out_dim,
                    out_dim,
                    bias=False,
                    norm=norm,
                    activation=nn.PReLU(),
                )
            )

    def forward(self, graph, feat, edge_weight=None):
        h = self.layers[0](graph, feat, edge_weight=edge_weight)

        hg = self.pooling(graph, h)
@@ -70,17 +82,19 @@ class MVGRL(nn.Module):
    edge_weight: tensor
        Edge weight of the diffusion graph
    """

    def __init__(self, in_dim, out_dim, num_layers):
        super(MVGRL, self).__init__()
        self.local_mlp = MLP(out_dim, out_dim)
        self.global_mlp = MLP(num_layers * out_dim, out_dim)
        self.encoder1 = GCN(in_dim, out_dim, num_layers, norm="both")
        self.encoder2 = GCN(in_dim, out_dim, num_layers, norm="none")

    def get_embedding(self, graph1, graph2, feat, edge_weight):
        local_v1, global_v1 = self.encoder1(graph1, feat)
        local_v2, global_v2 = self.encoder2(
            graph2, feat, edge_weight=edge_weight
        )

        global_v1 = self.global_mlp(global_v1)
        global_v2 = self.global_mlp(global_v2)
@@ -90,7 +104,9 @@ class MVGRL(nn.Module):
    def forward(self, graph1, graph2, feat, edge_weight, graph_id):
        # calculate node embeddings and graph embeddings
        local_v1, global_v1 = self.encoder1(graph1, feat)
        local_v2, global_v2 = self.encoder2(
            graph2, feat, edge_weight=edge_weight
        )

        local_v1 = self.local_mlp(local_v1)
        local_v2 = self.local_mlp(local_v2)
@@ -105,8 +121,3 @@ class MVGRL(nn.Module):
        loss = loss1 + loss2

        return loss
""" Code adapted from https://github.com/fanyun-sun/InfoGraph """
import math

import numpy as np
import torch as th
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.svm import LinearSVC


def linearsvc(embeds, labels):
    x = embeds.cpu().numpy()
    y = labels.cpu().numpy()
    params = {"C": [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
    accuracies = []
    for train_index, test_index in kf.split(x, y):
        x_train, x_test = x[train_index], x[test_index]
        y_train, y_test = y[train_index], y[test_index]
        classifier = GridSearchCV(
            LinearSVC(), params, cv=5, scoring="accuracy", verbose=0
        )
        classifier.fit(x_train, y_train)
        accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))
    return np.mean(accuracies), np.std(accuracies)


def get_positive_expectation(p_samples, average=True):
    """Computes the positive part of a JS Divergence.
    Args:
@@ -31,8 +34,8 @@ def get_positive_expectation(p_samples, average=True):
    Returns:
        th.Tensor
    """
    log_2 = math.log(2.0)
    Ep = log_2 - F.softplus(-p_samples)

    if average:
        return Ep.mean()
@@ -48,7 +51,7 @@ def get_negative_expectation(q_samples, average=True):
    Returns:
        th.Tensor
    """
    log_2 = math.log(2.0)
    Eq = F.softplus(-q_samples) + q_samples - log_2

    if average:
@@ -69,8 +72,8 @@ def local_global_loss_(l_enc, g_enc, graph_id):
    for nodeidx, graphidx in enumerate(graph_id):
        pos_mask[nodeidx][graphidx] = 1.0
        neg_mask[nodeidx][graphidx] = 0.0

    res = th.mm(l_enc, g_enc.t())
...
""" Code adapted from https://github.com/kavehhassani/mvgrl """
import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch as th
from scipy.linalg import fractional_matrix_power, inv
from sklearn.preprocessing import MinMaxScaler

import dgl
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.nn import APPNPConv


def preprocess_features(features):
    """Row-normalize feature matrix and convert to tuple representation"""
    rowsum = np.array(features.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    r_mat_inv = sp.diags(r_inv)
    features = r_mat_inv.dot(features)
    if isinstance(features, np.ndarray):
@@ -52,22 +51,24 @@ def compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):
    d = np.diag(np.sum(a, 1))  # D^ = Sigma A^_ii
    dinv = fractional_matrix_power(d, -0.5)  # D^(-1/2)
    at = np.matmul(np.matmul(dinv, a), dinv)  # A~ = D^(-1/2) x A^ x D^(-1/2)
    return alpha * inv(
        (np.eye(a.shape[0]) - (1 - alpha) * at)
    )  # a(I_n-(1-a)A~)^-1


def process_dataset(name, epsilon):
    if name == "cora":
        dataset = CoraGraphDataset()
    elif name == "citeseer":
        dataset = CiteseerGraphDataset()

    graph = dataset[0]
    feat = graph.ndata.pop("feat")
    label = graph.ndata.pop("label")

    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze()
@@ -75,12 +76,12 @@ def process_dataset(name, epsilon):
    nx_g = dgl.to_networkx(graph)

    print("computing ppr")
    diff_adj = compute_ppr(nx_g, 0.2)
    print("computing end")

    if name == "citeseer":
        print("additional processing")
        feat = th.tensor(preprocess_features(feat.numpy())).float()
        diff_adj[diff_adj < epsilon] = 0
        scaler = MinMaxScaler()
@@ -93,19 +94,29 @@ def process_dataset(name, epsilon):
    graph = graph.add_self_loop()

    return (
        graph,
        diff_graph,
        feat,
        label,
        train_idx,
        val_idx,
        test_idx,
        diff_weight,
    )


def process_dataset_appnp(epsilon):
    k = 20
    alpha = 0.2
    dataset = PubmedGraphDataset()
    graph = dataset[0]
    feat = graph.ndata.pop("feat")
    label = graph.ndata.pop("label")

    train_mask = graph.ndata.pop("train_mask")
    val_mask = graph.ndata.pop("val_mask")
    test_mask = graph.ndata.pop("test_mask")

    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze()
@@ -123,4 +134,13 @@ def process_dataset_appnp(epsilon):
    diff_weight = diff_adj[diff_edges]
    diff_graph = dgl.graph(diff_edges)

    return (
        graph,
        diff_graph,
        feat,
        label,
        train_idx,
        val_idx,
        test_idx,
        diff_weight,
    )
import argparse
import warnings

import numpy as np
import torch as th
import torch.nn as nn

warnings.filterwarnings("ignore")

from dataset import process_dataset
from model import MVGRL, LogReg

parser = argparse.ArgumentParser(description="mvgrl")
parser.add_argument(
    "--dataname", type=str, default="cora", help="Name of dataset."
)
parser.add_argument(
    "--gpu", type=int, default=0, help="GPU index. Default: -1, using cpu."
)
parser.add_argument("--epochs", type=int, default=500, help="Training epochs.")
parser.add_argument(
    "--patience",
    type=int,
    default=20,
    help="Patient epochs to wait before early stopping.",
)
parser.add_argument(
    "--lr1", type=float, default=0.001, help="Learning rate of mvgrl."
)
parser.add_argument(
    "--lr2", type=float, default=0.01, help="Learning rate of linear evaluator."
)
parser.add_argument(
    "--wd1", type=float, default=0.0, help="Weight decay of mvgrl."
)
parser.add_argument(
    "--wd2", type=float, default=0.0, help="Weight decay of linear evaluator."
)
parser.add_argument(
    "--epsilon",
    type=float,
    default=0.01,
    help="Edge mask threshold of diffusion graph.",
)
parser.add_argument(
    "--hid_dim", type=int, default=512, help="Hidden layer dim."
)

args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = "cuda:{}".format(args.gpu)
else:
    args.device = "cpu"

if __name__ == "__main__":
    print(args)

    # Step 1: Prepare data =================================================================== #
    (
        graph,
        diff_graph,
        feat,
        label,
        train_idx,
        val_idx,
        test_idx,
        edge_weight,
    ) = process_dataset(args.dataname, args.epsilon)
    n_feat = feat.shape[1]
    n_classes = np.unique(label).shape[0]
@@ -60,11 +93,13 @@ if __name__ == '__main__':
    lbl = lbl.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr1, weight_decay=args.wd1
    )
    loss_fn = nn.BCEWithLogitsLoss()

    # Step 4: Training epochs ================================================================ #
    best = float("inf")
    cnt_wait = 0

    for epoch in range(args.epochs):
        model.train()
@@ -80,20 +115,20 @@ if __name__ == '__main__':
        loss.backward()
        optimizer.step()

        print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, loss.item()))

        if loss < best:
            best = loss
            cnt_wait = 0
            th.save(model.state_dict(), "model.pkl")
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break

    model.load_state_dict(th.load("model.pkl"))
    embeds = model.get_embedding(graph, diff_graph, feat, edge_weight)

    train_embs = embeds[train_idx]
@@ -107,7 +142,9 @@ if __name__ == '__main__':
    # Step 5: Linear evaluation ========================================================== #
    for _ in range(5):
        model = LogReg(args.hid_dim, n_classes)
        opt = th.optim.Adam(
            model.parameters(), lr=args.lr2, weight_decay=args.wd2
        )

        model = model.to(args.device)

        loss_fn = nn.CrossEntropyLoss()
...
import argparse
import random
import warnings

import numpy as np
import torch as th
import torch.nn as nn

import dgl

warnings.filterwarnings("ignore")

from dataset import process_dataset, process_dataset_appnp
from model import MVGRL, LogReg

parser = argparse.ArgumentParser(description="mvgrl")
parser.add_argument(
    "--dataname", type=str, default="cora", help="Name of dataset."
)
parser.add_argument(
    "--gpu", type=int, default=-1, help="GPU index. Default: -1, using cpu."
)
parser.add_argument("--epochs", type=int, default=500, help="Training epochs.")
parser.add_argument(
    "--patience",
    type=int,
    default=20,
    help="Patient epochs to wait before early stopping.",
)
parser.add_argument(
    "--lr1", type=float, default=0.001, help="Learning rate of mvgrl."
)
parser.add_argument(
    "--lr2", type=float, default=0.01, help="Learning rate of linear evaluator."
)
parser.add_argument(
    "--wd1", type=float, default=0.0, help="Weight decay of mvgrl."
)
parser.add_argument(
    "--wd2", type=float, default=0.0, help="Weight decay of linear evaluator."
)
parser.add_argument(
    "--epsilon",
    type=float,
    default=0.01,
    help="Edge mask threshold of diffusion graph.",
)
parser.add_argument(
    "--hid_dim", type=int, default=512, help="Hidden layer dim."
)
parser.add_argument(
    "--sample_size", type=int, default=2000, help="Subgraph size."
)

args = parser.parse_args()

# check cuda
if args.gpu != -1 and th.cuda.is_available():
    args.device = "cuda:{}".format(args.gpu)
else:
    args.device = "cpu"

if __name__ == "__main__":
    print(args)
    # Step 1: Prepare data =================================================================== #
    if args.dataname == "pubmed":
        (
            graph,
            diff_graph,
            feat,
            label,
            train_idx,
            val_idx,
            test_idx,
            edge_weight,
        ) = process_dataset_appnp(args.epsilon)
    else:
        (
            graph,
            diff_graph,
            feat,
            label,
            train_idx,
            val_idx,
            test_idx,
            edge_weight,
        ) = process_dataset(args.dataname, args.epsilon)
    edge_weight = th.tensor(edge_weight).float()
    graph.ndata["feat"] = feat
    diff_graph.edata["edge_weight"] = edge_weight

    n_feat = feat.shape[1]
    n_classes = np.unique(label).shape[0]
@@ -67,13 +113,15 @@ if __name__ == '__main__':
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(
        model.parameters(), lr=args.lr1, weight_decay=args.wd1
    )
    loss_fn = nn.BCEWithLogitsLoss()

    node_list = list(range(n_node))

    # Step 4: Training epochs ================================================================ #
    best = float("inf")
    cnt_wait = 0

    for epoch in range(args.epochs):
        model.train()
@@ -84,8 +132,8 @@ if __name__ == '__main__':
        g = dgl.node_subgraph(graph, sample_idx)
        dg = dgl.node_subgraph(diff_graph, sample_idx)

        f = g.ndata.pop("feat")
        ew = dg.edata.pop("edge_weight")

        shuf_idx = np.random.permutation(sample_size)
        sf = f[shuf_idx, :]
@@ -103,20 +151,20 @@ if __name__ == '__main__':
        loss.backward()
        optimizer.step()

        print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, loss.item()))

        if loss < best:
            best = loss
            cnt_wait = 0
            th.save(model.state_dict(), "model.pkl")
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break

    model.load_state_dict(th.load("model.pkl"))

    graph = graph.to(args.device)
    diff_graph = diff_graph.to(args.device)
@@ -135,7 +183,9 @@ if __name__ == '__main__':
    # Step 5: Linear evaluation ========================================================== #
    for _ in range(5):
        model = LogReg(args.hid_dim, n_classes)
        opt = th.optim.Adam(
            model.parameters(), lr=args.lr2, weight_decay=args.wd2
        )

        model = model.to(args.device)

        loss_fn = nn.CrossEntropyLoss()
...
@@ -4,6 +4,7 @@ import torch.nn as nn
from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.glob import AvgPooling


class LogReg(nn.Module):
    def __init__(self, hid_dim, n_classes):
        super(LogReg, self).__init__()
@@ -36,13 +37,17 @@ class Discriminator(nn.Module):
        return logits


class MVGRL(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(MVGRL, self).__init__()
        self.encoder1 = GraphConv(
            in_dim, out_dim, norm="both", bias=True, activation=nn.PReLU()
        )
        self.encoder2 = GraphConv(
            in_dim, out_dim, norm="none", bias=True, activation=nn.PReLU()
        )
        self.pooling = AvgPooling()

        self.disc = Discriminator(out_dim)
@@ -66,4 +71,4 @@ class MVGRL(nn.Module):
        out = self.disc(h1, h2, h3, h4, c1, c2)

        return out
import time

from model import Node2vecModel
from utils import load_graph, parse_arguments

from dgl.sampling import node2vec_random_walk


def time_randomwalk(graph, args):
    """
@@ -12,44 +14,50 @@ def time_randomwalk(graph, args):
    start_time = time.time()

    # default setting for testing
    params = {"p": 0.25, "q": 4, "walk_length": 50}

    for i in range(args.runs):
        node2vec_random_walk(graph, graph.nodes(), **params)
    end_time = time.time()
    cost_time_avg = (end_time - start_time) / args.runs
    print(
        "Run dataset {} {} trials, mean run time: {:.3f}s".format(
            args.dataset, args.runs, cost_time_avg
        )
    )


def train_node2vec(graph, eval_set, args):
    """
    Train node2vec model
    """
    trainer = Node2vecModel(
        graph,
        embedding_dim=args.embedding_dim,
        walk_length=args.walk_length,
        p=args.p,
        q=args.q,
        num_walks=args.num_walks,
        eval_set=eval_set,
        eval_steps=1,
        device=args.device,
    )

    trainer.train(
        epochs=args.epochs, batch_size=args.batch_size, learning_rate=0.01
    )


if __name__ == "__main__":
    args = parse_arguments()
    graph, eval_set = load_graph(args.dataset)

    if args.task == "train":
        print("Perform training node2vec model")
        train_node2vec(graph, eval_set, args)
    elif args.task == "time":
        print("Timing random walks")
        time_randomwalk(graph, args)
    else:
        raise ValueError("Task type error!")
import torch
import torch.nn as nn
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader

from dgl.sampling import node2vec_random_walk
@@ -39,8 +40,19 @@ class Node2vec(nn.Module):
        If omitted, DGL assumes that the neighbors are picked uniformly.
    """

    def __init__(
        self,
        g,
        embedding_dim,
        walk_length,
        p,
        q,
        num_walks=10,
        window_size=5,
        num_negatives=5,
        use_sparse=True,
        weight_name=None,
    ):
        super(Node2vec, self).__init__()

        assert walk_length >= window_size
@@ -75,13 +87,17 @@ class Node2vec(nn.Module):
        batch = batch.repeat(self.num_walks)
        # positive
        pos_traces = node2vec_random_walk(
            self.g, batch, self.p, self.q, self.walk_length, self.prob
        )
        pos_traces = pos_traces.unfold(1, self.window_size, 1)  # rolling window
        pos_traces = pos_traces.contiguous().view(-1, self.window_size)

        # negative
        neg_batch = batch.repeat(self.num_negatives)
        neg_traces = torch.randint(
            self.N, (neg_batch.size(0), self.walk_length)
        )
        neg_traces = torch.cat([neg_batch.view(-1, 1), neg_traces], dim=-1)
        neg_traces = neg_traces.unfold(1, self.window_size, 1)  # rolling window
        neg_traces = neg_traces.contiguous().view(-1, self.window_size)
@@ -122,7 +138,10 @@ class Node2vec(nn.Module):
        e = 1e-15

        # Positive
        pos_start, pos_rest = (
            pos_trace[:, 0],
            pos_trace[:, 1:].contiguous(),
        )  # start node and following trace

        w_start = self.embedding(pos_start).unsqueeze(dim=1)
        w_rest = self.embedding(pos_rest)
        pos_out = (w_start * w_rest).sum(dim=-1).view(-1)
@@ -154,7 +173,12 @@ class Node2vec(nn.Module):
        Node2vec training data loader
        """
        return DataLoader(
            torch.arange(self.N),
            batch_size=batch_size,
            shuffle=True,
            collate_fn=self.sample,
        )

    @torch.no_grad()
    def evaluate(self, x_train, y_train, x_val, y_val):
@@ -166,7 +190,9 @@ class Node2vec(nn.Module):
        x_train, y_train = x_train.cpu().numpy(), y_train.cpu().numpy()
        x_val, y_val = x_val.cpu().numpy(), y_val.cpu().numpy()
        lr = LogisticRegression(
            solver="lbfgs", multi_class="auto", max_iter=150
        ).fit(x_train, y_train)

        return lr.score(x_val, y_val)
@@ -213,26 +239,52 @@ class Node2vecModel(object):
        device, default 'cpu'.
    """

    def __init__(
        self,
        g,
        embedding_dim,
        walk_length,
        p=1.0,
        q=1.0,
        num_walks=1,
        window_size=5,
        num_negatives=5,
        use_sparse=True,
        weight_name=None,
        eval_set=None,
        eval_steps=-1,
        device="cpu",
    ):
        self.model = Node2vec(
            g,
            embedding_dim,
            walk_length,
            p,
            q,
            num_walks,
            window_size,
            num_negatives,
            use_sparse,
            weight_name,
        )
        self.g = g
        self.use_sparse = use_sparse
        self.eval_steps = eval_steps
        self.eval_set = eval_set

        if device == "cpu":
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _train_step(self, model, loader, optimizer, device):
        model.train()
        total_loss = 0
        for pos_traces, neg_traces in loader:
            pos_traces, neg_traces = pos_traces.to(device), neg_traces.to(
                device
            )
            optimizer.zero_grad()
            loss = model.loss(pos_traces, neg_traces)
            loss.backward()
@@ -265,15 +317,23 @@ class Node2vecModel(object):
        self.model = self.model.to(self.device)
        loader = self.model.loader(batch_size)
        if self.use_sparse:
            optimizer = torch.optim.SparseAdam(
                list(self.model.parameters()), lr=learning_rate
            )
        else:
            optimizer = torch.optim.Adam(
                self.model.parameters(), lr=learning_rate
            )
        for i in range(epochs):
            loss = self._train_step(self.model, loader, optimizer, self.device)
            if self.eval_steps > 0:
                if epochs % self.eval_steps == 0:
                    acc = self._evaluate_step()
                    print(
                        "Epoch: {}, Train Loss: {:.4f}, Val Acc: {:.4f}".format(
                            i, loss, acc
                        )
                    )

    def embedding(self, nodes=None):
        """
...
import argparse

from ogb.linkproppred import *
from ogb.nodeproppred import *

from dgl.data import CitationGraphDataset


def load_graph(name):
    cite_graphs = ["cora", "citeseer", "pubmed"]

    if name in cite_graphs:
        dataset = CitationGraphDataset(name)
        graph = dataset[0]

        nodes = graph.nodes()
        y = graph.ndata["label"]
        train_mask = graph.ndata["train_mask"]
        val_mask = graph.ndata["test_mask"]

        nodes_train, y_train = nodes[train_mask], y[train_mask]
        nodes_val, y_val = nodes[val_mask], y[val_mask]
        eval_set = [(nodes_train, y_train), (nodes_val, y_val)]

    elif name.startswith("ogbn"):
        dataset = DglNodePropPredDataset(name)
        graph, y = dataset[0]

        split_nodes = dataset.get_idx_split()

        nodes = graph.nodes()
        train_idx = split_nodes["train"]
        val_idx = split_nodes["valid"]

        nodes_train, y_train = nodes[train_idx], y[train_idx]
        nodes_val, y_val = nodes[val_idx], y[val_idx]
@@ -44,19 +46,19 @@ def parse_arguments():
    """
    Parse arguments
    """
    parser = argparse.ArgumentParser(description="Node2vec")
    parser.add_argument("--dataset", type=str, default="cora")
    # 'train' for training node2vec model, 'time' for testing speed of random walk
    parser.add_argument("--task", type=str, default="train")
    parser.add_argument("--runs", type=int, default=10)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--walk_length", type=int, default=50)
    parser.add_argument("--p", type=float, default=0.25)
    parser.add_argument("--q", type=float, default=4.0)
    parser.add_argument("--num_walks", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)

    args = parser.parse_args()
...
import dgl import argparse
import time
from functools import partial from functools import partial
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm import tqdm
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
from sampler import ClusterIter, subgraph_collate_fn from sampler import ClusterIter, subgraph_collate_fn
from torch.utils.data import DataLoader
import dgl
import dgl.nn.pytorch as dglnn
class GAT(nn.Module): class GAT(nn.Module):
def __init__(self, def __init__(
in_feats, self,
num_heads, in_feats,
n_hidden, num_heads,
n_classes, n_hidden,
n_layers, n_classes,
activation, n_layers,
dropout=0.): activation,
dropout=0.0,
):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.n_hidden = n_hidden self.n_hidden = n_hidden
self.n_classes = n_classes self.n_classes = n_classes
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.num_heads = num_heads self.num_heads = num_heads
self.layers.append(dglnn.GATConv(in_feats, self.layers.append(
n_hidden, dglnn.GATConv(
num_heads=num_heads, in_feats,
feat_drop=dropout, n_hidden,
attn_drop=dropout, num_heads=num_heads,
activation=activation, feat_drop=dropout,
negative_slope=0.2)) attn_drop=dropout,
activation=activation,
negative_slope=0.2,
)
)
for i in range(1, n_layers - 1): for i in range(1, n_layers - 1):
self.layers.append(dglnn.GATConv(n_hidden * num_heads, self.layers.append(
n_hidden, dglnn.GATConv(
num_heads=num_heads, n_hidden * num_heads,
feat_drop=dropout, n_hidden,
attn_drop=dropout, num_heads=num_heads,
activation=activation, feat_drop=dropout,
negative_slope=0.2)) attn_drop=dropout,
self.layers.append(dglnn.GATConv(n_hidden * num_heads, activation=activation,
n_classes, negative_slope=0.2,
num_heads=num_heads, )
feat_drop=dropout, )
attn_drop=dropout, self.layers.append(
activation=None, dglnn.GATConv(
negative_slope=0.2)) n_hidden * num_heads,
n_classes,
num_heads=num_heads,
feat_drop=dropout,
attn_drop=dropout,
activation=None,
negative_slope=0.2,
)
)
def forward(self, g, x): def forward(self, g, x):
h = x h = x
for l, conv in enumerate(self.layers): for l, conv in enumerate(self.layers):
...@@ -72,24 +88,35 @@ class GAT(nn.Module): ...@@ -72,24 +88,35 @@ class GAT(nn.Module):
num_heads = self.num_heads num_heads = self.num_heads
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
if l < self.n_layers - 1: if l < self.n_layers - 1:
y = th.zeros(g.num_nodes(), self.n_hidden * num_heads if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(
g.num_nodes(),
self.n_hidden * num_heads
if l != len(self.layers) - 1
else self.n_classes,
)
else: else:
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(
g.num_nodes(),
self.n_hidden
if l != len(self.layers) - 1
else self.n_classes,
)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
g, g,
th.arange(g.num_nodes()), th.arange(g.num_nodes()),
sampler, sampler,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
num_workers=args.num_workers) num_workers=args.num_workers,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0].int().to(device) block = blocks[0].int().to(device)
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
if l < self.n_layers - 1: if l < self.n_layers - 1:
h = layer(block, h).flatten(1) h = layer(block, h).flatten(1)
else: else:
h = layer(block, h) h = layer(block, h)
h = h.mean(1) h = h.mean(1)
...@@ -99,12 +126,14 @@ class GAT(nn.Module): ...@@ -99,12 +126,14 @@ class GAT(nn.Module):
x = y x = y
return y return y
def compute_acc(pred, labels): def compute_acc(pred, labels):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
""" """
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device): def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):
""" """
Evaluate the model on the validation set specified by ``val_mask``. Evaluate the model on the validation set specified by ``val_mask``.
...@@ -119,22 +148,45 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device): ...@@ -119,22 +148,45 @@ def evaluate(model, g, nfeat, labels, val_nid, test_nid, batch_size, device):
with th.no_grad(): with th.no_grad():
pred = model.inference(g, nfeat, batch_size, device) pred = model.inference(g, nfeat, batch_size, device)
model.train() model.train()
labels_cpu = labels.to(th.device('cpu')) labels_cpu = labels.to(th.device("cpu"))
return compute_acc(pred[val_nid], labels_cpu[val_nid]), compute_acc(pred[test_nid], labels_cpu[test_nid]), pred return (
compute_acc(pred[val_nid], labels_cpu[val_nid]),
compute_acc(pred[test_nid], labels_cpu[test_nid]),
pred,
)
def model_param_summary(model): def model_param_summary(model):
""" Count the model parameters """ """Count the model parameters"""
cnt = sum(p.numel() for p in model.parameters() if p.requires_grad) cnt = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total Params {}".format(cnt)) print("Total Params {}".format(cnt))
#### Entry point #### Entry point
def run(args, device, data, nfeat): def run(args, device, data, nfeat):
# Unpack data # Unpack data
train_nid, val_nid, test_nid, in_feats, labels, n_classes, g, cluster_iterator = data (
train_nid,
val_nid,
test_nid,
in_feats,
labels,
n_classes,
g,
cluster_iterator,
) = data
labels = labels.to(device) labels = labels.to(device)
# Define model and optimizer # Define model and optimizer
model = GAT(in_feats, args.num_heads, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout) model = GAT(
in_feats,
args.num_heads,
args.num_hidden,
n_classes,
args.num_layers,
F.relu,
args.dropout,
)
model_param_summary(model) model_param_summary(model)
model = model.to(device) model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
...@@ -153,7 +205,7 @@ def run(args, device, data, nfeat): ...@@ -153,7 +205,7 @@ def run(args, device, data, nfeat):
# blocks. # blocks.
tic_start = time.time() tic_start = time.time()
for step, cluster in enumerate(cluster_iterator): for step, cluster in enumerate(cluster_iterator):
mask = cluster.ndata.pop('train_mask') mask = cluster.ndata.pop("train_mask")
if mask.sum() == 0: if mask.sum() == 0:
continue continue
cluster.edata.pop(dgl.EID) cluster.edata.pop(dgl.EID)
...@@ -173,99 +225,156 @@ def run(args, device, data, nfeat): ...@@ -173,99 +225,156 @@ def run(args, device, data, nfeat):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
tic_back = time.time() tic_back = time.time()
iter_load += (tic_step - tic_start) iter_load += tic_step - tic_start
iter_far += (tic_far - tic_step) iter_far += tic_far - tic_step
iter_back += (tic_back - tic_far) iter_back += tic_back - tic_far
if step % args.log_every == 0: if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels) acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0 gpu_mem_alloc = (
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB'.format( th.cuda.max_memory_allocated() / 1000000
epoch, step, loss.item(), acc.item(), gpu_mem_alloc)) if th.cuda.is_available()
else 0
)
print(
"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB".format(
epoch, step, loss.item(), acc.item(), gpu_mem_alloc
)
)
tic_start = time.time() tic_start = time.time()
toc = time.time() toc = time.time()
print('Epoch Time(s): {:.4f} Load {:.4f} Forward {:.4f} Backward {:.4f}'.format(toc - tic, iter_load, iter_far, iter_back)) print(
"Epoch Time(s): {:.4f} Load {:.4f} Forward {:.4f} Backward {:.4f}".format(
toc - tic, iter_load, iter_far, iter_back
)
)
if epoch >= 5:
avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0:
eval_acc, test_acc, pred = evaluate(
model,
g,
nfeat,
labels,
val_nid,
test_nid,
args.val_batch_size,
device,
)
model = model.to(device)
if args.save_pred:
np.savetxt(
args.save_pred + "%02d" % epoch,
pred.argmax(1).cpu().numpy(),
"%d",
)
print("Eval Acc {:.4f}".format(eval_acc))
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
best_test_acc = test_acc
print(
"Best Eval Acc {:.4f} Test Acc {:.4f}".format(
best_eval_acc, best_test_acc
)
)
print("Avg epoch time: {}".format(avg / (epoch - 4)))
return best_test_acc.to(th.device("cpu"))
if __name__ == "__main__":
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument(
"--gpu",
type=int,
default=0,
help="GPU device ID. Use -1 for CPU training",
)
argparser.add_argument("--num-epochs", type=int, default=20)
argparser.add_argument("--num-hidden", type=int, default=128)
argparser.add_argument("--num-layers", type=int, default=3)
argparser.add_argument("--num-heads", type=int, default=8)
argparser.add_argument("--batch-size", type=int, default=32)
argparser.add_argument("--val-batch-size", type=int, default=2000)
argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument("--lr", type=float, default=0.001)
argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument("--save-pred", type=str, default="")
argparser.add_argument("--wd", type=float, default=0)
argparser.add_argument("--num_partitions", type=int, default=15000)
argparser.add_argument("--num-workers", type=int, default=0)
argparser.add_argument(
"--data-cpu",
action="store_true",
help="By default the script puts all node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.",
)
args = argparser.parse_args()
if args.gpu >= 0:
device = th.device("cuda:%d" % args.gpu)
else:
device = th.device("cpu")
# load ogbn-products data
data = DglNodePropPredDataset(name="ogbn-products")
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0]
labels = labels[:, 0]
print("Total edges before adding self-loop {}".format(graph.num_edges()))
graph = dgl.remove_self_loop(graph)
graph = dgl.add_self_loop(graph)
print("Total edges after adding self-loop {}".format(graph.num_edges()))
num_nodes = train_idx.shape[0] + val_idx.shape[0] + test_idx.shape[0]
assert num_nodes == graph.num_nodes()
mask = th.zeros(num_nodes, dtype=th.bool)
mask[train_idx] = True
graph.ndata["train_mask"] = mask
graph.in_degrees(0)
graph.out_degrees(0)
graph.find_edges(0)
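# These one-off degree/edge queries appear to force DGL to materialize the graph's
# internal sparse formats up front, before the DataLoader workers start iterating.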
cluster_iter_data = ClusterIter(
"ogbn-products", graph, args.num_partitions, args.batch_size
)
cluster_iterator = DataLoader(
cluster_iter_data,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=4,
collate_fn=partial(subgraph_collate_fn, graph),
)
in_feats = graph.ndata["feat"].shape[1]
n_classes = (labels.max() + 1).item()
# Pack data
data = (
train_idx,
val_idx,
test_idx,
in_feats,
labels,
n_classes,
graph,
cluster_iterator,
)
# Run 10 times
test_accs = []
nfeat = graph.ndata.pop("feat").to(device)
for i in range(10):
test_accs.append(run(args, device, data, nfeat))
print(
"Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs)
)
@@ -3,8 +3,9 @@ from time import time
import numpy as np
import dgl
from dgl import backend as F
from dgl.transforms import metis_partition
def get_partition_list(g, psize):
p_gs = metis_partition(g, psize)
import os
import torch
from partition_utils import *
class ClusterIter(object):
"""The partition sampler given a DGLGraph and partition number.
The metis is used as the graph partition backend.
"""
def __init__(self, dn, g, psize, batch_size):
"""Initialize the sampler.
@@ -26,11 +27,11 @@ class ClusterIter(object):
self.batch_size = batch_size
# cache the partitions of known datasets&partition number
if dn:
fn = os.path.join("./datasets/", dn + "_{}.npy".format(psize))
if os.path.exists(fn):
self.par_li = np.load(fn, allow_pickle=True)
else:
os.makedirs("./datasets/", exist_ok=True)
self.par_li = get_partition_list(g, psize)
np.save(fn, self.par_li)
else:
@@ -47,6 +48,7 @@ class ClusterIter(object):
def __getitem__(self, idx):
return self.par_li[idx]
def subgraph_collate_fn(g, batch):
nids = np.concatenate(batch).reshape(-1).astype(np.int64)
g1 = g.subgraph(nids)
import argparse
import time
import traceback
from functools import partial
import numpy as np
import torch as th
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset
from sampler import ClusterIter, subgraph_collate_fn
from torch.utils.data import DataLoader
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset
#### Neighbor sampler
class SAGE(nn.Module):
def __init__(
self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(dropout)
self.activation = activation
@@ -70,12 +68,14 @@ class SAGE(nn.Module):
return h
def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
"""
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, labels, val_nid, test_nid, batch_size, device):
"""
Evaluate the model on the validation set specified by ``val_mask``.
@@ -88,28 +88,49 @@ def evaluate(model, g, labels, val_nid, test_nid, batch_size, device):
"""
model.eval()
with th.no_grad():
inputs = g.ndata["feat"]
model = model.cpu()
pred = model.inference(g, inputs, batch_size, device)
model.train()
return (
compute_acc(pred[val_nid], labels[val_nid]),
compute_acc(pred[test_nid], labels[test_nid]),
pred,
)
def load_subtensor(g, labels, seeds, input_nodes, device):
"""
Copies features and labels of a set of nodes onto GPU.
"""
batch_inputs = g.ndata["feat"][input_nodes].to(device)
batch_labels = labels[seeds].to(device)
return batch_inputs, batch_labels
#### Entry point
def run(args, device, data):
# Unpack data
(
train_nid,
val_nid,
test_nid,
in_feats,
labels,
n_classes,
g,
cluster_iterator,
) = data
# Define model and optimizer
model = SAGE(
in_feats,
args.num_hidden,
n_classes,
args.num_layers,
F.relu,
args.dropout,
)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
@@ -132,11 +153,11 @@ def run(args, device, data):
tic_start = time.time()
for step, cluster in enumerate(cluster_iterator):
cluster = cluster.int().to(device)
mask = cluster.ndata["train_mask"].to(device)
if mask.sum() == 0:
continue
feat = cluster.ndata["feat"].to(device)
batch_labels = cluster.ndata["labels"].to(device)
tic_step = time.time()
batch_pred = model(cluster, feat)
@@ -148,94 +169,147 @@ def run(args, device, data):
loss.backward()
optimizer.step()
tic_back = time.time()
iter_load += tic_step - tic_start
iter_far += tic_far - tic_step
iter_back += tic_back - tic_far
tic_start = time.time()
if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = (
th.cuda.max_memory_allocated() / 1000000
if th.cuda.is_available()
else 0
)
print(
"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB".format(
epoch, step, loss.item(), acc.item(), gpu_mem_alloc
)
)
toc = time.time()
print(
"Epoch Time(s): {:.4f} Load {:.4f} Forward {:.4f} Backward {:.4f}".format(
toc - tic, iter_load, iter_far, iter_back
)
)
if epoch >= 5:
avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0:
eval_acc, test_acc, pred = evaluate(
model, g, labels, val_nid, test_nid, args.val_batch_size, device
)
model = model.to(device)
if args.save_pred:
np.savetxt(
args.save_pred + "%02d" % epoch,
pred.argmax(1).cpu().numpy(),
"%d",
)
print("Eval Acc {:.4f}".format(eval_acc))
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
best_test_acc = test_acc
print(
"Best Eval Acc {:.4f} Test Acc {:.4f}".format(
best_eval_acc, best_test_acc
)
)
print("Avg epoch time: {}".format(avg / (epoch - 4)))
return best_test_acc
if __name__ == "__main__":
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument(
"--gpu",
type=int,
default=0,
help="GPU device ID. Use -1 for CPU training",
)
argparser.add_argument("--num-epochs", type=int, default=30)
argparser.add_argument("--num-hidden", type=int, default=256)
argparser.add_argument("--num-layers", type=int, default=3)
argparser.add_argument("--batch-size", type=int, default=32)
argparser.add_argument("--val-batch-size", type=int, default=10000)
argparser.add_argument("--log-every", type=int, default=20)
argparser.add_argument("--eval-every", type=int, default=1)
argparser.add_argument("--lr", type=float, default=0.001)
argparser.add_argument("--dropout", type=float, default=0.5)
argparser.add_argument("--save-pred", type=str, default="")
argparser.add_argument("--wd", type=float, default=0)
argparser.add_argument("--num_partitions", type=int, default=15000)
args = argparser.parse_args()
if args.gpu >= 0:
device = th.device("cuda:%d" % args.gpu)
else:
device = th.device("cpu")
# load ogbn-products data
data = DglNodePropPredDataset(name="ogbn-products")
splitted_idx = data.get_idx_split()
train_idx, val_idx, test_idx = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
graph, labels = data[0]
labels = labels[:, 0]
num_nodes = train_idx.shape[0] + val_idx.shape[0] + test_idx.shape[0]
assert num_nodes == graph.number_of_nodes()
graph.ndata["labels"] = labels
mask = th.zeros(num_nodes, dtype=th.bool)
mask[train_idx] = True
graph.ndata["train_mask"] = mask
mask = th.zeros(num_nodes, dtype=th.bool)
mask[val_idx] = True
graph.ndata["valid_mask"] = mask
mask = th.zeros(num_nodes, dtype=th.bool)
mask[test_idx] = True
graph.ndata["test_mask"] = mask
graph.in_degree(0)
graph.out_degree(0)
graph.find_edges(0)
cluster_iter_data = ClusterIter(
"ogbn-products",
graph,
args.num_partitions,
args.batch_size,
th.cat([train_idx, val_idx, test_idx]),
)
idx = th.arange(args.num_partitions // args.batch_size)
cluster_iterator = DataLoader(
cluster_iter_data,
batch_size=32,
shuffle=True,
pin_memory=True,
num_workers=4,
collate_fn=partial(subgraph_collate_fn, graph),
)
in_feats = graph.ndata["feat"].shape[1]
print(in_feats)
n_classes = (labels.max() + 1).item()
# Pack data
data = (
train_idx,
val_idx,
test_idx,
in_feats,
labels,
n_classes,
graph,
cluster_iterator,
)
# Run 10 times
test_accs = []
for i in range(10):
test_accs.append(run(args, device, data))
print(
"Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs)
)
@@ -3,8 +3,9 @@ from time import time
import numpy as np
import dgl
from dgl import backend as F
from dgl.transforms import metis_partition
def get_partition_list(g, psize):
p_gs = metis_partition(g, psize)
import os
import random
import time
import torch
from partition_utils import *
import dgl.function as fn
class ClusterIter(object):
"""The partition sampler given a DGLGraph and partition number.
The metis is used as the graph partition backend.
"""
def __init__(self, dn, g, psize, batch_size, seed_nid):
"""Initialize the sampler.
@@ -32,11 +33,11 @@ class ClusterIter(object):
self.batch_size = batch_size
# cache the partitions of known datasets&partition number
if dn:
fn = os.path.join("./datasets/", dn + "_{}.npy".format(psize))
if os.path.exists(fn):
self.par_li = np.load(fn, allow_pickle=True)
else:
os.makedirs("./datasets/", exist_ok=True)
self.par_li = get_partition_list(g, psize)
np.save(fn, self.par_li)
else:
@@ -49,9 +50,9 @@ class ClusterIter(object):
# use one side normalization
def get_norm(self, g):
norm = 1.0 / g.in_degrees().float().unsqueeze(1)
norm[torch.isinf(norm)] = 0
norm = norm.to(self.g.ndata["feat"].device)
return norm
def __len__(self):
@@ -60,6 +61,7 @@ class ClusterIter(object):
def __getitem__(self, idx):
return self.par_li[idx]
def subgraph_collate_fn(g, batch):
nids = np.concatenate(batch).reshape(-1).astype(np.int64)
g1 = g.subgraph(nids)
import argparse
import os
import random
import time
import numpy as np
import torch
import torch.multiprocessing as mp
from model import SkipGramModel
from reading_data import DeepwalkDataset
from torch.utils.data import DataLoader
from utils import shuffle_walks, sum_up_params
import dgl
class DeepwalkTrainer:
def __init__(self, args):
"""Initializing the trainer with the input arguments"""
self.args = args
self.dataset = DeepwalkDataset(
net_file=args.data_file,
@@ -28,20 +30,22 @@ class DeepwalkTrainer:
fast_neg=args.fast_neg,
ogbl_name=args.ogbl_name,
load_from_ogbl=args.load_from_ogbl,
)
self.emb_size = self.dataset.G.number_of_nodes()
self.emb_model = None
def init_device_emb(self):
"""set the device before training
will be called once in fast_train_mp / fast_train
"""
choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
assert (
choices == 1
), "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"
# initializing embedding on CPU # initializing embedding on CPU
self.emb_model = SkipGramModel( self.emb_model = SkipGramModel(
emb_size=self.emb_size, emb_size=self.emb_size,
emb_dimension=self.args.dim, emb_dimension=self.args.dim,
walk_length=self.args.walk_length, walk_length=self.args.walk_length,
window_size=self.args.window_size, window_size=self.args.window_size,
...@@ -59,8 +63,8 @@ class DeepwalkTrainer: ...@@ -59,8 +63,8 @@ class DeepwalkTrainer:
use_context_weight=self.args.use_context_weight, use_context_weight=self.args.use_context_weight,
async_update=self.args.async_update, async_update=self.args.async_update,
num_threads=self.args.num_threads, num_threads=self.args.num_threads,
) )
torch.set_num_threads(self.args.num_threads) torch.set_num_threads(self.args.num_threads)
if self.args.only_gpu: if self.args.only_gpu:
print("Run in 1 GPU") print("Run in 1 GPU")
...@@ -69,22 +73,23 @@ class DeepwalkTrainer: ...@@ -69,22 +73,23 @@ class DeepwalkTrainer:
elif self.args.mix: elif self.args.mix:
print("Mix CPU with %d GPU" % len(self.args.gpus)) print("Mix CPU with %d GPU" % len(self.args.gpus))
if len(self.args.gpus) == 1: if len(self.args.gpus) == 1:
assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have available GPU' assert (
self.args.gpus[0] >= 0
), "mix CPU with GPU should have available GPU"
self.emb_model.set_device(self.args.gpus[0]) self.emb_model.set_device(self.args.gpus[0])
else: else:
print("Run in CPU process") print("Run in CPU process")
self.args.gpus = [torch.device('cpu')] self.args.gpus = [torch.device("cpu")]
def train(self):
"""train the embedding"""
if len(self.args.gpus) > 1:
self.fast_train_mp()
else:
self.fast_train()
def fast_train_mp(self):
"""multi-cpu-core or mix cpu & multi-gpu"""
self.init_device_emb()
self.emb_model.share_memory()
...@@ -95,26 +100,34 @@ class DeepwalkTrainer: ...@@ -95,26 +100,34 @@ class DeepwalkTrainer:
ps = [] ps = []
for i in range(len(self.args.gpus)): for i in range(len(self.args.gpus)):
p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i])) p = mp.Process(
target=self.fast_train_sp, args=(i, self.args.gpus[i])
)
ps.append(p) ps.append(p)
p.start() p.start()
for p in ps: for p in ps:
p.join() p.join()
print("Used time: %.2fs" % (time.time()-start_all)) print("Used time: %.2fs" % (time.time() - start_all))
if self.args.save_in_txt: if self.args.save_in_txt:
self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding_txt(
self.dataset, self.args.output_emb_file
)
elif self.args.save_in_pt: elif self.args.save_in_pt:
self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding_pt(
self.dataset, self.args.output_emb_file
)
else: else:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding(
self.dataset, self.args.output_emb_file
)
def fast_train_sp(self, rank, gpu_id):
"""a subprocess for fast_train_mp"""
if self.args.mix:
self.emb_model.set_device(gpu_id)
torch.set_num_threads(self.args.num_threads)
if self.args.async_update:
self.emb_model.create_async_update()
...@@ -128,13 +141,18 @@ class DeepwalkTrainer: ...@@ -128,13 +141,18 @@ class DeepwalkTrainer:
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
num_workers=self.args.num_sampler_threads, num_workers=self.args.num_sampler_threads,
) )
num_batches = len(dataloader) num_batches = len(dataloader)
print("num batchs: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id)) print(
"num batchs: %d in process [%d] GPU [%d]"
% (num_batches, rank, gpu_id)
)
# number of positive node pairs in a sequence
num_pos = int(
2 * self.args.walk_length * self.args.window_size
- self.args.window_size * (self.args.window_size + 1)
)
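# Each of the walk_length positions pairs with up to window_size neighbors on
# each side; the window_size * (window_size + 1) term removes the pairs that are
# missing at the two ends of the walk (window_size * (window_size + 1) / 2 per end).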
start = time.time() start = time.time()
with torch.no_grad(): with torch.no_grad():
for i, walks in enumerate(dataloader): for i, walks in enumerate(dataloader):
...@@ -144,28 +162,44 @@ class DeepwalkTrainer: ...@@ -144,28 +162,44 @@ class DeepwalkTrainer:
# do negative sampling # do negative sampling
bs = len(walks) bs = len(walks)
neg_nodes = torch.LongTensor( neg_nodes = torch.LongTensor(
np.random.choice(self.dataset.neg_table, np.random.choice(
bs * num_pos * self.args.negative, self.dataset.neg_table,
replace=True)) bs * num_pos * self.args.negative,
replace=True,
)
)
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes) self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0: if i > 0 and i % self.args.print_interval == 0:
if self.args.print_loss: if self.args.print_loss:
print("GPU-[%d] batch %d time: %.2fs loss: %.4f" \ print(
% (gpu_id, i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval)) "GPU-[%d] batch %d time: %.2fs loss: %.4f"
% (
gpu_id,
i,
time.time() - start,
-sum(self.emb_model.loss)
/ self.args.print_interval,
)
)
self.emb_model.loss = [] self.emb_model.loss = []
else: else:
print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time()-start)) print(
"GPU-[%d] batch %d time: %.2fs"
% (gpu_id, i, time.time() - start)
)
start = time.time() start = time.time()
if self.args.async_update: if self.args.async_update:
self.emb_model.finish_async_update() self.emb_model.finish_async_update()
def fast_train(self):
"""fast train with dataloader with only gpu / only cpu"""
# the number of positive node pairs of a node sequence
num_pos = (
2 * self.args.walk_length * self.args.window_size
- self.args.window_size * (self.args.window_size + 1) - self.args.window_size * (self.args.window_size + 1)
)
num_pos = int(num_pos) num_pos = int(num_pos)
self.init_device_emb() self.init_device_emb()
...@@ -186,8 +220,8 @@ class DeepwalkTrainer: ...@@ -186,8 +220,8 @@ class DeepwalkTrainer:
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
num_workers=self.args.num_sampler_threads, num_workers=self.args.num_sampler_threads,
) )
num_batches = len(dataloader) num_batches = len(dataloader)
print("num batchs: %d\n" % num_batches) print("num batchs: %d\n" % num_batches)
...@@ -202,109 +236,228 @@ class DeepwalkTrainer: ...@@ -202,109 +236,228 @@ class DeepwalkTrainer:
# do negative sampling # do negative sampling
bs = len(walks) bs = len(walks)
neg_nodes = torch.LongTensor( neg_nodes = torch.LongTensor(
np.random.choice(self.dataset.neg_table, np.random.choice(
bs * num_pos * self.args.negative, self.dataset.neg_table,
replace=True)) bs * num_pos * self.args.negative,
replace=True,
)
)
self.emb_model.fast_learn(walks, neg_nodes=neg_nodes) self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)
if i > 0 and i % self.args.print_interval == 0: if i > 0 and i % self.args.print_interval == 0:
if self.args.print_loss: if self.args.print_loss:
print("Batch %d training time: %.2fs loss: %.4f" \ print(
% (i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval)) "Batch %d training time: %.2fs loss: %.4f"
% (
i,
time.time() - start,
-sum(self.emb_model.loss)
/ self.args.print_interval,
)
)
self.emb_model.loss = [] self.emb_model.loss = []
else: else:
print("Batch %d, training time: %.2fs" % (i, time.time()-start)) print(
"Batch %d, training time: %.2fs"
% (i, time.time() - start)
)
start = time.time() start = time.time()
if self.args.async_update: if self.args.async_update:
self.emb_model.finish_async_update() self.emb_model.finish_async_update()
print("Training used time: %.2fs" % (time.time()-start_all)) print("Training used time: %.2fs" % (time.time() - start_all))
if self.args.save_in_txt: if self.args.save_in_txt:
self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding_txt(
self.dataset, self.args.output_emb_file
)
elif self.args.save_in_pt: elif self.args.save_in_pt:
self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding_pt(
self.dataset, self.args.output_emb_file
)
else: else:
self.emb_model.save_embedding(self.dataset, self.args.output_emb_file) self.emb_model.save_embedding(
self.dataset, self.args.output_emb_file
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="DeepWalk") parser = argparse.ArgumentParser(description="DeepWalk")
# input files # input files
## personal datasets ## personal datasets
parser.add_argument('--data_file', type=str, parser.add_argument(
help="path of the txt network file, builtin dataset include youtube-net and blog-net") "--data_file",
type=str,
help="path of the txt network file, builtin dataset include youtube-net and blog-net",
)
## ogbl datasets ## ogbl datasets
parser.add_argument('--ogbl_name', type=str, parser.add_argument(
help="name of ogbl dataset, e.g. ogbl-ddi") "--ogbl_name", type=str, help="name of ogbl dataset, e.g. ogbl-ddi"
parser.add_argument('--load_from_ogbl', default=False, action="store_true", )
help="whether load dataset from ogbl") parser.add_argument(
"--load_from_ogbl",
default=False,
action="store_true",
help="whether load dataset from ogbl",
)
# output files # output files
parser.add_argument('--save_in_txt', default=False, action="store_true", parser.add_argument(
help='Whether save dat in txt format or npy') "--save_in_txt",
parser.add_argument('--save_in_pt', default=False, action="store_true", default=False,
help='Whether save dat in pt format or npy') action="store_true",
parser.add_argument('--output_emb_file', type=str, default="emb.npy", help="Whether save dat in txt format or npy",
help='path of the output npy embedding file') )
parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle", parser.add_argument(
help='path of the mapping dict that maps node ids to embedding index') "--save_in_pt",
parser.add_argument('--norm', default=False, action="store_true", default=False,
help="whether to do normalization over node embedding after training") action="store_true",
help="Whether save dat in pt format or npy",
)
parser.add_argument(
"--output_emb_file",
type=str,
default="emb.npy",
help="path of the output npy embedding file",
)
parser.add_argument(
"--map_file",
type=str,
default="nodeid_to_index.pickle",
help="path of the mapping dict that maps node ids to embedding index",
)
parser.add_argument(
"--norm",
default=False,
action="store_true",
help="whether to do normalization over node embedding after training",
)
# model parameters # model parameters
parser.add_argument('--dim', default=128, type=int, parser.add_argument(
help="embedding dimensions") "--dim", default=128, type=int, help="embedding dimensions"
parser.add_argument('--window_size', default=5, type=int, )
help="context window size") parser.add_argument(
parser.add_argument('--use_context_weight', default=False, action="store_true", "--window_size", default=5, type=int, help="context window size"
help="whether to add weights over nodes in the context window") )
parser.add_argument('--num_walks', default=10, type=int, parser.add_argument(
help="number of walks for each node") "--use_context_weight",
parser.add_argument('--negative', default=1, type=int, default=False,
help="negative samples for each positve node pair") action="store_true",
parser.add_argument('--batch_size', default=128, type=int, help="whether to add weights over nodes in the context window",
help="number of node sequences in each batch") )
parser.add_argument('--walk_length', default=80, type=int, parser.add_argument(
help="number of nodes in a sequence") "--num_walks",
parser.add_argument('--neg_weight', default=1., type=float, default=10,
help="negative weight") type=int,
parser.add_argument('--lap_norm', default=0.01, type=float, help="number of walks for each node",
help="weight of laplacian normalization, recommend to set as 0.1 / windoe_size") )
parser.add_argument(
"--negative",
default=1,
type=int,
help="negative samples for each positve node pair",
)
parser.add_argument(
"--batch_size",
default=128,
type=int,
help="number of node sequences in each batch",
)
parser.add_argument(
"--walk_length",
default=80,
type=int,
help="number of nodes in a sequence",
)
parser.add_argument(
"--neg_weight", default=1.0, type=float, help="negative weight"
)
parser.add_argument(
"--lap_norm",
default=0.01,
type=float,
help="weight of laplacian normalization, recommend to set as 0.1 / windoe_size",
)
# training parameters # training parameters
parser.add_argument('--print_interval', default=100, type=int, parser.add_argument(
help="number of batches between printing") "--print_interval",
parser.add_argument('--print_loss', default=False, action="store_true", default=100,
help="whether print loss during training") type=int,
parser.add_argument('--lr', default=0.2, type=float, help="number of batches between printing",
help="learning rate") )
parser.add_argument(
"--print_loss",
default=False,
action="store_true",
help="whether print loss during training",
)
parser.add_argument("--lr", default=0.2, type=float, help="learning rate")
# optimization settings # optimization settings
parser.add_argument('--mix', default=False, action="store_true", parser.add_argument(
help="mixed training with CPU and GPU") "--mix",
parser.add_argument('--gpus', type=int, default=[-1], nargs='+', default=False,
help='a list of active gpu ids, e.g. 0, used with --mix') action="store_true",
parser.add_argument('--only_cpu', default=False, action="store_true", help="mixed training with CPU and GPU",
help="training with CPU") )
parser.add_argument('--only_gpu', default=False, action="store_true", parser.add_argument(
help="training with GPU") "--gpus",
parser.add_argument('--async_update', default=False, action="store_true", type=int,
help="mixed training asynchronously, not recommended") default=[-1],
nargs="+",
parser.add_argument('--true_neg', default=False, action="store_true", help="a list of active gpu ids, e.g. 0, used with --mix",
help="If not specified, this program will use " )
"a faster negative sampling method, " parser.add_argument(
"but the samples might be false negative " "--only_cpu",
"with a small probability. If specified, " default=False,
"this program will generate a true negative sample table," action="store_true",
"and select from it when doing negative samling") help="training with CPU",
parser.add_argument('--num_threads', default=8, type=int, )
help="number of threads used for each CPU-core/GPU") parser.add_argument(
parser.add_argument('--num_sampler_threads', default=2, type=int, "--only_gpu",
help="number of threads used for sampling") default=False,
action="store_true",
parser.add_argument('--count_params', default=False, action="store_true", help="training with GPU",
help="count the params, exit once counting over") )
parser.add_argument(
"--async_update",
default=False,
action="store_true",
help="mixed training asynchronously, not recommended",
)
parser.add_argument(
"--true_neg",
default=False,
action="store_true",
help="If not specified, this program will use "
"a faster negative sampling method, "
"but the samples might be false negative "
"with a small probability. If specified, "
"this program will generate a true negative sample table,"
"and select from it when doing negative samling",
)
parser.add_argument(
"--num_threads",
default=8,
type=int,
help="number of threads used for each CPU-core/GPU",
)
parser.add_argument(
"--num_sampler_threads",
default=2,
type=int,
help="number of threads used for sampling",
)
parser.add_argument(
"--count_params",
default=False,
action="store_true",
help="count the params, exit once counting over",
)
args = parser.parse_args() args = parser.parse_args()
args.fast_neg = not args.true_neg args.fast_neg = not args.true_neg
""" load dataset from ogb """ """ load dataset from ogb """
import argparse
import time
from ogb.linkproppred import DglLinkPropPredDataset
def load_from_ogbl_with_name(name):
choices = ["ogbl-collab", "ogbl-ddi", "ogbl-ppa", "ogbl-citation"]
assert name in choices, "name must be selected from " + str(choices) assert name in choices, "name must be selected from " + str(choices)
dataset = DglLinkPropPredDataset(name) dataset = DglLinkPropPredDataset(name)
return dataset[0] return dataset[0]
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str, parser.add_argument(
choices=['ogbl-collab', 'ogbl-ddi', 'ogbl-ppa', 'ogbl-citation'], "--name",
default='ogbl-collab', type=str,
help="name of datasets by ogb") choices=["ogbl-collab", "ogbl-ddi", "ogbl-ppa", "ogbl-citation"],
default="ogbl-collab",
help="name of datasets by ogb",
)
args = parser.parse_args() args = parser.parse_args()
print("loading graph... it might take some time") print("loading graph... it might take some time")
...@@ -23,28 +29,32 @@ if __name__ == "__main__": ...@@ -23,28 +29,32 @@ if __name__ == "__main__":
g = load_from_ogbl_with_name(name=name) g = load_from_ogbl_with_name(name=name)
try: try:
w = g.edata['edge_weight'] w = g.edata["edge_weight"]
weighted = True weighted = True
except: except:
weighted = False weighted = False
edge_num = g.edges()[0].shape[0] edge_num = g.edges()[0].shape[0]
src = list(g.edges()[0]) src = list(g.edges()[0])
tgt = list(g.edges()[1]) tgt = list(g.edges()[1])
if weighted: if weighted:
weight = list(g.edata['edge_weight']) weight = list(g.edata["edge_weight"])
print("writing...") print("writing...")
start_time = time.time() start_time = time.time()
with open(name + "-net.txt", "w") as f: with open(name + "-net.txt", "w") as f:
for i in range(edge_num): for i in range(edge_num):
if weighted: if weighted:
f.write(str(src[i].item()) + " "\ f.write(
+str(tgt[i].item()) + " "\ str(src[i].item())
+str(weight[i].item()) + "\n") + " "
+ str(tgt[i].item())
+ " "
+ str(weight[i].item())
+ "\n"
)
else: else:
f.write(str(src[i].item()) + " "\ f.write(
+str(tgt[i].item()) + " "\ str(src[i].item()) + " " + str(tgt[i].item()) + " " + "1\n"
+"1\n") )
print("writing used time: %d s" % int(time.time() - start_time)) print("writing used time: %d s" % int(time.time() - start_time))
\ No newline at end of file
import random
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.multiprocessing import Queue
from torch.nn import init
def init_emb2pos_index(walk_length, window_size, batch_size):
"""select embedding of positive nodes from a batch of node embeddings
Return
------
index_emb_posu torch.LongTensor : the indices of u_embeddings
@@ -20,12 +21,12 @@ def init_emb2pos_index(walk_length, window_size, batch_size):
-----
# emb_u.shape: [batch_size * walk_length, dim]
batch_emb2posu = torch.index_select(emb_u, 0, index_emb_posu)
"""
idx_list_u = []
idx_list_v = []
for b in range(batch_size):
for i in range(walk_length):
for j in range(i - window_size, i):
if j >= 0:
idx_list_u.append(j + b * walk_length)
idx_list_v.append(i + b * walk_length)
@@ -40,10 +41,11 @@ def init_emb2pos_index(walk_length, window_size, batch_size):
return index_emb_posu, index_emb_posv
def init_emb2neg_index(walk_length, window_size, negative, batch_size):
"""select embedding of negative nodes from a batch of node embeddings
for fast negative sampling
Return
------
index_emb_negu torch.LongTensor : the indices of u_embeddings
@@ -53,21 +55,22 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size):
-----
# emb_u.shape: [batch_size * walk_length, dim]
batch_emb2negu = torch.index_select(emb_u, 0, index_emb_negu)
"""
idx_list_u = []
for b in range(batch_size):
for i in range(walk_length):
for j in range(i - window_size, i):
if j >= 0:
idx_list_u += [i + b * walk_length] * negative
for j in range(i + 1, i + 1 + window_size):
if j < walk_length:
idx_list_u += [i + b * walk_length] * negative
idx_list_v = (
list(range(batch_size * walk_length)) * negative * window_size * 2
)
random.shuffle(idx_list_v)
idx_list_v = idx_list_v[: len(idx_list_u)]
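# Fast negative sampling: the "contexts" for the negative pairs are simply random
# positions from the same batch of walks, shuffled and truncated to match the
# number of center-node indices built above.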
# [bs * walk_length * negative]
index_emb_negu = torch.LongTensor(idx_list_u)
@@ -75,42 +78,46 @@ def init_emb2neg_index(walk_length, window_size, negative, batch_size):
return index_emb_negu, index_emb_negv
def init_weight(walk_length, window_size, batch_size):
"""init context weight"""
weight = []
for b in range(batch_size):
for i in range(walk_length):
for j in range(i - window_size, i):
if j >= 0:
weight.append(1.0 - float(i - j - 1) / float(window_size))
for j in range(i + 1, i + 1 + window_size):
if j < walk_length:
weight.append(1.0 - float(j - i - 1) / float(window_size))
# [num_pos * batch_size]
return torch.Tensor(weight).unsqueeze(1)
def init_empty_grad(emb_dimension, walk_length, batch_size):
"""initialize gradient matrix"""
grad_u = torch.zeros((batch_size * walk_length, emb_dimension))
grad_v = torch.zeros((batch_size * walk_length, emb_dimension))
return grad_u, grad_v
def adam(grad, state_sum, nodes, lr, device, only_gpu):
"""calculate gradients according to adam"""
grad_sum = (grad * grad).mean(1)
if not only_gpu:
grad_sum = grad_sum.cpu()
state_sum.index_add_(0, nodes, grad_sum)  # cpu
std = state_sum[nodes].to(device)  # gpu
std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
grad = lr * grad / std_values  # gpu
return grad
def async_update(num_threads, model, queue):
"""asynchronous embedding update"""
torch.set_num_threads(num_threads)
while True:
(grad_u, grad_v, grad_v_neg, nodes, neg_nodes) = queue.get()
@@ -120,12 +127,17 @@ def async_update(num_threads, model, queue):
model.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
model.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None:
model.v_embeddings.weight.data.index_add_(
0, neg_nodes.view(-1), grad_v_neg
)
class SkipGramModel(nn.Module):
"""Negative sampling based skip-gram"""
def __init__(
self,
emb_size,
emb_dimension, emb_dimension,
walk_length, walk_length,
window_size, window_size,
...@@ -143,8 +155,8 @@ class SkipGramModel(nn.Module): ...@@ -143,8 +155,8 @@ class SkipGramModel(nn.Module):
use_context_weight, use_context_weight,
async_update, async_update,
num_threads, num_threads,
):
"""initialize embedding on CPU
Parameters
----------
@@ -185,16 +197,18 @@ class SkipGramModel(nn.Module):
self.use_context_weight = use_context_weight self.use_context_weight = use_context_weight
self.async_update = async_update self.async_update = async_update
self.num_threads = num_threads self.num_threads = num_threads
# initialize the device as cpu # initialize the device as cpu
self.device = torch.device("cpu") self.device = torch.device("cpu")
# content embedding # content embedding
self.u_embeddings = nn.Embedding( self.u_embeddings = nn.Embedding(
self.emb_size, self.emb_dimension, sparse=True) self.emb_size, self.emb_dimension, sparse=True
)
# context embedding # context embedding
self.v_embeddings = nn.Embedding( self.v_embeddings = nn.Embedding(
self.emb_size, self.emb_dimension, sparse=True) self.emb_size, self.emb_dimension, sparse=True
)
# initialize embedding
initrange = 1.0 / self.emb_dimension initrange = 1.0 / self.emb_dimension
init.uniform_(self.u_embeddings.weight.data, -initrange, initrange) init.uniform_(self.u_embeddings.weight.data, -initrange, initrange)
@@ -202,28 +216,26 @@ class SkipGramModel(nn.Module):
# lookup_table is used for fast sigmoid computing
self.lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))
self.lookup_table[0] = 0.0
self.lookup_table[-1] = 1.0
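# The table covers scores in roughly [-6, 6] at 0.01 resolution; fast_sigmoid()
# indexes it with floor((score + 6.01) / 0.01), so callers are expected to clamp
# scores into that range before the lookup.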
if self.record_loss: if self.record_loss:
self.logsigmoid_table = torch.log(torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))) self.logsigmoid_table = torch.log(
torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))
)
self.loss = [] self.loss = []
# indexes to select positive/negative node pairs from batch_walks # indexes to select positive/negative node pairs from batch_walks
self.index_emb_posu, self.index_emb_posv = init_emb2pos_index( self.index_emb_posu, self.index_emb_posv = init_emb2pos_index(
self.walk_length, self.walk_length, self.window_size, self.batch_size
self.window_size, )
self.batch_size)
self.index_emb_negu, self.index_emb_negv = init_emb2neg_index( self.index_emb_negu, self.index_emb_negv = init_emb2neg_index(
self.walk_length, self.walk_length, self.window_size, self.negative, self.batch_size
self.window_size, )
self.negative,
self.batch_size)
if self.use_context_weight: if self.use_context_weight:
self.context_weight = init_weight( self.context_weight = init_weight(
self.walk_length, self.walk_length, self.window_size, self.batch_size
self.window_size, )
self.batch_size)
# adam # adam
self.state_sum_u = torch.zeros(self.emb_size) self.state_sum_u = torch.zeros(self.emb_size)
...@@ -231,32 +243,31 @@ class SkipGramModel(nn.Module): ...@@ -231,32 +243,31 @@ class SkipGramModel(nn.Module):
# gradients of nodes in batch_walks # gradients of nodes in batch_walks
self.grad_u, self.grad_v = init_empty_grad( self.grad_u, self.grad_v = init_empty_grad(
self.emb_dimension, self.emb_dimension, self.walk_length, self.batch_size
self.walk_length, )
self.batch_size)
def create_async_update(self):
"""Set up the async update subprocess."""
self.async_q = Queue(1) self.async_q = Queue(1)
self.async_p = mp.Process(target=async_update, args=(self.num_threads, self, self.async_q)) self.async_p = mp.Process(
target=async_update, args=(self.num_threads, self, self.async_q)
)
self.async_p.start() self.async_p.start()
def finish_async_update(self): def finish_async_update(self):
""" Notify the async update subprocess to quit. """Notify the async update subprocess to quit."""
"""
self.async_q.put((None, None, None, None, None)) self.async_q.put((None, None, None, None, None))
self.async_p.join() self.async_p.join()
def share_memory(self): def share_memory(self):
""" share the parameters across subprocesses """ """share the parameters across subprocesses"""
self.u_embeddings.weight.share_memory_() self.u_embeddings.weight.share_memory_()
self.v_embeddings.weight.share_memory_() self.v_embeddings.weight.share_memory_()
self.state_sum_u.share_memory_() self.state_sum_u.share_memory_()
self.state_sum_v.share_memory_() self.state_sum_v.share_memory_()
def set_device(self, gpu_id): def set_device(self, gpu_id):
""" set gpu device """ """set gpu device"""
self.device = torch.device("cuda:%d" % gpu_id) self.device = torch.device("cuda:%d" % gpu_id)
print("The device is", self.device) print("The device is", self.device)
self.lookup_table = self.lookup_table.to(self.device) self.lookup_table = self.lookup_table.to(self.device)
...@@ -272,7 +283,7 @@ class SkipGramModel(nn.Module): ...@@ -272,7 +283,7 @@ class SkipGramModel(nn.Module):
self.context_weight = self.context_weight.to(self.device) self.context_weight = self.context_weight.to(self.device)
def all_to_device(self, gpu_id): def all_to_device(self, gpu_id):
""" move all of the parameters to a single GPU """ """move all of the parameters to a single GPU"""
self.device = torch.device("cuda:%d" % gpu_id) self.device = torch.device("cuda:%d" % gpu_id)
self.set_device(gpu_id) self.set_device(gpu_id)
self.u_embeddings = self.u_embeddings.cuda(gpu_id) self.u_embeddings = self.u_embeddings.cuda(gpu_id)
...@@ -281,17 +292,17 @@ class SkipGramModel(nn.Module): ...@@ -281,17 +292,17 @@ class SkipGramModel(nn.Module):
self.state_sum_v = self.state_sum_v.to(self.device) self.state_sum_v = self.state_sum_v.to(self.device)
def fast_sigmoid(self, score): def fast_sigmoid(self, score):
""" do fast sigmoid by looking up in a pre-defined table """ """do fast sigmoid by looking up in a pre-defined table"""
idx = torch.floor((score + 6.01) / 0.01).long() idx = torch.floor((score + 6.01) / 0.01).long()
return self.lookup_table[idx] return self.lookup_table[idx]
def fast_logsigmoid(self, score): def fast_logsigmoid(self, score):
""" do fast logsigmoid by looking up in a pre-defined table """ """do fast logsigmoid by looking up in a pre-defined table"""
idx = torch.floor((score + 6.01) / 0.01).long() idx = torch.floor((score + 6.01) / 0.01).long()
return self.logsigmoid_table[idx] return self.logsigmoid_table[idx]
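fast_sigmoid and fast_logsigmoid trade a little accuracy for speed: scores are clamped to [-6, 6] elsewhere in the model, so a table precomputed on a 0.01 grid covers every possible input. A self-contained sketch of the same trick (the table name here is local, not the module attribute):

import torch

# Precompute sigmoid on a 0.01 grid over [-6.01, 6.01), as done in __init__.
lookup_table = torch.sigmoid(torch.arange(-6.01, 6.01, 0.01))

def table_sigmoid(score):
    # Scores are assumed to be clamped to [-6, 6] before lookup, as the model
    # does with torch.clamp before calling fast_sigmoid.
    idx = torch.floor((score + 6.01) / 0.01).long()
    return lookup_table[idx]

scores = torch.clamp(torch.tensor([-7.5, -1.5, 0.0, 2.3]), min=-6, max=6)
print(table_sigmoid(scores))   # approximately torch.sigmoid(scores)
print(torch.sigmoid(scores))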
def fast_learn(self, batch_walks, neg_nodes=None): def fast_learn(self, batch_walks, neg_nodes=None):
""" Learn a batch of random walks in a fast way. It has the following features: """Learn a batch of random walks in a fast way. It has the following features:
1. It calculates the gradients directly without the forward operation. 1. It calculates the gradients directly without the forward operation.
2. It does sigmoid by a lookup table. 2. It does sigmoid by a lookup table.
...@@ -310,7 +321,7 @@ class SkipGramModel(nn.Module): ...@@ -310,7 +321,7 @@ class SkipGramModel(nn.Module):
Usage example Usage example
------------- -------------
batch_walks = [torch.LongTensor([1,2,3,4]), batch_walks = [torch.LongTensor([1,2,3,4]),
torch.LongTensor([2,3,4,2])] torch.LongTensor([2,3,4,2])]
lr = 0.01 lr = 0.01
neg_nodes = None neg_nodes = None
...@@ -326,16 +337,23 @@ class SkipGramModel(nn.Module): ...@@ -326,16 +337,23 @@ class SkipGramModel(nn.Module):
nodes = nodes.to(self.device) nodes = nodes.to(self.device)
if neg_nodes is not None: if neg_nodes is not None:
neg_nodes = neg_nodes.to(self.device) neg_nodes = neg_nodes.to(self.device)
emb_u = self.u_embeddings(nodes).view(-1, self.emb_dimension).to(self.device) emb_u = (
emb_v = self.v_embeddings(nodes).view(-1, self.emb_dimension).to(self.device) self.u_embeddings(nodes)
.view(-1, self.emb_dimension)
.to(self.device)
)
emb_v = (
self.v_embeddings(nodes)
.view(-1, self.emb_dimension)
.to(self.device)
)
## Positive ## Positive
bs = len(batch_walks) bs = len(batch_walks)
if bs < self.batch_size: if bs < self.batch_size:
index_emb_posu, index_emb_posv = init_emb2pos_index( index_emb_posu, index_emb_posv = init_emb2pos_index(
self.walk_length, self.walk_length, self.window_size, bs
self.window_size, )
bs)
index_emb_posu = index_emb_posu.to(self.device) index_emb_posu = index_emb_posu.to(self.device)
index_emb_posv = index_emb_posv.to(self.device) index_emb_posv = index_emb_posv.to(self.device)
else: else:
...@@ -356,8 +374,12 @@ class SkipGramModel(nn.Module): ...@@ -356,8 +374,12 @@ class SkipGramModel(nn.Module):
# [batch_size * num_pos, dim] # [batch_size * num_pos, dim]
if self.lap_norm > 0: if self.lap_norm > 0:
grad_u_pos = score * emb_pos_v + self.lap_norm * (emb_pos_v - emb_pos_u) grad_u_pos = score * emb_pos_v + self.lap_norm * (
grad_v_pos = score * emb_pos_u + self.lap_norm * (emb_pos_u - emb_pos_v) emb_pos_v - emb_pos_u
)
grad_v_pos = score * emb_pos_u + self.lap_norm * (
emb_pos_u - emb_pos_v
)
else: else:
grad_u_pos = score * emb_pos_v grad_u_pos = score * emb_pos_v
grad_v_pos = score * emb_pos_u grad_v_pos = score * emb_pos_u
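The positive gradients above come from differentiating the skip-gram objective by hand rather than calling backward(). Assuming score holds 1 - sigmoid(u . v) for a positive pair (that assignment sits in a part of this method hidden by the hunk), score * emb_pos_v and score * emb_pos_u are exactly the gradients of log sigmoid(u . v) with respect to u and v, and the lap_norm term only adds a smoothing pull between the paired embeddings. A small autograd check of that assumption:

import torch

u = torch.randn(16, requires_grad=True)
v = torch.randn(16, requires_grad=True)

# Objective for a single positive pair: log sigmoid(u . v)
loss = torch.log(torch.sigmoid(torch.dot(u, v)))
loss.backward()

score = (1 - torch.sigmoid(torch.dot(u, v))).detach()  # assumed meaning of `score`
assert torch.allclose(u.grad, score * v, atol=1e-5)
assert torch.allclose(v.grad, score * u, atol=1e-5)

The negative branch further down plays the same game with log sigmoid(-u . v'), whose gradient factor is -sigmoid(u . v'), which is what the score = -self.fast_sigmoid(neg_score) line computes.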
...@@ -365,9 +387,8 @@ class SkipGramModel(nn.Module): ...@@ -365,9 +387,8 @@ class SkipGramModel(nn.Module):
if self.use_context_weight: if self.use_context_weight:
if bs < self.batch_size: if bs < self.batch_size:
context_weight = init_weight( context_weight = init_weight(
self.walk_length, self.walk_length, self.window_size, bs
self.window_size, ).to(self.device)
bs).to(self.device)
else: else:
context_weight = self.context_weight context_weight = self.context_weight
grad_u_pos *= context_weight grad_u_pos *= context_weight
...@@ -376,9 +397,8 @@ class SkipGramModel(nn.Module): ...@@ -376,9 +397,8 @@ class SkipGramModel(nn.Module):
# [batch_size * walk_length, dim] # [batch_size * walk_length, dim]
if bs < self.batch_size: if bs < self.batch_size:
grad_u, grad_v = init_empty_grad( grad_u, grad_v = init_empty_grad(
self.emb_dimension, self.emb_dimension, self.walk_length, bs
self.walk_length, )
bs)
grad_u = grad_u.to(self.device) grad_u = grad_u.to(self.device)
grad_v = grad_v.to(self.device) grad_v = grad_v.to(self.device)
else: else:
...@@ -394,14 +414,15 @@ class SkipGramModel(nn.Module): ...@@ -394,14 +414,15 @@ class SkipGramModel(nn.Module):
## Negative ## Negative
if bs < self.batch_size: if bs < self.batch_size:
index_emb_negu, index_emb_negv = init_emb2neg_index( index_emb_negu, index_emb_negv = init_emb2neg_index(
self.walk_length, self.window_size, self.negative, bs) self.walk_length, self.window_size, self.negative, bs
)
index_emb_negu = index_emb_negu.to(self.device) index_emb_negu = index_emb_negu.to(self.device)
index_emb_negv = index_emb_negv.to(self.device) index_emb_negv = index_emb_negv.to(self.device)
else: else:
index_emb_negu = self.index_emb_negu index_emb_negu = self.index_emb_negu
index_emb_negv = self.index_emb_negv index_emb_negv = self.index_emb_negv
emb_neg_u = torch.index_select(emb_u, 0, index_emb_negu) emb_neg_u = torch.index_select(emb_u, 0, index_emb_negu)
if neg_nodes is None: if neg_nodes is None:
emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv) emb_neg_v = torch.index_select(emb_v, 0, index_emb_negv)
else: else:
...@@ -411,9 +432,13 @@ class SkipGramModel(nn.Module): ...@@ -411,9 +432,13 @@ class SkipGramModel(nn.Module):
neg_score = torch.sum(torch.mul(emb_neg_u, emb_neg_v), dim=1) neg_score = torch.sum(torch.mul(emb_neg_u, emb_neg_v), dim=1)
neg_score = torch.clamp(neg_score, max=6, min=-6) neg_score = torch.clamp(neg_score, max=6, min=-6)
# [batch_size * walk_length * negative, 1] # [batch_size * walk_length * negative, 1]
score = - self.fast_sigmoid(neg_score).unsqueeze(1) score = -self.fast_sigmoid(neg_score).unsqueeze(1)
if self.record_loss: if self.record_loss:
self.loss.append(self.negative * self.neg_weight * torch.mean(self.fast_logsigmoid(-neg_score)).item()) self.loss.append(
self.negative
* self.neg_weight
* torch.mean(self.fast_logsigmoid(-neg_score)).item()
)
grad_u_neg = self.neg_weight * score * emb_neg_v grad_u_neg = self.neg_weight * score * emb_neg_v
grad_v_neg = self.neg_weight * score * emb_neg_u grad_v_neg = self.neg_weight * score * emb_neg_u
...@@ -426,10 +451,21 @@ class SkipGramModel(nn.Module): ...@@ -426,10 +451,21 @@ class SkipGramModel(nn.Module):
nodes = nodes.view(-1) nodes = nodes.view(-1)
# use adam optimizer # use adam optimizer
grad_u = adam(grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu) grad_u = adam(
grad_v = adam(grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu) grad_u, self.state_sum_u, nodes, lr, self.device, self.only_gpu
)
grad_v = adam(
grad_v, self.state_sum_v, nodes, lr, self.device, self.only_gpu
)
if neg_nodes is not None: if neg_nodes is not None:
grad_v_neg = adam(grad_v_neg, self.state_sum_v, neg_nodes, lr, self.device, self.only_gpu) grad_v_neg = adam(
grad_v_neg,
self.state_sum_v,
neg_nodes,
lr,
self.device,
self.only_gpu,
)
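The adam helper applied here lives in utils and is outside this hunk; given the per-node state_sum tensors created in __init__, it is presumably a lazily evaluated, row-wise adaptive update that only touches the rows indexed by nodes. A rough sketch of that idea under that assumption (sparse_adaptive_step is a hypothetical name, not the utils function):

import torch

def sparse_adaptive_step(grad, state_sum, nodes, lr, eps=1e-10):
    # grad:      [len(nodes), dim] raw gradients for the touched rows
    # state_sum: [num_nodes] running per-row squared-gradient statistic
    grad_sq = grad.mul(grad).mean(dim=1)        # one scalar per touched row
    state_sum.index_add_(0, nodes, grad_sq)     # accumulate Adagrad-style state
    denom = state_sum[nodes].sqrt().add_(eps).unsqueeze(1)
    return lr * grad / denom                    # scaled update for index_add_

The returned tensor would then be scattered into the embedding matrices with index_add_, as the non-async branch below does.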
if self.mixed_train: if self.mixed_train:
grad_u = grad_u.cpu() grad_u = grad_u.cpu()
...@@ -447,16 +483,18 @@ class SkipGramModel(nn.Module): ...@@ -447,16 +483,18 @@ class SkipGramModel(nn.Module):
neg_nodes.share_memory_() neg_nodes.share_memory_()
grad_v_neg.share_memory_() grad_v_neg.share_memory_()
self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes)) self.async_q.put((grad_u, grad_v, grad_v_neg, nodes, neg_nodes))
if not self.async_update: if not self.async_update:
self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u) self.u_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_u)
self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v) self.v_embeddings.weight.data.index_add_(0, nodes.view(-1), grad_v)
if neg_nodes is not None: if neg_nodes is not None:
self.v_embeddings.weight.data.index_add_(0, neg_nodes.view(-1), grad_v_neg) self.v_embeddings.weight.data.index_add_(
0, neg_nodes.view(-1), grad_v_neg
)
return return
def forward(self, pos_u, pos_v, neg_v): def forward(self, pos_u, pos_v, neg_v):
''' Do forward and backward. It is designed for future use. ''' """Do forward and backward. It is designed for future use."""
emb_u = self.u_embeddings(pos_u) emb_u = self.u_embeddings(pos_u)
emb_v = self.v_embeddings(pos_v) emb_v = self.v_embeddings(pos_v)
emb_neg_v = self.v_embeddings(neg_v) emb_neg_v = self.v_embeddings(neg_v)
...@@ -469,11 +507,11 @@ class SkipGramModel(nn.Module): ...@@ -469,11 +507,11 @@ class SkipGramModel(nn.Module):
neg_score = torch.clamp(neg_score, max=6, min=-6) neg_score = torch.clamp(neg_score, max=6, min=-6)
neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1) neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)
#return torch.mean(score + neg_score) # return torch.mean(score + neg_score)
return torch.sum(score), torch.sum(neg_score) return torch.sum(score), torch.sum(neg_score)
def save_embedding(self, dataset, file_name): def save_embedding(self, dataset, file_name):
""" Write embedding to local file. Only used when node ids are numbers. """Write embedding to local file. Only used when node ids are numbers.
Parameter Parameter
--------- ---------
...@@ -482,42 +520,55 @@ class SkipGramModel(nn.Module): ...@@ -482,42 +520,55 @@ class SkipGramModel(nn.Module):
""" """
embedding = self.u_embeddings.weight.cpu().data.numpy() embedding = self.u_embeddings.weight.cpu().data.numpy()
if self.norm: if self.norm:
embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1) embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(
-1, 1
)
np.save(file_name, embedding) np.save(file_name, embedding)
def save_embedding_pt(self, dataset, file_name): def save_embedding_pt(self, dataset, file_name):
""" For ogb leaderboard. """For ogb leaderboard."""
"""
try: try:
max_node_id = max(dataset.node2id.keys()) max_node_id = max(dataset.node2id.keys())
if max_node_id + 1 != self.emb_size: if max_node_id + 1 != self.emb_size:
print("WARNING: The node ids are not serial.") print("WARNING: The node ids are not serial.")
embedding = torch.zeros(max_node_id + 1, self.emb_dimension) embedding = torch.zeros(max_node_id + 1, self.emb_dimension)
index = torch.LongTensor(list(map(lambda id: dataset.id2node[id], list(range(self.emb_size))))) index = torch.LongTensor(
list(
map(
lambda id: dataset.id2node[id],
list(range(self.emb_size)),
)
)
)
embedding.index_add_(0, index, self.u_embeddings.weight.cpu().data) embedding.index_add_(0, index, self.u_embeddings.weight.cpu().data)
if self.norm: if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1) embedding /= torch.sqrt(
torch.sum(embedding.mul(embedding), 1) + 1e-6
).unsqueeze(1)
torch.save(embedding, file_name) torch.save(embedding, file_name)
except: except:
self.save_embedding_pt_dgl_graph(dataset, file_name) self.save_embedding_pt_dgl_graph(dataset, file_name)
def save_embedding_pt_dgl_graph(self, dataset, file_name): def save_embedding_pt_dgl_graph(self, dataset, file_name):
""" For ogb leaderboard """ """For ogb leaderboard"""
embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data) embedding = torch.zeros_like(self.u_embeddings.weight.cpu().data)
valid_seeds = torch.LongTensor(dataset.valid_seeds) valid_seeds = torch.LongTensor(dataset.valid_seeds)
valid_embedding = self.u_embeddings.weight.cpu().data.index_select(0, valid_embedding = self.u_embeddings.weight.cpu().data.index_select(
valid_seeds) 0, valid_seeds
)
embedding.index_add_(0, valid_seeds, valid_embedding) embedding.index_add_(0, valid_seeds, valid_embedding)
if self.norm: if self.norm:
embedding /= torch.sqrt(torch.sum(embedding.mul(embedding), 1) + 1e-6).unsqueeze(1) embedding /= torch.sqrt(
torch.sum(embedding.mul(embedding), 1) + 1e-6
).unsqueeze(1)
torch.save(embedding, file_name) torch.save(embedding, file_name)
def save_embedding_txt(self, dataset, file_name): def save_embedding_txt(self, dataset, file_name):
""" Write embedding to local file. For future use. """Write embedding to local file. For future use.
Parameter Parameter
--------- ---------
...@@ -526,9 +577,11 @@ class SkipGramModel(nn.Module): ...@@ -526,9 +577,11 @@ class SkipGramModel(nn.Module):
""" """
embedding = self.u_embeddings.weight.cpu().data.numpy() embedding = self.u_embeddings.weight.cpu().data.numpy()
if self.norm: if self.norm:
embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(-1, 1) embedding /= np.sqrt(np.sum(embedding * embedding, 1)).reshape(
with open(file_name, 'w') as f: -1, 1
f.write('%d %d\n' % (self.emb_size, self.emb_dimension)) )
with open(file_name, "w") as f:
f.write("%d %d\n" % (self.emb_size, self.emb_dimension))
for wid in range(self.emb_size): for wid in range(self.emb_size):
e = ' '.join(map(lambda x: str(x), embedding[wid])) e = " ".join(map(lambda x: str(x), embedding[wid]))
f.write('%s %s\n' % (str(dataset.id2node[wid]), e)) f.write("%s %s\n" % (str(dataset.id2node[wid]), e))
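The text format written above is a header line "num_embeddings dim" followed by one "node_id v1 ... vdim" line per node. A small reader for that format, for illustration only (load_embedding_txt is a hypothetical helper, not part of this commit):

import numpy as np

def load_embedding_txt(file_name):
    # Reads back the format produced by save_embedding_txt above.
    with open(file_name) as f:
        n, dim = map(int, f.readline().split())
        emb = {}
        for line in f:
            parts = line.strip().split()
            emb[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
    assert len(emb) == n and all(v.shape == (dim,) for v in emb.values())
    return emb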
import os import os
import pickle
import random
import time
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
import pickle
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from dgl.data.utils import download, _get_dgl_url, get_download_dir, extract_archive from utils import shuffle_walks
import random
import time
import dgl import dgl
from dgl.data.utils import (
_get_dgl_url,
download,
extract_archive,
get_download_dir,
)
from utils import shuffle_walks
def ReadTxtNet(file_path="", undirected=True): def ReadTxtNet(file_path="", undirected=True):
""" Read the txt network file. """Read the txt network file.
Notations: The network is unweighted. Notations: The network is unweighted.
Parameters Parameters
...@@ -23,16 +30,20 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -23,16 +30,20 @@ def ReadTxtNet(file_path="", undirected=True):
Return Return
------ ------
net dict : a dict recording the connections in the graph net dict : a dict recording the connections in the graph
node2id dict : a dict mapping the nodes to their embedding indices node2id dict : a dict mapping the nodes to their embedding indices
id2node dict : a dict mapping node embedding indices to the nodes id2node dict : a dict mapping node embedding indices to the nodes
""" """
if file_path == 'youtube' or file_path == 'blog': if file_path == "youtube" or file_path == "blog":
name = file_path name = file_path
dir = get_download_dir() dir = get_download_dir()
zip_file_path='{}/{}.zip'.format(dir, name) zip_file_path = "{}/{}.zip".format(dir, name)
download(_get_dgl_url(os.path.join('dataset/DeepWalk/', '{}.zip'.format(file_path))), path=zip_file_path) download(
extract_archive(zip_file_path, _get_dgl_url(
'{}/{}'.format(dir, name)) os.path.join("dataset/DeepWalk/", "{}.zip".format(file_path))
),
path=zip_file_path,
)
extract_archive(zip_file_path, "{}/{}".format(dir, name))
file_path = "{}/{}/{}-net.txt".format(dir, name, name) file_path = "{}/{}/{}-net.txt".format(dir, name, name)
node2id = {} node2id = {}
...@@ -46,7 +57,10 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -46,7 +57,10 @@ def ReadTxtNet(file_path="", undirected=True):
with open(file_path, "r") as f: with open(file_path, "r") as f:
for line in f.readlines(): for line in f.readlines():
tup = list(map(int, line.strip().split(" "))) tup = list(map(int, line.strip().split(" ")))
assert len(tup) in [2, 3], "The format of network file is unrecognizable." assert len(tup) in [
2,
3,
], "The format of network file is unrecognizable."
if len(tup) == 3: if len(tup) == 3:
n1, n2, w = tup n1, n2, w = tup
elif len(tup) == 2: elif len(tup) == 2:
...@@ -73,7 +87,7 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -73,7 +87,7 @@ def ReadTxtNet(file_path="", undirected=True):
src.append(n1) src.append(n1)
dst.append(n2) dst.append(n2)
weight.append(w) weight.append(w)
if undirected: if undirected:
if n2 not in net: if n2 not in net:
net[n2] = {n1: w} net[n2] = {n1: w}
...@@ -90,16 +104,15 @@ def ReadTxtNet(file_path="", undirected=True): ...@@ -90,16 +104,15 @@ def ReadTxtNet(file_path="", undirected=True):
print("edge num: %d" % len(src)) print("edge num: %d" % len(src))
assert max(net.keys()) == len(net) - 1, "error reading net, quit" assert max(net.keys()) == len(net) - 1, "error reading net, quit"
sm = sp.coo_matrix( sm = sp.coo_matrix((np.array(weight), (src, dst)), dtype=np.float32)
(np.array(weight), (src, dst)),
dtype=np.float32)
return net, node2id, id2node, sm return net, node2id, id2node, sm
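For reference, the network file ReadTxtNet parses is a plain whitespace-separated edge list with two or three columns per line; with three columns the last one is the edge weight. A toy example of a valid input (file name and printed summary are illustrative only):

# toy-net.txt: "src dst" or "src dst weight" per line
with open("toy-net.txt", "w") as f:
    f.write("0 1\n1 2 3\n2 0\n")

net, node2id, id2node, sm = ReadTxtNet("toy-net.txt", undirected=True)
print("%d nodes, %d stored edges" % (len(node2id), sm.nnz))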
def net2graph(net_sm): def net2graph(net_sm):
""" Transform the network to DGL graph """Transform the network to DGL graph
Return Return
------ ------
G DGLGraph : graph by DGL G DGLGraph : graph by DGL
""" """
...@@ -110,30 +123,34 @@ def net2graph(net_sm): ...@@ -110,30 +123,34 @@ def net2graph(net_sm):
print("Building DGLGraph in %.2fs" % t) print("Building DGLGraph in %.2fs" % t)
return G return G
def make_undirected(G): def make_undirected(G):
#G.readonly(False) # G.readonly(False)
G.add_edges(G.edges()[1], G.edges()[0]) G.add_edges(G.edges()[1], G.edges()[0])
return G return G
def find_connected_nodes(G): def find_connected_nodes(G):
nodes = G.out_degrees().nonzero().squeeze(-1) nodes = G.out_degrees().nonzero().squeeze(-1)
return nodes return nodes
class DeepwalkDataset: class DeepwalkDataset:
def __init__(self, def __init__(
net_file, self,
map_file, net_file,
walk_length, map_file,
window_size, walk_length,
num_walks, window_size,
batch_size, num_walks,
negative=5, batch_size,
gpus=[0], negative=5,
fast_neg=True, gpus=[0],
ogbl_name="", fast_neg=True,
load_from_ogbl=False, ogbl_name="",
): load_from_ogbl=False,
""" This class has the following functions: ):
"""This class has the following functions:
1. Transform the txt network file into DGL graph; 1. Transform the txt network file into DGL graph;
2. Generate random walk sequences for the trainer; 2. Generate random walk sequences for the trainer;
3. Provide the negative table if the user hopes to sample negative 3. Provide the negative table if the user hopes to sample negative
...@@ -158,8 +175,11 @@ class DeepwalkDataset: ...@@ -158,8 +175,11 @@ class DeepwalkDataset:
self.fast_neg = fast_neg self.fast_neg = fast_neg
if load_from_ogbl: if load_from_ogbl:
assert len(gpus) == 1, "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)." assert (
len(gpus) == 1
), "ogb.linkproppred is not compatible with multi-gpu training (CUDA error)."
from load_dataset import load_from_ogbl_with_name from load_dataset import load_from_ogbl_with_name
self.G = load_from_ogbl_with_name(ogbl_name) self.G = load_from_ogbl_with_name(ogbl_name)
self.G = make_undirected(self.G) self.G = make_undirected(self.G)
else: else:
...@@ -173,12 +193,18 @@ class DeepwalkDataset: ...@@ -173,12 +193,18 @@ class DeepwalkDataset:
start = time.time() start = time.time()
self.valid_seeds = find_connected_nodes(self.G) self.valid_seeds = find_connected_nodes(self.G)
if len(self.valid_seeds) != self.num_nodes: if len(self.valid_seeds) != self.num_nodes:
print("WARNING: The node ids are not serial. Some nodes are invalid.") print(
"WARNING: The node ids are not serial. Some nodes are invalid."
)
seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks) seeds = torch.cat([torch.LongTensor(self.valid_seeds)] * num_walks)
self.seeds = torch.split(shuffle_walks(seeds), self.seeds = torch.split(
int(np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs)), shuffle_walks(seeds),
0) int(
np.ceil(len(self.valid_seeds) * self.num_walks / self.num_procs)
),
0,
)
end = time.time() end = time.time()
t = end - start t = end - start
print("%d seeds in %.2fs" % (len(seeds), t)) print("%d seeds in %.2fs" % (len(seeds), t))
...@@ -190,7 +216,7 @@ class DeepwalkDataset: ...@@ -190,7 +216,7 @@ class DeepwalkDataset:
node_degree /= np.sum(node_degree) node_degree /= np.sum(node_degree)
node_degree = np.array(node_degree * 1e8, dtype=np.int) node_degree = np.array(node_degree * 1e8, dtype=np.int)
self.neg_table = [] self.neg_table = []
for idx, node in enumerate(self.valid_seeds): for idx, node in enumerate(self.valid_seeds):
self.neg_table += [node] * node_degree[idx] self.neg_table += [node] * node_degree[idx]
self.neg_table_size = len(self.neg_table) self.neg_table_size = len(self.neg_table)
...@@ -198,18 +224,19 @@ class DeepwalkDataset: ...@@ -198,18 +224,19 @@ class DeepwalkDataset:
del node_degree del node_degree
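The table built above repeats each valid node roughly in proportion to its degree, so drawing uniform indices from it yields degree-proportional negative samples in fast_neg mode. A sketch of how such a table would typically be consumed (sample_negatives is a hypothetical helper, not part of this commit):

import numpy as np

def sample_negatives(neg_table, num_samples, rng=np.random):
    # Uniform indices into the degree-proportional table give negatives
    # distributed roughly like the node degrees.
    idx = rng.randint(0, len(neg_table), size=num_samples)
    return [neg_table[i] for i in idx]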
def create_sampler(self, i): def create_sampler(self, i):
""" create random walk sampler """ """create random walk sampler"""
return DeepwalkSampler(self.G, self.seeds[i], self.walk_length) return DeepwalkSampler(self.G, self.seeds[i], self.walk_length)
def save_mapping(self, map_file): def save_mapping(self, map_file):
""" save the mapping dict that maps node IDs to embedding indices """ """save the mapping dict that maps node IDs to embedding indices"""
with open(map_file, "wb") as f: with open(map_file, "wb") as f:
pickle.dump(self.node2id, f) pickle.dump(self.node2id, f)
class DeepwalkSampler(object): class DeepwalkSampler(object):
def __init__(self, G, seeds, walk_length): def __init__(self, G, seeds, walk_length):
""" random walk sampler """random walk sampler
Parameter Parameter
--------- ---------
G dgl.Graph : the input graph G dgl.Graph : the input graph
...@@ -219,7 +246,9 @@ class DeepwalkSampler(object): ...@@ -219,7 +246,9 @@ class DeepwalkSampler(object):
self.G = G self.G = G
self.seeds = seeds self.seeds = seeds
self.walk_length = walk_length self.walk_length = walk_length
def sample(self, seeds): def sample(self, seeds):
walks = dgl.sampling.random_walk(self.G, seeds, length=self.walk_length-1)[0] walks = dgl.sampling.random_walk(
self.G, seeds, length=self.walk_length - 1
)[0]
return walks return walks
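DeepwalkSampler.sample is a thin wrapper over dgl.sampling.random_walk; each returned row is a walk of walk_length node IDs starting at the corresponding seed. A quick self-contained usage sketch on a toy cycle graph, using the class defined above:

import torch
import dgl

g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))  # 4-node cycle
sampler = DeepwalkSampler(g, seeds=torch.arange(4), walk_length=5)
walks = sampler.sample(torch.tensor([0, 2]))
print(walks.shape)   # torch.Size([2, 5]); each row is one walk of 5 node IDs

In the trainer these walks are typically produced batch by batch through a torch DataLoader that passes sampler.sample as its collate function; that hookup lives outside this hunk.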