Unverified commit 704bcaf6 authored by Hongzhi (Steve), Chen, committed by GitHub
parent 6bc82161
@@ -10,7 +10,7 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import f1_score
 from sklearn.model_selection import GridSearchCV, train_test_split
 from sklearn.multiclass import OneVsRestClassifier
-from sklearn.preprocessing import OneHotEncoder, normalize
+from sklearn.preprocessing import normalize, OneHotEncoder
 def repeat(n_times):
...
@@ -75,7 +75,6 @@ else:
     args.device = "cpu"
 if __name__ == "__main__":
     # Step 1: Load hyperparameters =================================================================== #
     lr = args.lr
     hid_dim = args.hid_dim
...
 import argparse
 import warnings
+import dgl
 import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from model import GRAND
-import dgl
 from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
+from model import GRAND
 warnings.filterwarnings("ignore")
 def argument():
     parser = argparse.ArgumentParser(description="GRAND")
     # data source params
@@ -111,7 +110,6 @@ def consis_loss(logps, temp, lam):
 if __name__ == "__main__":
     # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
     # Load from DGL dataset
     args = argument()
@@ -175,7 +173,6 @@ if __name__ == "__main__":
     # Step 4: training epoches =============================================================== #
     for epoch in range(args.epochs):
         """Training"""
         model.train()
@@ -204,7 +201,6 @@ if __name__ == "__main__":
         """ Validating """
         model.eval()
         with th.no_grad():
             val_logits = model(graph, feats, False)
             loss_val = F.nll_loss(val_logits[val_idx], labels[val_idx])
...
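Editor's note: the hunks above call consis_loss(logps, temp, lam), but the commit elides its body. Below is a minimal, hedged sketch of GRAND-style consistency regularization for context; the helper name consis_loss_sketch and its default arguments are this editor's assumptions, not part of the commit.

import torch as th

def consis_loss_sketch(logps, temp=0.5, lam=1.0):
    # Stack the S per-augmentation predictions into shape (N, C, S).
    ps = th.stack([th.exp(p) for p in logps], dim=2)
    # Average the S distributions, then sharpen with temperature `temp`.
    avg_p = ps.mean(dim=2)
    sharp_p = avg_p.pow(1.0 / temp)
    sharp_p = (sharp_p / sharp_p.sum(dim=1, keepdim=True)).detach()
    # Penalize the squared distance of each prediction to the sharpened target.
    return lam * (ps - sharp_p.unsqueeze(2)).pow(2).sum(dim=1).mean()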
+import dgl.function as fn
 import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn
 def drop_node(feats, drop_rate, training):
     n = feats.shape[0]
     drop_rates = th.FloatTensor(np.ones(n) * drop_rate)
     if training:
         masks = th.bernoulli(1.0 - drop_rates).unsqueeze(1)
         feats = masks.to(feats.device) * feats
@@ -42,7 +39,6 @@ class MLP(nn.Module):
         self.layer2.reset_parameters()
     def forward(self, x):
         if self.use_bn:
             x = self.bn1(x)
         x = self.input_dropout(x)
@@ -68,7 +64,6 @@ def GRANDConv(graph, feats, order):
     Propagation Steps
     """
     with graph.local_scope():
         """Calculate Symmetric normalized adjacency matrix \hat{A}"""
         degs = graph.in_degrees().float().clamp(min=1)
         norm = th.pow(degs, -0.5).to(feats.device).unsqueeze(1)
@@ -127,7 +122,6 @@ class GRAND(nn.Module):
         hidden_droprate=0.0,
         batchnorm=False,
     ):
         super(GRAND, self).__init__()
         self.in_dim = in_dim
         self.hid_dim = hid_dim
@@ -143,7 +137,6 @@ class GRAND(nn.Module):
         self.node_dropout = nn.Dropout(node_dropout)
     def forward(self, graph, feats, training=True):
         X = feats
         S = self.S
...
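Editor's note: for readers skimming the GRANDConv hunks above, here is a dense-matrix sketch of the propagation they implement (averaging powers of the symmetrically normalized adjacency). The real code uses sparse DGL message passing; this is intuition only, and the function name is hypothetical.

import torch as th

def grand_conv_dense(adj, feats, order):
    # adj: (N, N) dense adjacency with self-loops; feats: (N, F).
    deg = adj.sum(dim=1).clamp(min=1)
    norm = deg.pow(-0.5)
    a_hat = norm.unsqueeze(1) * adj * norm.unsqueeze(0)  # D^-1/2 A D^-1/2
    x, acc = feats, feats.clone()
    for _ in range(order):
        x = a_hat @ x  # one propagation step
        acc = acc + x
    return acc / (order + 1)  # average of A_hat^0 X, ..., A_hat^order X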
+import dgl
 import numpy as np
 from ged import graph_edit_distance
-import dgl
 src1 = [0, 1, 2, 3, 4, 5]
 dst1 = [1, 2, 3, 4, 5, 6]
...
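Editor's note: a hedged usage sketch for the example above. The second toy graph and the treatment of the return value are assumptions; the docstring later in this diff only promises that the GED itself is returned.

import dgl
from ged import graph_edit_distance

g1 = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]))  # the 7-node path above
g2 = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]))  # hypothetical 6-node cycle
result = graph_edit_distance(g1, g2, algorithm="bipartite", max_beam_size=100)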
+from copy import deepcopy
+from heapq import heapify, heappop, heappush, nsmallest
 import dgl
 import numpy as np
-from heapq import heappush, heappop, heapify, nsmallest
-from copy import deepcopy
 # We use lapjv implementation (https://github.com/src-d/lapjv) to solve assignment problem, because of its scalability
 # Also see https://github.com/berhane/LAP-solvers for benchmarking of LAP solvers
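Editor's note: for readers without lapjv installed, the same linear assignment problem can be solved (more slowly) with SciPy's Hungarian-style solver. This sketch is illustrative and not part of the commit.

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[4.0, 1.0, 3.0],
                 [2.0, 0.0, 5.0],
                 [3.0, 2.0, 2.0]])
row_ind, col_ind = linear_sum_assignment(cost)  # optimal row-to-column matching
print(cost[row_ind, col_ind].sum())  # minimal total cost: 5.0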
@@ -247,7 +248,6 @@ class search_tree_node:
         cost_matrix_nodes,
         cost_matrix_edges,
     ):
         self.matched_cost = parent_matched_cost
         self.future_approximate_cost = 0.0
         self.matched_nodes = deepcopy(parent_matched_nodes)
@@ -1156,7 +1156,6 @@ def graph_edit_distance(
     algorithm="bipartite",
     max_beam_size=100,
 ):
     """Returns GED (graph edit distance) between DGLGraphs G1 and G2.
...
+import dgl
+import dgl.nn as dglnn
 import sklearn.linear_model as lm
 import sklearn.metrics as skm
 import torch as th
@@ -5,9 +7,6 @@ import torch.functional as F
 import torch.nn as nn
 import tqdm
-import dgl
-import dgl.nn as dglnn
 class SAGE(nn.Module):
     def __init__(
...
-import torch as th
 import dgl
+import torch as th
 class NegativeSampler(object):
...
+import argparse
+import glob
+import os
+import sys
+import time
 import dgl
+import dgl.function as fn
+import dgl.nn.pytorch as dglnn
 import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-import dgl.nn.pytorch as dglnn
-import dgl.function as fn
-import time
-import argparse
 import tqdm
-import glob
-import os
+from model import compute_acc_unsupervised as compute_acc, SAGE
 from negative_sampler import NegativeSampler
-from pytorch_lightning.callbacks import ModelCheckpoint, Callback
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
-from model import SAGE, compute_acc_unsupervised as compute_acc
-import sys
+from pytorch_lightning.callbacks import Callback, ModelCheckpoint
-sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
-from load_graph import load_reddit, inductive_split, load_ogb
+from load_graph import inductive_split, load_ogb, load_reddit
 class CrossEntropyLoss(nn.Module):
     def forward(self, block_outputs, pos_graph, neg_graph):
         with pos_graph.local_scope():
-            pos_graph.ndata['h'] = block_outputs
-            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
-            pos_score = pos_graph.edata['score']
+            pos_graph.ndata["h"] = block_outputs
+            pos_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
+            pos_score = pos_graph.edata["score"]
         with neg_graph.local_scope():
-            neg_graph.ndata['h'] = block_outputs
-            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
-            neg_score = neg_graph.edata['score']
+            neg_graph.ndata["h"] = block_outputs
+            neg_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
+            neg_score = neg_graph.edata["score"]
         score = th.cat([pos_score, neg_score])
-        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
+        label = th.cat(
+            [th.ones_like(pos_score), th.zeros_like(neg_score)]
+        ).long()
         loss = F.binary_cross_entropy_with_logits(score, label.float())
         return loss
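Editor's note: a self-contained illustration of what CrossEntropyLoss above computes once the per-edge dot-product scores exist. The toy scores are made up, and the class's .long()/.float() round-trip is skipped since BCE-with-logits only needs float labels.

import torch as th
import torch.nn.functional as F

pos_score = th.tensor([2.0, 1.5])         # u . v for observed edges
neg_score = th.tensor([-0.5, 0.1, -1.0])  # u . v for sampled non-edges
score = th.cat([pos_score, neg_score])
label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
loss = F.binary_cross_entropy_with_logits(score, label)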
 class SAGELightning(LightningModule):
-    def __init__(self,
-                 in_feats,
-                 n_hidden,
-                 n_classes,
-                 n_layers,
-                 activation,
-                 dropout,
-                 lr):
+    def __init__(
+        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, lr
+    ):
         super().__init__()
         self.save_hyperparameters()
-        self.module = SAGE(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
+        self.module = SAGE(
+            in_feats, n_hidden, n_classes, n_layers, activation, dropout
+        )
         self.lr = lr
         self.loss_fcn = CrossEntropyLoss()
@@ -57,18 +60,20 @@ class SAGELightning(LightningModule):
         mfgs = [mfg.int().to(device) for mfg in mfgs]
         pos_graph = pos_graph.to(device)
         neg_graph = neg_graph.to(device)
-        batch_inputs = mfgs[0].srcdata['features']
-        batch_labels = mfgs[-1].dstdata['labels']
+        batch_inputs = mfgs[0].srcdata["features"]
+        batch_labels = mfgs[-1].dstdata["labels"]
         batch_pred = self.module(mfgs, batch_inputs)
         loss = self.loss_fcn(batch_pred, pos_graph, neg_graph)
-        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
+        self.log(
+            "train_loss", loss, prog_bar=True, on_step=False, on_epoch=True
+        )
         return loss
     def validation_step(self, batch, batch_idx):
         input_nodes, output_nodes, mfgs = batch
         mfgs = [mfg.int().to(device) for mfg in mfgs]
-        batch_inputs = mfgs[0].srcdata['features']
-        batch_labels = mfgs[-1].dstdata['labels']
+        batch_inputs = mfgs[0].srcdata["features"]
+        batch_labels = mfgs[-1].dstdata["labels"]
         batch_pred = self.module(mfgs, batch_inputs)
         return batch_pred
@@ -78,54 +83,73 @@ class SAGELightning(LightningModule):
 class DataModule(LightningDataModule):
-    def __init__(self, dataset_name, data_cpu=False, fan_out=[10, 25],
-                 device=th.device('cpu'), batch_size=1000, num_workers=4):
+    def __init__(
+        self,
+        dataset_name,
+        data_cpu=False,
+        fan_out=[10, 25],
+        device=th.device("cpu"),
+        batch_size=1000,
+        num_workers=4,
+    ):
         super().__init__()
-        if dataset_name == 'reddit':
+        if dataset_name == "reddit":
             g, n_classes = load_reddit()
             n_edges = g.num_edges()
-            reverse_eids = th.cat([
-                th.arange(n_edges // 2, n_edges),
-                th.arange(0, n_edges // 2)])
+            reverse_eids = th.cat(
+                [th.arange(n_edges // 2, n_edges), th.arange(0, n_edges // 2)]
+            )
-        elif dataset_name == 'ogbn-products':
-            g, n_classes = load_ogb('ogbn-products')
+        elif dataset_name == "ogbn-products":
+            g, n_classes = load_ogb("ogbn-products")
             n_edges = g.num_edges()
             # The reverse edge of edge 0 in OGB products dataset is 1.
             # The reverse edge of edge 2 is 3. So on so forth.
             reverse_eids = th.arange(n_edges) ^ 1
         else:
-            raise ValueError('unknown dataset')
+            raise ValueError("unknown dataset")
-        train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
-        val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
-        test_nid = th.nonzero(~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]
+        train_nid = th.nonzero(g.ndata["train_mask"], as_tuple=True)[0]
+        val_nid = th.nonzero(g.ndata["val_mask"], as_tuple=True)[0]
+        test_nid = th.nonzero(
+            ~(g.ndata["train_mask"] | g.ndata["val_mask"]), as_tuple=True
+        )[0]
-        sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])
+        sampler = dgl.dataloading.MultiLayerNeighborSampler(
+            [int(_) for _ in fan_out]
+        )
-        dataloader_device = th.device('cpu')
+        dataloader_device = th.device("cpu")
         if not data_cpu:
             train_nid = train_nid.to(device)
             val_nid = val_nid.to(device)
             test_nid = test_nid.to(device)
-            g = g.formats(['csc'])
+            g = g.formats(["csc"])
             g = g.to(device)
             dataloader_device = device
         self.g = g
-        self.train_nid, self.val_nid, self.test_nid = train_nid, val_nid, test_nid
+        self.train_nid, self.val_nid, self.test_nid = (
+            train_nid,
+            val_nid,
+            test_nid,
+        )
         self.sampler = sampler
         self.device = dataloader_device
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.in_feats = g.ndata['features'].shape[1]
+        self.in_feats = g.ndata["features"].shape[1]
         self.n_classes = n_classes
         self.reverse_eids = reverse_eids
     def train_dataloader(self):
         sampler = dgl.dataloading.as_edge_prediction_sampler(
-            self.sampler, exclude='reverse_id',
+            self.sampler,
+            exclude="reverse_id",
             reverse_eids=self.reverse_eids,
-            negative_sampler=NegativeSampler(self.g, args.num_negs, args.neg_share))
+            negative_sampler=NegativeSampler(
+                self.g, args.num_negs, args.neg_share
+            ),
+        )
         return dgl.dataloading.DataLoader(
             self.g,
             np.arange(self.g.num_edges()),
@@ -134,7 +158,8 @@ class DataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=True,
             drop_last=False,
-            num_workers=self.num_workers)
+            num_workers=self.num_workers,
+        )
     def val_dataloader(self):
         # Note that the validation data loader is a DataLoader
@@ -147,63 +172,92 @@ class DataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=False,
             drop_last=False,
-            num_workers=self.num_workers)
+            num_workers=self.num_workers,
+        )
 class UnsupervisedClassification(Callback):
     def on_validation_epoch_start(self, trainer, pl_module):
         self.val_outputs = []
-    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_validation_batch_end(
+        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+    ):
         self.val_outputs.append(outputs)
     def on_validation_epoch_end(self, trainer, pl_module):
         node_emb = th.cat(self.val_outputs, 0)
         g = trainer.datamodule.g
-        labels = g.ndata['labels']
+        labels = g.ndata["labels"]
         f1_micro, f1_macro = compute_acc(
-            node_emb, labels, trainer.datamodule.train_nid,
-            trainer.datamodule.val_nid, trainer.datamodule.test_nid)
+            node_emb,
+            labels,
+            trainer.datamodule.train_nid,
+            trainer.datamodule.val_nid,
+            trainer.datamodule.test_nid,
+        )
-        pl_module.log('val_f1_micro', f1_micro)
+        pl_module.log("val_f1_micro", f1_micro)
-if __name__ == '__main__':
+if __name__ == "__main__":
     argparser = argparse.ArgumentParser("multi-gpu training")
     argparser.add_argument("--gpu", type=int, default=0)
-    argparser.add_argument('--dataset', type=str, default='reddit')
-    argparser.add_argument('--num-epochs', type=int, default=20)
-    argparser.add_argument('--num-hidden', type=int, default=16)
-    argparser.add_argument('--num-layers', type=int, default=2)
-    argparser.add_argument('--num-negs', type=int, default=1)
+    argparser.add_argument("--dataset", type=str, default="reddit")
+    argparser.add_argument("--num-epochs", type=int, default=20)
+    argparser.add_argument("--num-hidden", type=int, default=16)
+    argparser.add_argument("--num-layers", type=int, default=2)
+    argparser.add_argument("--num-negs", type=int, default=1)
-    argparser.add_argument('--neg-share', default=False, action='store_true',
-                           help="sharing neg nodes for positive nodes")
+    argparser.add_argument(
+        "--neg-share",
+        default=False,
+        action="store_true",
+        help="sharing neg nodes for positive nodes",
+    )
-    argparser.add_argument('--fan-out', type=str, default='10,25')
-    argparser.add_argument('--batch-size', type=int, default=10000)
-    argparser.add_argument('--log-every', type=int, default=20)
-    argparser.add_argument('--eval-every', type=int, default=1000)
-    argparser.add_argument('--lr', type=float, default=0.003)
-    argparser.add_argument('--dropout', type=float, default=0.5)
+    argparser.add_argument("--fan-out", type=str, default="10,25")
+    argparser.add_argument("--batch-size", type=int, default=10000)
+    argparser.add_argument("--log-every", type=int, default=20)
+    argparser.add_argument("--eval-every", type=int, default=1000)
+    argparser.add_argument("--lr", type=float, default=0.003)
+    argparser.add_argument("--dropout", type=float, default=0.5)
-    argparser.add_argument('--num-workers', type=int, default=0,
-                           help="Number of sampling processes. Use 0 for no extra process.")
+    argparser.add_argument(
+        "--num-workers",
+        type=int,
+        default=0,
+        help="Number of sampling processes. Use 0 for no extra process.",
+    )
     args = argparser.parse_args()
     if args.gpu >= 0:
-        device = th.device('cuda:%d' % args.gpu)
+        device = th.device("cuda:%d" % args.gpu)
     else:
-        device = th.device('cpu')
+        device = th.device("cpu")
     datamodule = DataModule(
-        args.dataset, True, [int(_) for _ in args.fan_out.split(',')],
-        device, args.batch_size, args.num_workers)
+        args.dataset,
+        True,
+        [int(_) for _ in args.fan_out.split(",")],
+        device,
+        args.batch_size,
+        args.num_workers,
+    )
     model = SAGELightning(
-        datamodule.in_feats, args.num_hidden, datamodule.n_classes, args.num_layers,
-        F.relu, args.dropout, args.lr)
+        datamodule.in_feats,
+        args.num_hidden,
+        datamodule.n_classes,
+        args.num_layers,
+        F.relu,
+        args.dropout,
+        args.lr,
+    )
     # Train
     unsupervised_callback = UnsupervisedClassification()
-    checkpoint_callback = ModelCheckpoint(monitor='val_f1_micro', save_top_k=1)
+    checkpoint_callback = ModelCheckpoint(monitor="val_f1_micro", save_top_k=1)
-    trainer = Trainer(gpus=[args.gpu] if args.gpu != -1 else None,
-                      max_epochs=args.num_epochs,
-                      val_check_interval=1000,
-                      callbacks=[checkpoint_callback, unsupervised_callback],
-                      num_sanity_val_steps=0)
+    trainer = Trainer(
+        gpus=[args.gpu] if args.gpu != -1 else None,
+        max_epochs=args.num_epochs,
+        val_check_interval=1000,
+        callbacks=[checkpoint_callback, unsupervised_callback],
+        num_sanity_val_steps=0,
+    )
     trainer.fit(model, datamodule=datamodule)
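Editor's note: a quick demonstration of the XOR trick the DataModule above uses for ogbn-products reverse edges; purely illustrative.

import torch as th

eids = th.arange(6)
print(eids ^ 1)  # tensor([1, 0, 3, 2, 5, 4]): pairs (0,1), (2,3), (4,5) are mutual reverses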
 import glob
 import os
+import dgl
+import dgl.nn.pytorch as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -12,9 +15,6 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from torchmetrics import Accuracy
-import dgl
-import dgl.nn.pytorch as dglnn
 class SAGE(LightningModule):
     def __init__(self, in_feats, n_hidden, n_classes):
...
+import argparse
+import dgl
+import dgl.nn as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics.functional as MF
-import dgl
-import dgl.nn as dglnn
-from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler, as_edge_prediction_sampler, negative_sampler
 import tqdm
-import argparse
+from dgl.dataloading import (
+    as_edge_prediction_sampler,
+    DataLoader,
+    MultiLayerFullNeighborSampler,
+    negative_sampler,
+    NeighborSampler,
+)
 from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
 def to_bidirected_with_reverse_mapping(g):
     """Makes a graph bidirectional, and returns a mapping array ``mapping`` where ``mapping[i]``
     is the reverse edge of edge ID ``i``. Does not work with graphs that have self-loops.
     """
     g_simple, mapping = dgl.to_simple(
-        dgl.add_reverse_edges(g), return_counts='count', writeback_mapping=True)
+        dgl.add_reverse_edges(g), return_counts="count", writeback_mapping=True
+    )
-    c = g_simple.edata['count']
+    c = g_simple.edata["count"]
     num_edges = g.num_edges()
-    mapping_offset = torch.zeros(g_simple.num_edges() + 1, dtype=g_simple.idtype)
+    mapping_offset = torch.zeros(
+        g_simple.num_edges() + 1, dtype=g_simple.idtype
+    )
     mapping_offset[1:] = c.cumsum(0)
     idx = mapping.argsort()
     idx_uniq = idx[mapping_offset[:-1]]
-    reverse_idx = torch.where(idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges)
+    reverse_idx = torch.where(
+        idx_uniq >= num_edges, idx_uniq - num_edges, idx_uniq + num_edges
+    )
     reverse_mapping = mapping[reverse_idx]
     # sanity check
     src1, dst1 = g_simple.edges()
@@ -30,21 +43,23 @@ def to_bidirected_with_reverse_mapping(g):
     assert torch.equal(src2, dst1)
     return g_simple, reverse_mapping
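Editor's note: a hedged sanity check of the helper above on a toy directed 3-cycle; it mirrors the function's own assertions, so if the relations below were wrong the in-file checks would fail too.

import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g_bi, rev = to_bidirected_with_reverse_mapping(g)
src, dst = g_bi.edges()
# rev[i] is the reverse of edge i, so indexing by rev swaps endpoints:
assert torch.equal(src[rev], dst) and torch.equal(dst[rev], src)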
 class SAGE(nn.Module):
     def __init__(self, in_size, hid_size):
         super().__init__()
         self.layers = nn.ModuleList()
         # three-layer GraphSAGE-mean
-        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
+        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
         self.hid_size = hid_size
         self.predictor = nn.Sequential(
             nn.Linear(hid_size, hid_size),
             nn.ReLU(),
             nn.Linear(hid_size, hid_size),
             nn.ReLU(),
-            nn.Linear(hid_size, 1))
+            nn.Linear(hid_size, 1),
+        )
     def forward(self, pair_graph, neg_pair_graph, blocks, x):
         h = x
@@ -60,19 +75,31 @@ class SAGE(nn.Module):
     def inference(self, g, device, batch_size):
         """Layer-wise inference algorithm to compute GNN node embeddings."""
-        feat = g.ndata['feat']
-        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
+        feat = g.ndata["feat"]
+        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
         dataloader = DataLoader(
-            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
-            batch_size=batch_size, shuffle=False, drop_last=False,
-            num_workers=0)
+            g,
+            torch.arange(g.num_nodes()).to(g.device),
+            sampler,
+            device=device,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=0,
+        )
-        buffer_device = torch.device('cpu')
-        pin_memory = (buffer_device != device)
+        buffer_device = torch.device("cpu")
+        pin_memory = buffer_device != device
         for l, layer in enumerate(self.layers):
-            y = torch.empty(g.num_nodes(), self.hid_size, device=buffer_device,
-                            pin_memory=pin_memory)
+            y = torch.empty(
+                g.num_nodes(),
+                self.hid_size,
+                device=buffer_device,
+                pin_memory=pin_memory,
+            )
             feat = feat.to(device)
-            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc='Inference'):
+            for input_nodes, output_nodes, blocks in tqdm.tqdm(
+                dataloader, desc="Inference"
+            ):
                 x = feat[input_nodes]
                 h = layer(blocks[0], x)
                 if l != len(self.layers) - 1:
@@ -81,49 +108,70 @@ class SAGE(nn.Module):
             feat = y
         return y
-def compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500):
+def compute_mrr(
+    model, evaluator, node_emb, src, dst, neg_dst, device, batch_size=500
+):
     """Compute Mean Reciprocal Rank (MRR) in batches."""
     rr = torch.zeros(src.shape[0])
-    for start in tqdm.trange(0, src.shape[0], batch_size, desc='Evaluate'):
+    for start in tqdm.trange(0, src.shape[0], batch_size, desc="Evaluate"):
         end = min(start + batch_size, src.shape[0])
         all_dst = torch.cat([dst[start:end, None], neg_dst[start:end]], 1)
         h_src = node_emb[src[start:end]][:, None, :].to(device)
         h_dst = node_emb[all_dst.view(-1)].view(*all_dst.shape, -1).to(device)
-        pred = model.predictor(h_src*h_dst).squeeze(-1)
+        pred = model.predictor(h_src * h_dst).squeeze(-1)
-        input_dict = {'y_pred_pos': pred[:,0], 'y_pred_neg': pred[:,1:]}
+        input_dict = {"y_pred_pos": pred[:, 0], "y_pred_neg": pred[:, 1:]}
-        rr[start:end] = evaluator.eval(input_dict)['mrr_list']
+        rr[start:end] = evaluator.eval(input_dict)["mrr_list"]
     return rr.mean()
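Editor's note: a tiny numeric illustration of the MRR the OGB evaluator returns above: the reciprocal rank of the positive target among its negatives, averaged over queries (the evaluator's tie-handling conventions may differ slightly).

import torch

pred = torch.tensor([[3.0, 1.0, 2.0],    # positive scored highest -> rr = 1
                     [0.5, 2.0, 1.0]])   # positive ranked third   -> rr = 1/3
rank = (pred[:, 1:] >= pred[:, :1]).sum(dim=1) + 1
print((1.0 / rank.float()).mean())  # tensor(0.6667)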
 def evaluate(device, graph, edge_split, model, batch_size):
     model.eval()
-    evaluator = Evaluator(name='ogbl-citation2')
+    evaluator = Evaluator(name="ogbl-citation2")
     with torch.no_grad():
         node_emb = model.inference(graph, device, batch_size)
         results = []
-        for split in ['valid', 'test']:
-            src = edge_split[split]['source_node'].to(node_emb.device)
-            dst = edge_split[split]['target_node'].to(node_emb.device)
-            neg_dst = edge_split[split]['target_node_neg'].to(node_emb.device)
+        for split in ["valid", "test"]:
+            src = edge_split[split]["source_node"].to(node_emb.device)
+            dst = edge_split[split]["target_node"].to(node_emb.device)
+            neg_dst = edge_split[split]["target_node_neg"].to(node_emb.device)
-            results.append(compute_mrr(model, evaluator, node_emb, src, dst, neg_dst, device))
+            results.append(
+                compute_mrr(
+                    model, evaluator, node_emb, src, dst, neg_dst, device
+                )
+            )
     return results
 def train(args, device, g, reverse_eids, seed_edges, model):
     # create sampler & dataloader
-    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=['feat'])
+    sampler = NeighborSampler([15, 10, 5], prefetch_node_feats=["feat"])
     sampler = as_edge_prediction_sampler(
-        sampler, exclude='reverse_id', reverse_eids=reverse_eids,
-        negative_sampler=negative_sampler.Uniform(1))
+        sampler,
+        exclude="reverse_id",
+        reverse_eids=reverse_eids,
+        negative_sampler=negative_sampler.Uniform(1),
+    )
-    use_uva = (args.mode == 'mixed')
+    use_uva = args.mode == "mixed"
     dataloader = DataLoader(
-        g, seed_edges, sampler,
-        device=device, batch_size=512, shuffle=True,
-        drop_last=False, num_workers=0, use_uva=use_uva)
+        g,
+        seed_edges,
+        sampler,
+        device=device,
+        batch_size=512,
+        shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        use_uva=use_uva,
+    )
     opt = torch.optim.Adam(model.parameters(), lr=0.0005)
     for epoch in range(10):
         model.train()
         total_loss = 0
-        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(dataloader):
+        for it, (input_nodes, pair_graph, neg_pair_graph, blocks) in enumerate(
+            dataloader
+        ):
-            x = blocks[0].srcdata['feat']
+            x = blocks[0].srcdata["feat"]
             pos_score, neg_score = model(pair_graph, neg_pair_graph, blocks, x)
             score = torch.cat([pos_score, neg_score])
             pos_label = torch.ones_like(pos_score)
@@ -134,39 +182,51 @@ def train(args, device, g, reverse_eids, seed_edges, model):
             loss.backward()
             opt.step()
             total_loss += loss.item()
-            if (it+1) == 1000: break
+            if (it + 1) == 1000:
+                break
-        print("Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it+1)))
+        print("Epoch {:05d} | Loss {:.4f}".format(epoch, total_loss / (it + 1)))
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
-                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
-                        "'puregpu' for pure-GPU training.")
+    parser.add_argument(
+        "--mode",
+        default="mixed",
+        choices=["cpu", "mixed", "puregpu"],
+        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
+        "'puregpu' for pure-GPU training.",
+    )
     args = parser.parse_args()
     if not torch.cuda.is_available():
-        args.mode = 'cpu'
+        args.mode = "cpu"
-    print(f'Training in {args.mode} mode.')
+    print(f"Training in {args.mode} mode.")
     # load and preprocess dataset
-    print('Loading data')
-    dataset = DglLinkPropPredDataset('ogbl-citation2')
+    print("Loading data")
+    dataset = DglLinkPropPredDataset("ogbl-citation2")
     g = dataset[0]
-    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
-    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
+    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
+    device = torch.device("cpu" if args.mode == "cpu" else "cuda")
     g, reverse_eids = to_bidirected_with_reverse_mapping(g)
     reverse_eids = reverse_eids.to(device)
     seed_edges = torch.arange(g.num_edges()).to(device)
     edge_split = dataset.get_edge_split()
     # create GraphSAGE model
-    in_size = g.ndata['feat'].shape[1]
+    in_size = g.ndata["feat"].shape[1]
     model = SAGE(in_size, 256).to(device)
     # model training
-    print('Training...')
+    print("Training...")
     train(args, device, g, reverse_eids, seed_edges, model)
     # validate/test the model
-    print('Validation/Testing...')
-    valid_mrr, test_mrr = evaluate(device, g, edge_split, model, batch_size=1000)
-    print('Validation MRR {:.4f}, Test MRR {:.4f}'.format(valid_mrr.item(),test_mrr.item()))
+    print("Validation/Testing...")
+    valid_mrr, test_mrr = evaluate(
+        device, g, edge_split, model, batch_size=1000
+    )
+    print(
+        "Validation MRR {:.4f}, Test MRR {:.4f}".format(
+            valid_mrr.item(), test_mrr.item()
+        )
+    )
-import torch as th
 import dgl
+import torch as th
 def load_reddit(self_loop=True):
...
+import argparse
+import dgl
+import dgl.nn as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics.functional as MF
-import dgl
-import dgl.nn as dglnn
+import tqdm
 from dgl.data import AsNodePredDataset
-from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler
+from dgl.dataloading import (
+    DataLoader,
+    MultiLayerFullNeighborSampler,
+    NeighborSampler,
+)
 from ogb.nodeproppred import DglNodePropPredDataset
-import tqdm
-import argparse
 class SAGE(nn.Module):
     def __init__(self, in_size, hid_size, out_size):
         super().__init__()
         self.layers = nn.ModuleList()
         # three-layer GraphSAGE-mean
-        self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
-        self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
+        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
+        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
         self.dropout = nn.Dropout(0.5)
         self.hid_size = hid_size
         self.out_size = out_size
@@ -33,76 +39,108 @@ class SAGE(nn.Module):
     def inference(self, g, device, batch_size):
         """Conduct layer-wise inference to get all the node embeddings."""
-        feat = g.ndata['feat']
-        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=['feat'])
+        feat = g.ndata["feat"]
+        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["feat"])
         dataloader = DataLoader(
-            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
-            batch_size=batch_size, shuffle=False, drop_last=False,
-            num_workers=0)
+            g,
+            torch.arange(g.num_nodes()).to(g.device),
+            sampler,
+            device=device,
+            batch_size=batch_size,
+            shuffle=False,
+            drop_last=False,
+            num_workers=0,
+        )
-        buffer_device = torch.device('cpu')
-        pin_memory = (buffer_device != device)
+        buffer_device = torch.device("cpu")
+        pin_memory = buffer_device != device
         for l, layer in enumerate(self.layers):
             y = torch.empty(
-                g.num_nodes(), self.hid_size if l != len(self.layers) - 1 else self.out_size,
-                device=buffer_device, pin_memory=pin_memory)
+                g.num_nodes(),
+                self.hid_size if l != len(self.layers) - 1 else self.out_size,
+                device=buffer_device,
+                pin_memory=pin_memory,
+            )
             feat = feat.to(device)
             for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                 x = feat[input_nodes]
                 h = layer(blocks[0], x)  # len(blocks) = 1
                 if l != len(self.layers) - 1:
                     h = F.relu(h)
                     h = self.dropout(h)
                 # by design, our output nodes are contiguous
-                y[output_nodes[0]:output_nodes[-1]+1] = h.to(buffer_device)
+                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
             feat = y
         return y
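Editor's note: why the layer-wise loop above is worth its bookkeeping: naive per-seed inference expands roughly fanout**layers neighbors per node, while the layer-wise scheme performs one full-neighbor sweep per layer and stages results in a CPU buffer. A toy comparison of the growth rates, with made-up numbers:

fanout, layers = 10, 3
print(fanout ** layers)  # ~1000 neighbors pulled in per seed by naive multi-layer sampling
print(layers)            # full-graph passes needed by the layer-wise approach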
 def evaluate(model, graph, dataloader):
     model.eval()
     ys = []
     y_hats = []
     for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
         with torch.no_grad():
-            x = blocks[0].srcdata['feat']
-            ys.append(blocks[-1].dstdata['label'])
+            x = blocks[0].srcdata["feat"]
+            ys.append(blocks[-1].dstdata["label"])
             y_hats.append(model(blocks, x))
     return MF.accuracy(torch.cat(y_hats), torch.cat(ys))
 def layerwise_infer(device, graph, nid, model, batch_size):
     model.eval()
     with torch.no_grad():
-        pred = model.inference(graph, device, batch_size)  # pred in buffer_device
+        pred = model.inference(
+            graph, device, batch_size
+        )  # pred in buffer_device
         pred = pred[nid]
-        label = graph.ndata['label'][nid].to(pred.device)
+        label = graph.ndata["label"][nid].to(pred.device)
         return MF.accuracy(pred, label)
 def train(args, device, g, dataset, model):
     # create sampler & dataloader
     train_idx = dataset.train_idx.to(device)
     val_idx = dataset.val_idx.to(device)
-    sampler = NeighborSampler([10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
-                              prefetch_node_feats=['feat'],
-                              prefetch_labels=['label'])
+    sampler = NeighborSampler(
+        [10, 10, 10],  # fanout for [layer-0, layer-1, layer-2]
+        prefetch_node_feats=["feat"],
+        prefetch_labels=["label"],
+    )
-    use_uva = (args.mode == 'mixed')
+    use_uva = args.mode == "mixed"
-    train_dataloader = DataLoader(g, train_idx, sampler, device=device,
-                                  batch_size=1024, shuffle=True,
-                                  drop_last=False, num_workers=0,
-                                  use_uva=use_uva)
+    train_dataloader = DataLoader(
+        g,
+        train_idx,
+        sampler,
+        device=device,
+        batch_size=1024,
+        shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        use_uva=use_uva,
+    )
-    val_dataloader = DataLoader(g, val_idx, sampler, device=device,
-                                batch_size=1024, shuffle=True,
-                                drop_last=False, num_workers=0,
-                                use_uva=use_uva)
+    val_dataloader = DataLoader(
+        g,
+        val_idx,
+        sampler,
+        device=device,
+        batch_size=1024,
+        shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        use_uva=use_uva,
+    )
     opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
     for epoch in range(10):
         model.train()
         total_loss = 0
-        for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
+        for it, (input_nodes, output_nodes, blocks) in enumerate(
+            train_dataloader
+        ):
-            x = blocks[0].srcdata['feat']
-            y = blocks[-1].dstdata['label']
+            x = blocks[0].srcdata["feat"]
+            y = blocks[-1].dstdata["label"]
             y_hat = model(blocks, x)
             loss = F.cross_entropy(y_hat, y)
             opt.zero_grad()
@@ -110,36 +148,44 @@ def train(args, device, g, dataset, model):
             opt.step()
             total_loss += loss.item()
         acc = evaluate(model, g, val_dataloader)
-        print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "
-              .format(epoch, total_loss / (it+1), acc.item()))
+        print(
+            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
+                epoch, total_loss / (it + 1), acc.item()
+            )
+        )
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--mode", default='mixed', choices=['cpu', 'mixed', 'puregpu'],
-                        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
-                        "'puregpu' for pure-GPU training.")
+    parser.add_argument(
+        "--mode",
+        default="mixed",
+        choices=["cpu", "mixed", "puregpu"],
+        help="Training mode. 'cpu' for CPU training, 'mixed' for CPU-GPU mixed training, "
+        "'puregpu' for pure-GPU training.",
+    )
     args = parser.parse_args()
     if not torch.cuda.is_available():
-        args.mode = 'cpu'
+        args.mode = "cpu"
-    print(f'Training in {args.mode} mode.')
+    print(f"Training in {args.mode} mode.")
     # load and preprocess dataset
-    print('Loading data')
-    dataset = AsNodePredDataset(DglNodePropPredDataset('ogbn-products'))
+    print("Loading data")
+    dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
     g = dataset[0]
-    g = g.to('cuda' if args.mode == 'puregpu' else 'cpu')
-    device = torch.device('cpu' if args.mode == 'cpu' else 'cuda')
+    g = g.to("cuda" if args.mode == "puregpu" else "cpu")
+    device = torch.device("cpu" if args.mode == "cpu" else "cuda")
     # create GraphSAGE model
-    in_size = g.ndata['feat'].shape[1]
+    in_size = g.ndata["feat"].shape[1]
     out_size = dataset.num_classes
     model = SAGE(in_size, 256, out_size).to(device)
     # model training
-    print('Training...')
+    print("Training...")
     train(args, device, g, dataset, model)
     # test the model
-    print('Testing...')
+    print("Testing...")
     acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
     print("Test Accuracy {:.4f}".format(acc.item()))
 import argparse
+import dgl.nn as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.nn as dglnn
 from dgl import AddSelfLoop
 from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
...
+import dgl.function as fn
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-import dgl.function as fn
 class GCNLayer(nn.Module):
     def __init__(
...
@@ -3,14 +3,14 @@ import os
 import random
 import time
+import dgl
+import dgl.function as fn
 import numpy as np
 import scipy
 import torch as th
-from torch.utils.data import DataLoader
-import dgl
-import dgl.function as fn
 from dgl.sampling import pack_traces, random_walk
+from torch.utils.data import DataLoader
 # The base class of sampler
@@ -123,7 +123,6 @@ class SAINTSampler:
             t = time.perf_counter()
             for num_nodes, subgraphs_nids, subgraphs_eids in loader:
                 self.subgraphs.extend(subgraphs_nids)
                 sampled_nodes += num_nodes
@@ -214,7 +213,6 @@ class SAINTSampler:
         raise NotImplementedError
     def __compute_norm__(self):
         self.node_counter[self.node_counter == 0] = 1
         self.edge_counter[self.edge_counter == 0] = 1
@@ -231,7 +229,6 @@ class SAINTSampler:
         return aggr_norm.numpy(), loss_norm.numpy()
     def __compute_degree_norm(self):
         self.train_g.ndata[
             "train_D_norm"
         ] = 1.0 / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)
...
@@ -9,7 +9,7 @@ from config import CONFIG
 from modules import GCNNet
 from sampler import SAINTEdgeSampler, SAINTNodeSampler, SAINTRandomWalkSampler
 from torch.utils.data import DataLoader
-from utils import Logger, calc_f1, evaluate, load_data, save_log_dir
+from utils import calc_f1, evaluate, load_data, Logger, save_log_dir
 def main(args, task):
...
@@ -2,14 +2,14 @@ import json
 import os
 from functools import namedtuple
+import dgl
 import numpy as np
 import scipy.sparse
 import torch
 from sklearn.metrics import f1_score
 from sklearn.preprocessing import StandardScaler
-import dgl
 class Logger(object):
     """A custom logger to log stdout to a logging file."""
...
 import copy
 import os
+import dgl
 import networkx as nx
 import numpy as np
 import torch
 from torch.utils.data import DataLoader, Dataset
-import dgl
 def build_dense_graph(n_particles):
     g = nx.complete_graph(n_particles)
...
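Editor's note: the NetworkX complete graph built above is typically converted to DGL with dgl.from_networkx; a small sketch (the conversion turns each undirected edge into two directed ones).

import dgl
import networkx as nx

g = dgl.from_networkx(nx.complete_graph(4))
print(g.num_nodes(), g.num_edges())  # 4 12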
 import copy
 from functools import partial
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
 import dgl
 import dgl.function as fn
 import dgl.nn as dglnn
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
 class MLP(nn.Module):
     def __init__(self, in_feats, out_feats, num_layers=2, hidden=128):
...