"examples/pytorch/vscode:/vscode.git/clone" did not exist on "3fef5d27d32ef2cec0871e6676f5c09c7e91fe02"
Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen, committed by GitHub
parent 6bc82161
@@ -9,9 +9,9 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from model import EntityClassify_HeteroAPI


def main(args):
...
@@ -6,13 +6,13 @@ import argparse
import itertools
import time

import dgl
import numpy as np
import torch as th
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from model import EntityClassify, RelGraphEmbed


def main(args):
...
"""RGCN layer implementation""" """RGCN layer implementation"""
from collections import defaultdict from collections import defaultdict
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import tqdm import tqdm
import dgl
import dgl.function as fn
import dgl.nn as dglnn
class RelGraphConvLayer(nn.Module): class RelGraphConvLayer(nn.Module):
r"""Relational graph convolution layer. r"""Relational graph convolution layer.
......
@@ -5,9 +5,9 @@ from functools import partial
import torch as th
import torch.nn.functional as F
from dgl.data.rdf import AIFB, AM, BGS, MUTAG
from entity_classify import EntityClassify


def main(args):
...
import argparse

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.nn.pytorch import RelGraphConv
from torchmetrics.functional import accuracy


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )

    def forward(self, g):
        x = self.emb.weight
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata["norm"]))
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata["norm"])
        return h


def evaluate(g, target_idx, labels, test_mask, model):
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    model.eval()
@@ -31,6 +46,7 @@ def evaluate(g, target_idx, labels, test_mask, model):
    logits = logits[target_idx]
    return accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()


def train(g, target_idx, labels, train_mask, model):
    # define train idx, loss function and optimizer
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
@@ -45,48 +61,60 @@ def train(g, target_idx, labels, train_mask, model):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = accuracy(
            logits[train_idx].argmax(dim=1), labels[train_idx]
        ).item()
        print(
            "Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with DGL built-in RGCN module.")
    # load and preprocess dataset
    if args.dataset == "aifb":
        data = AIFBDataset()
    elif args.dataset == "mutag":
        data = MUTAGDataset()
    elif args.dataset == "bgs":
        data = BGSDataset()
    elif args.dataset == "am":
        data = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    g = g.int().to(device)
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop("labels")
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    # calculate normalization weight for each edge, and find target category and node id
    for cetype in g.canonical_etypes:
        g.edges[cetype].data["norm"] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g, edata=["norm"])
    node_ids = torch.arange(g.num_nodes()).to(device)
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)
    train(g, target_idx, labels, train_mask, model)
    acc = evaluate(g, target_idx, labels, test_mask, model)
    print("Test accuracy {:.4f}".format(acc))
import argparse

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
from dgl.nn.pytorch import RelGraphConv
from torchmetrics.functional import accuracy


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )

    def forward(self, g):
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(
            self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata["norm"])
        )
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata["norm"])
        return h


def evaluate(model, label, dataloader, inv_target):
    model.eval()
    eval_logits = []
@@ -32,13 +49,14 @@ def evaluate(model, label, dataloader, inv_target):
    for input_nodes, output_nodes, blocks in dataloader:
        output_nodes = inv_target[output_nodes]
        for block in blocks:
            block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
        logits = model(blocks)
        eval_logits.append(logits.cpu().detach())
        eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()


def train(device, g, target_idx, labels, train_mask, model):
    # define train idx, loss function and optimizer
@@ -47,18 +65,30 @@ def train(device, g, target_idx, labels, train_mask, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloader
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=True,
    )
    # no separate validation subset, use train index instead for validation
    val_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=False,
    )
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
@@ -66,55 +96,75 @@ def train(device, g, target_idx, labels, train_mask, model):
            optimizer.step()
            total_loss += loss.item()
        acc = evaluate(model, labels, val_loader, inv_target)
        print(
            "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
                epoch, total_loss / (it + 1), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with DGL built-in RGCN module with sampling.")
    # load and preprocess dataset
    if args.dataset == "aifb":
        data = AIFBDataset()
    elif args.dataset == "mutag":
        data = MUTAGDataset()
    elif args.dataset == "bgs":
        data = BGSDataset()
    elif args.dataset == "am":
        data = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop("labels").to(device)
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    ).to(device)
    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)
    train(device, g, target_idx, labels, train_mask, model)
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler(
        [-1, -1]
    )  # -1 for sampling all neighbors
    test_loader = DataLoader(
        g,
        target_idx[test_idx],
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
    )
    acc = evaluate(model, labels, test_loader, inv_target)
    print("Test accuracy {:.4f}".format(acc))
import argparse
import os

import dgl
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler
from dgl.nn.pytorch import RelGraphConv
from torch.nn.parallel import DistributedDataParallel
from torchmetrics.functional import accuracy


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer="basis",
            num_bases=num_rels,
            self_loop=False,
        )

    def forward(self, g):
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(
            self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata["norm"])
        )
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata["norm"])
        return h


def evaluate(model, labels, dataloader, inv_target):
    model.eval()
    eval_logits = []
@@ -36,34 +53,51 @@ def evaluate(model, labels, dataloader, inv_target):
    for input_nodes, output_nodes, blocks in dataloader:
        output_nodes = inv_target[output_nodes]
        for block in blocks:
            block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
        logits = model(blocks)
        eval_logits.append(logits.cpu().detach())
        eval_seeds.append(output_nodes.cpu().detach())
    eval_logits = torch.cat(eval_logits)
    eval_seeds = torch.cat(eval_seeds)
    num_seeds = len(eval_seeds)
    loc_sum = accuracy(
        eval_logits.argmax(dim=1), labels[eval_seeds].cpu()
    ) * float(num_seeds)
    return torch.tensor([loc_sum.item(), float(num_seeds)])


def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
    # define loss function and optimizer
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloader
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=True,
        use_ddp=True,
    )
    # no separate validation subset, use train index instead for validation
    val_loader = DataLoader(
        g,
        target_idx[train_idx],
        sampler,
        device=device,
        batch_size=100,
        shuffle=False,
        use_ddp=True,
    )
    for epoch in range(50):
        model.train()
        total_loss = 0
        for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
            output_nodes = inv_target[output_nodes]
            for block in blocks:
                block.edata["norm"] = dgl.norm_by_dst(block).unsqueeze(1)
            logits = model(blocks)
            loss = loss_fcn(logits, labels[output_nodes])
            optimizer.zero_grad()
@@ -71,89 +105,142 @@ def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
            optimizer.step()
            total_loss += loss.item()
        # torchmetric accuracy defined as num_correct_labels / num_train_nodes
        # loc_acc_split = [loc_accuracy * loc_num_train_nodes, loc_num_train_nodes]
        loc_acc_split = evaluate(model, labels, val_loader, inv_target).to(
            device
        )
        dist.reduce(loc_acc_split, 0)
        if proc_id == 0:
            acc = loc_acc_split[0] / loc_acc_split[1]
            print(
                "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
                    epoch, total_loss / (it + 1), acc.item()
                )
            )


def run(proc_id, nprocs, devices, g, data):
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    (
        num_rels,
        num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target,
    ) = data
    labels = labels.to(device)
    inv_target = inv_target.to(device)
    # create RGCN model (distributed)
    in_size = g.num_nodes()
    out_size = num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )
    # training + testing
    train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model)
    test_sampler = MultiLayerNeighborSampler(
        [-1, -1]
    )  # -1 for sampling all neighbors
    test_loader = DataLoader(
        g,
        target_idx[test_idx],
        test_sampler,
        device=device,
        batch_size=32,
        shuffle=False,
        use_ddp=True,
    )
    loc_acc_split = evaluate(model, labels, test_loader, inv_target).to(device)
    dist.reduce(loc_acc_split, 0)
    if proc_id == 0:
        acc = loc_acc_split[0] / loc_acc_split[1]
        print("Test accuracy {:.4f}".format(acc))
    # cleanup process group
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="RGCN for entity classification with sampling (multi-gpu)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aifb",
        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    print(
        f"Training with DGL built-in RGCN module with sampling using",
        nprocs,
        f"GPU(s)",
    )
    # load and preprocess dataset at master (parent) process
    if args.dataset == "aifb":
        data = AIFBDataset()
    elif args.dataset == "mutag":
        data = MUTAGDataset()
    elif args.dataset == "bgs":
        data = BGSDataset()
    elif args.dataset == "am":
        data = AMDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop("labels")
    train_mask = g.nodes[category].data.pop("train_mask")
    test_mask = g.nodes[category].data.pop("test_mask")
    # find target category and node id
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g)
    node_ids = torch.arange(g.num_nodes())
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename the fields as they can be changed by DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    # find the mapping (inv_target) from global node IDs to type-specific node IDs
    inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64)
    inv_target[target_idx] = torch.arange(
        0, target_idx.shape[0], dtype=inv_target.dtype
    )
    # avoid creating certain graph formats and train/test indexes in each sub-process to save memory
    g.create_formats_()
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    # thread limiting to avoid resource competition
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    data = (
        num_rels,
        data.num_classes,
        labels,
        train_idx,
        test_idx,
        target_idx,
        inv_target,
    )
    mp.spawn(run, args=(nprocs, devices, g, data), nprocs=nprocs)
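A note on the accuracy aggregation above: evaluate returns a local [correct_count, seed_count] pair and dist.reduce sums the pairs across ranks, because ranks may hold different numbers of seeds and averaging per-rank accuracies would be biased. The following small sketch, with made-up numbers, illustrates the difference; it is not part of the commit.

import torch

# per-rank [correct_count, seed_count], values made up for illustration
rank_stats = [
    torch.tensor([45.0, 50.0]),  # rank 0: 45 of 50 seeds correct
    torch.tensor([10.0, 20.0]),  # rank 1: 10 of 20 seeds correct
]
total = torch.stack(rank_stats).sum(0)  # what dist.reduce(loc_acc_split, 0) leaves on rank 0
print((total[0] / total[1]).item())     # 0.7857..., the true global accuracy
print((45 / 50 + 10 / 20) / 2)          # 0.70, the biased mean of per-rank accuracies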
import dgl
import torch as th
from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset


def load_data(data_name, get_norm=False, inv_target=False):
    if data_name == "aifb":
        dataset = AIFBDataset()
    elif data_name == "mutag":
        dataset = MUTAGDataset()
    elif data_name == "bgs":
        dataset = BGSDataset()
    else:
        dataset = AMDataset()
@@ -19,9 +20,9 @@ def load_data(data_name, get_norm=False, inv_target=False):
    num_rels = len(hg.canonical_etypes)
    category = dataset.predict_category
    num_classes = dataset.num_classes
    labels = hg.nodes[category].data.pop("labels")
    train_mask = hg.nodes[category].data.pop("train_mask")
    test_mask = hg.nodes[category].data.pop("test_mask")
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze()
    test_idx = th.nonzero(test_mask, as_tuple=False).squeeze()
@@ -29,8 +30,10 @@ def load_data(data_name, get_norm=False, inv_target=False):
        # Calculate normalization weight for each edge,
        # 1. / d, d is the degree of the destination node
        for cetype in hg.canonical_etypes:
            hg.edges[cetype].data["norm"] = dgl.norm_by_dst(
                hg, cetype
            ).unsqueeze(1)
        edata = ["norm"]
    else:
        edata = None
@@ -39,20 +42,30 @@ def load_data(data_name, get_norm=False, inv_target=False):
    g = dgl.to_homogeneous(hg, edata=edata)
    # Rename the fields as they can be changed by for example DataLoader
    g.ndata["ntype"] = g.ndata.pop(dgl.NTYPE)
    g.ndata["type_id"] = g.ndata.pop(dgl.NID)
    node_ids = th.arange(g.num_nodes())
    # find out the target node ids in g
    loc = g.ndata["ntype"] == category_id
    target_idx = node_ids[loc]
    if inv_target:
        # Map global node IDs to type-specific node IDs. This is required for
        # looking up type-specific labels in a minibatch
        inv_target = th.empty((g.num_nodes(),), dtype=th.int64)
        inv_target[target_idx] = th.arange(
            0, target_idx.shape[0], dtype=inv_target.dtype
        )
        return (
            g,
            num_rels,
            num_classes,
            labels,
            train_idx,
            test_idx,
            target_idx,
            inv_target,
        )
    else:
        return g, num_rels, num_classes, labels, train_idx, test_idx, target_idx
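A hedged usage sketch for the load_data helper above (meant to run in the same module, not part of the commit): with inv_target=True it returns an eight-element tuple, unpacked in the order of the return statement.

g, num_rels, num_classes, labels, train_idx, test_idx, target_idx, inv_target = load_data(
    "aifb", get_norm=True, inv_target=True
)
# inv_target maps a global node ID back to its position among the target nodes
assert inv_target[target_idx[0]] == 0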
...@@ -7,28 +7,30 @@ Difference compared to tkipf/relation-gcn ...@@ -7,28 +7,30 @@ Difference compared to tkipf/relation-gcn
* remove nodes that won't be touched * remove nodes that won't be touched
""" """
import argparse import argparse
import gc, os
import itertools import itertools
import numpy as np
import time import time
import os, gc
os.environ['DGLBACKEND']='pytorch'
import numpy as np
os.environ["DGLBACKEND"] = "pytorch"
from functools import partial
import dgl
import torch as th import torch as th
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dgl
from dgl import nn as dglnn
from dgl import DGLGraph
from dgl.distributed import DistDataLoader
from functools import partial
import tqdm import tqdm
from dgl import DGLGraph, nn as dglnn
from dgl.distributed import DistDataLoader
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
class RelGraphConvLayer(nn.Module): class RelGraphConvLayer(nn.Module):
...@@ -54,17 +56,20 @@ class RelGraphConvLayer(nn.Module): ...@@ -54,17 +56,20 @@ class RelGraphConvLayer(nn.Module):
dropout : float, optional dropout : float, optional
Dropout rate. Default: 0.0 Dropout rate. Default: 0.0
""" """
def __init__(self,
in_feat, def __init__(
out_feat, self,
rel_names, in_feat,
num_bases, out_feat,
*, rel_names,
weight=True, num_bases,
bias=True, *,
activation=None, weight=True,
self_loop=False, bias=True,
dropout=0.0): activation=None,
self_loop=False,
dropout=0.0
):
super(RelGraphConvLayer, self).__init__() super(RelGraphConvLayer, self).__init__()
self.in_feat = in_feat self.in_feat = in_feat
self.out_feat = out_feat self.out_feat = out_feat
...@@ -74,19 +79,29 @@ class RelGraphConvLayer(nn.Module): ...@@ -74,19 +79,29 @@ class RelGraphConvLayer(nn.Module):
self.activation = activation self.activation = activation
self.self_loop = self_loop self.self_loop = self_loop
self.conv = dglnn.HeteroGraphConv({ self.conv = dglnn.HeteroGraphConv(
rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False) {
rel: dglnn.GraphConv(
in_feat, out_feat, norm="right", weight=False, bias=False
)
for rel in rel_names for rel in rel_names
}) }
)
self.use_weight = weight self.use_weight = weight
self.use_basis = num_bases < len(self.rel_names) and weight self.use_basis = num_bases < len(self.rel_names) and weight
if self.use_weight: if self.use_weight:
if self.use_basis: if self.use_basis:
self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names)) self.basis = dglnn.WeightBasis(
(in_feat, out_feat), num_bases, len(self.rel_names)
)
else: else:
self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat)) self.weight = nn.Parameter(
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) th.Tensor(len(self.rel_names), in_feat, out_feat)
)
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain("relu")
)
# bias # bias
if bias: if bias:
...@@ -96,8 +111,9 @@ class RelGraphConvLayer(nn.Module): ...@@ -96,8 +111,9 @@ class RelGraphConvLayer(nn.Module):
# weight for self loop # weight for self loop
if self.self_loop: if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight, nn.init.xavier_uniform_(
gain=nn.init.calculate_gain('relu')) self.loop_weight, gain=nn.init.calculate_gain("relu")
)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
...@@ -117,14 +133,18 @@ class RelGraphConvLayer(nn.Module): ...@@ -117,14 +133,18 @@ class RelGraphConvLayer(nn.Module):
g = g.local_var() g = g.local_var()
if self.use_weight: if self.use_weight:
weight = self.basis() if self.use_basis else self.weight weight = self.basis() if self.use_basis else self.weight
wdict = {self.rel_names[i] : {'weight' : w.squeeze(0)} wdict = {
for i, w in enumerate(th.split(weight, 1, dim=0))} self.rel_names[i]: {"weight": w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))
}
else: else:
wdict = {} wdict = {}
if g.is_block: if g.is_block:
inputs_src = inputs inputs_src = inputs
inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()} inputs_dst = {
k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
}
else: else:
inputs_src = inputs_dst = inputs inputs_src = inputs_dst = inputs
...@@ -138,10 +158,12 @@ class RelGraphConvLayer(nn.Module): ...@@ -138,10 +158,12 @@ class RelGraphConvLayer(nn.Module):
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return self.dropout(h) return self.dropout(h)
return {ntype : _apply(ntype, h) for ntype, h in hs.items()}
return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
class EntityClassify(nn.Module): class EntityClassify(nn.Module):
""" Entity classification class for RGCN """Entity classification class for RGCN
Parameters Parameters
---------- ----------
device : int device : int
...@@ -163,16 +185,19 @@ class EntityClassify(nn.Module): ...@@ -163,16 +185,19 @@ class EntityClassify(nn.Module):
use_self_loop : bool use_self_loop : bool
Use self loop if True, default False. Use self loop if True, default False.
""" """
def __init__(self,
device, def __init__(
h_dim, self,
out_dim, device,
rel_names, h_dim,
num_bases=None, out_dim,
num_hidden_layers=1, rel_names,
dropout=0, num_bases=None,
use_self_loop=False, num_hidden_layers=1,
layer_norm=False): dropout=0,
use_self_loop=False,
layer_norm=False,
):
super(EntityClassify, self).__init__() super(EntityClassify, self).__init__()
self.device = device self.device = device
self.h_dim = h_dim self.h_dim = h_dim
...@@ -185,20 +210,41 @@ class EntityClassify(nn.Module): ...@@ -185,20 +210,41 @@ class EntityClassify(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# i2h # i2h
self.layers.append(RelGraphConvLayer( self.layers.append(
self.h_dim, self.h_dim, rel_names, RelGraphConvLayer(
self.num_bases, activation=F.relu, self_loop=self.use_self_loop, self.h_dim,
dropout=self.dropout)) self.h_dim,
rel_names,
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
)
)
# h2h # h2h
for idx in range(self.num_hidden_layers): for idx in range(self.num_hidden_layers):
self.layers.append(RelGraphConvLayer( self.layers.append(
self.h_dim, self.h_dim, rel_names, RelGraphConvLayer(
self.num_bases, activation=F.relu, self_loop=self.use_self_loop, self.h_dim,
dropout=self.dropout)) self.h_dim,
rel_names,
self.num_bases,
activation=F.relu,
self_loop=self.use_self_loop,
dropout=self.dropout,
)
)
# h2o # h2o
self.layers.append(RelGraphConvLayer( self.layers.append(
self.h_dim, self.out_dim, rel_names, RelGraphConvLayer(
self.num_bases, activation=None, self_loop=self.use_self_loop)) self.h_dim,
self.out_dim,
rel_names,
self.num_bases,
activation=None,
self_loop=self.use_self_loop,
)
)
def forward(self, blocks, feats, norm=None): def forward(self, blocks, feats, norm=None):
if blocks is None: if blocks is None:
...@@ -210,11 +256,13 @@ class EntityClassify(nn.Module): ...@@ -210,11 +256,13 @@ class EntityClassify(nn.Module):
h = layer(block, h) h = layer(block, h)
return h return h
def init_emb(shape, dtype): def init_emb(shape, dtype):
arr = th.zeros(shape, dtype=dtype) arr = th.zeros(shape, dtype=dtype)
nn.init.uniform_(arr, -1.0, 1.0) nn.init.uniform_(arr, -1.0, 1.0)
return arr return arr
class DistEmbedLayer(nn.Module): class DistEmbedLayer(nn.Module):
r"""Embedding layer for featureless heterograph. r"""Embedding layer for featureless heterograph.
Parameters Parameters
...@@ -234,14 +282,17 @@ class DistEmbedLayer(nn.Module): ...@@ -234,14 +282,17 @@ class DistEmbedLayer(nn.Module):
embed_name : str, optional embed_name : str, optional
Embed name Embed name
""" """
def __init__(self,
dev_id, def __init__(
g, self,
embed_size, dev_id,
sparse_emb=False, g,
dgl_sparse_emb=False, embed_size,
feat_name='feat', sparse_emb=False,
embed_name='node_emb'): dgl_sparse_emb=False,
feat_name="feat",
embed_name="node_emb",
):
super(DistEmbedLayer, self).__init__() super(DistEmbedLayer, self).__init__()
self.dev_id = dev_id self.dev_id = dev_id
self.embed_size = embed_size self.embed_size = embed_size
...@@ -249,14 +300,16 @@ class DistEmbedLayer(nn.Module): ...@@ -249,14 +300,16 @@ class DistEmbedLayer(nn.Module):
self.feat_name = feat_name self.feat_name = feat_name
self.sparse_emb = sparse_emb self.sparse_emb = sparse_emb
self.g = g self.g = g
self.ntype_id_map = {g.get_ntype_id(ntype):ntype for ntype in g.ntypes} self.ntype_id_map = {g.get_ntype_id(ntype): ntype for ntype in g.ntypes}
self.node_projs = nn.ModuleDict() self.node_projs = nn.ModuleDict()
for ntype in g.ntypes: for ntype in g.ntypes:
if feat_name in g.nodes[ntype].data: if feat_name in g.nodes[ntype].data:
self.node_projs[ntype] = nn.Linear(g.nodes[ntype].data[feat_name].shape[1], embed_size) self.node_projs[ntype] = nn.Linear(
g.nodes[ntype].data[feat_name].shape[1], embed_size
)
nn.init.xavier_uniform_(self.node_projs[ntype].weight) nn.init.xavier_uniform_(self.node_projs[ntype].weight)
print('node {} has data {}'.format(ntype, feat_name)) print("node {} has data {}".format(ntype, feat_name))
if sparse_emb: if sparse_emb:
if dgl_sparse_emb: if dgl_sparse_emb:
self.node_embeds = {} self.node_embeds = {}
...@@ -264,24 +317,34 @@ class DistEmbedLayer(nn.Module): ...@@ -264,24 +317,34 @@ class DistEmbedLayer(nn.Module):
# We only create embeddings for nodes without node features. # We only create embeddings for nodes without node features.
if feat_name not in g.nodes[ntype].data: if feat_name not in g.nodes[ntype].data:
part_policy = g.get_node_partition_policy(ntype) part_policy = g.get_node_partition_policy(ntype)
self.node_embeds[ntype] = dgl.distributed.DistEmbedding(g.number_of_nodes(ntype), self.node_embeds[ntype] = dgl.distributed.DistEmbedding(
self.embed_size, g.number_of_nodes(ntype),
embed_name + '_' + ntype, self.embed_size,
init_emb, embed_name + "_" + ntype,
part_policy) init_emb,
part_policy,
)
else: else:
self.node_embeds = nn.ModuleDict() self.node_embeds = nn.ModuleDict()
for ntype in g.ntypes: for ntype in g.ntypes:
# We only create embeddings for nodes without node features. # We only create embeddings for nodes without node features.
if feat_name not in g.nodes[ntype].data: if feat_name not in g.nodes[ntype].data:
self.node_embeds[ntype] = th.nn.Embedding(g.number_of_nodes(ntype), self.embed_size, sparse=self.sparse_emb) self.node_embeds[ntype] = th.nn.Embedding(
nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0) g.number_of_nodes(ntype),
self.embed_size,
sparse=self.sparse_emb,
)
nn.init.uniform_(
self.node_embeds[ntype].weight, -1.0, 1.0
)
else: else:
self.node_embeds = nn.ModuleDict() self.node_embeds = nn.ModuleDict()
for ntype in g.ntypes: for ntype in g.ntypes:
# We only create embeddings for nodes without node features. # We only create embeddings for nodes without node features.
if feat_name not in g.nodes[ntype].data: if feat_name not in g.nodes[ntype].data:
self.node_embeds[ntype] = th.nn.Embedding(g.number_of_nodes(ntype), self.embed_size) self.node_embeds[ntype] = th.nn.Embedding(
g.number_of_nodes(ntype), self.embed_size
)
nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0) nn.init.uniform_(self.node_embeds[ntype].weight, -1.0, 1.0)
def forward(self, node_ids): def forward(self, node_ids):
...@@ -298,11 +361,18 @@ class DistEmbedLayer(nn.Module): ...@@ -298,11 +361,18 @@ class DistEmbedLayer(nn.Module):
embeds = {} embeds = {}
for ntype in node_ids: for ntype in node_ids:
if self.feat_name in self.g.nodes[ntype].data: if self.feat_name in self.g.nodes[ntype].data:
embeds[ntype] = self.node_projs[ntype](self.g.nodes[ntype].data[self.feat_name][node_ids[ntype]].to(self.dev_id)) embeds[ntype] = self.node_projs[ntype](
self.g.nodes[ntype]
.data[self.feat_name][node_ids[ntype]]
.to(self.dev_id)
)
else: else:
embeds[ntype] = self.node_embeds[ntype](node_ids[ntype]).to(self.dev_id) embeds[ntype] = self.node_embeds[ntype](node_ids[ntype]).to(
self.dev_id
)
return embeds return embeds
def compute_acc(results, labels): def compute_acc(results, labels):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
...@@ -310,25 +380,37 @@ def compute_acc(results, labels): ...@@ -310,25 +380,37 @@ def compute_acc(results, labels):
labels = labels.long() labels = labels.long()
return (results == labels).float().sum() / len(results) return (results == labels).float().sum() / len(results)
def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_nid, all_test_nid):
def evaluate(
g,
model,
embed_layer,
labels,
eval_loader,
test_loader,
all_val_nid,
all_test_nid,
):
model.eval() model.eval()
embed_layer.eval() embed_layer.eval()
eval_logits = [] eval_logits = []
eval_seeds = [] eval_seeds = []
global_results = dgl.distributed.DistTensor(labels.shape, th.long, 'results', persistent=True) global_results = dgl.distributed.DistTensor(
labels.shape, th.long, "results", persistent=True
)
with th.no_grad(): with th.no_grad():
th.cuda.empty_cache() th.cuda.empty_cache()
for sample_data in tqdm.tqdm(eval_loader): for sample_data in tqdm.tqdm(eval_loader):
input_nodes, seeds, blocks = sample_data input_nodes, seeds, blocks = sample_data
seeds = seeds['paper'] seeds = seeds["paper"]
feats = embed_layer(input_nodes) feats = embed_layer(input_nodes)
logits = model(blocks, feats) logits = model(blocks, feats)
assert len(logits) == 1 assert len(logits) == 1
logits = logits['paper'] logits = logits["paper"]
eval_logits.append(logits.cpu().detach()) eval_logits.append(logits.cpu().detach())
assert np.all(seeds.numpy() < g.number_of_nodes('paper')) assert np.all(seeds.numpy() < g.number_of_nodes("paper"))
eval_seeds.append(seeds.cpu().detach()) eval_seeds.append(seeds.cpu().detach())
eval_logits = th.cat(eval_logits) eval_logits = th.cat(eval_logits)
eval_seeds = th.cat(eval_seeds) eval_seeds = th.cat(eval_seeds)
...@@ -340,13 +422,13 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni ...@@ -340,13 +422,13 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni
th.cuda.empty_cache() th.cuda.empty_cache()
for sample_data in tqdm.tqdm(test_loader): for sample_data in tqdm.tqdm(test_loader):
input_nodes, seeds, blocks = sample_data input_nodes, seeds, blocks = sample_data
seeds = seeds['paper'] seeds = seeds["paper"]
feats = embed_layer(input_nodes) feats = embed_layer(input_nodes)
logits = model(blocks, feats) logits = model(blocks, feats)
assert len(logits) == 1 assert len(logits) == 1
logits = logits['paper'] logits = logits["paper"]
test_logits.append(logits.cpu().detach()) test_logits.append(logits.cpu().detach())
assert np.all(seeds.numpy() < g.number_of_nodes('paper')) assert np.all(seeds.numpy() < g.number_of_nodes("paper"))
test_seeds.append(seeds.cpu().detach()) test_seeds.append(seeds.cpu().detach())
test_logits = th.cat(test_logits) test_logits = th.cat(test_logits)
test_seeds = th.cat(test_seeds) test_seeds = th.cat(test_seeds)
...@@ -354,60 +436,78 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni ...@@ -354,60 +436,78 @@ def evaluate(g, model, embed_layer, labels, eval_loader, test_loader, all_val_ni
g.barrier() g.barrier()
if g.rank() == 0: if g.rank() == 0:
return compute_acc(global_results[all_val_nid], labels[all_val_nid]), \ return compute_acc(
compute_acc(global_results[all_test_nid], labels[all_test_nid]) global_results[all_val_nid], labels[all_val_nid]
), compute_acc(global_results[all_test_nid], labels[all_test_nid])
else: else:
return -1, -1 return -1, -1
def run(args, device, data):
g, num_classes, train_nid, val_nid, test_nid, labels, all_val_nid, all_test_nid = data
fanouts = [int(fanout) for fanout in args.fanout.split(',')] def run(args, device, data):
val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(',')] (
g,
num_classes,
train_nid,
val_nid,
test_nid,
labels,
all_val_nid,
all_test_nid,
) = data
fanouts = [int(fanout) for fanout in args.fanout.split(",")]
val_fanouts = [int(fanout) for fanout in args.validation_fanout.split(",")]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts) sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
dataloader = dgl.dataloading.DistNodeDataLoader( dataloader = dgl.dataloading.DistNodeDataLoader(
g, g,
{'paper': train_nid}, {"paper": train_nid},
sampler, sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=False) drop_last=False,
)
valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts) valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)
valid_dataloader = dgl.dataloading.DistNodeDataLoader( valid_dataloader = dgl.dataloading.DistNodeDataLoader(
g, g,
{'paper': val_nid}, {"paper": val_nid},
valid_sampler, valid_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False, shuffle=False,
drop_last=False) drop_last=False,
)
test_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts) test_sampler = dgl.dataloading.MultiLayerNeighborSampler(val_fanouts)
test_dataloader = dgl.dataloading.DistNodeDataLoader( test_dataloader = dgl.dataloading.DistNodeDataLoader(
g, g,
{'paper': test_nid}, {"paper": test_nid},
test_sampler, test_sampler,
batch_size=args.eval_batch_size, batch_size=args.eval_batch_size,
shuffle=False, shuffle=False,
drop_last=False) drop_last=False,
)
embed_layer = DistEmbedLayer(device,
g, embed_layer = DistEmbedLayer(
args.n_hidden, device,
sparse_emb=args.sparse_embedding, g,
dgl_sparse_emb=args.dgl_sparse, args.n_hidden,
feat_name='feat') sparse_emb=args.sparse_embedding,
dgl_sparse_emb=args.dgl_sparse,
model = EntityClassify(device, feat_name="feat",
args.n_hidden, )
num_classes,
g.etypes, model = EntityClassify(
num_bases=args.n_bases, device,
num_hidden_layers=args.n_layers-2, args.n_hidden,
dropout=args.dropout, num_classes,
use_self_loop=args.use_self_loop, g.etypes,
layer_norm=args.layer_norm) num_bases=args.n_bases,
num_hidden_layers=args.n_layers - 2,
dropout=args.dropout,
use_self_loop=args.use_self_loop,
layer_norm=args.layer_norm,
)
model = model.to(device) model = model.to(device)
if not args.standalone: if not args.standalone:
...@@ -419,38 +519,63 @@ def run(args, device, data): ...@@ -419,38 +519,63 @@ def run(args, device, data):
embed_layer = DistributedDataParallel(embed_layer) embed_layer = DistributedDataParallel(embed_layer)
else: else:
dev_id = g.rank() % args.num_gpus dev_id = g.rank() % args.num_gpus
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) model = DistributedDataParallel(
model, device_ids=[dev_id], output_device=dev_id
)
# If there are dense parameters in the embedding layer # If there are dense parameters in the embedding layer
# or we use Pytorch saprse embeddings. # or we use Pytorch saprse embeddings.
if len(embed_layer.node_projs) > 0 or not args.dgl_sparse: if len(embed_layer.node_projs) > 0 or not args.dgl_sparse:
embed_layer = embed_layer.to(device) embed_layer = embed_layer.to(device)
embed_layer = DistributedDataParallel(embed_layer, device_ids=[dev_id], output_device=dev_id) embed_layer = DistributedDataParallel(
embed_layer, device_ids=[dev_id], output_device=dev_id
)
if args.sparse_embedding: if args.sparse_embedding:
if args.dgl_sparse and args.standalone: if args.dgl_sparse and args.standalone:
emb_optimizer = dgl.distributed.optim.SparseAdam(list(embed_layer.node_embeds.values()), lr=args.sparse_lr) emb_optimizer = dgl.distributed.optim.SparseAdam(
print('optimize DGL sparse embedding:', embed_layer.node_embeds.keys()) list(embed_layer.node_embeds.values()), lr=args.sparse_lr
)
print(
"optimize DGL sparse embedding:", embed_layer.node_embeds.keys()
)
elif args.dgl_sparse: elif args.dgl_sparse:
emb_optimizer = dgl.distributed.optim.SparseAdam(list(embed_layer.module.node_embeds.values()), lr=args.sparse_lr) emb_optimizer = dgl.distributed.optim.SparseAdam(
print('optimize DGL sparse embedding:', embed_layer.module.node_embeds.keys()) list(embed_layer.module.node_embeds.values()), lr=args.sparse_lr
)
print(
"optimize DGL sparse embedding:",
embed_layer.module.node_embeds.keys(),
)
elif args.standalone: elif args.standalone:
emb_optimizer = th.optim.SparseAdam(list(embed_layer.node_embeds.parameters()), lr=args.sparse_lr) emb_optimizer = th.optim.SparseAdam(
print('optimize Pytorch sparse embedding:', embed_layer.node_embeds) list(embed_layer.node_embeds.parameters()), lr=args.sparse_lr
)
print("optimize Pytorch sparse embedding:", embed_layer.node_embeds)
else: else:
emb_optimizer = th.optim.SparseAdam(list(embed_layer.module.node_embeds.parameters()), lr=args.sparse_lr) emb_optimizer = th.optim.SparseAdam(
print('optimize Pytorch sparse embedding:', embed_layer.module.node_embeds) list(embed_layer.module.node_embeds.parameters()),
lr=args.sparse_lr,
)
print(
"optimize Pytorch sparse embedding:",
embed_layer.module.node_embeds,
)
dense_params = list(model.parameters()) dense_params = list(model.parameters())
if args.standalone: if args.standalone:
dense_params += list(embed_layer.node_projs.parameters()) dense_params += list(embed_layer.node_projs.parameters())
print('optimize dense projection:', embed_layer.node_projs) print("optimize dense projection:", embed_layer.node_projs)
else: else:
dense_params += list(embed_layer.module.node_projs.parameters()) dense_params += list(embed_layer.module.node_projs.parameters())
print('optimize dense projection:', embed_layer.module.node_projs) print("optimize dense projection:", embed_layer.module.node_projs)
optimizer = th.optim.Adam(dense_params, lr=args.lr, weight_decay=args.l2norm) optimizer = th.optim.Adam(
dense_params, lr=args.lr, weight_decay=args.l2norm
)
else: else:
all_params = list(model.parameters()) + list(embed_layer.parameters()) all_params = list(model.parameters()) + list(embed_layer.parameters())
optimizer = th.optim.Adam(all_params, lr=args.lr, weight_decay=args.l2norm) optimizer = th.optim.Adam(
all_params, lr=args.lr, weight_decay=args.l2norm
)
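# Note: when sparse embeddings are enabled, two optimizers exist side by side
# (emb_optimizer for the embeddings, optimizer for the dense parameters) and
# both must be stepped every iteration. A minimal sketch of that pattern; the
# actual backward/step code of this training loop is elided in the hunk below,
# so treat this helper as illustrative rather than the exact implementation:
def sparse_dense_step_sketch(loss, optimizer, emb_optimizer=None):
    # clear gradients of both parameter groups
    optimizer.zero_grad()
    if emb_optimizer is not None:
        emb_optimizer.zero_grad()
    loss.backward()
    # apply the dense update, then the sparse embedding update
    optimizer.step()
    if emb_optimizer is not None:
        emb_optimizer.step()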
# training loop # training loop
print("start training...") print("start training...")
...@@ -480,9 +605,11 @@ def run(args, device, data): ...@@ -480,9 +605,11 @@ def run(args, device, data):
step_time = [] step_time = []
for step, sample_data in enumerate(dataloader): for step, sample_data in enumerate(dataloader):
input_nodes, seeds, blocks = sample_data input_nodes, seeds, blocks = sample_data
seeds = seeds['paper'] seeds = seeds["paper"]
number_train += seeds.shape[0] number_train += seeds.shape[0]
number_input += np.sum([blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes]) number_input += np.sum(
[blocks[0].num_src_nodes(ntype) for ntype in blocks[0].ntypes]
)
tic_step = time.time() tic_step = time.time()
sample_time += tic_step - start sample_time += tic_step - start
sample_t.append(tic_step - start) sample_t.append(tic_step - start)
...@@ -495,7 +622,7 @@ def run(args, device, data): ...@@ -495,7 +622,7 @@ def run(args, device, data):
# forward # forward
logits = model(blocks, feats) logits = model(blocks, feats)
assert len(logits) == 1 assert len(logits) == 1
logits = logits['paper'] logits = logits["paper"]
loss = F.cross_entropy(logits, label) loss = F.cross_entropy(logits, label)
forward_end = time.time() forward_end = time.time()
...@@ -516,125 +643,276 @@ def run(args, device, data): ...@@ -516,125 +643,276 @@ def run(args, device, data):
step_t = time.time() - start step_t = time.time() - start
step_time.append(step_t) step_time.append(step_t)
train_acc = th.sum(logits.argmax(dim=1) == label).item() / len(seeds) train_acc = th.sum(logits.argmax(dim=1) == label).item() / len(
seeds
)
if step % args.log_every == 0: if step % args.log_every == 0:
print('[{}] Epoch {:05d} | Step {:05d} | Train acc {:.4f} | Loss {:.4f} | time {:.3f} s' \ print(
'| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}'.format( "[{}] Epoch {:05d} | Step {:05d} | Train acc {:.4f} | Loss {:.4f} | time {:.3f} s"
g.rank(), epoch, step, train_acc, loss.item(), np.sum(step_time[-args.log_every:]), "| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}".format(
np.sum(sample_t[-args.log_every:]), np.sum(feat_copy_t[-args.log_every:]), np.sum(forward_t[-args.log_every:]), g.rank(),
np.sum(backward_t[-args.log_every:]), np.sum(update_t[-args.log_every:]))) epoch,
step,
train_acc,
loss.item(),
np.sum(step_time[-args.log_every :]),
np.sum(sample_t[-args.log_every :]),
np.sum(feat_copy_t[-args.log_every :]),
np.sum(forward_t[-args.log_every :]),
np.sum(backward_t[-args.log_every :]),
np.sum(update_t[-args.log_every :]),
)
)
start = time.time() start = time.time()
gc.collect() gc.collect()
print('[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #train: {}, #input: {}'.format( print(
g.rank(), np.sum(step_time), np.sum(sample_t), np.sum(feat_copy_t), np.sum(forward_t), np.sum(backward_t), np.sum(update_t), number_train, number_input)) "[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #train: {}, #input: {}".format(
g.rank(),
np.sum(step_time),
np.sum(sample_t),
np.sum(feat_copy_t),
np.sum(forward_t),
np.sum(backward_t),
np.sum(update_t),
number_train,
number_input,
)
)
epoch += 1 epoch += 1
start = time.time() start = time.time()
g.barrier() g.barrier()
val_acc, test_acc = evaluate(g, model, embed_layer, labels, val_acc, test_acc = evaluate(
valid_dataloader, test_dataloader, all_val_nid, all_test_nid) g,
model,
embed_layer,
labels,
valid_dataloader,
test_dataloader,
all_val_nid,
all_test_nid,
)
if val_acc >= 0: if val_acc >= 0:
print('Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}'.format(val_acc, test_acc, print(
time.time() - start)) "Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
val_acc, test_acc, time.time() - start
)
)
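# For reference, a minimal single-trainer sketch of what an evaluation pass over
# one of the dataloaders above computes. The real evaluate() is defined earlier
# in this file and may differ (e.g., it aggregates accuracy across trainers via
# all_val_nid/all_test_nid); the embed_layer(input_nodes) call below is an
# assumption about the DistEmbedLayer interface.
def evaluate_sketch(model, embed_layer, labels, dataloader):
    model.eval()
    preds, nids = [], []
    with th.no_grad():
        for input_nodes, seeds, blocks in dataloader:
            seeds = seeds["paper"]
            feats = embed_layer(input_nodes)  # hypothetical: look up / project input features
            logits = model(blocks, feats)["paper"]
            preds.append(logits.argmax(dim=1).cpu())
            nids.append(seeds.cpu())
    model.train()
    preds, nids = th.cat(preds), th.cat(nids)
    # accuracy over the seed nodes seen by this trainer's dataloader
    return th.sum(preds == labels[nids]).item() / len(nids)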
def main(args): def main(args):
dgl.distributed.initialize(args.ip_config) dgl.distributed.initialize(args.ip_config)
if not args.standalone: if not args.standalone:
th.distributed.init_process_group(backend='gloo') th.distributed.init_process_group(backend="gloo")
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.conf_path) g = dgl.distributed.DistGraph(args.graph_name, part_config=args.conf_path)
print('rank:', g.rank()) print("rank:", g.rank())
pb = g.get_partition_book() pb = g.get_partition_book()
if 'trainer_id' in g.nodes['paper'].data: if "trainer_id" in g.nodes["paper"].data:
train_nid = dgl.distributed.node_split(g.nodes['paper'].data['train_mask'], train_nid = dgl.distributed.node_split(
pb, ntype='paper', force_even=True, g.nodes["paper"].data["train_mask"],
node_trainer_ids=g.nodes['paper'].data['trainer_id']) pb,
val_nid = dgl.distributed.node_split(g.nodes['paper'].data['val_mask'], ntype="paper",
pb, ntype='paper', force_even=True, force_even=True,
node_trainer_ids=g.nodes['paper'].data['trainer_id']) node_trainer_ids=g.nodes["paper"].data["trainer_id"],
test_nid = dgl.distributed.node_split(g.nodes['paper'].data['test_mask'], )
pb, ntype='paper', force_even=True, val_nid = dgl.distributed.node_split(
node_trainer_ids=g.nodes['paper'].data['trainer_id']) g.nodes["paper"].data["val_mask"],
pb,
ntype="paper",
force_even=True,
node_trainer_ids=g.nodes["paper"].data["trainer_id"],
)
test_nid = dgl.distributed.node_split(
g.nodes["paper"].data["test_mask"],
pb,
ntype="paper",
force_even=True,
node_trainer_ids=g.nodes["paper"].data["trainer_id"],
)
else: else:
train_nid = dgl.distributed.node_split(g.nodes['paper'].data['train_mask'], train_nid = dgl.distributed.node_split(
pb, ntype='paper', force_even=True) g.nodes["paper"].data["train_mask"],
val_nid = dgl.distributed.node_split(g.nodes['paper'].data['val_mask'], pb,
pb, ntype='paper', force_even=True) ntype="paper",
test_nid = dgl.distributed.node_split(g.nodes['paper'].data['test_mask'], force_even=True,
pb, ntype='paper', force_even=True) )
local_nid = pb.partid2nids(pb.partid, 'paper').detach().numpy() val_nid = dgl.distributed.node_split(
print('part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})'.format( g.nodes["paper"].data["val_mask"],
g.rank(), len(train_nid), len(np.intersect1d(train_nid.numpy(), local_nid)), pb,
len(val_nid), len(np.intersect1d(val_nid.numpy(), local_nid)), ntype="paper",
len(test_nid), len(np.intersect1d(test_nid.numpy(), local_nid)))) force_even=True,
)
test_nid = dgl.distributed.node_split(
g.nodes["paper"].data["test_mask"],
pb,
ntype="paper",
force_even=True,
)
local_nid = pb.partid2nids(pb.partid, "paper").detach().numpy()
print(
"part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})".format(
g.rank(),
len(train_nid),
len(np.intersect1d(train_nid.numpy(), local_nid)),
len(val_nid),
len(np.intersect1d(val_nid.numpy(), local_nid)),
len(test_nid),
len(np.intersect1d(test_nid.numpy(), local_nid)),
)
)
if args.num_gpus == -1: if args.num_gpus == -1:
device = th.device('cpu') device = th.device("cpu")
else: else:
dev_id = g.rank() % args.num_gpus dev_id = g.rank() % args.num_gpus
device = th.device('cuda:'+str(dev_id)) device = th.device("cuda:" + str(dev_id))
labels = g.nodes['paper'].data['labels'][np.arange(g.number_of_nodes('paper'))] labels = g.nodes["paper"].data["labels"][
all_val_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['val_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze() np.arange(g.number_of_nodes("paper"))
all_test_nid = th.LongTensor(np.nonzero(g.nodes['paper'].data['test_mask'][np.arange(g.number_of_nodes('paper'))])).squeeze() ]
all_val_nid = th.LongTensor(
np.nonzero(
g.nodes["paper"].data["val_mask"][
np.arange(g.number_of_nodes("paper"))
]
)
).squeeze()
all_test_nid = th.LongTensor(
np.nonzero(
g.nodes["paper"].data["test_mask"][
np.arange(g.number_of_nodes("paper"))
]
)
).squeeze()
n_classes = len(th.unique(labels[labels >= 0])) n_classes = len(th.unique(labels[labels >= 0]))
print('#classes:', n_classes) print("#classes:", n_classes)
run(args, device, (g, n_classes, train_nid, val_nid, test_nid, labels, all_val_nid, all_test_nid)) run(
args,
if __name__ == '__main__': device,
parser = argparse.ArgumentParser(description='RGCN') (
g,
n_classes,
train_nid,
val_nid,
test_nid,
labels,
all_val_nid,
all_test_nid,
),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="RGCN")
# distributed training related # distributed training related
parser.add_argument('--graph-name', type=str, help='graph name') parser.add_argument("--graph-name", type=str, help="graph name")
parser.add_argument('--id', type=int, help='the partition id') parser.add_argument("--id", type=int, help="the partition id")
parser.add_argument('--ip-config', type=str, help='The file for IP configuration') parser.add_argument(
parser.add_argument('--conf-path', type=str, help='The path to the partition config file') "--ip-config", type=str, help="The file for IP configuration"
)
parser.add_argument(
"--conf-path", type=str, help="The path to the partition config file"
)
# rgcn related # rgcn related
parser.add_argument('--num_gpus', type=int, default=-1, parser.add_argument(
help="the number of GPU device. Use -1 for CPU training") "--num_gpus",
parser.add_argument("--dropout", type=float, default=0, type=int,
help="dropout probability") default=-1,
parser.add_argument("--n-hidden", type=int, default=16, help="the number of GPU device. Use -1 for CPU training",
help="number of hidden units") )
parser.add_argument("--lr", type=float, default=1e-2, parser.add_argument(
help="learning rate") "--dropout", type=float, default=0, help="dropout probability"
parser.add_argument("--sparse-lr", type=float, default=1e-2, )
help="sparse lr rate") parser.add_argument(
parser.add_argument("--n-bases", type=int, default=-1, "--n-hidden", type=int, default=16, help="number of hidden units"
help="number of filter weight matrices, default: -1 [use all]") )
parser.add_argument("--n-layers", type=int, default=2, parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
help="number of propagation rounds") parser.add_argument(
parser.add_argument("-e", "--n-epochs", type=int, default=50, "--sparse-lr", type=float, default=1e-2, help="sparse lr rate"
help="number of training epochs") )
parser.add_argument("-d", "--dataset", type=str, required=True, parser.add_argument(
help="dataset to use") "--n-bases",
parser.add_argument("--l2norm", type=float, default=0, type=int,
help="l2 norm coef") default=-1,
parser.add_argument("--relabel", default=False, action='store_true', help="number of filter weight matrices, default: -1 [use all]",
help="remove untouched nodes and relabel") )
parser.add_argument("--fanout", type=str, default="4, 4", parser.add_argument(
help="Fan-out of neighbor sampling.") "--n-layers", type=int, default=2, help="number of propagation rounds"
parser.add_argument("--validation-fanout", type=str, default=None, )
help="Fan-out of neighbor sampling during validation.") parser.add_argument(
parser.add_argument("--use-self-loop", default=False, action='store_true', "-e",
help="include self feature as a special relation") "--n-epochs",
parser.add_argument("--batch-size", type=int, default=100, type=int,
help="Mini-batch size. ") default=50,
parser.add_argument("--eval-batch-size", type=int, default=128, help="number of training epochs",
help="Mini-batch size. ") )
parser.add_argument('--log-every', type=int, default=20) parser.add_argument(
parser.add_argument("--low-mem", default=False, action='store_true', "-d", "--dataset", type=str, required=True, help="dataset to use"
help="Whether use low mem RelGraphCov") )
parser.add_argument("--sparse-embedding", action='store_true', parser.add_argument("--l2norm", type=float, default=0, help="l2 norm coef")
help='Use sparse embedding for node embeddings.') parser.add_argument(
parser.add_argument("--dgl-sparse", action='store_true', "--relabel",
help='Whether to use DGL sparse embedding') default=False,
parser.add_argument('--layer-norm', default=False, action='store_true', action="store_true",
help='Use layer norm') help="remove untouched nodes and relabel",
parser.add_argument('--local_rank', type=int, help='get rank of the process') )
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode') parser.add_argument(
"--fanout",
type=str,
default="4, 4",
help="Fan-out of neighbor sampling.",
)
parser.add_argument(
"--validation-fanout",
type=str,
default=None,
help="Fan-out of neighbor sampling during validation.",
)
parser.add_argument(
"--use-self-loop",
default=False,
action="store_true",
help="include self feature as a special relation",
)
parser.add_argument(
"--batch-size", type=int, default=100, help="Mini-batch size. "
)
parser.add_argument(
"--eval-batch-size", type=int, default=128, help="Mini-batch size. "
)
parser.add_argument("--log-every", type=int, default=20)
parser.add_argument(
"--low-mem",
default=False,
action="store_true",
help="Whether use low mem RelGraphCov",
)
parser.add_argument(
"--sparse-embedding",
action="store_true",
help="Use sparse embedding for node embeddings.",
)
parser.add_argument(
"--dgl-sparse",
action="store_true",
help="Whether to use DGL sparse embedding",
)
parser.add_argument(
"--layer-norm",
default=False,
action="store_true",
help="Use layer norm",
)
parser.add_argument(
"--local_rank", type=int, help="get rank of the process"
)
parser.add_argument(
"--standalone", action="store_true", help="run in the standalone mode"
)
args = parser.parse_args() args = parser.parse_args()
# if validation_fanout is None, set it with args.fanout # if validation_fanout is None, set it with args.fanout
......
import dgl
import json import json
import torch as th
import dgl
import numpy as np import numpy as np
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
# Load OGB-MAG. # Load OGB-MAG.
dataset = DglNodePropPredDataset(name='ogbn-mag') dataset = DglNodePropPredDataset(name="ogbn-mag")
hg_orig, labels = dataset[0] hg_orig, labels = dataset[0]
subgs = {} subgs = {}
for etype in hg_orig.canonical_etypes: for etype in hg_orig.canonical_etypes:
u, v = hg_orig.all_edges(etype=etype) u, v = hg_orig.all_edges(etype=etype)
subgs[etype] = (u, v) subgs[etype] = (u, v)
subgs[(etype[2], 'rev-'+etype[1], etype[0])] = (v, u) subgs[(etype[2], "rev-" + etype[1], etype[0])] = (v, u)
hg = dgl.heterograph(subgs) hg = dgl.heterograph(subgs)
hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat'] hg.nodes["paper"].data["feat"] = hg_orig.nodes["paper"].data["feat"]
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]['paper'] train_idx = split_idx["train"]["paper"]
val_idx = split_idx["valid"]['paper'] val_idx = split_idx["valid"]["paper"]
test_idx = split_idx["test"]['paper'] test_idx = split_idx["test"]["paper"]
paper_labels = labels['paper'].squeeze() paper_labels = labels["paper"].squeeze()
train_mask = th.zeros((hg.number_of_nodes('paper'),), dtype=th.bool) train_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
train_mask[train_idx] = True train_mask[train_idx] = True
val_mask = th.zeros((hg.number_of_nodes('paper'),), dtype=th.bool) val_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
val_mask[val_idx] = True val_mask[val_idx] = True
test_mask = th.zeros((hg.number_of_nodes('paper'),), dtype=th.bool) test_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
test_mask[test_idx] = True test_mask[test_idx] = True
hg.nodes['paper'].data['train_mask'] = train_mask hg.nodes["paper"].data["train_mask"] = train_mask
hg.nodes['paper'].data['val_mask'] = val_mask hg.nodes["paper"].data["val_mask"] = val_mask
hg.nodes['paper'].data['test_mask'] = test_mask hg.nodes["paper"].data["test_mask"] = test_mask
hg.nodes['paper'].data['labels'] = paper_labels hg.nodes["paper"].data["labels"] = paper_labels
with open('outputs/mag.json') as json_file: with open("outputs/mag.json") as json_file:
metadata = json.load(json_file) metadata = json.load(json_file)
for part_id in range(metadata['num_parts']): for part_id in range(metadata["num_parts"]):
subg = dgl.load_graphs('outputs/part{}/graph.dgl'.format(part_id))[0][0] subg = dgl.load_graphs("outputs/part{}/graph.dgl".format(part_id))[0][0]
node_data = {} node_data = {}
for ntype in hg.ntypes: for ntype in hg.ntypes:
local_node_idx = th.logical_and(subg.ndata['inner_node'].bool(), local_node_idx = th.logical_and(
subg.ndata[dgl.NTYPE] == hg.get_ntype_id(ntype)) subg.ndata["inner_node"].bool(),
local_nodes = subg.ndata['orig_id'][local_node_idx].numpy() subg.ndata[dgl.NTYPE] == hg.get_ntype_id(ntype),
)
local_nodes = subg.ndata["orig_id"][local_node_idx].numpy()
for name in hg.nodes[ntype].data: for name in hg.nodes[ntype].data:
node_data[ntype + '/' + name] = hg.nodes[ntype].data[name][local_nodes] node_data[ntype + "/" + name] = hg.nodes[ntype].data[name][
print('node features:', node_data.keys()) local_nodes
dgl.data.utils.save_tensors('outputs/' + metadata['part-{}'.format(part_id)]['node_feats'], node_data) ]
print("node features:", node_data.keys())
dgl.data.utils.save_tensors(
"outputs/" + metadata["part-{}".format(part_id)]["node_feats"],
node_data,
)
edge_data = {} edge_data = {}
for etype in hg.etypes: for etype in hg.etypes:
local_edges = subg.edata['orig_id'][subg.edata[dgl.ETYPE] == hg.get_etype_id(etype)] local_edges = subg.edata["orig_id"][
subg.edata[dgl.ETYPE] == hg.get_etype_id(etype)
]
for name in hg.edges[etype].data: for name in hg.edges[etype].data:
edge_data[etype + '/' + name] = hg.edges[etype].data[name][local_edges] edge_data[etype + "/" + name] = hg.edges[etype].data[name][
print('edge features:', edge_data.keys()) local_edges
dgl.data.utils.save_tensors('outputs/' + metadata['part-{}'.format(part_id)]['edge_feats'], edge_data) ]
print("edge features:", edge_data.keys())
dgl.data.utils.save_tensors(
"outputs/" + metadata["part-{}".format(part_id)]["edge_feats"],
edge_data,
)
import argparse
import time
import dgl import dgl
import numpy as np import numpy as np
import torch as th import torch as th
import argparse
import time
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
def load_ogb(dataset): def load_ogb(dataset):
if dataset == 'ogbn-mag': if dataset == "ogbn-mag":
dataset = DglNodePropPredDataset(name=dataset) dataset = DglNodePropPredDataset(name=dataset)
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]['paper'] train_idx = split_idx["train"]["paper"]
val_idx = split_idx["valid"]['paper'] val_idx = split_idx["valid"]["paper"]
test_idx = split_idx["test"]['paper'] test_idx = split_idx["test"]["paper"]
hg_orig, labels = dataset[0] hg_orig, labels = dataset[0]
subgs = {} subgs = {}
for etype in hg_orig.canonical_etypes: for etype in hg_orig.canonical_etypes:
u, v = hg_orig.all_edges(etype=etype) u, v = hg_orig.all_edges(etype=etype)
subgs[etype] = (u, v) subgs[etype] = (u, v)
subgs[(etype[2], 'rev-'+etype[1], etype[0])] = (v, u) subgs[(etype[2], "rev-" + etype[1], etype[0])] = (v, u)
hg = dgl.heterograph(subgs) hg = dgl.heterograph(subgs)
hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat'] hg.nodes["paper"].data["feat"] = hg_orig.nodes["paper"].data["feat"]
paper_labels = labels['paper'].squeeze() paper_labels = labels["paper"].squeeze()
num_rels = len(hg.canonical_etypes) num_rels = len(hg.canonical_etypes)
num_of_ntype = len(hg.ntypes) num_of_ntype = len(hg.ntypes)
num_classes = dataset.num_classes num_classes = dataset.num_classes
category = 'paper' category = "paper"
print('Number of relations: {}'.format(num_rels)) print("Number of relations: {}".format(num_rels))
print('Number of classes: {}'.format(num_classes)) print("Number of classes: {}".format(num_classes))
print('Number of train: {}'.format(len(train_idx))) print("Number of train: {}".format(len(train_idx)))
print('Number of valid: {}'.format(len(val_idx))) print("Number of valid: {}".format(len(val_idx)))
print('Number of test: {}'.format(len(test_idx))) print("Number of test: {}".format(len(test_idx)))
# get target category id # get target category id
category_id = len(hg.ntypes) category_id = len(hg.ntypes)
...@@ -39,58 +41,90 @@ def load_ogb(dataset): ...@@ -39,58 +41,90 @@ def load_ogb(dataset):
if ntype == category: if ntype == category:
category_id = i category_id = i
train_mask = th.zeros((hg.number_of_nodes('paper'),), dtype=th.bool) train_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
train_mask[train_idx] = True train_mask[train_idx] = True
val_mask = th.zeros((hg.number_of_nodes('paper'),), dtype=th.bool) val_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
val_mask[val_idx] = True val_mask[val_idx] = True
test_mask = th.zeros((hg.number_of_nodes('paper'),), dtype=th.bool) test_mask = th.zeros((hg.number_of_nodes("paper"),), dtype=th.bool)
test_mask[test_idx] = True test_mask[test_idx] = True
hg.nodes['paper'].data['train_mask'] = train_mask hg.nodes["paper"].data["train_mask"] = train_mask
hg.nodes['paper'].data['val_mask'] = val_mask hg.nodes["paper"].data["val_mask"] = val_mask
hg.nodes['paper'].data['test_mask'] = test_mask hg.nodes["paper"].data["test_mask"] = test_mask
hg.nodes['paper'].data['labels'] = paper_labels hg.nodes["paper"].data["labels"] = paper_labels
return hg return hg
else: else:
raise("Do not support other ogbn datasets.") raise ("Do not support other ogbn datasets.")
if __name__ == '__main__': if __name__ == "__main__":
argparser = argparse.ArgumentParser("Partition builtin graphs") argparser = argparse.ArgumentParser("Partition builtin graphs")
argparser.add_argument('--dataset', type=str, default='ogbn-mag', argparser.add_argument(
help='datasets: ogbn-mag') "--dataset", type=str, default="ogbn-mag", help="datasets: ogbn-mag"
argparser.add_argument('--num_parts', type=int, default=4, )
help='number of partitions') argparser.add_argument(
argparser.add_argument('--part_method', type=str, default='metis', "--num_parts", type=int, default=4, help="number of partitions"
help='the partition method') )
argparser.add_argument('--balance_train', action='store_true', argparser.add_argument(
help='balance the training size in each partition.') "--part_method", type=str, default="metis", help="the partition method"
argparser.add_argument('--undirected', action='store_true', )
help='turn the graph into an undirected graph.') argparser.add_argument(
argparser.add_argument('--balance_edges', action='store_true', "--balance_train",
help='balance the number of edges in each partition.') action="store_true",
argparser.add_argument('--num_trainers_per_machine', type=int, default=1, help="balance the training size in each partition.",
help='the number of trainers per machine. The trainer ids are stored\ )
in the node feature \'trainer_id\'') argparser.add_argument(
argparser.add_argument('--output', type=str, default='data', "--undirected",
help='Output path of partitioned graph.') action="store_true",
help="turn the graph into an undirected graph.",
)
argparser.add_argument(
"--balance_edges",
action="store_true",
help="balance the number of edges in each partition.",
)
argparser.add_argument(
"--num_trainers_per_machine",
type=int,
default=1,
help="the number of trainers per machine. The trainer ids are stored\
in the node feature 'trainer_id'",
)
argparser.add_argument(
"--output",
type=str,
default="data",
help="Output path of partitioned graph.",
)
args = argparser.parse_args() args = argparser.parse_args()
start = time.time() start = time.time()
g = load_ogb(args.dataset) g = load_ogb(args.dataset)
print('load {} takes {:.3f} seconds'.format(args.dataset, time.time() - start)) print(
print('|V|={}, |E|={}'.format(g.number_of_nodes(), g.number_of_edges())) "load {} takes {:.3f} seconds".format(args.dataset, time.time() - start)
print('train: {}, valid: {}, test: {}'.format(th.sum(g.nodes['paper'].data['train_mask']), )
th.sum(g.nodes['paper'].data['val_mask']), print("|V|={}, |E|={}".format(g.number_of_nodes(), g.number_of_edges()))
th.sum(g.nodes['paper'].data['test_mask']))) print(
"train: {}, valid: {}, test: {}".format(
th.sum(g.nodes["paper"].data["train_mask"]),
th.sum(g.nodes["paper"].data["val_mask"]),
th.sum(g.nodes["paper"].data["test_mask"]),
)
)
if args.balance_train: if args.balance_train:
balance_ntypes = {'paper': g.nodes['paper'].data['train_mask']} balance_ntypes = {"paper": g.nodes["paper"].data["train_mask"]}
else: else:
balance_ntypes = None balance_ntypes = None
dgl.distributed.partition_graph(g, args.dataset, args.num_parts, args.output, dgl.distributed.partition_graph(
part_method=args.part_method, g,
balance_ntypes=balance_ntypes, args.dataset,
balance_edges=args.balance_edges, args.num_parts,
num_trainers_per_machine=args.num_trainers_per_machine) args.output,
part_method=args.part_method,
balance_ntypes=balance_ntypes,
balance_edges=args.balance_edges,
num_trainers_per_machine=args.num_trainers_per_machine,
)
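# After partitioning finishes, dgl.distributed.partition_graph writes a
# partition config JSON named after the graph under the output directory.
# A small inspection sketch (assumes the default --output "data" and
# --dataset "ogbn-mag"; it relies on the "num_parts"/"node_map"/"edge_map"
# keys that the partition config files loaded later in this guide also use):
import json
import os

part_config = os.path.join("data", "ogbn-mag.json")
with open(part_config) as f:
    meta = json.load(f)
print("num_parts:", meta["num_parts"])
print("node types:", list(meta["node_map"].keys()))
print("edge types:", list(meta["edge_map"].keys()))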
import pandas as pd import argparse
import os
import glob import glob
import json import json
import argparse import os
from collections import defaultdict from collections import defaultdict
import pandas as pd
path = os.getcwd() path = os.getcwd()
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", help="name of graph to create", default="order") parser.add_argument(
parser.add_argument("-nc", "--node_column", nargs="+", default=['order_id', 'entity_index', 'order_datetime', 'cid']) "-n", "--name", help="name of graph to create", default="order"
parser.add_argument("-nk", "--node_key", default='entity_index') )
parser.add_argument("-ec", "--edge_column", nargs="+", default=['predicate_type', 'predicate_index', 'entity_index', 'entity_index_y']) parser.add_argument(
"-nc",
"--node_column",
nargs="+",
default=["order_id", "entity_index", "order_datetime", "cid"],
)
parser.add_argument("-nk", "--node_key", default="entity_index")
parser.add_argument(
"-ec",
"--edge_column",
nargs="+",
default=[
"predicate_type",
"predicate_index",
"entity_index",
"entity_index_y",
],
)
parser.add_argument("-es", "--edge_start", default="entity_index") parser.add_argument("-es", "--edge_start", default="entity_index")
parser.add_argument("-en", "--edge_end", default="entity_index_y") parser.add_argument("-en", "--edge_end", default="entity_index_y")
args = parser.parse_args() args = parser.parse_args()
#Store all node types in the nodes folder # Store all node types in the nodes folder
nodes_list = sorted(glob.glob(os.path.join(path, "nodes/*"))) nodes_list = sorted(glob.glob(os.path.join(path, "nodes/*")))
if os.path.exists("{}_nodes.txt".format(args.name)): if os.path.exists("{}_nodes.txt".format(args.name)):
...@@ -31,10 +49,16 @@ for node_type_name in nodes_list: ...@@ -31,10 +49,16 @@ for node_type_name in nodes_list:
nodes_count = 0 nodes_count = 0
csv_files = sorted(glob.glob(os.path.join(node_type_name, "*.csv"))) csv_files = sorted(glob.glob(os.path.join(node_type_name, "*.csv")))
for file_name in csv_files: for file_name in csv_files:
df = pd.read_csv(file_name, error_bad_lines=False, escapechar='\\', names=args.node_column, usecols=[*range(len(args.node_column))]) df = pd.read_csv(
file_name,
error_bad_lines=False,
escapechar="\\",
names=args.node_column,
usecols=[*range(len(args.node_column))],
)
df_entity = pd.DataFrame(df[args.node_key], columns=[args.node_key]) df_entity = pd.DataFrame(df[args.node_key], columns=[args.node_key])
df_entity['type'] = node_type_id df_entity["type"] = node_type_id
column_list = ['type'] column_list = ["type"]
for weight_index in range(len(nodes_list)): for weight_index in range(len(nodes_list)):
weight_num = "weight{}".format(weight_index) weight_num = "weight{}".format(weight_index)
column_list.append(weight_num) column_list.append(weight_num)
...@@ -44,10 +68,20 @@ for node_type_name in nodes_list: ...@@ -44,10 +68,20 @@ for node_type_name in nodes_list:
df_entity[weight_num] = 0 df_entity[weight_num] = 0
nodes_count += len(df_entity.index) nodes_count += len(df_entity.index)
column_list.append(args.node_key) column_list.append(args.node_key)
#This loop creates a file that serves as input for the METIS algorithm. # This loop creates a file that serves as input for the METIS algorithm.
#More details about the METIS input format can be found here: https://docs.dgl.ai/en/0.6.x/guide/distributed-preprocessing.html#input-format-for-parmetis # More details about the METIS input format can be found here: https://docs.dgl.ai/en/0.6.x/guide/distributed-preprocessing.html#input-format-for-parmetis
df_entity.to_csv("{}_nodes.txt".format(args.name), columns=column_list, sep=" ", index=False, header=False, mode='a') df_entity.to_csv(
schema_dict['nid'][os.path.basename(node_type_name)] = [all_nodes_count, nodes_count + all_nodes_count] "{}_nodes.txt".format(args.name),
columns=column_list,
sep=" ",
index=False,
header=False,
mode="a",
)
schema_dict["nid"][os.path.basename(node_type_name)] = [
all_nodes_count,
nodes_count + all_nodes_count,
]
all_nodes_count += nodes_count all_nodes_count += nodes_count
node_type_id += 1 node_type_id += 1
...@@ -55,7 +89,7 @@ for node_type_name in nodes_list: ...@@ -55,7 +89,7 @@ for node_type_name in nodes_list:
if os.path.exists("{}_edges.txt".format(args.name)): if os.path.exists("{}_edges.txt".format(args.name)):
os.remove("{}_edges.txt".format(args.name)) os.remove("{}_edges.txt".format(args.name))
#Store all edge types in the edges folder # Store all edge types in the edges folder
edges_list = sorted(glob.glob(os.path.join(path, "edges/*"))) edges_list = sorted(glob.glob(os.path.join(path, "edges/*")))
...@@ -65,16 +99,35 @@ for edge_type_name in edges_list: ...@@ -65,16 +99,35 @@ for edge_type_name in edges_list:
edge_count = 0 edge_count = 0
csv_files = sorted(glob.glob(os.path.join(edge_type_name, "*.csv"))) csv_files = sorted(glob.glob(os.path.join(edge_type_name, "*.csv")))
for file_name in csv_files: for file_name in csv_files:
df = pd.read_csv(file_name, error_bad_lines=False, escapechar='\\', names=args.edge_column, usecols=[*range(len(args.edge_column))]) df = pd.read_csv(
df_entity = pd.DataFrame(df[[args.edge_start, args.edge_end]], columns=[args.edge_start, args.edge_end]) file_name,
df_entity['type'] = edge_type_id error_bad_lines=False,
escapechar="\\",
names=args.edge_column,
usecols=[*range(len(args.edge_column))],
)
df_entity = pd.DataFrame(
df[[args.edge_start, args.edge_end]],
columns=[args.edge_start, args.edge_end],
)
df_entity["type"] = edge_type_id
df_entity = df_entity.reset_index() df_entity = df_entity.reset_index()
df_entity['number'] = df_entity.index + edge_count df_entity["number"] = df_entity.index + edge_count
edge_count += len(df_entity.index) edge_count += len(df_entity.index)
#This loop creates a file that serves as input for the METIS algorithm. # This loop creates a file that serves as input for the METIS algorithm.
#More details about the METIS input format can be found here: https://docs.dgl.ai/en/0.6.x/guide/distributed-preprocessing.html#input-format-for-parmetis # More details about the METIS input format can be found here: https://docs.dgl.ai/en/0.6.x/guide/distributed-preprocessing.html#input-format-for-parmetis
df_entity.to_csv("{}_edges.txt".format(args.name), columns=[args.edge_start, args.edge_end, 'number', 'type'], sep=" ", index=False, header=False, mode='a') df_entity.to_csv(
schema_dict['eid'][os.path.basename(edge_type_name)] = [all_edges_count, all_edges_count + edge_count] "{}_edges.txt".format(args.name),
columns=[args.edge_start, args.edge_end, "number", "type"],
sep=" ",
index=False,
header=False,
mode="a",
)
schema_dict["eid"][os.path.basename(edge_type_name)] = [
all_edges_count,
all_edges_count + edge_count,
]
edge_type_id += 1 edge_type_id += 1
all_edges_count += edge_count all_edges_count += edge_count
...@@ -82,12 +135,20 @@ if os.path.exists("{}_stats.txt".format(args.name)): ...@@ -82,12 +135,20 @@ if os.path.exists("{}_stats.txt".format(args.name)):
os.remove("{}_stats.txt".format(args.name)) os.remove("{}_stats.txt".format(args.name))
df = pd.DataFrame([[all_nodes_count, all_edges_count, len(nodes_list)]], columns=['nodes_count', 'edges_count', 'weight_count']) df = pd.DataFrame(
df.to_csv("{}_stats.txt".format(args.name), columns=['nodes_count', 'edges_count', 'weight_count'], sep=" ", index=False, header=False) [[all_nodes_count, all_edges_count, len(nodes_list)]],
columns=["nodes_count", "edges_count", "weight_count"],
)
df.to_csv(
"{}_stats.txt".format(args.name),
columns=["nodes_count", "edges_count", "weight_count"],
sep=" ",
index=False,
header=False,
)
if os.path.exists("{}.json".format(args.name)): if os.path.exists("{}.json".format(args.name)):
os.remove("{}.json".format(args.name)) os.remove("{}.json".format(args.name))
with open("{}.json".format(args.name), "w", encoding="utf8") as json_file: with open("{}.json".format(args.name), "w", encoding="utf8") as json_file:
json.dump(schema_dict, json_file, ensure_ascii=False) json.dump(schema_dict, json_file, ensure_ascii=False)
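# The files written above follow the ParMETIS input convention used by DGL:
# each row of "{name}_nodes.txt" is "<type> <weight_0> ... <weight_k> <node_key>",
# each row of "{name}_edges.txt" is "<src> <dst> <edge_id> <type>", and
# "{name}_stats.txt" holds "<nodes_count> <edges_count> <weight_count>".
# A quick sanity check of that output (a sketch; it simply compares the stats
# line against the number of rows written above):
with open("{}_stats.txt".format(args.name)) as f:
    nodes_count, edges_count, weight_count = (int(x) for x in f.read().split())
with open("{}_nodes.txt".format(args.name)) as f:
    assert sum(1 for _ in f) == nodes_count
with open("{}_edges.txt".format(args.name)) as f:
    assert sum(1 for _ in f) == edges_count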
import os
import json import json
import numpy as np import os
import dgl import dgl
import numpy as np
import torch as th import torch as th
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
partitions_folder = 'outputs' partitions_folder = "outputs"
graph_name = 'mag' graph_name = "mag"
with open('{}/{}.json'.format(partitions_folder, graph_name)) as json_file: with open("{}/{}.json".format(partitions_folder, graph_name)) as json_file:
metadata = json.load(json_file) metadata = json.load(json_file)
num_parts = metadata['num_parts'] num_parts = metadata["num_parts"]
# Load OGB-MAG. # Load OGB-MAG.
dataset = DglNodePropPredDataset(name='ogbn-mag') dataset = DglNodePropPredDataset(name="ogbn-mag")
hg_orig, labels = dataset[0] hg_orig, labels = dataset[0]
subgs = {} subgs = {}
for etype in hg_orig.canonical_etypes: for etype in hg_orig.canonical_etypes:
u, v = hg_orig.all_edges(etype=etype) u, v = hg_orig.all_edges(etype=etype)
subgs[etype] = (u, v) subgs[etype] = (u, v)
subgs[(etype[2], 'rev-'+etype[1], etype[0])] = (v, u) subgs[(etype[2], "rev-" + etype[1], etype[0])] = (v, u)
hg = dgl.heterograph(subgs) hg = dgl.heterograph(subgs)
hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat'] hg.nodes["paper"].data["feat"] = hg_orig.nodes["paper"].data["feat"]
# Construct node data and edge data after reshuffling. # Construct node data and edge data after reshuffling.
node_feats = {} node_feats = {}
edge_feats = {} edge_feats = {}
for partid in range(num_parts): for partid in range(num_parts):
part_node_feats = dgl.data.utils.load_tensors( part_node_feats = dgl.data.utils.load_tensors(
'{}/part{}/node_feat.dgl'.format(partitions_folder, partid)) "{}/part{}/node_feat.dgl".format(partitions_folder, partid)
)
part_edge_feats = dgl.data.utils.load_tensors( part_edge_feats = dgl.data.utils.load_tensors(
'{}/part{}/edge_feat.dgl'.format(partitions_folder, partid)) "{}/part{}/edge_feat.dgl".format(partitions_folder, partid)
)
for key in part_node_feats: for key in part_node_feats:
if key in node_feats: if key in node_feats:
node_feats[key].append(part_node_feats[key]) node_feats[key].append(part_node_feats[key])
...@@ -45,43 +48,51 @@ for key in node_feats: ...@@ -45,43 +48,51 @@ for key in node_feats:
for key in edge_feats: for key in edge_feats:
edge_feats[key] = th.cat(edge_feats[key]) edge_feats[key] = th.cat(edge_feats[key])
ntype_map = metadata['ntypes'] ntype_map = metadata["ntypes"]
ntypes = [None] * len(ntype_map) ntypes = [None] * len(ntype_map)
for key in ntype_map: for key in ntype_map:
ntype_id = ntype_map[key] ntype_id = ntype_map[key]
ntypes[ntype_id] = key ntypes[ntype_id] = key
etype_map = metadata['etypes'] etype_map = metadata["etypes"]
etypes = [None] * len(etype_map) etypes = [None] * len(etype_map)
for key in etype_map: for key in etype_map:
etype_id = etype_map[key] etype_id = etype_map[key]
etypes[etype_id] = key etypes[etype_id] = key
etype2canonical = {etype: (srctype, etype, dsttype) etype2canonical = {
for srctype, etype, dsttype in hg.canonical_etypes} etype: (srctype, etype, dsttype)
for srctype, etype, dsttype in hg.canonical_etypes
}
node_map = metadata['node_map'] node_map = metadata["node_map"]
for key in node_map: for key in node_map:
node_map[key] = th.stack([th.tensor(row) for row in node_map[key]], 0) node_map[key] = th.stack([th.tensor(row) for row in node_map[key]], 0)
nid_map = dgl.distributed.id_map.IdMap(node_map) nid_map = dgl.distributed.id_map.IdMap(node_map)
edge_map = metadata['edge_map'] edge_map = metadata["edge_map"]
for key in edge_map: for key in edge_map:
edge_map[key] = th.stack([th.tensor(row) for row in edge_map[key]], 0) edge_map[key] = th.stack([th.tensor(row) for row in edge_map[key]], 0)
eid_map = dgl.distributed.id_map.IdMap(edge_map) eid_map = dgl.distributed.id_map.IdMap(edge_map)
for ntype in node_map: for ntype in node_map:
assert hg.number_of_nodes(ntype) == th.sum( assert hg.number_of_nodes(ntype) == th.sum(
node_map[ntype][:, 1] - node_map[ntype][:, 0]) node_map[ntype][:, 1] - node_map[ntype][:, 0]
)
for etype in edge_map: for etype in edge_map:
assert hg.number_of_edges(etype) == th.sum( assert hg.number_of_edges(etype) == th.sum(
edge_map[etype][:, 1] - edge_map[etype][:, 0]) edge_map[etype][:, 1] - edge_map[etype][:, 0]
)
# verify part_0 with graph_partition_book # verify part_0 with graph_partition_book
eid = [] eid = []
gpb = dgl.distributed.graph_partition_book.RangePartitionBook(0, num_parts, node_map, edge_map, gpb = dgl.distributed.graph_partition_book.RangePartitionBook(
{ntype: i for i, ntype in enumerate( 0,
hg.ntypes)}, num_parts,
{etype: i for i, etype in enumerate(hg.etypes)}) node_map,
subg0 = dgl.load_graphs('{}/part0/graph.dgl'.format(partitions_folder))[0][0] edge_map,
{ntype: i for i, ntype in enumerate(hg.ntypes)},
{etype: i for i, etype in enumerate(hg.etypes)},
)
subg0 = dgl.load_graphs("{}/part0/graph.dgl".format(partitions_folder))[0][0]
for etype in hg.etypes: for etype in hg.etypes:
type_eid = th.zeros((1,), dtype=th.int64) type_eid = th.zeros((1,), dtype=th.int64)
eid.append(gpb.map_to_homo_eid(type_eid, etype)) eid.append(gpb.map_to_homo_eid(type_eid, etype))
...@@ -96,8 +107,7 @@ gsrc, gdst = subg0.ndata[dgl.NID][lsrc], subg0.ndata[dgl.NID][ldst] ...@@ -96,8 +107,7 @@ gsrc, gdst = subg0.ndata[dgl.NID][lsrc], subg0.ndata[dgl.NID][ldst]
# The destination nodes are owned by the partition. # The destination nodes are owned by the partition.
assert th.all(gdst == ldst) assert th.all(gdst == ldst)
# gdst which is not assigned into current partition is not required to equal ldst # gdst which is not assigned into current partition is not required to equal ldst
assert th.all(th.logical_or( assert th.all(th.logical_or(gdst == ldst, subg0.ndata["inner_node"][ldst] == 0))
gdst == ldst, subg0.ndata['inner_node'][ldst] == 0))
etids, _ = gpb.map_to_per_etype(eid) etids, _ = gpb.map_to_per_etype(eid)
src_tids, _ = gpb.map_to_per_ntype(gsrc) src_tids, _ = gpb.map_to_per_ntype(gsrc)
dst_tids, _ = gpb.map_to_per_ntype(gdst) dst_tids, _ = gpb.map_to_per_ntype(gdst)
...@@ -105,7 +115,8 @@ canonical_etypes = [] ...@@ -105,7 +115,8 @@ canonical_etypes = []
etype_ids = th.arange(0, len(etypes)) etype_ids = th.arange(0, len(etypes))
for src_tid, etype_id, dst_tid in zip(src_tids, etype_ids, dst_tids): for src_tid, etype_id, dst_tid in zip(src_tids, etype_ids, dst_tids):
canonical_etypes.append( canonical_etypes.append(
(ntypes[src_tid], etypes[etype_id], ntypes[dst_tid])) (ntypes[src_tid], etypes[etype_id], ntypes[dst_tid])
)
for etype in canonical_etypes: for etype in canonical_etypes:
assert etype in hg.canonical_etypes assert etype in hg.canonical_etypes
...@@ -113,12 +124,12 @@ for etype in canonical_etypes: ...@@ -113,12 +124,12 @@ for etype in canonical_etypes:
orig_node_ids = {ntype: [] for ntype in hg.ntypes} orig_node_ids = {ntype: [] for ntype in hg.ntypes}
orig_edge_ids = {etype: [] for etype in hg.etypes} orig_edge_ids = {etype: [] for etype in hg.etypes}
for partid in range(num_parts): for partid in range(num_parts):
print('test part', partid) print("test part", partid)
part_file = '{}/part{}/graph.dgl'.format(partitions_folder, partid) part_file = "{}/part{}/graph.dgl".format(partitions_folder, partid)
subg = dgl.load_graphs(part_file)[0][0] subg = dgl.load_graphs(part_file)[0][0]
subg_src_id, subg_dst_id = subg.edges() subg_src_id, subg_dst_id = subg.edges()
orig_src_id = subg.ndata['orig_id'][subg_src_id] orig_src_id = subg.ndata["orig_id"][subg_src_id]
orig_dst_id = subg.ndata['orig_id'][subg_dst_id] orig_dst_id = subg.ndata["orig_id"][subg_dst_id]
global_src_id = subg.ndata[dgl.NID][subg_src_id] global_src_id = subg.ndata[dgl.NID][subg_src_id]
global_dst_id = subg.ndata[dgl.NID][subg_dst_id] global_dst_id = subg.ndata[dgl.NID][subg_dst_id]
subg_ntype = subg.ndata[dgl.NTYPE] subg_ntype = subg.ndata[dgl.NTYPE]
...@@ -129,19 +140,23 @@ for partid in range(num_parts): ...@@ -129,19 +140,23 @@ for partid in range(num_parts):
# This is global IDs after reshuffle. # This is global IDs after reshuffle.
nid = subg.ndata[dgl.NID][idx] nid = subg.ndata[dgl.NID][idx]
ntype_ids1, type_nid = nid_map(nid) ntype_ids1, type_nid = nid_map(nid)
orig_type_nid = subg.ndata['orig_id'][idx] orig_type_nid = subg.ndata["orig_id"][idx]
inner_node = subg.ndata['inner_node'][idx] inner_node = subg.ndata["inner_node"][idx]
# All nodes should have the same node type. # All nodes should have the same node type.
assert np.all(ntype_ids1.numpy() == int(ntype_id)) assert np.all(ntype_ids1.numpy() == int(ntype_id))
assert np.all(nid[inner_node == 1].numpy() == np.arange( assert np.all(
node_map[ntype][partid, 0], node_map[ntype][partid, 1])) nid[inner_node == 1].numpy()
== np.arange(node_map[ntype][partid, 0], node_map[ntype][partid, 1])
)
orig_node_ids[ntype].append(orig_type_nid[inner_node == 1]) orig_node_ids[ntype].append(orig_type_nid[inner_node == 1])
# Check the degree of the inner nodes. # Check the degree of the inner nodes.
inner_nids = th.nonzero(th.logical_and(subg_ntype == ntype_id, subg.ndata['inner_node']), inner_nids = th.nonzero(
as_tuple=True)[0] th.logical_and(subg_ntype == ntype_id, subg.ndata["inner_node"]),
as_tuple=True,
)[0]
subg_deg = subg.in_degrees(inner_nids) subg_deg = subg.in_degrees(inner_nids)
orig_nids = subg.ndata['orig_id'][inner_nids] orig_nids = subg.ndata["orig_id"][inner_nids]
# Calculate the in-degrees of nodes of a particular node type. # Calculate the in-degrees of nodes of a particular node type.
glob_deg = th.zeros(len(subg_deg), dtype=th.int64) glob_deg = th.zeros(len(subg_deg), dtype=th.int64)
for etype in hg.canonical_etypes: for etype in hg.canonical_etypes:
...@@ -152,7 +167,7 @@ for partid in range(num_parts): ...@@ -152,7 +167,7 @@ for partid in range(num_parts):
# Check node data. # Check node data.
for name in hg.nodes[ntype].data: for name in hg.nodes[ntype].data:
local_data = node_feats[ntype + '/' + name][type_nid] local_data = node_feats[ntype + "/" + name][type_nid]
local_data1 = hg.nodes[ntype].data[name][orig_type_nid] local_data1 = hg.nodes[ntype].data[name][orig_type_nid]
assert np.all(local_data.numpy() == local_data1.numpy()) assert np.all(local_data.numpy() == local_data1.numpy())
...@@ -163,7 +178,7 @@ for partid in range(num_parts): ...@@ -163,7 +178,7 @@ for partid in range(num_parts):
exist = hg[etype].has_edges_between(orig_src_id[idx], orig_dst_id[idx]) exist = hg[etype].has_edges_between(orig_src_id[idx], orig_dst_id[idx])
assert np.all(exist.numpy()) assert np.all(exist.numpy())
eid = hg[etype].edge_ids(orig_src_id[idx], orig_dst_id[idx]) eid = hg[etype].edge_ids(orig_src_id[idx], orig_dst_id[idx])
assert np.all(eid.numpy() == subg.edata['orig_id'][idx].numpy()) assert np.all(eid.numpy() == subg.edata["orig_id"][idx].numpy())
ntype_ids, type_nid = nid_map(global_src_id[idx]) ntype_ids, type_nid = nid_map(global_src_id[idx])
assert len(th.unique(ntype_ids)) == 1 assert len(th.unique(ntype_ids)) == 1
...@@ -175,17 +190,19 @@ for partid in range(num_parts): ...@@ -175,17 +190,19 @@ for partid in range(num_parts):
# This is global IDs after reshuffle. # This is global IDs after reshuffle.
eid = subg.edata[dgl.EID][idx] eid = subg.edata[dgl.EID][idx]
etype_ids1, type_eid = eid_map(eid) etype_ids1, type_eid = eid_map(eid)
orig_type_eid = subg.edata['orig_id'][idx] orig_type_eid = subg.edata["orig_id"][idx]
inner_edge = subg.edata['inner_edge'][idx] inner_edge = subg.edata["inner_edge"][idx]
# All edges should have the same edge type. # All edges should have the same edge type.
assert np.all(etype_ids1.numpy() == int(etype_id)) assert np.all(etype_ids1.numpy() == int(etype_id))
assert np.all(np.sort(eid[inner_edge == 1].numpy()) == np.arange( assert np.all(
edge_map[etype][partid, 0], edge_map[etype][partid, 1])) np.sort(eid[inner_edge == 1].numpy())
== np.arange(edge_map[etype][partid, 0], edge_map[etype][partid, 1])
)
orig_edge_ids[etype].append(orig_type_eid[inner_edge == 1]) orig_edge_ids[etype].append(orig_type_eid[inner_edge == 1])
# Check edge data. # Check edge data.
for name in hg.edges[etype].data: for name in hg.edges[etype].data:
local_data = edge_feats[etype + '/' + name][type_eid] local_data = edge_feats[etype + "/" + name][type_eid]
local_data1 = hg.edges[etype].data[name][orig_type_eid] local_data1 = hg.edges[etype].data[name][orig_type_eid]
assert np.all(local_data.numpy() == local_data1.numpy()) assert np.all(local_data.numpy() == local_data1.numpy())
......
import dgl
import json import json
import torch as th
import dgl
import numpy as np import numpy as np
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
# Load OGB-MAG. # Load OGB-MAG.
dataset = DglNodePropPredDataset(name='ogbn-mag') dataset = DglNodePropPredDataset(name="ogbn-mag")
hg_orig, labels = dataset[0] hg_orig, labels = dataset[0]
subgs = {} subgs = {}
for etype in hg_orig.canonical_etypes: for etype in hg_orig.canonical_etypes:
u, v = hg_orig.all_edges(etype=etype) u, v = hg_orig.all_edges(etype=etype)
subgs[etype] = (u, v) subgs[etype] = (u, v)
subgs[(etype[2], 'rev-'+etype[1], etype[0])] = (v, u) subgs[(etype[2], "rev-" + etype[1], etype[0])] = (v, u)
hg = dgl.heterograph(subgs) hg = dgl.heterograph(subgs)
hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat'] hg.nodes["paper"].data["feat"] = hg_orig.nodes["paper"].data["feat"]
print(hg) print(hg)
# OGB-MAG is stored in a heterogeneous format. We need to convert it into a homogeneous format. # OGB-MAG is stored in a heterogeneous format. We need to convert it into a homogeneous format.
g = dgl.to_homogeneous(hg) g = dgl.to_homogeneous(hg)
g.ndata['orig_id'] = g.ndata[dgl.NID] g.ndata["orig_id"] = g.ndata[dgl.NID]
g.edata['orig_id'] = g.edata[dgl.EID] g.edata["orig_id"] = g.edata[dgl.EID]
print('|V|=' + str(g.number_of_nodes())) print("|V|=" + str(g.number_of_nodes()))
print('|E|=' + str(g.number_of_edges())) print("|E|=" + str(g.number_of_edges()))
print('|NTYPE|=' + str(len(th.unique(g.ndata[dgl.NTYPE])))) print("|NTYPE|=" + str(len(th.unique(g.ndata[dgl.NTYPE]))))
# Store the metadata of nodes. # Store the metadata of nodes.
num_node_weights = 0 num_node_weights = 0
...@@ -30,15 +31,15 @@ node_data = [g.ndata[dgl.NTYPE].numpy()] ...@@ -30,15 +31,15 @@ node_data = [g.ndata[dgl.NTYPE].numpy()]
for ntype_id in th.unique(g.ndata[dgl.NTYPE]): for ntype_id in th.unique(g.ndata[dgl.NTYPE]):
node_data.append((g.ndata[dgl.NTYPE] == ntype_id).numpy()) node_data.append((g.ndata[dgl.NTYPE] == ntype_id).numpy())
num_node_weights += 1 num_node_weights += 1
node_data.append(g.ndata['orig_id'].numpy()) node_data.append(g.ndata["orig_id"].numpy())
node_data = np.stack(node_data, 1) node_data = np.stack(node_data, 1)
np.savetxt('mag_nodes.txt', node_data, fmt='%d', delimiter=' ') np.savetxt("mag_nodes.txt", node_data, fmt="%d", delimiter=" ")
# Store the node features # Store the node features
node_feats = {} node_feats = {}
for ntype in hg.ntypes: for ntype in hg.ntypes:
for name in hg.nodes[ntype].data: for name in hg.nodes[ntype].data:
node_feats[ntype + '/' + name] = hg.nodes[ntype].data[name] node_feats[ntype + "/" + name] = hg.nodes[ntype].data[name]
dgl.data.utils.save_tensors("node_feat.dgl", node_feats) dgl.data.utils.save_tensors("node_feat.dgl", node_feats)
# Store the metadata of edges. # Store the metadata of edges.
...@@ -50,11 +51,11 @@ self_loop_idx = src_id == dst_id ...@@ -50,11 +51,11 @@ self_loop_idx = src_id == dst_id
not_self_loop_idx = src_id != dst_id not_self_loop_idx = src_id != dst_id
self_loop_src_id = src_id[self_loop_idx] self_loop_src_id = src_id[self_loop_idx]
self_loop_dst_id = dst_id[self_loop_idx] self_loop_dst_id = dst_id[self_loop_idx]
self_loop_orig_id = g.edata['orig_id'][self_loop_idx] self_loop_orig_id = g.edata["orig_id"][self_loop_idx]
self_loop_etype = g.edata[dgl.ETYPE][self_loop_idx] self_loop_etype = g.edata[dgl.ETYPE][self_loop_idx]
src_id = src_id[not_self_loop_idx] src_id = src_id[not_self_loop_idx]
dst_id = dst_id[not_self_loop_idx] dst_id = dst_id[not_self_loop_idx]
orig_id = g.edata['orig_id'][not_self_loop_idx] orig_id = g.edata["orig_id"][not_self_loop_idx]
etype = g.edata[dgl.ETYPE][not_self_loop_idx] etype = g.edata[dgl.ETYPE][not_self_loop_idx]
# Remove duplicated edges. # Remove duplicated edges.
ids = (src_id * g.number_of_nodes() + dst_id).numpy() ids = (src_id * g.number_of_nodes() + dst_id).numpy()
...@@ -69,30 +70,38 @@ dst_id = dst_id[idx] ...@@ -69,30 +70,38 @@ dst_id = dst_id[idx]
orig_id = orig_id[idx] orig_id = orig_id[idx]
etype = etype[idx] etype = etype[idx]
edge_data = th.stack([src_id, dst_id, orig_id, etype], 1) edge_data = th.stack([src_id, dst_id, orig_id, etype], 1)
np.savetxt('mag_edges.txt', edge_data.numpy(), fmt='%d', delimiter=' ') np.savetxt("mag_edges.txt", edge_data.numpy(), fmt="%d", delimiter=" ")
removed_edge_data = th.stack([th.cat([self_loop_src_id, duplicate_src_id]), removed_edge_data = th.stack(
th.cat([self_loop_dst_id, duplicate_dst_id]), [
th.cat([self_loop_orig_id, duplicate_orig_id]), th.cat([self_loop_src_id, duplicate_src_id]),
th.cat([self_loop_etype, duplicate_etype])], th.cat([self_loop_dst_id, duplicate_dst_id]),
1) th.cat([self_loop_orig_id, duplicate_orig_id]),
np.savetxt('mag_removed_edges.txt', th.cat([self_loop_etype, duplicate_etype]),
removed_edge_data.numpy(), fmt='%d', delimiter=' ') ],
print('There are {} edges; removed {} self-loops and {} duplicated edges'.format(g.number_of_edges(), 1,
len(self_loop_src_id), )
len(duplicate_src_id))) np.savetxt(
"mag_removed_edges.txt", removed_edge_data.numpy(), fmt="%d", delimiter=" "
)
print(
"There are {} edges, remove {} self-loops and {} duplicated edges".format(
g.number_of_edges(), len(self_loop_src_id), len(duplicate_src_id)
)
)
# Store the edge features # Store the edge features
edge_feats = {} edge_feats = {}
for etype in hg.etypes: for etype in hg.etypes:
for name in hg.edges[etype].data: for name in hg.edges[etype].data:
edge_feats[etype + '/' + name] = hg.edges[etype].data[name] edge_feats[etype + "/" + name] = hg.edges[etype].data[name]
dgl.data.utils.save_tensors("edge_feat.dgl", edge_feats) dgl.data.utils.save_tensors("edge_feat.dgl", edge_feats)
# Store the basic metadata of the graph. # Store the basic metadata of the graph.
graph_stats = [g.number_of_nodes(), len(src_id), num_node_weights] graph_stats = [g.number_of_nodes(), len(src_id), num_node_weights]
with open('mag_stats.txt', 'w') as filehandle: with open("mag_stats.txt", "w") as filehandle:
filehandle.writelines("{} {} {}".format( filehandle.writelines(
graph_stats[0], graph_stats[1], graph_stats[2])) "{} {} {}".format(graph_stats[0], graph_stats[1], graph_stats[2])
)
# Store the ID ranges of nodes and edges of the entire graph. # Store the ID ranges of nodes and edges of the entire graph.
nid_ranges = {} nid_ranges = {}
...@@ -100,7 +109,7 @@ eid_ranges = {} ...@@ -100,7 +109,7 @@ eid_ranges = {}
for ntype in hg.ntypes: for ntype in hg.ntypes:
ntype_id = hg.get_ntype_id(ntype) ntype_id = hg.get_ntype_id(ntype)
nid = th.nonzero(g.ndata[dgl.NTYPE] == ntype_id, as_tuple=True)[0] nid = th.nonzero(g.ndata[dgl.NTYPE] == ntype_id, as_tuple=True)[0]
per_type_nid = g.ndata['orig_id'][nid] per_type_nid = g.ndata["orig_id"][nid]
assert np.all((per_type_nid == th.arange(len(per_type_nid))).numpy()) assert np.all((per_type_nid == th.arange(len(per_type_nid))).numpy())
assert np.all((nid == th.arange(nid[0], nid[-1] + 1)).numpy()) assert np.all((nid == th.arange(nid[0], nid[-1] + 1)).numpy())
nid_ranges[ntype] = [int(nid[0]), int(nid[-1] + 1)] nid_ranges[ntype] = [int(nid[0]), int(nid[-1] + 1)]
...@@ -109,5 +118,5 @@ for etype in hg.etypes: ...@@ -109,5 +118,5 @@ for etype in hg.etypes:
eid = th.nonzero(g.edata[dgl.ETYPE] == etype_id, as_tuple=True)[0] eid = th.nonzero(g.edata[dgl.ETYPE] == etype_id, as_tuple=True)[0]
assert np.all((eid == th.arange(eid[0], eid[-1] + 1)).numpy()) assert np.all((eid == th.arange(eid[0], eid[-1] + 1)).numpy())
eid_ranges[etype] = [int(eid[0]), int(eid[-1] + 1)] eid_ranges[etype] = [int(eid[0]), int(eid[-1] + 1)]
with open('mag.json', 'w') as outfile: with open("mag.json", "w") as outfile:
json.dump({'nid': nid_ranges, 'eid': eid_ranges}, outfile, indent=4) json.dump({"nid": nid_ranges, "eid": eid_ranges}, outfile, indent=4)
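# The emitted mag.json maps every node/edge type to its half-open [start, end)
# range of IDs in the homogeneous graph. Because dgl.to_homogeneous orders
# nodes and edges by type, these ranges should tile the ID space without gaps.
# A quick check of that invariant (a sketch; it reuses the ranges and the
# homogeneous graph g computed above):
for ranges, total in (
    (nid_ranges, g.number_of_nodes()),
    (eid_ranges, g.number_of_edges()),
):
    prev_end = 0
    for start, end in sorted(ranges.values()):
        # each type's range must start exactly where the previous one ended
        assert start == prev_end, "type ranges should tile the ID space"
        prev_end = end
    assert prev_end == total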
import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from dgl.data.knowledge_graph import FB15k237Dataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch import RelGraphConv


# for building training/testing graphs
def get_subset_g(g, mask, num_rels, bidirected=False):
    src, dst = g.edges()
    sub_src = src[mask]
    sub_dst = dst[mask]
    sub_rel = g.edata["etype"][mask]

    if bidirected:
        sub_src, sub_dst = torch.cat([sub_src, sub_dst]), torch.cat(
            [sub_dst, sub_src]
        )
        sub_rel = torch.cat([sub_rel, sub_rel + num_rels])

    sub_g = dgl.graph((sub_src, sub_dst), num_nodes=g.num_nodes())
    sub_g.edata[dgl.ETYPE] = sub_rel
    return sub_g


class GlobalUniform:
    def __init__(self, g, sample_size):
        self.sample_size = sample_size
@@ -31,8 +35,9 @@ class GlobalUniform:
    def sample(self):
        return torch.from_numpy(np.random.choice(self.eids, self.sample_size))


class NegativeSampler:
    def __init__(self, k=10):  # negative sampling rate = 10
        self.k = k

    def sample(self, pos_samples, num_nodes):
@@ -54,6 +59,7 @@ class NegativeSampler:
        return torch.from_numpy(samples), torch.from_numpy(labels)
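

# SubgraphIterator (below) drives training: each iteration draws a fixed-size
# batch of training edges (see the uniform edge sampler above), relabels the
# incident nodes, builds a DGL graph from the sampled edges, and pairs the
# positive triplets with negatives from NegativeSampler. Wrapped in a
# GraphDataLoader, every "batch" is therefore one freshly sampled subgraph
# plus (samples, labels) for the link-prediction loss.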
class SubgraphIterator:
    def __init__(self, g, num_rels, sample_size=30000, num_epochs=6000):
        self.g = g
@@ -82,9 +88,11 @@ class SubgraphIterator:
        samples, labels = self.neg_sampler.sample(relabeled_data, num_nodes)

        # use only half of the positive edges
        chosen_ids = np.random.choice(
            np.arange(self.sample_size),
            size=int(self.sample_size / 2),
            replace=False,
        )
        src = src[chosen_ids]
        dst = dst[chosen_ids]
        rel = rel[chosen_ids]
@@ -92,42 +100,57 @@ class SubgraphIterator:
        rel = np.concatenate((rel, rel + self.num_rels))
        sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
        sub_g.edata[dgl.ETYPE] = torch.from_numpy(rel)
        sub_g.edata["norm"] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
        uniq_v = torch.from_numpy(uniq_v).view(-1).long()
        return sub_g, uniq_v, samples, labels


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, num_rels):
        super().__init__()
        # two-layer RGCN
        self.emb = nn.Embedding(num_nodes, h_dim)
        self.conv1 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="bdd",
            num_bases=100,
            self_loop=True,
        )
        self.conv2 = RelGraphConv(
            h_dim,
            h_dim,
            num_rels,
            regularizer="bdd",
            num_bases=100,
            self_loop=True,
        )
        self.dropout = nn.Dropout(0.2)

    def forward(self, g, nids):
        x = self.emb(nids)
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata["norm"]))
        h = self.dropout(h)
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata["norm"])
        return self.dropout(h)
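

# LinkPredict (below) couples the RGCN encoder with a DistMult-style decoder:
# every relation r owns a learnable vector w_r, and a triplet (s, r, o) is
# scored as sum(e_s * w_r * e_o) in calc_score. get_loss combines the
# prediction loss over the sampled positive/negative triplets with a
# regularization term on the embeddings, weighted by reg_param.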
class LinkPredict(nn.Module):
    def __init__(self, num_nodes, num_rels, h_dim=500, reg_param=0.01):
        super().__init__()
        self.rgcn = RGCN(num_nodes, h_dim, num_rels * 2)
        self.reg_param = reg_param
        self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(
            self.w_relation, gain=nn.init.calculate_gain("relu")
        )

    def calc_score(self, embedding, triplets):
        s = embedding[triplets[:, 0]]
        r = self.w_relation[triplets[:, 1]]
        o = embedding[triplets[:, 2]]
        score = torch.sum(s * r * o, dim=1)
        return score
@@ -144,7 +167,10 @@ class LinkPredict(nn.Module):
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss
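

# Filtered ranking evaluation. For each test triplet (s, r, o), `filter` keeps
# only candidate heads/tails that do not form another known true triplet,
# `perturb_and_get_filtered_rank` ranks the ground-truth entity among those
# candidates by its DistMult score, and `calc_mrr` averages the reciprocal
# ranks over both head and tail corruption (the standard filtered MRR
# protocol for knowledge-graph link prediction).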
def filter(
    triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o=True
):
    """Get candidate heads or tails to score"""
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # Add the ground truth node first
@@ -153,13 +179,18 @@ def filter(triplets_to_filter, target_s, target_r, target_o, num_nodes, filter_o
    else:
        candidate_nodes = [target_s]
    for e in range(num_nodes):
        triplet = (
            (target_s, target_r, e) if filter_o else (e, target_r, target_o)
        )
        # Do not consider a node if it leads to a real triplet
        if triplet not in triplets_to_filter:
            candidate_nodes.append(e)
    return torch.LongTensor(candidate_nodes)


def perturb_and_get_filtered_rank(
    emb, w, s, r, o, test_size, triplets_to_filter, filter_o=True
):
    """Perturb subject or object in the triplets"""
    num_nodes = emb.shape[0]
    ranks = []
@@ -167,8 +198,14 @@ def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        candidate_nodes = filter(
            triplets_to_filter,
            target_s,
            target_r,
            target_o,
            num_nodes,
            filter_o=filter_o,
        )
        if filter_o:
            emb_s = emb[target_s]
            emb_o = emb[candidate_nodes]
@@ -185,25 +222,42 @@ def perturb_and_get_filtered_rank(emb, w, s, r, o, test_size, triplets_to_filter
        ranks.append(rank)
    return torch.LongTensor(ranks)


def calc_mrr(
    emb, w, test_mask, triplets_to_filter, batch_size=100, filter=True
):
    with torch.no_grad():
        test_triplets = triplets_to_filter[test_mask]
        s, r, o = test_triplets[:, 0], test_triplets[:, 1], test_triplets[:, 2]
        test_size = len(s)
        triplets_to_filter = {
            tuple(triplet) for triplet in triplets_to_filter.tolist()
        }
        ranks_s = perturb_and_get_filtered_rank(
            emb, w, s, r, o, test_size, triplets_to_filter, filter_o=False
        )
        ranks_o = perturb_and_get_filtered_rank(
            emb, w, s, r, o, test_size, triplets_to_filter
        )
        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1  # change to 1-indexed
        mrr = torch.mean(1.0 / ranks.float()).item()
    return mrr


def train(
    dataloader,
    test_g,
    test_nids,
    test_mask,
    triplets,
    device,
    model_state_file,
    model,
):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    best_mrr = 0
    for epoch, batch_data in enumerate(dataloader):  # single graph batch
        model.train()
        g, train_nids, edges, labels = batch_data
        g = g.to(device)
@@ -215,57 +269,84 @@ def train(dataloader, test_g, test_nids, test_mask, triplets, device, model_stat
        loss = model.get_loss(embed, edges, labels)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=1.0
        )  # clip gradients
        optimizer.step()
        print(
            "Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f}".format(
                epoch, loss.item(), best_mrr
            )
        )
        if (epoch + 1) % 500 == 0:
            # perform validation on CPU because full graph is too large
            model = model.cpu()
            model.eval()
            embed = model(test_g, test_nids)
            mrr = calc_mrr(
                embed, model.w_relation, test_mask, triplets, batch_size=500
            )
            # save best model
            if best_mrr < mrr:
                best_mrr = mrr
                torch.save(
                    {"state_dict": model.state_dict(), "epoch": epoch},
                    model_state_file,
                )
            model = model.to(device)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training with DGL built-in RGCN module")

    # load and preprocess dataset
    data = FB15k237Dataset(reverse=False)
    g = data[0]
    num_nodes = g.num_nodes()
    num_rels = data.num_rels
    train_g = get_subset_g(g, g.edata["train_mask"], num_rels)
    test_g = get_subset_g(g, g.edata["train_mask"], num_rels, bidirected=True)
    test_g.edata["norm"] = dgl.norm_by_dst(test_g).unsqueeze(-1)
    test_nids = torch.arange(0, num_nodes)
    test_mask = g.edata["test_mask"]
    subg_iter = SubgraphIterator(train_g, num_rels)  # uniform edge sampling
    dataloader = GraphDataLoader(
        subg_iter, batch_size=1, collate_fn=lambda x: x[0]
    )

    # Prepare data for metric computation
    src, dst = g.edges()
    triplets = torch.stack([src, g.edata["etype"], dst], dim=1)

    # create RGCN model
    model = LinkPredict(num_nodes, num_rels).to(device)

    # train
    model_state_file = "model_state.pth"
    train(
        dataloader,
        test_g,
        test_nids,
        test_mask,
        triplets,
        device,
        model_state_file,
        model,
    )

    # testing
    print("Testing...")
    checkpoint = torch.load(model_state_file)
    model = model.cpu()  # test on CPU
    model.eval()
    model.load_state_dict(checkpoint["state_dict"])
    embed = model(test_g, test_nids)
    best_mrr = calc_mrr(
        embed, model.w_relation, test_mask, triplets, batch_size=500
    )
    print(
        "Best MRR {:.4f} achieved using the epoch {:04d}".format(
            best_mrr, checkpoint["epoch"]
        )
    )
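
As a usage note, the checkpoint saved by the training loop can be reloaded later to score arbitrary triplets with the trained decoder. A minimal sketch reusing the objects defined above; the candidate (head, relation, tail) IDs are made up for illustration:

# Reload the best checkpoint and score a few candidate triplets on CPU.
checkpoint = torch.load("model_state.pth", map_location="cpu")
model = LinkPredict(num_nodes, num_rels)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

with torch.no_grad():
    embed = model(test_g, test_nids)  # entity embeddings from the RGCN encoder
    candidates = torch.tensor(
        [[0, 5, 17], [0, 5, 42]]  # hypothetical (head, relation, tail) IDs
    )
    scores = model.calc_score(embed, candidates)  # higher = more plausible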
import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.nn.pytorch import RelGraphConv


class RGCN(nn.Module):
    def __init__(
        self,
        num_nodes,
        h_dim,
        out_dim,
        num_rels,
        regularizer="basis",
        num_bases=-1,
        dropout=0.0,
        self_loop=False,
        ns_mode=False,
    ):
        super(RGCN, self).__init__()
        if num_bases == -1:
            num_bases = num_rels
        self.emb = nn.Embedding(num_nodes, h_dim)
        self.conv1 = RelGraphConv(
            h_dim, h_dim, num_rels, regularizer, num_bases, self_loop=self_loop
        )
        self.conv2 = RelGraphConv(
            h_dim,
            out_dim,
            num_rels,
            regularizer,
            num_bases,
            self_loop=self_loop,
        )
        self.dropout = nn.Dropout(dropout)
        self.ns_mode = ns_mode
@@ -26,13 +43,13 @@ class RGCN(nn.Module):
        if self.ns_mode:
            # forward for neighbor sampling
            x = self.emb(g[0].srcdata[dgl.NID])
            h = self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata["norm"])
            h = self.dropout(F.relu(h))
            h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata["norm"])
            return h
        else:
            x = self.emb.weight if nids is None else self.emb(nids)
            h = self.conv1(g, x, g.edata[dgl.ETYPE], g.edata["norm"])
            h = self.dropout(F.relu(h))
            h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata["norm"])
            return h
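
To make the expected inputs concrete, here is a small full-graph usage sketch of the class above. The toy graph, the feature sizes, and the assumption that the elided forward signature is forward(self, g, nids=None) are illustrative only:

# Toy full-graph usage of the RGCN above (illustrative values only).
num_nodes, num_rels = 6, 3
src = th.tensor([0, 1, 2, 3, 4, 5])
dst = th.tensor([1, 2, 3, 4, 5, 0])
g = dgl.graph((src, dst), num_nodes=num_nodes)
g.edata[dgl.ETYPE] = th.tensor([0, 1, 2, 0, 1, 2])  # one relation ID per edge
g.edata["norm"] = dgl.norm_by_dst(g).unsqueeze(-1)  # per-edge normalizer, as in link.py

model = RGCN(num_nodes, h_dim=16, out_dim=4, num_rels=num_rels)
logits = model(g)  # assumes forward(self, g, nids=None); uses all node embeddings
print(logits.shape)  # torch.Size([6, 4])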
@@ -7,11 +7,10 @@ References:
- Original Code: https://github.com/rasmusbergpalm/recurrent-relational-networks
"""
import dgl.function as fn
import torch
from torch import nn


class RRNLayer(nn.Module):
    def __init__(self, msg_layer, node_update_func, edge_drop):
...
@@ -4,13 +4,13 @@ import urllib.request
import zipfile
from copy import copy

import dgl
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import Dataset


def _basic_sudoku_graph():
    grids = [
...
import dgl
import torch
import torch.nn.functional as F
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from utils import get_batch_id, topk


class SAGPool(torch.nn.Module):
...
@@ -4,17 +4,17 @@ import logging
import os
from time import time

import dgl
import torch
import torch.nn
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from network import get_sag_network
from torch.utils.data import random_split
from utils import get_stats


def parse_args():
    parser = argparse.ArgumentParser(description="Self-Attention Graph Pooling")
...