Unverified commit f19f05ce, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4651)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 977b1ba4
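
Every hunk below is the same mechanical rewrite: Black normalizes string quotes to double quotes, adds spaces around operators, explodes calls that exceed the line length into one argument per line, and leaves a trailing comma. A minimal illustration of the style (the function here is a made-up stand-in, not code from this repository):

# Hypothetical helper, only to show the formatting style Black enforces.
def make_loader(graph, indices, sampler, device, batch_size, shuffle):
    return (graph, indices, sampler, device, batch_size, shuffle)


# Before: make_loader(g, idx, s, device='cuda', batch_size=1024, shuffle=True)
loader = make_loader(
    "g",
    "idx",
    "s",
    device="cuda",
    batch_size=1024,
    shuffle=True,
)
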
import torch as th

import dgl


class NegativeSampler(object):
    def __init__(self, g, k, neg_share=False, device=None):
        if device is None:
@@ -16,6 +18,6 @@ class NegativeSampler(object):
            dst = self.weights.multinomial(n, replacement=True)
            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
        else:
            dst = self.weights.multinomial(n * self.k, replacement=True)
        src = src.repeat_interleave(self.k)
        return src, dst
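
The two multinomial branches above produce the same number of negatives but draw them differently; a small, self-contained PyTorch sketch of the tensor shapes, with a made-up weights vector standing in for the sampler's self.weights:

import torch as th

k = 2                              # negatives per positive edge
src = th.tensor([10, 11, 12, 13])  # positive source nodes (n divisible by k)
n = len(src)
weights = th.ones(5)               # stand-in for the degree-based self.weights

# neg_share=True: draw only n candidates and reuse each one across k negatives.
dst_shared = weights.multinomial(n, replacement=True)
dst_shared = dst_shared.view(-1, 1, k).expand(-1, k, -1).flatten()

# neg_share=False: draw an independent candidate for every negative.
dst_indep = weights.multinomial(n * k, replacement=True)

# Either way, each source is repeated k times to pair with its negatives.
src_rep = src.repeat_interleave(k)
print(src_rep.shape, dst_shared.shape, dst_indep.shape)  # all torch.Size([8])
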

import argparse
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from ogb.nodeproppred import DglNodePropPredDataset

import dgl
import dgl.nn as dglnn


class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
@@ -33,20 +36,31 @@ class SAGE(nn.Module):
    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        feat = g.ndata["feat"]
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=["feat"]
        )
        dataloader = dgl.dataloading.DataLoader(
            g,
            torch.arange(g.num_nodes()).to(g.device),
            sampler,
            device=device,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=num_workers,
        )
        if buffer_device is None:
            buffer_device = device
        for l, layer in enumerate(self.layers):
            y = torch.empty(
                g.num_nodes(),
                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device,
                pin_memory=True,
            )
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                # use an explicitly contiguous slice
@@ -57,44 +71,64 @@ class SAGE(nn.Module):
                    h = self.dropout(h)
                # by design, our output nodes are contiguous so we can take
                # advantage of that here
                y[output_nodes[0] : output_nodes[-1] + 1] = h.to(buffer_device)
            feat = y
        return y
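
The sliced write in the last loop body works because the dataloader above runs with shuffle=False, so each batch's output_nodes form one contiguous, sorted range of node IDs. A small stand-alone check of the equivalence (plain PyTorch, toy sizes):

import torch

y = torch.zeros(10, 4)
h = torch.randn(3, 4)
output_nodes = torch.tensor([5, 6, 7])  # contiguous IDs, as with shuffle=False

# The explicit slice used in inference() writes one contiguous block of rows...
y[output_nodes[0] : output_nodes[-1] + 1] = h
# ...and gives the same result as fancy indexing with the IDs themselves,
# while avoiding a row-by-row scatter into the (possibly pinned) buffer.
assert torch.equal(y[output_nodes], h)
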

dataset = DglNodePropPredDataset("ogbn-products")
graph, labels = dataset[0]
graph.ndata["label"] = labels.squeeze()
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = (
    split_idx["train"],
    split_idx["valid"],
    split_idx["test"],
)

device = "cuda"
train_idx = train_idx.to(device)
valid_idx = valid_idx.to(device)
test_idx = test_idx.to(device)
graph = graph.to(device)
model = SAGE(graph.ndata["feat"].shape[1], 256, dataset.num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

sampler = dgl.dataloading.NeighborSampler(
    [15, 10, 5], prefetch_node_feats=["feat"], prefetch_labels=["label"]
)
train_dataloader = dgl.dataloading.DataLoader(
    graph,
    train_idx,
    sampler,
    device=device,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=False,
)
valid_dataloader = dgl.dataloading.DataLoader(
    graph,
    valid_idx,
    sampler,
    device=device,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
    use_uva=False,
)

durations = []
for _ in range(10):
    model.train()
    t0 = time.time()
    for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = model(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        opt.zero_grad()
@@ -103,7 +137,7 @@ for _ in range(10):
        if it % 20 == 0:
            acc = MF.accuracy(torch.argmax(y_hat, dim=1), y)
            mem = torch.cuda.max_memory_allocated() / 1000000
            print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB")
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)
@@ -113,19 +147,19 @@ for _ in range(10):
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(torch.argmax(model(blocks, x), dim=1))
    acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
    print("Validation acc:", acc.item())
print(np.mean(durations[4:]), np.std(durations[4:]))

# Test accuracy and offline inference of all nodes
model.eval()
with torch.no_grad():
    pred = model.inference(graph, device, 4096, 0, "cpu")
    pred = pred[test_idx].to(device)
    label = graph.ndata["label"][test_idx]
    acc = MF.accuracy(torch.argmax(pred, dim=1), label)
    print("Test acc:", acc.item())

import argparse
import os
import sys
import time

import numpy as np
import torch as th

import dgl

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from load_graph import load_ogb, load_reddit

if __name__ == "__main__":
    argparser = argparse.ArgumentParser("Partition builtin graphs")
    argparser.add_argument(
        "--dataset",
        type=str,
        default="reddit",
        help="datasets: reddit, ogb-product, ogb-paper100M",
    )
    argparser.add_argument(
        "--num_parts", type=int, default=4, help="number of partitions"
    )
    argparser.add_argument(
        "--part_method", type=str, default="metis", help="the partition method"
    )
    argparser.add_argument(
        "--balance_train",
        action="store_true",
        help="balance the training size in each partition.",
    )
    argparser.add_argument(
        "--undirected",
        action="store_true",
        help="turn the graph into an undirected graph.",
    )
    argparser.add_argument(
        "--balance_edges",
        action="store_true",
        help="balance the number of edges in each partition.",
    )
    argparser.add_argument(
        "--num_trainers_per_machine",
        type=int,
        default=1,
        help="the number of trainers per machine. The trainer ids are stored\
             in the node feature 'trainer_id'",
    )
    argparser.add_argument(
        "--output",
        type=str,
        default="data",
        help="Output path of partitioned graph.",
    )
    args = argparser.parse_args()

    start = time.time()
    if args.dataset == "reddit":
        g, _ = load_reddit()
    elif args.dataset == "ogb-product":
        g, _ = load_ogb("ogbn-products")
    elif args.dataset == "ogb-paper100M":
        g, _ = load_ogb("ogbn-papers100M")
    print(
        "load {} takes {:.3f} seconds".format(args.dataset, time.time() - start)
    )
    print("|V|={}, |E|={}".format(g.number_of_nodes(), g.number_of_edges()))
    print(
        "train: {}, valid: {}, test: {}".format(
            th.sum(g.ndata["train_mask"]),
            th.sum(g.ndata["val_mask"]),
            th.sum(g.ndata["test_mask"]),
        )
    )

    if args.balance_train:
        balance_ntypes = g.ndata["train_mask"]
    else:
        balance_ntypes = None
@@ -52,8 +84,13 @@ if __name__ == '__main__':
            sym_g.ndata[key] = g.ndata[key]
        g = sym_g

    dgl.distributed.partition_graph(
        g,
        args.dataset,
        args.num_parts,
        args.output,
        part_method=args.part_method,
        balance_ntypes=balance_ntypes,
        balance_edges=args.balance_edges,
        num_trainers_per_machine=args.num_trainers_per_machine,
    )
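
For a quick smoke test of the same API without any of the command-line options, something like the following works on a toy random graph (the graph name, partition count, and output directory are arbitrary choices for illustration):

import torch as th

import dgl

# Partition a small random graph into two METIS parts; the result is written
# as one folder per partition plus a JSON config under out_path.
g = dgl.rand_graph(1000, 5000)
dgl.distributed.partition_graph(
    g,
    graph_name="toy",
    num_parts=2,
    out_path="data_toy",
    part_method="metis",
)
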

import os

os.environ["DGLBACKEND"] = "pytorch"
import argparse
import math
import socket
import time
from functools import wraps
from multiprocessing import Process

import numpy as np
import torch as th
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.data.utils import load_graphs
from dgl.distributed import DistDataLoader


def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
    """
    Copies features and labels of a set of nodes onto GPU.
    """
    batch_inputs = (
        g.ndata["features"][input_nodes].to(device) if load_feat else None
    )
    batch_labels = g.ndata["labels"][seeds].to(device)
    return batch_inputs, batch_labels


class NeighborSampler(object):
    def __init__(self, g, fanouts, sample_neighbors, device, load_feat=True):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors
        self.device = device
        self.load_feat = load_feat

    def sample_blocks(self, seeds):
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(
                self.g, seeds, fanout, replace=True
            )
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for the next layer.
@@ -53,24 +62,28 @@ class NeighborSampler(object):
        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]
        batch_inputs, batch_labels = load_subtensor(
            self.g, seeds, input_nodes, "cpu", self.load_feat
        )
        if self.load_feat:
            blocks[0].srcdata["features"] = batch_inputs
        blocks[-1].dstdata["labels"] = batch_labels
        return blocks
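
sample_blocks builds one message-flow graph (block) per fanout. A minimal, non-distributed sketch of a single such step, using dgl.sampling.sample_neighbors in place of the distributed sampler and a tiny hand-made graph:

import torch as th

import dgl

# A 4-node cycle graph stands in for the partitioned DistGraph.
g = dgl.graph((th.tensor([0, 1, 2, 3]), th.tensor([1, 2, 3, 0])))
seeds = th.tensor([0, 1])

# Sample up to 2 in-neighbors per seed, then compact the frontier into a
# bipartite block whose destination nodes are exactly the seeds.
frontier = dgl.sampling.sample_neighbors(g, seeds, fanout=2, replace=True)
block = dgl.to_block(frontier, seeds)
print(block.num_src_nodes(), block.num_dst_nodes())
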


class DistSAGE(nn.Module):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
@@ -97,31 +110,49 @@ class DistSAGE(nn.Module):
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(
            np.arange(g.number_of_nodes()),
            g.get_partition_book(),
            force_even=True,
        )
        y = dgl.distributed.DistTensor(
            (g.number_of_nodes(), self.n_hidden),
            th.float32,
            "h",
            persistent=True,
        )
        for l, layer in enumerate(self.layers):
            if l == len(self.layers) - 1:
                y = dgl.distributed.DistTensor(
                    (g.number_of_nodes(), self.n_classes),
                    th.float32,
                    "h_last",
                    persistent=True,
                )

            sampler = NeighborSampler(
                g, [-1], dgl.distributed.sample_neighbors, device
            )
            print(
                "|V|={}, eval batch size: {}".format(
                    g.number_of_nodes(), batch_size
                )
            )
            # Create PyTorch DataLoader for constructing blocks
            dataloader = DistDataLoader(
                dataset=nodes,
                batch_size=batch_size,
                collate_fn=sampler.sample_blocks,
                shuffle=False,
                drop_last=False,
            )

            for blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                input_nodes = block.srcdata[dgl.NID]
                output_nodes = block.dstdata[dgl.NID]
                h = x[input_nodes].to(device)
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
@@ -133,6 +164,7 @@ class DistSAGE(nn.Module):
        g.barrier()
        return y


def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
@@ -140,6 +172,7 @@ def compute_acc(pred, labels):
    labels = labels.long()
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
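
compute_acc is a plain argmax-and-compare; a tiny worked example with made-up tensors shows the arithmetic:

import torch as th

pred = th.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # 3 nodes, 2 classes
labels = th.tensor([0, 1, 1])
# argmax gives [0, 1, 0]; two of the three predictions match the labels.
acc = (th.argmax(pred, dim=1) == labels.long()).float().sum() / len(pred)
print(acc)  # tensor(0.6667)
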


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_nid``.
@@ -154,7 +187,9 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
    with th.no_grad():
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )


def run(args, device, data):
@@ -162,8 +197,12 @@ def run(args, device, data):
    train_nid, val_nid, test_nid, in_feats, n_classes, g = data
    shuffle = True
    # Create sampler
    sampler = NeighborSampler(
        g,
        [int(fanout) for fanout in args.fan_out.split(",")],
        dgl.distributed.sample_neighbors,
        device,
    )

    # Create DataLoader for constructing blocks
    dataloader = DistDataLoader(
@@ -171,16 +210,26 @@ def run(args, device, data):
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=shuffle,
        drop_last=False,
    )

    # Define model and optimizer
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[device], output_device=device
            )
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
@@ -209,8 +258,8 @@ def run(args, device, data):
            # The nodes for input lie at the LHS side of the first block.
            # The nodes for output lie at the RHS side of the last block.
            batch_inputs = blocks[0].srcdata["features"]
            batch_labels = blocks[-1].dstdata["labels"]
            batch_labels = batch_labels.long()

            num_seeds += len(blocks[-1].dstdata[dgl.NID])
@@ -219,7 +268,7 @@ def run(args, device, data):
            batch_labels = batch_labels.to(device)
            # Compute loss and prediction
            start = time.time()
            # print(g.rank(), blocks[0].device, model.module.layers[0].fc_neigh.weight.device, dev_id)
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            forward_end = time.time()
@@ -237,102 +286,191 @@ def run(args, device, data):
            iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = (
                    th.cuda.max_memory_allocated() / 1000000
                    if th.cuda.is_available()
                    else 0
                )
                print(
                    "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s".format(
                        g.rank(),
                        epoch,
                        step,
                        loss.item(),
                        acc.item(),
                        np.mean(iter_tput[3:]),
                        gpu_mem_alloc,
                        np.sum(step_time[-args.log_every :]),
                    )
                )
            start = time.time()

        toc = time.time()
        print(
            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}".format(
                g.rank(),
                toc - tic,
                sample_time,
                forward_time,
                backward_time,
                update_time,
                num_seeds,
                num_inputs,
            )
        )
        epoch += 1

        if epoch % args.eval_every == 0 and epoch != 0:
            start = time.time()
            val_acc, test_acc = evaluate(
                model.module,
                g,
                g.ndata["features"],
                g.ndata["labels"],
                val_nid,
                test_nid,
                args.batch_size_eval,
                device,
            )
            print(
                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    g.rank(), val_acc, test_acc, time.time() - start
                )
            )


def main(args):
    print(socket.gethostname(), "Initializing DGL dist")
    dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
    if not args.standalone:
        print(socket.gethostname(), "Initializing DGL process group")
        th.distributed.init_process_group(backend=args.backend)
    print(socket.gethostname(), "Initializing DistGraph")
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print(socket.gethostname(), "rank:", g.rank())

    pb = g.get_partition_book()
    if "trainer_id" in g.ndata:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"],
            pb,
            force_even=True,
            node_trainer_ids=g.ndata["trainer_id"],
        )
    else:
        train_nid = dgl.distributed.node_split(
            g.ndata["train_mask"], pb, force_even=True
        )
        val_nid = dgl.distributed.node_split(
            g.ndata["val_mask"], pb, force_even=True
        )
        test_nid = dgl.distributed.node_split(
            g.ndata["test_mask"], pb, force_even=True
        )
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    del local_nid
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))
    n_classes = args.n_classes
    if n_classes == -1:
        labels = g.ndata["labels"][np.arange(g.number_of_nodes())]
        n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
        del labels
    print("#labels:", n_classes)

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    data = train_nid, val_nid, test_nid, in_feats, n_classes, g
    run(args, device, data)
    print("parent ends")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GCN")
    register_data_args(parser)
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument("--num_clients", type=int, help="The number of clients")
    parser.add_argument(
        "--n_classes",
        type=int,
        default=-1,
        help="The number of classes. If not specified, this"
        " value will be calculated via scanning all the labels"
        " in the dataset, which may cause a memory burst.",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="gloo",
        help="pytorch distributed backend",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU devices. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument(
        "--local_rank", type=int, help="get rank of the process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument(
        "--pad-data",
        default=False,
        action="store_true",
        help="Pad train nid to the same length across machines, to ensure the number of batches is the same.",
    )
    parser.add_argument(
        "--net_type",
        type=str,
        default="socket",
        help="backend net type, 'socket' or 'tensorpipe'",
    )
    args = parser.parse_args()
    print(args)
...

import os

os.environ["DGLBACKEND"] = "pytorch"
import argparse
import math
import time
from functools import wraps
from multiprocessing import Process

import numpy as np
import torch as th
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader
from train_dist import DistSAGE, NeighborSampler, compute_acc

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.data.utils import load_graphs
from dgl.distributed import DistDataLoader, DistEmbedding


class TransDistSAGE(DistSAGE):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super(TransDistSAGE, self).__init__(
            in_feats, n_hidden, n_classes, n_layers, activation, dropout
        )

    def inference(self, standalone, g, x, batch_size, device):
        """
@@ -43,31 +48,53 @@ class TransDistSAGE(DistSAGE):
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(
            np.arange(g.number_of_nodes()),
            g.get_partition_book(),
            force_even=True,
        )
        y = dgl.distributed.DistTensor(
            (g.number_of_nodes(), self.n_hidden),
            th.float32,
            "h",
            persistent=True,
        )
        for l, layer in enumerate(self.layers):
            if l == len(self.layers) - 1:
                y = dgl.distributed.DistTensor(
                    (g.number_of_nodes(), self.n_classes),
                    th.float32,
                    "h_last",
                    persistent=True,
                )

            sampler = NeighborSampler(
                g,
                [-1],
                dgl.distributed.sample_neighbors,
                device,
                load_feat=False,
            )
            print(
                "|V|={}, eval batch size: {}".format(
                    g.number_of_nodes(), batch_size
                )
            )
            # Create PyTorch DataLoader for constructing blocks
            dataloader = DistDataLoader(
                dataset=nodes,
                batch_size=batch_size,
                collate_fn=sampler.sample_blocks,
                shuffle=False,
                drop_last=False,
            )

            for blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                input_nodes = block.srcdata[dgl.NID]
                output_nodes = block.dstdata[dgl.NID]
                h = x[input_nodes].to(device)
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = self.activation(h)
@@ -79,19 +106,23 @@ class TransDistSAGE(DistSAGE):
        g.barrier()
        return y


def initializer(shape, dtype):
    arr = th.zeros(shape, dtype=dtype)
    arr.uniform_(-1, 1)
    return arr


class DistEmb(nn.Module):
    def __init__(self, num_nodes, emb_size, dgl_sparse_emb=False, dev_id="cpu"):
        super().__init__()
        self.dev_id = dev_id
        self.emb_size = emb_size
        self.dgl_sparse_emb = dgl_sparse_emb
        if dgl_sparse_emb:
            self.sparse_emb = DistEmbedding(
                num_nodes, emb_size, name="sage", init_func=initializer
            )
        else:
            self.sparse_emb = th.nn.Embedding(num_nodes, emb_size, sparse=True)
            nn.init.uniform_(self.sparse_emb.weight, -1.0, 1.0)
@@ -104,21 +135,29 @@ class DistEmb(nn.Module):
        else:
            return self.sparse_emb(idx).to(self.dev_id)


def load_embs(standalone, emb_layer, g):
    nodes = dgl.distributed.node_split(
        np.arange(g.number_of_nodes()), g.get_partition_book(), force_even=True
    )
    x = dgl.distributed.DistTensor(
        (
            g.number_of_nodes(),
            emb_layer.module.emb_size
            if isinstance(emb_layer, th.nn.parallel.DistributedDataParallel)
            else emb_layer.emb_size,
        ),
        th.float32,
        "eval_embs",
        persistent=True,
    )
    num_nodes = nodes.shape[0]
    for i in range((num_nodes + 1023) // 1024):
        idx = nodes[
            i * 1024 : (i + 1) * 1024
            if (i + 1) * 1024 < num_nodes
            else num_nodes
        ]
        embeds = emb_layer(idx).cpu()
        x[idx] = embeds
@@ -127,7 +166,18 @@ def load_embs(standalone, emb_layer, g):
    return x
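
The guarded slice bound inside the loop above is equivalent to plain slicing, since Python slices already clamp at the end of a tensor; a small self-contained check:

import torch as th

nodes = th.arange(2500)
num_nodes = nodes.shape[0]
for i in range((num_nodes + 1023) // 1024):
    # The guarded form used in load_embs ...
    a = nodes[
        i * 1024 : (i + 1) * 1024 if (i + 1) * 1024 < num_nodes else num_nodes
    ]
    # ... picks the same rows as the unguarded slice, which clamps on its own.
    b = nodes[i * 1024 : (i + 1) * 1024]
    assert th.equal(a, b)
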


def evaluate(
    standalone,
    model,
    emb_layer,
    g,
    labels,
    val_nid,
    test_nid,
    batch_size,
    device,
):
    """
    Evaluate the model on the validation set specified by ``val_nid``.
    g : The entire graph.
@@ -144,14 +194,22 @@ def evaluate(standalone, model, emb_layer, g, labels, val_nid, test_nid, batch_s
        pred = model.inference(standalone, g, inputs, batch_size, device)
    model.train()
    emb_layer.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )


def run(args, device, data):
    # Unpack data
    train_nid, val_nid, test_nid, n_classes, g = data
    # Create sampler
    sampler = NeighborSampler(
        g,
        [int(fanout) for fanout in args.fan_out.split(",")],
        dgl.distributed.sample_neighbors,
        device,
        load_feat=False,
    )

    # Create DataLoader for constructing blocks
    dataloader = DistDataLoader(
@@ -159,35 +217,55 @@ def run(args, device, data):
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
    )

    # Define model and optimizer
    emb_layer = DistEmb(
        g.num_nodes(),
        args.num_hidden,
        dgl_sparse_emb=args.dgl_sparse,
        dev_id=device,
    )
    model = TransDistSAGE(
        args.num_hidden,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            dev_id = g.rank() % args.num_gpus
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[dev_id], output_device=dev_id
            )
        if not args.dgl_sparse:
            emb_layer = th.nn.parallel.DistributedDataParallel(emb_layer)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if args.dgl_sparse:
        emb_optimizer = dgl.distributed.optim.SparseAdam(
            [emb_layer.sparse_emb], lr=args.sparse_lr
        )
        print("optimize DGL sparse embedding:", emb_layer.sparse_emb)
    elif args.standalone:
        emb_optimizer = th.optim.SparseAdam(
            list(emb_layer.sparse_emb.parameters()), lr=args.sparse_lr
        )
        print("optimize Pytorch sparse embedding:", emb_layer.sparse_emb)
    else:
        emb_optimizer = th.optim.SparseAdam(
            list(emb_layer.module.sparse_emb.parameters()), lr=args.sparse_lr
        )
        print("optimize Pytorch sparse embedding:", emb_layer.module.sparse_emb)

    train_size = th.sum(g.ndata["train_mask"][0 : g.number_of_nodes()])

    # Training loop
    iter_tput = []
@@ -212,7 +290,7 @@ def run(args, device, data):
            # The nodes for input lie at the LHS side of the first block.
            # The nodes for output lie at the RHS side of the last block.
            batch_inputs = blocks[0].srcdata[dgl.NID]
            batch_labels = blocks[-1].dstdata["labels"]
            batch_labels = batch_labels.long()

            num_seeds += len(blocks[-1].dstdata[dgl.NID])
@@ -241,78 +319,146 @@ def run(args, device, data):
            iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = (
                    th.cuda.max_memory_allocated() / 1000000
                    if th.cuda.is_available()
                    else 0
                )
                print(
                    "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s".format(
                        g.rank(),
                        epoch,
                        step,
                        loss.item(),
                        acc.item(),
                        np.mean(iter_tput[3:]),
                        gpu_mem_alloc,
                        np.sum(step_time[-args.log_every :]),
                    )
                )
            start = time.time()

        toc = time.time()
        print(
            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}".format(
                g.rank(),
                toc - tic,
                sample_time,
                forward_time,
                backward_time,
                update_time,
                num_seeds,
                num_inputs,
            )
        )
        epoch += 1

        if epoch % args.eval_every == 0 and epoch != 0:
            start = time.time()
            val_acc, test_acc = evaluate(
                args.standalone,
                model.module,
                emb_layer,
                g,
                g.ndata["labels"],
                val_nid,
                test_nid,
                args.batch_size_eval,
                device,
            )
            print(
                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    g.rank(), val_acc, test_acc, time.time() - start
                )
            )


def main(args):
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        th.distributed.init_process_group(backend="gloo")
    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
    print("rank:", g.rank())

    pb = g.get_partition_book()
    train_nid = dgl.distributed.node_split(
        g.ndata["train_mask"], pb, force_even=True
    )
    val_nid = dgl.distributed.node_split(
        g.ndata["val_mask"], pb, force_even=True
    )
    test_nid = dgl.distributed.node_split(
        g.ndata["test_mask"], pb, force_even=True
    )
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} (local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        device = th.device("cuda:" + str(args.local_rank))
    labels = g.ndata["labels"][np.arange(g.number_of_nodes())]
    n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
    print("#labels:", n_classes)

    # Pack data
    data = train_nid, val_nid, test_nid, n_classes, g
    run(args, device, data)
    print("parent ends")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GCN")
    register_data_args(parser)
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument("--num_clients", type=int, help="The number of clients")
    parser.add_argument("--n_classes", type=int, help="the number of classes")
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU devices. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument(
        "--local_rank", type=int, help="get rank of the process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument(
        "--dgl_sparse",
        action="store_true",
        help="Whether to use DGL sparse embedding",
    )
    parser.add_argument(
        "--sparse_lr", type=float, default=1e-2, help="sparse learning rate"
    )
    args = parser.parse_args()
    print(args)
...

import os

os.environ["DGLBACKEND"] = "pytorch"
import argparse
import math
import time
from functools import wraps
from multiprocessing import Process

import numpy as np
import sklearn.linear_model as lm
import sklearn.metrics as skm
import torch as th
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.data.utils import load_graphs
from dgl.distributed import DistDataLoader
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, def __init__(
in_feats, self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
n_hidden, ):
n_classes,
n_layers,
activation,
dropout):
super().__init__() super().__init__()
self.n_layers = n_layers self.n_layers = n_layers
self.n_hidden = n_hidden self.n_hidden = n_hidden
self.n_classes = n_classes self.n_classes = n_classes
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
for i in range(1, n_layers - 1): for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.activation = activation self.activation = activation
...@@ -66,7 +66,10 @@ class SAGE(nn.Module): ...@@ -66,7 +66,10 @@ class SAGE(nn.Module):
# on each layer are of course split into batches. # on each layer are of course split into batches.
# TODO: can we standardize this? # TODO: can we standardize this?
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) y = th.zeros(
g.number_of_nodes(),
self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
)
sampler = dgl.dataloading.MultiLayerNeighborSampler([None]) sampler = dgl.dataloading.MultiLayerNeighborSampler([None])
dataloader = dgl.dataloading.DistNodeDataLoader( dataloader = dgl.dataloading.DistNodeDataLoader(
...@@ -76,7 +79,8 @@ class SAGE(nn.Module): ...@@ -76,7 +79,8 @@ class SAGE(nn.Module):
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
num_workers=0) num_workers=0,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0] block = blocks[0]
...@@ -92,16 +96,22 @@ class SAGE(nn.Module): ...@@ -92,16 +96,22 @@ class SAGE(nn.Module):
x = y x = y
return y return y
class NegativeSampler(object): class NegativeSampler(object):
def __init__(self, g, neg_nseeds): def __init__(self, g, neg_nseeds):
self.neg_nseeds = neg_nseeds self.neg_nseeds = neg_nseeds
def __call__(self, num_samples): def __call__(self, num_samples):
# select local neg nodes as seeds # select local neg nodes as seeds
return self.neg_nseeds[th.randint(self.neg_nseeds.shape[0], (num_samples,))] return self.neg_nseeds[
th.randint(self.neg_nseeds.shape[0], (num_samples,))
]
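# A minimal sketch (not part of this commit) of how the uniform NegativeSampler
# above behaves: it draws `num_samples` node IDs with replacement from the local
# candidate pool `neg_nseeds`, so repeated IDs are possible. The tensor values
# below are hypothetical.
import torch as th

neg_nseeds = th.tensor([10, 11, 12, 13])               # hypothetical local candidate nodes
num_samples = 6
idx = th.randint(neg_nseeds.shape[0], (num_samples,))  # uniform, with replacement
neg_tails = neg_nseeds[idx]                            # shape (6,), values drawn from the pool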
class NeighborSampler(object): class NeighborSampler(object):
def __init__(self, g, fanouts, neg_nseeds, sample_neighbors, num_negs, remove_edge): def __init__(
self, g, fanouts, neg_nseeds, sample_neighbors, num_negs, remove_edge
):
self.g = g self.g = g
self.fanouts = fanouts self.fanouts = fanouts
self.sample_neighbors = sample_neighbors self.sample_neighbors = sample_neighbors
...@@ -123,21 +133,28 @@ class NeighborSampler(object): ...@@ -123,21 +133,28 @@ class NeighborSampler(object):
# neg_graph contains the correspondence between each head and its negative tails. # neg_graph contains the correspondence between each head and its negative tails.
# Both pos_graph and neg_graph are first constructed with the same node space as # Both pos_graph and neg_graph are first constructed with the same node space as
# the original graph. Then they are compacted together with dgl.compact_graphs. # the original graph. Then they are compacted together with dgl.compact_graphs.
pos_graph = dgl.graph((heads, tails), num_nodes=self.g.number_of_nodes()) pos_graph = dgl.graph(
neg_graph = dgl.graph((neg_heads, neg_tails), num_nodes=self.g.number_of_nodes()) (heads, tails), num_nodes=self.g.number_of_nodes()
)
neg_graph = dgl.graph(
(neg_heads, neg_tails), num_nodes=self.g.number_of_nodes()
)
pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph]) pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
seeds = pos_graph.ndata[dgl.NID] seeds = pos_graph.ndata[dgl.NID]
blocks = [] blocks = []
for fanout in self.fanouts: for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors. # For each seed node, sample ``fanout`` neighbors.
frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True) frontier = self.sample_neighbors(
self.g, seeds, fanout, replace=True
)
if self.remove_edge: if self.remove_edge:
# Remove all edges between heads and tails, as well as heads and neg_tails. # Remove all edges between heads and tails, as well as heads and neg_tails.
_, _, edge_ids = frontier.edge_ids( _, _, edge_ids = frontier.edge_ids(
th.cat([heads, tails, neg_heads, neg_tails]), th.cat([heads, tails, neg_heads, neg_tails]),
th.cat([tails, heads, neg_tails, neg_heads]), th.cat([tails, heads, neg_tails, neg_heads]),
return_uv=True) return_uv=True,
)
frontier = dgl.remove_edges(frontier, edge_ids) frontier = dgl.remove_edges(frontier, edge_ids)
# Then we compact the frontier into a bipartite graph for message passing. # Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds) block = dgl.to_block(frontier, seeds)
...@@ -148,10 +165,13 @@ class NeighborSampler(object): ...@@ -148,10 +165,13 @@ class NeighborSampler(object):
blocks.insert(0, block) blocks.insert(0, block)
input_nodes = blocks[0].srcdata[dgl.NID] input_nodes = blocks[0].srcdata[dgl.NID]
blocks[0].srcdata['features'] = load_subtensor(self.g, input_nodes, 'cpu') blocks[0].srcdata["features"] = load_subtensor(
self.g, input_nodes, "cpu"
)
# Pre-generate CSR format that it can be used in training directly # Pre-generate CSR format that it can be used in training directly
return pos_graph, neg_graph, blocks return pos_graph, neg_graph, blocks
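# A small illustration (not part of this commit) of the dgl.compact_graphs step
# used above: two graphs built in the same hypothetical 10-node ID space are
# compacted together, so only nodes touched by either graph remain and both
# graphs share one relabeled ID space, with the original IDs kept in ndata[dgl.NID].
import dgl
import torch as th

pos = dgl.graph((th.tensor([0, 2]), th.tensor([2, 5])), num_nodes=10)
neg = dgl.graph((th.tensor([0, 2]), th.tensor([7, 9])), num_nodes=10)
pos_c, neg_c = dgl.compact_graphs([pos, neg])
print(pos_c.num_nodes())      # 5: only nodes 0, 2, 5, 7, 9 survive
print(pos_c.ndata[dgl.NID])   # mapping from compacted IDs back to original IDs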
class PosNeighborSampler(object): class PosNeighborSampler(object):
def __init__(self, g, fanouts, sample_neighbors): def __init__(self, g, fanouts, sample_neighbors):
self.g = g self.g = g
...@@ -163,7 +183,9 @@ class PosNeighborSampler(object): ...@@ -163,7 +183,9 @@ class PosNeighborSampler(object):
blocks = [] blocks = []
for fanout in self.fanouts: for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors. # For each seed node, sample ``fanout`` neighbors.
frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True) frontier = self.sample_neighbors(
self.g, seeds, fanout, replace=True
)
# Then we compact the frontier into a bipartite graph for message passing. # Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds) block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer. # Obtain the seed nodes for next layer.
...@@ -172,11 +194,14 @@ class PosNeighborSampler(object): ...@@ -172,11 +194,14 @@ class PosNeighborSampler(object):
blocks.insert(0, block) blocks.insert(0, block)
return blocks return blocks
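# A toy example (not part of this commit) of the dgl.to_block conversion used by
# both samplers above: the sampled frontier becomes a bipartite block whose
# destination nodes are the seeds and whose source nodes are the seeds plus the
# sampled neighbors, with original IDs recorded in srcdata[dgl.NID]. The edges
# below are hypothetical.
import dgl
import torch as th

frontier = dgl.graph((th.tensor([4, 5, 6]), th.tensor([0, 0, 1])), num_nodes=7)
seeds = th.tensor([0, 1])
block = dgl.to_block(frontier, seeds)
print(block.number_of_dst_nodes())   # 2, one per seed
print(block.srcdata[dgl.NID])        # seeds first, then the sampled neighbors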
class DistSAGE(SAGE): class DistSAGE(SAGE):
def __init__(self, in_feats, n_hidden, n_classes, n_layers, def __init__(
activation, dropout): self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
super(DistSAGE, self).__init__(in_feats, n_hidden, n_classes, n_layers, ):
activation, dropout) super(DistSAGE, self).__init__(
in_feats, n_hidden, n_classes, n_layers, activation, dropout
)
def inference(self, g, x, batch_size, device): def inference(self, g, x, batch_size, device):
""" """
...@@ -192,31 +217,49 @@ class DistSAGE(SAGE): ...@@ -192,31 +217,49 @@ class DistSAGE(SAGE):
# Therefore, we compute the representation of all nodes layer by layer. The nodes # Therefore, we compute the representation of all nodes layer by layer. The nodes
# on each layer are of course split into batches. # on each layer are of course split into batches.
# TODO: can we standardize this? # TODO: can we standardize this?
nodes = dgl.distributed.node_split(np.arange(g.number_of_nodes()), nodes = dgl.distributed.node_split(
g.get_partition_book(), force_even=True) np.arange(g.number_of_nodes()),
y = dgl.distributed.DistTensor((g.number_of_nodes(), self.n_hidden), th.float32, 'h', g.get_partition_book(),
persistent=True) force_even=True,
)
y = dgl.distributed.DistTensor(
(g.number_of_nodes(), self.n_hidden),
th.float32,
"h",
persistent=True,
)
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
if l == len(self.layers) - 1: if l == len(self.layers) - 1:
y = dgl.distributed.DistTensor((g.number_of_nodes(), self.n_classes), y = dgl.distributed.DistTensor(
th.float32, 'h_last', persistent=True) (g.number_of_nodes(), self.n_classes),
th.float32,
sampler = PosNeighborSampler(g, [-1], dgl.distributed.sample_neighbors) "h_last",
print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size)) persistent=True,
)
sampler = PosNeighborSampler(
g, [-1], dgl.distributed.sample_neighbors
)
print(
"|V|={}, eval batch size: {}".format(
g.number_of_nodes(), batch_size
)
)
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
dataloader = DistDataLoader( dataloader = DistDataLoader(
dataset=nodes, dataset=nodes,
batch_size=batch_size, batch_size=batch_size,
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
shuffle=False, shuffle=False,
drop_last=False) drop_last=False,
)
for blocks in tqdm.tqdm(dataloader): for blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device) block = blocks[0].to(device)
input_nodes = block.srcdata[dgl.NID] input_nodes = block.srcdata[dgl.NID]
output_nodes = block.dstdata[dgl.NID] output_nodes = block.dstdata[dgl.NID]
h = x[input_nodes].to(device) h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()] h_dst = h[: block.number_of_dst_nodes()]
h = layer(block, (h, h_dst)) h = layer(block, (h, h_dst))
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = self.activation(h) h = self.activation(h)
...@@ -228,29 +271,34 @@ class DistSAGE(SAGE): ...@@ -228,29 +271,34 @@ class DistSAGE(SAGE):
g.barrier() g.barrier()
return y return y
def load_subtensor(g, input_nodes, device): def load_subtensor(g, input_nodes, device):
""" """
Copies features and labels of a set of nodes onto the GPU. Copies features and labels of a set of nodes onto the GPU.
""" """
batch_inputs = g.ndata['features'][input_nodes].to(device) batch_inputs = g.ndata["features"][input_nodes].to(device)
return batch_inputs return batch_inputs
class CrossEntropyLoss(nn.Module): class CrossEntropyLoss(nn.Module):
def forward(self, block_outputs, pos_graph, neg_graph): def forward(self, block_outputs, pos_graph, neg_graph):
with pos_graph.local_scope(): with pos_graph.local_scope():
pos_graph.ndata['h'] = block_outputs pos_graph.ndata["h"] = block_outputs
pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score')) pos_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
pos_score = pos_graph.edata['score'] pos_score = pos_graph.edata["score"]
with neg_graph.local_scope(): with neg_graph.local_scope():
neg_graph.ndata['h'] = block_outputs neg_graph.ndata["h"] = block_outputs
neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score')) neg_graph.apply_edges(fn.u_dot_v("h", "h", "score"))
neg_score = neg_graph.edata['score'] neg_score = neg_graph.edata["score"]
score = th.cat([pos_score, neg_score]) score = th.cat([pos_score, neg_score])
label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long() label = th.cat(
[th.ones_like(pos_score), th.zeros_like(neg_score)]
).long()
loss = F.binary_cross_entropy_with_logits(score, label.float()) loss = F.binary_cross_entropy_with_logits(score, label.float())
return loss return loss
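# A self-contained sketch (not part of this commit) of the edge scoring inside
# CrossEntropyLoss: fn.u_dot_v writes the dot product of the two endpoint
# embeddings onto each edge, and binary cross-entropy treats positive-graph
# edges as label 1. The tiny graph and embeddings below are hypothetical.
import dgl
import dgl.function as fn
import torch as th
import torch.nn.functional as F

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])), num_nodes=3)
h = th.randn(3, 4)
with g.local_scope():
    g.ndata["h"] = h
    g.apply_edges(fn.u_dot_v("h", "h", "score"))
    score = g.edata["score"]                  # shape (2, 1), one score per edge
label = th.ones_like(score)                   # pretend both edges are positive
loss = F.binary_cross_entropy_with_logits(score, label)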
def generate_emb(model, g, inputs, batch_size, device): def generate_emb(model, g, inputs, batch_size, device):
""" """
Generate embeddings for each node Generate embeddings for each node
...@@ -265,6 +313,7 @@ def generate_emb(model, g, inputs, batch_size, device): ...@@ -265,6 +313,7 @@ def generate_emb(model, g, inputs, batch_size, device):
return pred return pred
def compute_acc(emb, labels, train_nids, val_nids, test_nids): def compute_acc(emb, labels, train_nids, val_nids, test_nids):
""" """
Compute the accuracy of prediction given the labels. Compute the accuracy of prediction given the labels.
...@@ -288,7 +337,7 @@ def compute_acc(emb, labels, train_nids, val_nids, test_nids): ...@@ -288,7 +337,7 @@ def compute_acc(emb, labels, train_nids, val_nids, test_nids):
labels = labels.cpu().numpy() labels = labels.cpu().numpy()
emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True) emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)
lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000) lr = lm.LogisticRegression(multi_class="multinomial", max_iter=10000)
lr.fit(emb[train_nids], labels[train_nids]) lr.fit(emb[train_nids], labels[train_nids])
pred = lr.predict(emb) pred = lr.predict(emb)
...@@ -296,12 +345,28 @@ def compute_acc(emb, labels, train_nids, val_nids, test_nids): ...@@ -296,12 +345,28 @@ def compute_acc(emb, labels, train_nids, val_nids, test_nids):
test_acc = skm.accuracy_score(labels[test_nids], pred[test_nids]) test_acc = skm.accuracy_score(labels[test_nids], pred[test_nids])
return eval_acc, test_acc return eval_acc, test_acc
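# A synthetic example (not part of this commit) of the evaluation protocol in
# compute_acc: normalize the frozen embeddings, fit a logistic regression on the
# training split only, then report accuracy on a held-out split. All data below
# is randomly generated for illustration.
import numpy as np
import sklearn.linear_model as lm
import sklearn.metrics as skm

emb = np.random.randn(100, 16)
labels = np.random.randint(0, 3, size=100)
train_nids, val_nids = np.arange(60), np.arange(60, 80)
emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)
clf = lm.LogisticRegression(multi_class="multinomial", max_iter=1000)
clf.fit(emb[train_nids], labels[train_nids])
val_acc = skm.accuracy_score(labels[val_nids], clf.predict(emb[val_nids]))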
def run(args, device, data): def run(args, device, data):
# Unpack data # Unpack data
train_eids, train_nids, in_feats, g, global_train_nid, global_valid_nid, global_test_nid, labels = data (
train_eids,
train_nids,
in_feats,
g,
global_train_nid,
global_valid_nid,
global_test_nid,
labels,
) = data
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], train_nids, sampler = NeighborSampler(
dgl.distributed.sample_neighbors, args.num_negs, args.remove_edge) g,
[int(fanout) for fanout in args.fan_out.split(",")],
train_nids,
dgl.distributed.sample_neighbors,
args.num_negs,
args.remove_edge,
)
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
dataloader = dgl.distributed.DistDataLoader( dataloader = dgl.distributed.DistDataLoader(
...@@ -309,17 +374,27 @@ def run(args, device, data): ...@@ -309,17 +374,27 @@ def run(args, device, data):
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
shuffle=True, shuffle=True,
drop_last=False) drop_last=False,
)
# Define model and optimizer # Define model and optimizer
model = DistSAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout) model = DistSAGE(
in_feats,
args.num_hidden,
args.num_hidden,
args.num_layers,
F.relu,
args.dropout,
)
model = model.to(device) model = model.to(device)
if not args.standalone: if not args.standalone:
if args.num_gpus == -1: if args.num_gpus == -1:
model = th.nn.parallel.DistributedDataParallel(model) model = th.nn.parallel.DistributedDataParallel(model)
else: else:
dev_id = g.rank() % args.num_gpus dev_id = g.rank() % args.num_gpus
model = th.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) model = th.nn.parallel.DistributedDataParallel(
model, device_ids=[dev_id], output_device=dev_id
)
loss_fcn = CrossEntropyLoss() loss_fcn = CrossEntropyLoss()
loss_fcn = loss_fcn.to(device) loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
...@@ -358,7 +433,7 @@ def run(args, device, data): ...@@ -358,7 +433,7 @@ def run(args, device, data):
# The nodes for output lie on the RHS of the last block. # The nodes for output lie on the RHS of the last block.
# Load the input features as well as output labels # Load the input features as well as output labels
batch_inputs = blocks[0].srcdata['features'] batch_inputs = blocks[0].srcdata["features"]
copy_time = time.time() copy_time = time.time()
feat_copy_t.append(copy_time - tic_step) feat_copy_t.append(copy_time - tic_step)
...@@ -384,25 +459,53 @@ def run(args, device, data): ...@@ -384,25 +459,53 @@ def run(args, device, data):
iter_tput.append(pos_edges / step_t) iter_tput.append(pos_edges / step_t)
num_seeds += pos_edges num_seeds += pos_edges
if step % args.log_every == 0: if step % args.log_every == 0:
print('[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | time {:.3f} s' \ print(
'| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}'.format( "[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | time {:.3f} s"
g.rank(), epoch, step, loss.item(), np.mean(iter_tput[3:]), np.sum(step_time[-args.log_every:]), "| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}".format(
np.sum(sample_t[-args.log_every:]), np.sum(feat_copy_t[-args.log_every:]), np.sum(forward_t[-args.log_every:]), g.rank(),
np.sum(backward_t[-args.log_every:]), np.sum(update_t[-args.log_every:]))) epoch,
step,
loss.item(),
np.mean(iter_tput[3:]),
np.sum(step_time[-args.log_every :]),
np.sum(sample_t[-args.log_every :]),
np.sum(feat_copy_t[-args.log_every :]),
np.sum(forward_t[-args.log_every :]),
np.sum(backward_t[-args.log_every :]),
np.sum(update_t[-args.log_every :]),
)
)
start = time.time() start = time.time()
print('[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format( print(
g.rank(), np.sum(step_time), np.sum(sample_t), np.sum(feat_copy_t), np.sum(forward_t), np.sum(backward_t), np.sum(update_t), num_seeds, num_inputs)) "[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}".format(
g.rank(),
np.sum(step_time),
np.sum(sample_t),
np.sum(feat_copy_t),
np.sum(forward_t),
np.sum(backward_t),
np.sum(update_t),
num_seeds,
num_inputs,
)
)
epoch += 1 epoch += 1
# evaluate the embedding using LogisticRegression # evaluate the embedding using LogisticRegression
if args.standalone: if args.standalone:
pred = generate_emb(model,g, g.ndata['features'], args.batch_size_eval, device) pred = generate_emb(
model, g, g.ndata["features"], args.batch_size_eval, device
)
else: else:
pred = generate_emb(model.module, g, g.ndata['features'], args.batch_size_eval, device) pred = generate_emb(
model.module, g, g.ndata["features"], args.batch_size_eval, device
)
if g.rank() == 0: if g.rank() == 0:
eval_acc, test_acc = compute_acc(pred, labels, global_train_nid, global_valid_nid, global_test_nid) eval_acc, test_acc = compute_acc(
print('eval acc {:.4f}; test acc {:.4f}'.format(eval_acc, test_acc)) pred, labels, global_train_nid, global_valid_nid, global_test_nid
)
print("eval acc {:.4f}; test acc {:.4f}".format(eval_acc, test_acc))
# sync for eval and test # sync for eval and test
if not args.standalone: if not args.standalone:
...@@ -413,69 +516,112 @@ def run(args, device, data): ...@@ -413,69 +516,112 @@ def run(args, device, data):
# save features into file # save features into file
if g.rank() == 0: if g.rank() == 0:
th.save(pred, 'emb.pt') th.save(pred, "emb.pt")
else: else:
feat = g.ndata['features'] feat = g.ndata["features"]
th.save(pred, 'emb.pt') th.save(pred, "emb.pt")
def main(args): def main(args):
dgl.distributed.initialize(args.ip_config) dgl.distributed.initialize(args.ip_config)
if not args.standalone: if not args.standalone:
th.distributed.init_process_group(backend='gloo') th.distributed.init_process_group(backend="gloo")
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config) g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
print('rank:', g.rank()) print("rank:", g.rank())
print('number of edges', g.number_of_edges()) print("number of edges", g.number_of_edges())
train_eids = dgl.distributed.edge_split(th.ones((g.number_of_edges(),), dtype=th.bool), g.get_partition_book(), force_even=True) train_eids = dgl.distributed.edge_split(
train_nids = dgl.distributed.node_split(th.ones((g.number_of_nodes(),), dtype=th.bool), g.get_partition_book()) th.ones((g.number_of_edges(),), dtype=th.bool),
global_train_nid = th.LongTensor(np.nonzero(g.ndata['train_mask'][np.arange(g.number_of_nodes())])) g.get_partition_book(),
global_valid_nid = th.LongTensor(np.nonzero(g.ndata['val_mask'][np.arange(g.number_of_nodes())])) force_even=True,
global_test_nid = th.LongTensor(np.nonzero(g.ndata['test_mask'][np.arange(g.number_of_nodes())])) )
labels = g.ndata['labels'][np.arange(g.number_of_nodes())] train_nids = dgl.distributed.node_split(
th.ones((g.number_of_nodes(),), dtype=th.bool), g.get_partition_book()
)
global_train_nid = th.LongTensor(
np.nonzero(g.ndata["train_mask"][np.arange(g.number_of_nodes())])
)
global_valid_nid = th.LongTensor(
np.nonzero(g.ndata["val_mask"][np.arange(g.number_of_nodes())])
)
global_test_nid = th.LongTensor(
np.nonzero(g.ndata["test_mask"][np.arange(g.number_of_nodes())])
)
labels = g.ndata["labels"][np.arange(g.number_of_nodes())]
if args.num_gpus == -1: if args.num_gpus == -1:
device = th.device('cpu') device = th.device("cpu")
else: else:
device = th.device('cuda:'+str(args.local_rank)) device = th.device("cuda:" + str(args.local_rank))
# Pack data # Pack data
in_feats = g.ndata['features'].shape[1] in_feats = g.ndata["features"].shape[1]
global_train_nid = global_train_nid.squeeze() global_train_nid = global_train_nid.squeeze()
global_valid_nid = global_valid_nid.squeeze() global_valid_nid = global_valid_nid.squeeze()
global_test_nid = global_test_nid.squeeze() global_test_nid = global_test_nid.squeeze()
print("number of train {}".format(global_train_nid.shape[0])) print("number of train {}".format(global_train_nid.shape[0]))
print("number of valid {}".format(global_valid_nid.shape[0])) print("number of valid {}".format(global_valid_nid.shape[0]))
print("number of test {}".format(global_test_nid.shape[0])) print("number of test {}".format(global_test_nid.shape[0]))
data = train_eids, train_nids, in_feats, g, global_train_nid, global_valid_nid, global_test_nid, labels data = (
train_eids,
train_nids,
in_feats,
g,
global_train_nid,
global_valid_nid,
global_test_nid,
labels,
)
run(args, device, data) run(args, device, data)
print("parent ends") print("parent ends")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN') if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GCN")
register_data_args(parser) register_data_args(parser)
parser.add_argument('--graph_name', type=str, help='graph name') parser.add_argument("--graph_name", type=str, help="graph name")
parser.add_argument('--id', type=int, help='the partition id') parser.add_argument("--id", type=int, help="the partition id")
parser.add_argument('--ip_config', type=str, help='The file for IP configuration') parser.add_argument(
parser.add_argument('--part_config', type=str, help='The path to the partition config file') "--ip_config", type=str, help="The file for IP configuration"
parser.add_argument('--n_classes', type=int, help='the number of classes') )
parser.add_argument('--num_gpus', type=int, default=-1, parser.add_argument(
help="the number of GPU device. Use -1 for CPU training") "--part_config", type=str, help="The path to the partition config file"
parser.add_argument('--num_epochs', type=int, default=20) )
parser.add_argument('--num_hidden', type=int, default=16) parser.add_argument("--n_classes", type=int, help="the number of classes")
parser.add_argument('--num-layers', type=int, default=2) parser.add_argument(
parser.add_argument('--fan_out', type=str, default='10,25') "--num_gpus",
parser.add_argument('--batch_size', type=int, default=1000) type=int,
parser.add_argument('--batch_size_eval', type=int, default=100000) default=-1,
parser.add_argument('--log_every', type=int, default=20) help="the number of GPU device. Use -1 for CPU training",
parser.add_argument('--eval_every', type=int, default=5) )
parser.add_argument('--lr', type=float, default=0.003) parser.add_argument("--num_epochs", type=int, default=20)
parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument("--num_hidden", type=int, default=16)
parser.add_argument('--local_rank', type=int, help='get rank of the process') parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode') parser.add_argument("--fan_out", type=str, default="10,25")
parser.add_argument('--num_negs', type=int, default=1) parser.add_argument("--batch_size", type=int, default=1000)
parser.add_argument('--neg_share', default=False, action='store_true', parser.add_argument("--batch_size_eval", type=int, default=100000)
help="sharing neg nodes for positive nodes") parser.add_argument("--log_every", type=int, default=20)
parser.add_argument('--remove_edge', default=False, action='store_true', parser.add_argument("--eval_every", type=int, default=5)
help="whether to remove edges during sampling") parser.add_argument("--lr", type=float, default=0.003)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument(
"--local_rank", type=int, help="get rank of the process"
)
parser.add_argument(
"--standalone", action="store_true", help="run in the standalone mode"
)
parser.add_argument("--num_negs", type=int, default=1)
parser.add_argument(
"--neg_share",
default=False,
action="store_true",
help="sharing neg nodes for positive nodes",
)
parser.add_argument(
"--remove_edge",
default=False,
action="store_true",
help="whether to remove edges during sampling",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args) main(args)
import os import os
os.environ['DGLBACKEND']='pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import argparse
import math
import time
from functools import wraps
from multiprocessing import Process from multiprocessing import Process
import argparse, time, math
import numpy as np import numpy as np
from functools import wraps
import tqdm
import sklearn.linear_model as lm import sklearn.linear_model as lm
import sklearn.metrics as skm import sklearn.metrics as skm
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data.utils import load_graphs
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import torch as th import torch as th
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import torch.multiprocessing as mp import tqdm
from train_dist_transductive import DistEmb, load_embs
from train_dist_unsupervised import (
SAGE,
CrossEntropyLoss,
NeighborSampler,
PosNeighborSampler,
compute_acc,
)
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.data.utils import load_graphs
from dgl.distributed import DistDataLoader from dgl.distributed import DistDataLoader
from train_dist_unsupervised import SAGE, NeighborSampler, PosNeighborSampler, CrossEntropyLoss, compute_acc
from train_dist_transductive import DistEmb, load_embs
def generate_emb(standalone, model, emb_layer, g, batch_size, device): def generate_emb(standalone, model, emb_layer, g, batch_size, device):
""" """
...@@ -42,12 +51,27 @@ def generate_emb(standalone, model, emb_layer, g, batch_size, device): ...@@ -42,12 +51,27 @@ def generate_emb(standalone, model, emb_layer, g, batch_size, device):
g.barrier() g.barrier()
return pred return pred
def run(args, device, data): def run(args, device, data):
# Unpack data # Unpack data
train_eids, train_nids, g, global_train_nid, global_valid_nid, global_test_nid, labels = data (
train_eids,
train_nids,
g,
global_train_nid,
global_valid_nid,
global_test_nid,
labels,
) = data
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], train_nids, sampler = NeighborSampler(
dgl.distributed.sample_neighbors, args.num_negs, args.remove_edge) g,
[int(fanout) for fanout in args.fan_out.split(",")],
train_nids,
dgl.distributed.sample_neighbors,
args.num_negs,
args.remove_edge,
)
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
dataloader = dgl.distributed.DistDataLoader( dataloader = dgl.distributed.DistDataLoader(
...@@ -55,18 +79,33 @@ def run(args, device, data): ...@@ -55,18 +79,33 @@ def run(args, device, data):
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
shuffle=True, shuffle=True,
drop_last=False) drop_last=False,
)
# Define model and optimizer # Define model and optimizer
emb_layer = DistEmb(g.num_nodes(), args.num_hidden, dgl_sparse_emb=args.dgl_sparse, dev_id=device) emb_layer = DistEmb(
model = SAGE(args.num_hidden, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout) g.num_nodes(),
args.num_hidden,
dgl_sparse_emb=args.dgl_sparse,
dev_id=device,
)
model = SAGE(
args.num_hidden,
args.num_hidden,
args.num_hidden,
args.num_layers,
F.relu,
args.dropout,
)
model = model.to(device) model = model.to(device)
if not args.standalone: if not args.standalone:
if args.num_gpus == -1: if args.num_gpus == -1:
model = th.nn.parallel.DistributedDataParallel(model) model = th.nn.parallel.DistributedDataParallel(model)
else: else:
dev_id = g.rank() % args.num_gpus dev_id = g.rank() % args.num_gpus
model = th.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) model = th.nn.parallel.DistributedDataParallel(
model, device_ids=[dev_id], output_device=dev_id
)
if not args.dgl_sparse: if not args.dgl_sparse:
emb_layer = th.nn.parallel.DistributedDataParallel(emb_layer) emb_layer = th.nn.parallel.DistributedDataParallel(emb_layer)
loss_fcn = CrossEntropyLoss() loss_fcn = CrossEntropyLoss()
...@@ -74,14 +113,20 @@ def run(args, device, data): ...@@ -74,14 +113,20 @@ def run(args, device, data):
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
if args.dgl_sparse: if args.dgl_sparse:
emb_optimizer = dgl.distributed.optim.SparseAdam([emb_layer.sparse_emb], lr=args.sparse_lr) emb_optimizer = dgl.distributed.optim.SparseAdam(
print('optimize DGL sparse embedding:', emb_layer.sparse_emb) [emb_layer.sparse_emb], lr=args.sparse_lr
)
print("optimize DGL sparse embedding:", emb_layer.sparse_emb)
elif args.standalone: elif args.standalone:
emb_optimizer = th.optim.SparseAdam(list(emb_layer.sparse_emb.parameters()), lr=args.sparse_lr) emb_optimizer = th.optim.SparseAdam(
print('optimize Pytorch sparse embedding:', emb_layer.sparse_emb) list(emb_layer.sparse_emb.parameters()), lr=args.sparse_lr
)
print("optimize Pytorch sparse embedding:", emb_layer.sparse_emb)
else: else:
emb_optimizer = th.optim.SparseAdam(list(emb_layer.module.sparse_emb.parameters()), lr=args.sparse_lr) emb_optimizer = th.optim.SparseAdam(
print('optimize Pytorch sparse embedding:', emb_layer.module.sparse_emb) list(emb_layer.module.sparse_emb.parameters()), lr=args.sparse_lr
)
print("optimize Pytorch sparse embedding:", emb_layer.module.sparse_emb)
# Training loop # Training loop
epoch = 0 epoch = 0
...@@ -146,26 +191,54 @@ def run(args, device, data): ...@@ -146,26 +191,54 @@ def run(args, device, data):
iter_tput.append(pos_edges / step_t) iter_tput.append(pos_edges / step_t)
num_seeds += pos_edges num_seeds += pos_edges
if step % args.log_every == 0: if step % args.log_every == 0:
print('[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | time {:.3f} s' \ print(
'| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}'.format( "[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | time {:.3f} s"
g.rank(), epoch, step, loss.item(), np.mean(iter_tput[3:]), np.sum(step_time[-args.log_every:]), "| sample {:.3f} | copy {:.3f} | forward {:.3f} | backward {:.3f} | update {:.3f}".format(
np.sum(sample_t[-args.log_every:]), np.sum(feat_copy_t[-args.log_every:]), np.sum(forward_t[-args.log_every:]), g.rank(),
np.sum(backward_t[-args.log_every:]), np.sum(update_t[-args.log_every:]))) epoch,
step,
loss.item(),
np.mean(iter_tput[3:]),
np.sum(step_time[-args.log_every :]),
np.sum(sample_t[-args.log_every :]),
np.sum(feat_copy_t[-args.log_every :]),
np.sum(forward_t[-args.log_every :]),
np.sum(backward_t[-args.log_every :]),
np.sum(update_t[-args.log_every :]),
)
)
start = time.time() start = time.time()
print('[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format( print(
g.rank(), np.sum(step_time), np.sum(sample_t), np.sum(feat_copy_t), np.sum(forward_t), np.sum(backward_t), np.sum(update_t), num_seeds, num_inputs)) "[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}".format(
g.rank(),
np.sum(step_time),
np.sum(sample_t),
np.sum(feat_copy_t),
np.sum(forward_t),
np.sum(backward_t),
np.sum(update_t),
num_seeds,
num_inputs,
)
)
epoch += 1 epoch += 1
# evaluate the embedding using LogisticRegression # evaluate the embedding using LogisticRegression
if args.standalone: if args.standalone:
pred = generate_emb(True, model, emb_layer, g, args.batch_size_eval, device) pred = generate_emb(
True, model, emb_layer, g, args.batch_size_eval, device
)
else: else:
pred = generate_emb(False, model.module, emb_layer, g, args.batch_size_eval, device) pred = generate_emb(
False, model.module, emb_layer, g, args.batch_size_eval, device
)
if g.rank() == 0: if g.rank() == 0:
eval_acc, test_acc = compute_acc(pred, labels, global_train_nid, global_valid_nid, global_test_nid) eval_acc, test_acc = compute_acc(
print('eval acc {:.4f}; test acc {:.4f}'.format(eval_acc, test_acc)) pred, labels, global_train_nid, global_valid_nid, global_test_nid
)
print("eval acc {:.4f}; test acc {:.4f}".format(eval_acc, test_acc))
# sync for eval and test # sync for eval and test
if not args.standalone: if not args.standalone:
...@@ -176,29 +249,42 @@ def run(args, device, data): ...@@ -176,29 +249,42 @@ def run(args, device, data):
# save features into file # save features into file
if g.rank() == 0: if g.rank() == 0:
th.save(pred, 'emb.pt') th.save(pred, "emb.pt")
else: else:
feat = g.ndata['features'] feat = g.ndata["features"]
th.save(pred, 'emb.pt') th.save(pred, "emb.pt")
def main(args): def main(args):
dgl.distributed.initialize(args.ip_config) dgl.distributed.initialize(args.ip_config)
if not args.standalone: if not args.standalone:
th.distributed.init_process_group(backend='gloo') th.distributed.init_process_group(backend="gloo")
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config) g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
print('rank:', g.rank()) print("rank:", g.rank())
print('number of edges', g.number_of_edges()) print("number of edges", g.number_of_edges())
train_eids = dgl.distributed.edge_split(th.ones((g.number_of_edges(),), dtype=th.bool), g.get_partition_book(), force_even=True) train_eids = dgl.distributed.edge_split(
train_nids = dgl.distributed.node_split(th.ones((g.number_of_nodes(),), dtype=th.bool), g.get_partition_book()) th.ones((g.number_of_edges(),), dtype=th.bool),
global_train_nid = th.LongTensor(np.nonzero(g.ndata['train_mask'][np.arange(g.number_of_nodes())])) g.get_partition_book(),
global_valid_nid = th.LongTensor(np.nonzero(g.ndata['val_mask'][np.arange(g.number_of_nodes())])) force_even=True,
global_test_nid = th.LongTensor(np.nonzero(g.ndata['test_mask'][np.arange(g.number_of_nodes())])) )
labels = g.ndata['labels'][np.arange(g.number_of_nodes())] train_nids = dgl.distributed.node_split(
th.ones((g.number_of_nodes(),), dtype=th.bool), g.get_partition_book()
)
global_train_nid = th.LongTensor(
np.nonzero(g.ndata["train_mask"][np.arange(g.number_of_nodes())])
)
global_valid_nid = th.LongTensor(
np.nonzero(g.ndata["val_mask"][np.arange(g.number_of_nodes())])
)
global_test_nid = th.LongTensor(
np.nonzero(g.ndata["test_mask"][np.arange(g.number_of_nodes())])
)
labels = g.ndata["labels"][np.arange(g.number_of_nodes())]
if args.num_gpus == -1: if args.num_gpus == -1:
device = th.device('cpu') device = th.device("cpu")
else: else:
device = th.device('cuda:'+str(args.local_rank)) device = th.device("cuda:" + str(args.local_rank))
# Pack data # Pack data
global_train_nid = global_train_nid.squeeze() global_train_nid = global_train_nid.squeeze()
...@@ -207,41 +293,74 @@ def main(args): ...@@ -207,41 +293,74 @@ def main(args):
print("number of train {}".format(global_train_nid.shape[0])) print("number of train {}".format(global_train_nid.shape[0]))
print("number of valid {}".format(global_valid_nid.shape[0])) print("number of valid {}".format(global_valid_nid.shape[0]))
print("number of test {}".format(global_test_nid.shape[0])) print("number of test {}".format(global_test_nid.shape[0]))
data = train_eids, train_nids, g, global_train_nid, global_valid_nid, global_test_nid, labels data = (
train_eids,
train_nids,
g,
global_train_nid,
global_valid_nid,
global_test_nid,
labels,
)
run(args, device, data) run(args, device, data)
print("parent ends") print("parent ends")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GCN') if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GCN")
register_data_args(parser) register_data_args(parser)
parser.add_argument('--graph_name', type=str, help='graph name') parser.add_argument("--graph_name", type=str, help="graph name")
parser.add_argument('--id', type=int, help='the partition id') parser.add_argument("--id", type=int, help="the partition id")
parser.add_argument('--ip_config', type=str, help='The file for IP configuration') parser.add_argument(
parser.add_argument('--part_config', type=str, help='The path to the partition config file') "--ip_config", type=str, help="The file for IP configuration"
parser.add_argument('--n_classes', type=int, help='the number of classes') )
parser.add_argument('--num_gpus', type=int, default=-1, parser.add_argument(
help="the number of GPU device. Use -1 for CPU training") "--part_config", type=str, help="The path to the partition config file"
parser.add_argument('--num_epochs', type=int, default=5) )
parser.add_argument('--num_hidden', type=int, default=16) parser.add_argument("--n_classes", type=int, help="the number of classes")
parser.add_argument('--num-layers', type=int, default=2) parser.add_argument(
parser.add_argument('--fan_out', type=str, default='10,25') "--num_gpus",
parser.add_argument('--batch_size', type=int, default=1000) type=int,
parser.add_argument('--batch_size_eval', type=int, default=100000) default=-1,
parser.add_argument('--log_every', type=int, default=20) help="the number of GPU device. Use -1 for CPU training",
parser.add_argument('--eval_every', type=int, default=5) )
parser.add_argument('--lr', type=float, default=0.003) parser.add_argument("--num_epochs", type=int, default=5)
parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument("--num_hidden", type=int, default=16)
parser.add_argument('--local_rank', type=int, help='get rank of the process') parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode') parser.add_argument("--fan_out", type=str, default="10,25")
parser.add_argument('--num_negs', type=int, default=1) parser.add_argument("--batch_size", type=int, default=1000)
parser.add_argument('--neg_share', default=False, action='store_true', parser.add_argument("--batch_size_eval", type=int, default=100000)
help="sharing neg nodes for positive nodes") parser.add_argument("--log_every", type=int, default=20)
parser.add_argument('--remove_edge', default=False, action='store_true', parser.add_argument("--eval_every", type=int, default=5)
help="whether to remove edges during sampling") parser.add_argument("--lr", type=float, default=0.003)
parser.add_argument("--dgl_sparse", action='store_true', parser.add_argument("--dropout", type=float, default=0.5)
help='Whether to use DGL sparse embedding') parser.add_argument(
parser.add_argument("--sparse_lr", type=float, default=1e-2, "--local_rank", type=int, help="get rank of the process"
help="sparse lr rate") )
parser.add_argument(
"--standalone", action="store_true", help="run in the standalone mode"
)
parser.add_argument("--num_negs", type=int, default=1)
parser.add_argument(
"--neg_share",
default=False,
action="store_true",
help="sharing neg nodes for positive nodes",
)
parser.add_argument(
"--remove_edge",
default=False,
action="store_true",
help="whether to remove edges during sampling",
)
parser.add_argument(
"--dgl_sparse",
action="store_true",
help="Whether to use DGL sparse embedding",
)
parser.add_argument(
"--sparse_lr", type=float, default=1e-2, help="sparse lr rate"
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(args) main(args)
...@@ -11,47 +11,49 @@ Copyright (c) 2021 Intel Corporation ...@@ -11,47 +11,49 @@ Copyright (c) 2021 Intel Corporation
""" """
import os import argparse
import sys
import numpy as np
import csv import csv
from statistics import mean import os
import random import random
import sys
import time import time
import argparse from statistics import mean
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import numpy as np
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from load_graph import load_ogb from load_graph import load_ogb
import dgl import dgl
from dgl.base import DGLError
from dgl.data import load_data from dgl.data import load_data
from dgl.distgnn.partition import partition_graph from dgl.distgnn.partition import partition_graph
from dgl.distgnn.tools import load_proteins from dgl.distgnn.tools import load_proteins
from dgl.base import DGLError
if __name__ == "__main__": if __name__ == "__main__":
argparser = argparse.ArgumentParser() argparser = argparse.ArgumentParser()
argparser.add_argument('--dataset', type=str, default='cora') argparser.add_argument("--dataset", type=str, default="cora")
argparser.add_argument('--num-parts', type=int, default=2) argparser.add_argument("--num-parts", type=int, default=2)
argparser.add_argument('--out-dir', type=str, default='./') argparser.add_argument("--out-dir", type=str, default="./")
args = argparser.parse_args() args = argparser.parse_args()
dataset = args.dataset dataset = args.dataset
num_community = args.num_parts num_community = args.num_parts
out_dir = 'Libra_result_' + dataset ## "Libra_result_" prefix is mandatory out_dir = "Libra_result_" + dataset ## "Libra_result_" prefix is mandatory
resultdir = os.path.join(args.out_dir, out_dir) resultdir = os.path.join(args.out_dir, out_dir)
print("Input dataset for partitioning: ", dataset) print("Input dataset for partitioning: ", dataset)
if args.dataset == 'ogbn-products': if args.dataset == "ogbn-products":
print("Loading ogbn-products") print("Loading ogbn-products")
G, _ = load_ogb('ogbn-products') G, _ = load_ogb("ogbn-products")
elif args.dataset == 'ogbn-papers100M': elif args.dataset == "ogbn-papers100M":
print("Loading ogbn-papers100M") print("Loading ogbn-papers100M")
G, _ = load_ogb('ogbn-papers100M') G, _ = load_ogb("ogbn-papers100M")
elif args.dataset == 'proteins': elif args.dataset == "proteins":
G = load_proteins('proteins') G = load_proteins("proteins")
elif args.dataset == 'ogbn-arxiv': elif args.dataset == "ogbn-arxiv":
print("Loading ogbn-arxiv") print("Loading ogbn-arxiv")
G, _ = load_ogb('ogbn-arxiv') G, _ = load_ogb("ogbn-arxiv")
else: else:
try: try:
G = load_data(args)[0] G = load_data(args)[0]
import dgl import glob
import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import dgl.nn.pytorch as dglnn import torchmetrics.functional as MF
import tqdm import tqdm
import glob
import os
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
from torchmetrics import Accuracy
import torchmetrics.functional as MF
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
import dgl
import dgl.nn.pytorch as dglnn
class SAGE(LightningModule): class SAGE(LightningModule):
def __init__(self, in_feats, n_hidden, n_classes): def __init__(self, in_feats, n_hidden, n_classes):
super().__init__() super().__init__()
self.save_hyperparameters() self.save_hyperparameters()
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean')) self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
self.n_hidden = n_hidden self.n_hidden = n_hidden
self.n_classes = n_classes self.n_classes = n_classes
...@@ -39,108 +42,166 @@ class SAGE(LightningModule): ...@@ -39,108 +42,166 @@ class SAGE(LightningModule):
def inference(self, g, device, batch_size, num_workers, buffer_device=None): def inference(self, g, device, batch_size, num_workers, buffer_device=None):
# The difference between this inference function and the one in the official # The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching. # example is that the intermediate results can also benefit from prefetching.
g.ndata['h'] = g.ndata['feat'] g.ndata["h"] = g.ndata["feat"]
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h']) sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
1, prefetch_node_feats=["h"]
)
dataloader = dgl.dataloading.DataLoader( dataloader = dgl.dataloading.DataLoader(
g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device, g,
batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, torch.arange(g.num_nodes()).to(g.device),
persistent_workers=(num_workers > 0)) sampler,
device=device,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=num_workers,
persistent_workers=(num_workers > 0),
)
if buffer_device is None: if buffer_device is None:
buffer_device = device buffer_device = device
for l, layer in enumerate(self.layers): for l, layer in enumerate(self.layers):
y = torch.zeros( y = torch.zeros(
g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes, g.num_nodes(),
device=buffer_device) self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
device=buffer_device,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata['h'] x = blocks[0].srcdata["h"]
h = layer(blocks[0], x) h = layer(blocks[0], x)
if l != len(self.layers) - 1: if l != len(self.layers) - 1:
h = F.relu(h) h = F.relu(h)
h = self.dropout(h) h = self.dropout(h)
y[output_nodes] = h.to(buffer_device) y[output_nodes] = h.to(buffer_device)
g.ndata['h'] = y g.ndata["h"] = y
return y return y
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
input_nodes, output_nodes, blocks = batch input_nodes, output_nodes, blocks = batch
x = blocks[0].srcdata['feat'] x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata['label'] y = blocks[-1].dstdata["label"]
y_hat = self(blocks, x) y_hat = self(blocks, x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
self.train_acc(torch.argmax(y_hat, 1), y) self.train_acc(torch.argmax(y_hat, 1), y)
self.log('train_acc', self.train_acc, prog_bar=True, on_step=True, on_epoch=False) self.log(
"train_acc",
self.train_acc,
prog_bar=True,
on_step=True,
on_epoch=False,
)
return loss return loss
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
input_nodes, output_nodes, blocks = batch input_nodes, output_nodes, blocks = batch
x = blocks[0].srcdata['feat'] x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata['label'] y = blocks[-1].dstdata["label"]
y_hat = self(blocks, x) y_hat = self(blocks, x)
self.val_acc(torch.argmax(y_hat, 1), y) self.val_acc(torch.argmax(y_hat, 1), y)
self.log('val_acc', self.val_acc, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True) self.log(
"val_acc",
self.val_acc,
prog_bar=True,
on_step=True,
on_epoch=True,
sync_dist=True,
)
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4) optimizer = torch.optim.Adam(
self.parameters(), lr=0.001, weight_decay=5e-4
)
return optimizer return optimizer
class DataModule(LightningDataModule): class DataModule(LightningDataModule):
def __init__(self, graph, train_idx, val_idx, fanouts, batch_size, n_classes): def __init__(
self, graph, train_idx, val_idx, fanouts, batch_size, n_classes
):
super().__init__() super().__init__()
sampler = dgl.dataloading.NeighborSampler( sampler = dgl.dataloading.NeighborSampler(
fanouts, prefetch_node_feats=['feat'], prefetch_labels=['label']) fanouts, prefetch_node_feats=["feat"], prefetch_labels=["label"]
)
self.g = graph self.g = graph
self.train_idx, self.val_idx = train_idx, val_idx self.train_idx, self.val_idx = train_idx, val_idx
self.sampler = sampler self.sampler = sampler
self.batch_size = batch_size self.batch_size = batch_size
self.in_feats = graph.ndata['feat'].shape[1] self.in_feats = graph.ndata["feat"].shape[1]
self.n_classes = n_classes self.n_classes = n_classes
def train_dataloader(self): def train_dataloader(self):
return dgl.dataloading.DataLoader( return dgl.dataloading.DataLoader(
self.g, self.train_idx.to('cuda'), self.sampler, self.g,
device='cuda', batch_size=self.batch_size, shuffle=True, drop_last=False, self.train_idx.to("cuda"),
self.sampler,
device="cuda",
batch_size=self.batch_size,
shuffle=True,
drop_last=False,
# For CPU sampling, set num_workers to nonzero and use_uva=False # For CPU sampling, set num_workers to nonzero and use_uva=False
# Set use_ddp to False for single GPU. # Set use_ddp to False for single GPU.
num_workers=0, use_uva=True, use_ddp=True) num_workers=0,
use_uva=True,
use_ddp=True,
)
def val_dataloader(self): def val_dataloader(self):
return dgl.dataloading.DataLoader( return dgl.dataloading.DataLoader(
self.g, self.val_idx.to('cuda'), self.sampler, self.g,
device='cuda', batch_size=self.batch_size, shuffle=True, drop_last=False, self.val_idx.to("cuda"),
num_workers=0, use_uva=True) self.sampler,
device="cuda",
if __name__ == '__main__': batch_size=self.batch_size,
dataset = DglNodePropPredDataset('ogbn-products') shuffle=True,
drop_last=False,
num_workers=0,
use_uva=True,
)
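# A sketch (not part of this commit) of the CPU-sampling, single-GPU variant
# described by the comments in train_dataloader above: sampling runs in worker
# processes with use_uva=False and use_ddp=False, while the sampled blocks are
# still moved to the GPU. `g`, `train_idx` and `sampler` are assumed to be the
# same objects held by DataModule.
import dgl

def make_cpu_train_loader(g, train_idx, sampler, batch_size=1024):
    return dgl.dataloading.DataLoader(
        g,
        train_idx,            # kept on CPU for CPU-based sampling
        sampler,
        device="cuda",        # copy sampled blocks to the GPU for training
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=4,        # nonzero workers for CPU sampling
        use_uva=False,
        use_ddp=False,
    )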
if __name__ == "__main__":
dataset = DglNodePropPredDataset("ogbn-products")
graph, labels = dataset[0] graph, labels = dataset[0]
graph.ndata['label'] = labels.squeeze() graph.ndata["label"] = labels.squeeze()
graph.create_formats_() graph.create_formats_()
split_idx = dataset.get_idx_split() split_idx = dataset.get_idx_split()
train_idx, val_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test'] train_idx, val_idx, test_idx = (
datamodule = DataModule(graph, train_idx, val_idx, [15, 10, 5], 1024, dataset.num_classes) split_idx["train"],
split_idx["valid"],
split_idx["test"],
)
datamodule = DataModule(
graph, train_idx, val_idx, [15, 10, 5], 1024, dataset.num_classes
)
model = SAGE(datamodule.in_feats, 256, datamodule.n_classes) model = SAGE(datamodule.in_feats, 256, datamodule.n_classes)
# Train # Train
checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=1) checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
# Use this for single GPU # Use this for single GPU
#trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback]) # trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback])
trainer = Trainer(gpus=[0, 1, 2, 3], max_epochs=10, callbacks=[checkpoint_callback], strategy='ddp_spawn') trainer = Trainer(
gpus=[0, 1, 2, 3],
max_epochs=10,
callbacks=[checkpoint_callback],
strategy="ddp_spawn",
)
trainer.fit(model, datamodule=datamodule) trainer.fit(model, datamodule=datamodule)
# Test # Test
dirs = glob.glob('./lightning_logs/*') dirs = glob.glob("./lightning_logs/*")
version = max([int(os.path.split(x)[-1].split('_')[-1]) for x in dirs]) version = max([int(os.path.split(x)[-1].split("_")[-1]) for x in dirs])
logdir = './lightning_logs/version_%d' % version logdir = "./lightning_logs/version_%d" % version
print('Evaluating model in', logdir) print("Evaluating model in", logdir)
ckpt = glob.glob(os.path.join(logdir, 'checkpoints', '*'))[0] ckpt = glob.glob(os.path.join(logdir, "checkpoints", "*"))[0]
model = SAGE.load_from_checkpoint( model = SAGE.load_from_checkpoint(
checkpoint_path=ckpt, hparams_file=os.path.join(logdir, 'hparams.yaml')).to('cuda') checkpoint_path=ckpt, hparams_file=os.path.join(logdir, "hparams.yaml")
).to("cuda")
with torch.no_grad(): with torch.no_grad():
pred = model.inference(graph, 'cuda', 4096, 12, graph.device) pred = model.inference(graph, "cuda", 4096, 12, graph.device)
pred = pred[test_idx] pred = pred[test_idx]
label = graph.ndata['label'][test_idx] label = graph.ndata["label"][test_idx]
acc = MF.accuracy(pred, label) acc = MF.accuracy(pred, label)
print('Test accuracy:', acc) print("Test accuracy:", acc)
import dgl
import torch as th import torch as th
import dgl
def load_reddit(self_loop=True): def load_reddit(self_loop=True):
from dgl.data import RedditDataset from dgl.data import RedditDataset
# load reddit data # load reddit data
data = RedditDataset(self_loop=self_loop) data = RedditDataset(self_loop=self_loop)
g = data[0] g = data[0]
g.ndata['features'] = g.ndata.pop('feat') g.ndata["features"] = g.ndata.pop("feat")
g.ndata['labels'] = g.ndata.pop('label') g.ndata["labels"] = g.ndata.pop("label")
return g, data.num_classes return g, data.num_classes
def load_ogb(name, root='dataset'):
def load_ogb(name, root="dataset"):
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
print('load', name) print("load", name)
data = DglNodePropPredDataset(name=name, root=root) data = DglNodePropPredDataset(name=name, root=root)
print('finish loading', name) print("finish loading", name)
splitted_idx = data.get_idx_split() splitted_idx = data.get_idx_split()
graph, labels = data[0] graph, labels = data[0]
labels = labels[:, 0] labels = labels[:, 0]
graph.ndata['features'] = graph.ndata.pop('feat') graph.ndata["features"] = graph.ndata.pop("feat")
graph.ndata['labels'] = labels graph.ndata["labels"] = labels
in_feats = graph.ndata['features'].shape[1] in_feats = graph.ndata["features"].shape[1]
num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))])) num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
# Find the node IDs in the training, validation, and test set. # Find the node IDs in the training, validation, and test set.
train_nid, val_nid, test_nid = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test'] train_nid, val_nid, test_nid = (
splitted_idx["train"],
splitted_idx["valid"],
splitted_idx["test"],
)
train_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool) train_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
train_mask[train_nid] = True train_mask[train_nid] = True
val_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool) val_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
val_mask[val_nid] = True val_mask[val_nid] = True
test_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool) test_mask = th.zeros((graph.number_of_nodes(),), dtype=th.bool)
test_mask[test_nid] = True test_mask[test_nid] = True
graph.ndata['train_mask'] = train_mask graph.ndata["train_mask"] = train_mask
graph.ndata['val_mask'] = val_mask graph.ndata["val_mask"] = val_mask
graph.ndata['test_mask'] = test_mask graph.ndata["test_mask"] = test_mask
print('finish constructing', name) print("finish constructing", name)
return graph, num_labels return graph, num_labels
def inductive_split(g): def inductive_split(g):
"""Split the graph into training graph, validation graph, and test graph by training """Split the graph into training graph, validation graph, and test graph by training
and validation masks. Suitable for inductive models.""" and validation masks. Suitable for inductive models."""
train_g = g.subgraph(g.ndata['train_mask']) train_g = g.subgraph(g.ndata["train_mask"])
val_g = g.subgraph(g.ndata['train_mask'] | g.ndata['val_mask']) val_g = g.subgraph(g.ndata["train_mask"] | g.ndata["val_mask"])
test_g = g test_g = g
return train_g, val_g, test_g return train_g, val_g, test_g
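# A minimal usage sketch of the helpers above, not part of the original file
# (assumes the Reddit data can be fetched by dgl.data.RedditDataset; load_ogb
# works analogously with a name such as "ogbn-products" and a local ./dataset root).
g, n_classes = load_reddit(self_loop=True)
train_g, val_g, test_g = inductive_split(g)
print(n_classes, train_g.num_nodes(), val_g.num_nodes(), test_g.num_nodes())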
import argparse
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl.nn as dglnn import dgl.nn as dglnn
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from dgl import AddSelfLoop from dgl import AddSelfLoop
import argparse from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
class SAGE(nn.Module): class SAGE(nn.Module):
def __init__(self, in_size, hid_size, out_size): def __init__(self, in_size, hid_size, out_size):
super().__init__() super().__init__()
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# two-layer GraphSAGE-mean # two-layer GraphSAGE-mean
self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'gcn')) self.layers.append(dglnn.SAGEConv(in_size, hid_size, "gcn"))
self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'gcn')) self.layers.append(dglnn.SAGEConv(hid_size, out_size, "gcn"))
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
def forward(self, graph, x): def forward(self, graph, x):
...@@ -24,6 +27,7 @@ class SAGE(nn.Module): ...@@ -24,6 +27,7 @@ class SAGE(nn.Module):
h = self.dropout(h) h = self.dropout(h)
return h return h
def evaluate(g, features, labels, mask, model): def evaluate(g, features, labels, mask, model):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
...@@ -34,6 +38,7 @@ def evaluate(g, features, labels, mask, model): ...@@ -34,6 +38,7 @@ def evaluate(g, features, labels, mask, model):
correct = torch.sum(indices == labels) correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels) return correct.item() * 1.0 / len(labels)
def train(g, features, labels, masks, model): def train(g, features, labels, masks, model):
# define train/val samples, loss function and optimizer # define train/val samples, loss function and optimizer
train_mask, val_mask = masks train_mask, val_mask = masks
...@@ -49,32 +54,42 @@ def train(g, features, labels, masks, model): ...@@ -49,32 +54,42 @@ def train(g, features, labels, masks, model):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
acc = evaluate(g, features, labels, val_mask, model) acc = evaluate(g, features, labels, val_mask, model)
print("Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} " print(
. format(epoch, loss.item(), acc)) "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
epoch, loss.item(), acc
)
)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description='GraphSAGE') parser = argparse.ArgumentParser(description="GraphSAGE")
parser.add_argument("--dataset", type=str, default="cora", parser.add_argument(
help="Dataset name ('cora', 'citeseer', 'pubmed')") "--dataset",
type=str,
default="cora",
help="Dataset name ('cora', 'citeseer', 'pubmed')",
)
args = parser.parse_args() args = parser.parse_args()
print(f'Training with DGL built-in GraphSage module') print("Training with DGL built-in GraphSAGE module")
# load and preprocess dataset # load and preprocess dataset
transform = AddSelfLoop() # by default, it will first remove self-loops to prevent duplication transform = (
if args.dataset == 'cora': AddSelfLoop()
) # by default, it will first remove self-loops to prevent duplication
if args.dataset == "cora":
data = CoraGraphDataset(transform=transform) data = CoraGraphDataset(transform=transform)
elif args.dataset == 'citeseer': elif args.dataset == "citeseer":
data = CiteseerGraphDataset(transform=transform) data = CiteseerGraphDataset(transform=transform)
elif args.dataset == 'pubmed': elif args.dataset == "pubmed":
data = PubmedGraphDataset(transform=transform) data = PubmedGraphDataset(transform=transform)
else: else:
raise ValueError('Unknown dataset: {}'.format(args.dataset)) raise ValueError("Unknown dataset: {}".format(args.dataset))
g = data[0] g = data[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
g = g.int().to(device) g = g.int().to(device)
features = g.ndata['feat'] features = g.ndata["feat"]
labels = g.ndata['label'] labels = g.ndata["label"]
masks = g.ndata['train_mask'], g.ndata['val_mask'] masks = g.ndata["train_mask"], g.ndata["val_mask"]
# create GraphSAGE model # create GraphSAGE model
in_size = features.shape[1] in_size = features.shape[1]
...@@ -82,10 +97,10 @@ if __name__ == '__main__': ...@@ -82,10 +97,10 @@ if __name__ == '__main__':
model = SAGE(in_size, 16, out_size).to(device) model = SAGE(in_size, 16, out_size).to(device)
# model training # model training
print('Training...') print("Training...")
train(g, features, labels, masks, model) train(g, features, labels, masks, model)
# test the model # test the model
print('Testing...') print("Testing...")
acc = evaluate(g, features, labels, g.ndata['test_mask'], model) acc = evaluate(g, features, labels, g.ndata["test_mask"], model)
print("Test accuracy {:.4f}".format(acc)) print("Test accuracy {:.4f}".format(acc))
CONFIG = {
CONFIG={ "ppi_n": {
'ppi_n': "aggr": "concat",
{ "arch": "1-0-1-0",
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'ppi', 'dropout': 0, 'edge_budget': 4000, 'length': 2, "dataset": "ppi",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 6000, "dropout": 0,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0, "edge_budget": 4000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 50,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 6000,
"num_subg": 50,
"num_roots": 3000,
"sampler": "node",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 0,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"ppi_e": {
'ppi_e': "aggr": "concat",
{ "arch": "1-0-1-0",
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'ppi', 'dropout': 0.1, 'edge_budget': 4000, 'length': 2, "dataset": "ppi",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 6000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0, "edge_budget": 4000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 50,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 6000,
"num_subg": 50,
"num_roots": 3000,
"sampler": "edge",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 0,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"ppi_rw": {
'ppi_rw': "aggr": "concat",
{ "arch": "1-0-1-0",
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'ppi', 'dropout': 0.1, 'edge_budget': 4000, 'length': 2, "dataset": "ppi",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 6000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0, "edge_budget": 4000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 50,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 6000,
"num_subg": 50,
"num_roots": 3000,
"sampler": "rw",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 0,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"flickr_n": {
'flickr_n': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'flickr', 'dropout': 0.2, 'edge_budget': 6000, 'length': 2, "dataset": "flickr",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 256, 'no_batch_norm': False, 'node_budget': 8000, "dropout": 0.2,
'num_subg': 25, 'num_roots': 6000, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0, "edge_budget": 6000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': False "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 50,
"n_hidden": 256,
"no_batch_norm": False,
"node_budget": 8000,
"num_subg": 25,
"num_roots": 6000,
"sampler": "node",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 0,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": False,
}, },
"flickr_e": {
'flickr_e': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'flickr', 'dropout': 0.2, 'edge_budget': 6000, 'length': 2, "dataset": "flickr",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 256, 'no_batch_norm': False, 'node_budget': 8000, "dropout": 0.2,
'num_subg': 25, 'num_roots': 6000, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0, "edge_budget": 6000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': False "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 50,
"n_hidden": 256,
"no_batch_norm": False,
"node_budget": 8000,
"num_subg": 25,
"num_roots": 6000,
"sampler": "edge",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 0,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": False,
}, },
"flickr_rw": {
'flickr_rw': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'flickr', 'dropout': 0.2, 'edge_budget': 6000, 'length': 2, "dataset": "flickr",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 256, 'no_batch_norm': False, 'node_budget': 8000, "dropout": 0.2,
'num_subg': 25, 'num_roots': 6000, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0, "edge_budget": 6000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': False "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 50,
"n_hidden": 256,
"no_batch_norm": False,
"node_budget": 8000,
"num_subg": 25,
"num_roots": 6000,
"sampler": "rw",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 0,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": False,
}, },
"reddit_n": {
'reddit_n': "aggr": "concat",
{ "arch": "1-0-1-0",
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'reddit', 'dropout': 0.1, 'edge_budget': 4000, 'length': 2, "dataset": "reddit",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 20, 'n_hidden': 128, 'no_batch_norm': False, 'node_budget': 8000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8, "edge_budget": 4000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 20,
"n_hidden": 128,
"no_batch_norm": False,
"node_budget": 8000,
"num_subg": 50,
"num_roots": 3000,
"sampler": "node",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 8,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"reddit_e": {
'reddit_e': "aggr": "concat",
{ "arch": "1-0-1-0",
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'reddit', 'dropout': 0.1, 'edge_budget': 6000, 'length': 2, "dataset": "reddit",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 20, 'n_hidden': 128, 'no_batch_norm': False, 'node_budget': 8000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8, "edge_budget": 6000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 20,
"n_hidden": 128,
"no_batch_norm": False,
"node_budget": 8000,
"num_subg": 50,
"num_roots": 3000,
"sampler": "edge",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 8,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"reddit_rw": {
'reddit_rw': "aggr": "concat",
{ "arch": "1-0-1-0",
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'reddit', 'dropout': 0.1, 'edge_budget': 6000, 'length': 4, "dataset": "reddit",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 128, 'no_batch_norm': False, 'node_budget': 8000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 200, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8, "edge_budget": 6000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 4,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 10,
"n_hidden": 128,
"no_batch_norm": False,
"node_budget": 8000,
"num_subg": 50,
"num_roots": 200,
"sampler": "rw",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 8,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"yelp_n": {
'yelp_n': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'yelp', 'dropout': 0.1, 'edge_budget': 6000, 'length': 4, "dataset": "yelp",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 200, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8, "edge_budget": 6000,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 4,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 10,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 5000,
"num_subg": 50,
"num_roots": 200,
"sampler": "node",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 8,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"yelp_e": {
'yelp_e': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'yelp', 'dropout': 0.1, 'edge_budget': 2500, 'length': 4, "dataset": "yelp",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 200, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8, "edge_budget": 2500,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 4,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 10,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 5000,
"num_subg": 50,
"num_roots": 200,
"sampler": "edge",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 8,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"yelp_rw": {
'yelp_rw': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'yelp', 'dropout': 0.1, 'edge_budget': 2500, 'length': 2, "dataset": "yelp",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 1250, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8, "edge_budget": 2500,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 10,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 5000,
"num_subg": 50,
"num_roots": 1250,
"sampler": "rw",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 8,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"amazon_n": {
'amazon_n': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'amazon', 'dropout': 0.1, 'edge_budget': 2500, 'length': 4, "dataset": "amazon",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 5, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 4500, "dropout": 0.1,
'num_subg': 50, 'num_roots': 200, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 4, "edge_budget": 2500,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True "length": 4,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 5,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 4500,
"num_subg": 50,
"num_roots": 200,
"sampler": "node",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 4,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
"amazon_e": {
'amazon_e': "aggr": "concat",
{ "arch": "1-1-0",
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'amazon', 'dropout': 0.1, 'edge_budget': 2000, 'gpu': 0,'length': 4, "dataset": "amazon",
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000, "dropout": 0.1,
'num_subg': 50, 'num_roots': 200, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 20, "edge_budget": 2000,
'num_subg_sampler': 5000, 'batch_size_sampler': 50, 'num_workers': 26, 'full': True "gpu": 0,
"length": 4,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 10,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 5000,
"num_subg": 50,
"num_roots": 200,
"sampler": "edge",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 20,
"num_subg_sampler": 5000,
"batch_size_sampler": 50,
"num_workers": 26,
"full": True,
},
"amazon_rw": {
"aggr": "concat",
"arch": "1-1-0",
"dataset": "amazon",
"dropout": 0.1,
"edge_budget": 2500,
"gpu": 0,
"length": 2,
"log_dir": "none",
"lr": 0.01,
"n_epochs": 5,
"n_hidden": 512,
"no_batch_norm": False,
"node_budget": 5000,
"num_subg": 50,
"num_roots": 1500,
"sampler": "rw",
"use_val": True,
"val_every": 1,
"num_workers_sampler": 4,
"num_subg_sampler": 10000,
"batch_size_sampler": 200,
"num_workers": 8,
"full": True,
}, },
'amazon_rw':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'amazon', 'dropout': 0.1, 'edge_budget': 2500, 'gpu': 0,'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 5, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000,
'num_subg': 50, 'num_roots': 1500, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 4,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
}
} }
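# Sketch of how a CONFIG entry is consumed: the training script below turns the
# chosen dict into an argparse.Namespace, so every key becomes an attribute.
import argparse

args = argparse.Namespace(**CONFIG["ppi_n"])
print(args.sampler, args.node_budget, args.n_epochs)  # node 6000 50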
import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch as th
import dgl.function as fn import dgl.function as fn
class GCNLayer(nn.Module): class GCNLayer(nn.Module):
def __init__(self, in_dim, out_dim, order=1, act=None, def __init__(
dropout=0, batch_norm=False, aggr="concat"): self,
in_dim,
out_dim,
order=1,
act=None,
dropout=0,
batch_norm=False,
aggr="concat",
):
super(GCNLayer, self).__init__() super(GCNLayer, self).__init__()
self.lins = nn.ModuleList() self.lins = nn.ModuleList()
self.bias = nn.ParameterList() self.bias = nn.ParameterList()
...@@ -32,7 +41,9 @@ class GCNLayer(nn.Module): ...@@ -32,7 +41,9 @@ class GCNLayer(nn.Module):
for lin in self.lins: for lin in self.lins:
nn.init.xavier_normal_(lin.weight) nn.init.xavier_normal_(lin.weight)
def feat_trans(self, features, idx): # linear transformation + activation + batch normalization def feat_trans(
self, features, idx
): # linear transformation + activation + batch normalization
h = self.lins[idx](features) + self.bias[idx] h = self.lins[idx](features) + self.bias[idx]
if self.act is not None: if self.act is not None:
...@@ -50,14 +61,17 @@ class GCNLayer(nn.Module): ...@@ -50,14 +61,17 @@ class GCNLayer(nn.Module):
h_in = self.dropout(features) h_in = self.dropout(features)
h_hop = [h_in] h_hop = [h_in]
D_norm = g.ndata['train_D_norm'] if 'train_D_norm' in g.ndata else g.ndata['full_D_norm'] D_norm = (
g.ndata["train_D_norm"]
if "train_D_norm" in g.ndata
else g.ndata["full_D_norm"]
)
for _ in range(self.order): # forward propagation for _ in range(self.order): # forward propagation
g.ndata['h'] = h_hop[-1] g.ndata["h"] = h_hop[-1]
if 'w' not in g.edata: if "w" not in g.edata:
g.edata['w'] = th.ones((g.num_edges(), )).to(features.device) g.edata["w"] = th.ones((g.num_edges(),)).to(features.device)
g.update_all(fn.u_mul_e('h', 'w', 'm'), g.update_all(fn.u_mul_e("h", "w", "m"), fn.sum("m", "h"))
fn.sum('m', 'h')) h = g.ndata.pop("h")
h = g.ndata.pop('h')
h = h * D_norm h = h * D_norm
h_hop.append(h) h_hop.append(h)
...@@ -75,30 +89,73 @@ class GCNLayer(nn.Module): ...@@ -75,30 +89,73 @@ class GCNLayer(nn.Module):
class GCNNet(nn.Module): class GCNNet(nn.Module):
def __init__(self, in_dim, hid_dim, out_dim, arch="1-1-0", def __init__(
act=F.relu, dropout=0, batch_norm=False, aggr="concat"): self,
in_dim,
hid_dim,
out_dim,
arch="1-1-0",
act=F.relu,
dropout=0,
batch_norm=False,
aggr="concat",
):
super(GCNNet, self).__init__() super(GCNNet, self).__init__()
self.gcn = nn.ModuleList() self.gcn = nn.ModuleList()
orders = list(map(int, arch.split('-'))) orders = list(map(int, arch.split("-")))
self.gcn.append(GCNLayer(in_dim=in_dim, out_dim=hid_dim, order=orders[0], self.gcn.append(
act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr)) GCNLayer(
in_dim=in_dim,
out_dim=hid_dim,
order=orders[0],
act=act,
dropout=dropout,
batch_norm=batch_norm,
aggr=aggr,
)
)
pre_out = ((aggr == "concat") * orders[0] + 1) * hid_dim pre_out = ((aggr == "concat") * orders[0] + 1) * hid_dim
for i in range(1, len(orders)-1): for i in range(1, len(orders) - 1):
self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[i], self.gcn.append(
act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr)) GCNLayer(
in_dim=pre_out,
out_dim=hid_dim,
order=orders[i],
act=act,
dropout=dropout,
batch_norm=batch_norm,
aggr=aggr,
)
)
pre_out = ((aggr == "concat") * orders[i] + 1) * hid_dim pre_out = ((aggr == "concat") * orders[i] + 1) * hid_dim
self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[-1], self.gcn.append(
act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr)) GCNLayer(
in_dim=pre_out,
out_dim=hid_dim,
order=orders[-1],
act=act,
dropout=dropout,
batch_norm=batch_norm,
aggr=aggr,
)
)
pre_out = ((aggr == "concat") * orders[-1] + 1) * hid_dim pre_out = ((aggr == "concat") * orders[-1] + 1) * hid_dim
self.out_layer = GCNLayer(in_dim=pre_out, out_dim=out_dim, order=0, self.out_layer = GCNLayer(
act=None, dropout=dropout, batch_norm=False, aggr=aggr) in_dim=pre_out,
out_dim=out_dim,
order=0,
act=None,
dropout=dropout,
batch_norm=False,
aggr=aggr,
)
def forward(self, graph): def forward(self, graph):
h = graph.ndata['feat'] h = graph.ndata["feat"]
for layer in self.gcn: for layer in self.gcn:
h = layer(graph, h) h = layer(graph, h)
...@@ -107,4 +164,3 @@ class GCNNet(nn.Module): ...@@ -107,4 +164,3 @@ class GCNNet(nn.Module):
h = self.out_layer(graph, h) h = self.out_layer(graph, h)
return h return h
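# A construction sketch matching the "ppi_n" config above; the 50-dimensional input
# and 121 output classes are the usual PPI sizes and are an assumption here, not
# taken from this diff. forward() expects a DGLGraph whose ndata["feat"] holds the
# node features.
model = GCNNet(in_dim=50, hid_dim=512, out_dim=121, arch="1-0-1-0", aggr="concat")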
import math
import os import os
import random
import time import time
import math
import numpy as np
import scipy
import torch as th import torch as th
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import random
import numpy as np
import dgl.function as fn
import dgl import dgl
from dgl.sampling import random_walk, pack_traces import dgl.function as fn
import scipy from dgl.sampling import pack_traces, random_walk
# The base class of sampler # The base class of sampler
...@@ -67,8 +69,19 @@ class SAINTSampler: ...@@ -67,8 +69,19 @@ class SAINTSampler:
`batch_size` of `DataLoader`, that is, `batch_size_sampler` is not related to how the sampler works in the training procedure. `batch_size` of `DataLoader`, that is, `batch_size_sampler` is not related to how the sampler works in the training procedure.
""" """
def __init__(self, node_budget, dn, g, train_nid, num_workers_sampler, num_subg_sampler=10000, def __init__(
batch_size_sampler=200, online=True, num_subg=50, full=True): self,
node_budget,
dn,
g,
train_nid,
num_workers_sampler,
num_subg_sampler=10000,
batch_size_sampler=200,
online=True,
num_subg=50,
full=True,
):
self.g = g.cpu() self.g = g.cpu()
self.node_budget = node_budget self.node_budget = node_budget
self.train_g: dgl.graph = g.subgraph(train_nid) self.train_g: dgl.graph = g.subgraph(train_nid)
...@@ -83,22 +96,30 @@ class SAINTSampler: ...@@ -83,22 +96,30 @@ class SAINTSampler:
self.online = online self.online = online
self.full = full self.full = full
assert self.num_subg_sampler >= self.batch_size_sampler, "num_subg_sampler should be greater than batch_size_sampler" assert (
self.num_subg_sampler >= self.batch_size_sampler
), "num_subg_sampler should be greater than batch_size_sampler"
graph_fn, norm_fn = self.__generate_fn__() graph_fn, norm_fn = self.__generate_fn__()
if os.path.exists(graph_fn): if os.path.exists(graph_fn):
self.subgraphs = np.load(graph_fn, allow_pickle=True) self.subgraphs = np.load(graph_fn, allow_pickle=True)
aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True) aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True)
else: else:
os.makedirs('./subgraphs/', exist_ok=True) os.makedirs("./subgraphs/", exist_ok=True)
self.subgraphs = [] self.subgraphs = []
self.N, sampled_nodes = 0, 0 self.N, sampled_nodes = 0, 0
# N: the number of pre-sampled subgraphs # N: the number of pre-sampled subgraphs
# Employ parallelism to speed up the sampling procedure # Employ parallelism to speed up the sampling procedure
loader = DataLoader(self, batch_size=self.batch_size_sampler, shuffle=True, loader = DataLoader(
num_workers=self.num_workers_sampler, collate_fn=self.__collate_fn__, drop_last=False) self,
batch_size=self.batch_size_sampler,
shuffle=True,
num_workers=self.num_workers_sampler,
collate_fn=self.__collate_fn__,
drop_last=False,
)
t = time.perf_counter() t = time.perf_counter()
for num_nodes, subgraphs_nids, subgraphs_eids in loader: for num_nodes, subgraphs_nids, subgraphs_eids in loader:
...@@ -106,12 +127,16 @@ class SAINTSampler: ...@@ -106,12 +127,16 @@ class SAINTSampler:
self.subgraphs.extend(subgraphs_nids) self.subgraphs.extend(subgraphs_nids)
sampled_nodes += num_nodes sampled_nodes += num_nodes
_subgraphs, _node_counts = np.unique(np.concatenate(subgraphs_nids), return_counts=True) _subgraphs, _node_counts = np.unique(
np.concatenate(subgraphs_nids), return_counts=True
)
sampled_nodes_idx = th.from_numpy(_subgraphs) sampled_nodes_idx = th.from_numpy(_subgraphs)
_node_counts = th.from_numpy(_node_counts) _node_counts = th.from_numpy(_node_counts)
self.node_counter[sampled_nodes_idx] += _node_counts self.node_counter[sampled_nodes_idx] += _node_counts
_subgraphs_eids, _edge_counts = np.unique(np.concatenate(subgraphs_eids), return_counts=True) _subgraphs_eids, _edge_counts = np.unique(
np.concatenate(subgraphs_eids), return_counts=True
)
sampled_edges_idx = th.from_numpy(_subgraphs_eids) sampled_edges_idx = th.from_numpy(_subgraphs_eids)
_edge_counts = th.from_numpy(_edge_counts) _edge_counts = th.from_numpy(_edge_counts)
self.edge_counter[sampled_edges_idx] += _edge_counts self.edge_counter[sampled_edges_idx] += _edge_counts
...@@ -120,16 +145,16 @@ class SAINTSampler: ...@@ -120,16 +145,16 @@ class SAINTSampler:
if sampled_nodes > self.train_g.num_nodes() * num_subg: if sampled_nodes > self.train_g.num_nodes() * num_subg:
break break
print(f'Sampling time: [{time.perf_counter() - t:.2f}s]') print(f"Sampling time: [{time.perf_counter() - t:.2f}s]")
np.save(graph_fn, self.subgraphs) np.save(graph_fn, self.subgraphs)
t = time.perf_counter() t = time.perf_counter()
aggr_norm, loss_norm = self.__compute_norm__() aggr_norm, loss_norm = self.__compute_norm__()
print(f'Normalization time: [{time.perf_counter() - t:.2f}s]') print(f"Normalization time: [{time.perf_counter() - t:.2f}s]")
np.save(norm_fn, (aggr_norm, loss_norm)) np.save(norm_fn, (aggr_norm, loss_norm))
self.train_g.ndata['l_n'] = th.Tensor(loss_norm) self.train_g.ndata["l_n"] = th.Tensor(loss_norm)
self.train_g.edata['w'] = th.Tensor(aggr_norm) self.train_g.edata["w"] = th.Tensor(aggr_norm)
self.__compute_degree_norm() # basically normalizing the adjacency matrix self.__compute_degree_norm() # basically normalizing the adjacency matrix
random.shuffle(self.subgraphs) random.shuffle(self.subgraphs)
...@@ -159,11 +184,15 @@ class SAINTSampler: ...@@ -159,11 +184,15 @@ class SAINTSampler:
else: else:
subgraph_nids = self.__sample__() subgraph_nids = self.__sample__()
num_nodes = len(subgraph_nids) num_nodes = len(subgraph_nids)
subgraph_eids = dgl.node_subgraph(self.train_g, subgraph_nids).edata[dgl.EID] subgraph_eids = dgl.node_subgraph(
self.train_g, subgraph_nids
).edata[dgl.EID]
return num_nodes, subgraph_nids, subgraph_eids return num_nodes, subgraph_nids, subgraph_eids
def __collate_fn__(self, batch): def __collate_fn__(self, batch):
if self.train: # sample only one graph each epoch, batch_size in training phase in 1 if (
self.train
): # sample only one graph each epoch; batch_size in the training phase is 1
return batch[0] return batch[0]
else: else:
sum_num_nodes = 0 sum_num_nodes = 0
...@@ -191,20 +220,24 @@ class SAINTSampler: ...@@ -191,20 +220,24 @@ class SAINTSampler:
loss_norm = self.N / self.node_counter / self.train_g.num_nodes() loss_norm = self.N / self.node_counter / self.train_g.num_nodes()
self.train_g.ndata['n_c'] = self.node_counter self.train_g.ndata["n_c"] = self.node_counter
self.train_g.edata['e_c'] = self.edge_counter self.train_g.edata["e_c"] = self.edge_counter
self.train_g.apply_edges(fn.v_div_e('n_c', 'e_c', 'a_n')) self.train_g.apply_edges(fn.v_div_e("n_c", "e_c", "a_n"))
aggr_norm = self.train_g.edata.pop('a_n') aggr_norm = self.train_g.edata.pop("a_n")
self.train_g.ndata.pop('n_c') self.train_g.ndata.pop("n_c")
self.train_g.edata.pop('e_c') self.train_g.edata.pop("e_c")
return aggr_norm.numpy(), loss_norm.numpy() return aggr_norm.numpy(), loss_norm.numpy()
def __compute_degree_norm(self): def __compute_degree_norm(self):
self.train_g.ndata['train_D_norm'] = 1. / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1) self.train_g.ndata[
self.g.ndata['full_D_norm'] = 1. / self.g.in_degrees().float().clamp(min=1).unsqueeze(1) "train_D_norm"
] = 1.0 / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)
self.g.ndata["full_D_norm"] = 1.0 / self.g.in_degrees().float().clamp(
min=1
).unsqueeze(1)
def __sample__(self): def __sample__(self):
raise NotImplementedError raise NotImplementedError
...@@ -224,20 +257,30 @@ class SAINTNodeSampler(SAINTSampler): ...@@ -224,20 +257,30 @@ class SAINTNodeSampler(SAINTSampler):
def __init__(self, node_budget, **kwargs): def __init__(self, node_budget, **kwargs):
self.node_budget = node_budget self.node_budget = node_budget
super(SAINTNodeSampler, self).__init__(node_budget=node_budget, **kwargs) super(SAINTNodeSampler, self).__init__(
node_budget=node_budget, **kwargs
)
def __generate_fn__(self): def __generate_fn__(self):
graph_fn = os.path.join('./subgraphs/{}_Node_{}_{}.npy'.format(self.dn, self.node_budget, graph_fn = os.path.join(
self.num_subg)) "./subgraphs/{}_Node_{}_{}.npy".format(
norm_fn = os.path.join('./subgraphs/{}_Node_{}_{}_norm.npy'.format(self.dn, self.node_budget, self.dn, self.node_budget, self.num_subg
self.num_subg)) )
)
norm_fn = os.path.join(
"./subgraphs/{}_Node_{}_{}_norm.npy".format(
self.dn, self.node_budget, self.num_subg
)
)
return graph_fn, norm_fn return graph_fn, norm_fn
def __sample__(self): def __sample__(self):
if self.prob is None: if self.prob is None:
self.prob = self.train_g.in_degrees().float().clamp(min=1) self.prob = self.train_g.in_degrees().float().clamp(min=1)
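# Nodes are drawn with probability proportional to their clamped in-degree,
# with replacement; unique() below removes the duplicates.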
sampled_nodes = th.multinomial(self.prob, num_samples=self.node_budget, replacement=True).unique() sampled_nodes = th.multinomial(
self.prob, num_samples=self.node_budget, replacement=True
).unique()
return sampled_nodes.numpy() return sampled_nodes.numpy()
...@@ -257,14 +300,21 @@ class SAINTEdgeSampler(SAINTSampler): ...@@ -257,14 +300,21 @@ class SAINTEdgeSampler(SAINTSampler):
self.edge_budget = edge_budget self.edge_budget = edge_budget
self.rng = np.random.default_rng() self.rng = np.random.default_rng()
super(SAINTEdgeSampler, self).__init__(
super(SAINTEdgeSampler, self).__init__(node_budget=edge_budget*2, **kwargs) node_budget=edge_budget * 2, **kwargs
)
def __generate_fn__(self): def __generate_fn__(self):
graph_fn = os.path.join('./subgraphs/{}_Edge_{}_{}.npy'.format(self.dn, self.edge_budget, graph_fn = os.path.join(
self.num_subg)) "./subgraphs/{}_Edge_{}_{}.npy".format(
norm_fn = os.path.join('./subgraphs/{}_Edge_{}_{}_norm.npy'.format(self.dn, self.edge_budget, self.dn, self.edge_budget, self.num_subg
self.num_subg)) )
)
norm_fn = os.path.join(
"./subgraphs/{}_Edge_{}_{}_norm.npy".format(
self.dn, self.edge_budget, self.num_subg
)
)
return graph_fn, norm_fn return graph_fn, norm_fn
# TODO: only sample half edges, then add another half edges # TODO: only sample half edges, then add another half edges
...@@ -272,10 +322,15 @@ class SAINTEdgeSampler(SAINTSampler): ...@@ -272,10 +322,15 @@ class SAINTEdgeSampler(SAINTSampler):
def __sample__(self): def __sample__(self):
if self.prob is None: if self.prob is None:
src, dst = self.train_g.edges() src, dst = self.train_g.edges()
src_degrees, dst_degrees = self.train_g.in_degrees(src).float().clamp(min=1), \ src_degrees, dst_degrees = self.train_g.in_degrees(
self.train_g.in_degrees(dst).float().clamp(min=1) src
prob_mat = 1. / src_degrees + 1. / dst_degrees ).float().clamp(min=1), self.train_g.in_degrees(dst).float().clamp(
prob_mat = scipy.sparse.csr_matrix((prob_mat.numpy(), (src.numpy(), dst.numpy()))) min=1
)
prob_mat = 1.0 / src_degrees + 1.0 / dst_degrees
prob_mat = scipy.sparse.csr_matrix(
(prob_mat.numpy(), (src.numpy(), dst.numpy()))
)
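# Edge-sampling probability computed above: p(u, v) ∝ 1/deg(u) + 1/deg(v),
# stored in a sparse matrix indexed by the (src, dst) node pairs.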
# The edge probability here only contains that of edges in upper triangle adjacency matrix # The edge probability here only contains that of edges in upper triangle adjacency matrix
# Because we assume the graph is undirected, that is, the adjacency matrix is symmetric. We only need # Because we assume the graph is undirected, that is, the adjacency matrix is symmetric. We only need
# to consider half of edges in the graph. # to consider half of edges in the graph.
...@@ -284,9 +339,16 @@ class SAINTEdgeSampler(SAINTSampler): ...@@ -284,9 +339,16 @@ class SAINTEdgeSampler(SAINTSampler):
self.adj_nodes = np.stack(prob_mat.nonzero(), axis=1) self.adj_nodes = np.stack(prob_mat.nonzero(), axis=1)
sampled_edges = np.unique( sampled_edges = np.unique(
dgl.random.choice(len(self.prob), size=self.edge_budget, prob=self.prob, replace=False) dgl.random.choice(
) len(self.prob),
sampled_nodes = np.unique(self.adj_nodes[sampled_edges].flatten()).astype('long') size=self.edge_budget,
prob=self.prob,
replace=False,
)
)
sampled_nodes = np.unique(
self.adj_nodes[sampled_edges].flatten()
).astype("long")
return sampled_nodes return sampled_nodes
...@@ -307,18 +369,30 @@ class SAINTRandomWalkSampler(SAINTSampler): ...@@ -307,18 +369,30 @@ class SAINTRandomWalkSampler(SAINTSampler):
def __init__(self, num_roots, length, **kwargs): def __init__(self, num_roots, length, **kwargs):
self.num_roots, self.length = num_roots, length self.num_roots, self.length = num_roots, length
super(SAINTRandomWalkSampler, self).__init__(node_budget=num_roots * length, **kwargs) super(SAINTRandomWalkSampler, self).__init__(
node_budget=num_roots * length, **kwargs
)
def __generate_fn__(self): def __generate_fn__(self):
graph_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}.npy'.format(self.dn, self.num_roots, graph_fn = os.path.join(
self.length, self.num_subg)) "./subgraphs/{}_RW_{}_{}_{}.npy".format(
norm_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}_norm.npy'.format(self.dn, self.num_roots, self.dn, self.num_roots, self.length, self.num_subg
self.length, self.num_subg)) )
)
norm_fn = os.path.join(
"./subgraphs/{}_RW_{}_{}_{}_norm.npy".format(
self.dn, self.num_roots, self.length, self.num_subg
)
)
return graph_fn, norm_fn return graph_fn, norm_fn
def __sample__(self): def __sample__(self):
sampled_roots = th.randint(0, self.train_g.num_nodes(), (self.num_roots,)) sampled_roots = th.randint(
traces, types = random_walk(self.train_g, nodes=sampled_roots, length=self.length) 0, self.train_g.num_nodes(), (self.num_roots,)
)
traces, types = random_walk(
self.train_g, nodes=sampled_roots, length=self.length
)
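# Each of the num_roots walks visits at most length + 1 nodes (root included);
# pack_traces flattens the traces and unique() below yields the subgraph's node set.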
sampled_nodes, _, _, _ = pack_traces(traces, types) sampled_nodes, _, _, _ = pack_traces(traces, types)
sampled_nodes = sampled_nodes.unique() sampled_nodes = sampled_nodes.unique()
return sampled_nodes.numpy() return sampled_nodes.numpy()
import argparse import argparse
import os import os
import time import time
import warnings
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import DataLoader
from sampler import SAINTNodeSampler, SAINTEdgeSampler, SAINTRandomWalkSampler
from config import CONFIG from config import CONFIG
from modules import GCNNet from modules import GCNNet
from utils import Logger, evaluate, save_log_dir, load_data, calc_f1 from sampler import SAINTEdgeSampler, SAINTNodeSampler, SAINTRandomWalkSampler
import warnings from torch.utils.data import DataLoader
from utils import Logger, calc_f1, evaluate, load_data, save_log_dir
def main(args, task): def main(args, task):
warnings.filterwarnings('ignore') warnings.filterwarnings("ignore")
multilabel_data = {'ppi', 'yelp', 'amazon'} multilabel_data = {"ppi", "yelp", "amazon"}
multilabel = args.dataset in multilabel_data multilabel = args.dataset in multilabel_data
# This flag is excluded for too large dataset, like amazon, the graph of which is too large to be directly # This flag is excluded for too large dataset, like amazon, the graph of which is too large to be directly
...@@ -20,7 +22,7 @@ def main(args, task): ...@@ -20,7 +22,7 @@ def main(args, task):
# 1. put the whole graph on cpu, and put the subgraphs on gpu in training phase # 1. put the whole graph on cpu, and put the subgraphs on gpu in training phase
# 2. put the model on gpu in training phase, and put the model on cpu in validation/testing phase # 2. put the model on gpu in training phase, and put the model on cpu in validation/testing phase
# We need to judge cpu_flag and cuda (below) simultaneously when shift model between cpu and gpu # We need to judge cpu_flag and cuda (below) simultaneously when shift model between cpu and gpu
if args.dataset in ['amazon']: if args.dataset in ["amazon"]:
cpu_flag = True cpu_flag = True
else: else:
cpu_flag = False cpu_flag = False
...@@ -28,14 +30,14 @@ def main(args, task): ...@@ -28,14 +30,14 @@ def main(args, task):
# load and preprocess dataset # load and preprocess dataset
data = load_data(args, multilabel) data = load_data(args, multilabel)
g = data.g g = data.g
train_mask = g.ndata['train_mask'] train_mask = g.ndata["train_mask"]
val_mask = g.ndata['val_mask'] val_mask = g.ndata["val_mask"]
test_mask = g.ndata['test_mask'] test_mask = g.ndata["test_mask"]
labels = g.ndata['label'] labels = g.ndata["label"]
train_nid = data.train_nid train_nid = data.train_nid
in_feats = g.ndata['feat'].shape[1] in_feats = g.ndata["feat"].shape[1]
n_classes = data.num_classes n_classes = data.num_classes
n_nodes = g.num_nodes() n_nodes = g.num_nodes()
n_edges = g.num_edges() n_edges = g.num_edges()
...@@ -44,23 +46,35 @@ def main(args, task): ...@@ -44,23 +46,35 @@ def main(args, task):
n_val_samples = val_mask.int().sum().item() n_val_samples = val_mask.int().sum().item()
n_test_samples = test_mask.int().sum().item() n_test_samples = test_mask.int().sum().item()
print("""----Data statistics------' print(
"""----Data statistics------'
#Nodes %d #Nodes %d
#Edges %d #Edges %d
#Classes/Labels (multi binary labels) %d #Classes/Labels (multi binary labels) %d
#Train samples %d #Train samples %d
#Val samples %d #Val samples %d
#Test samples %d""" % #Test samples %d"""
(n_nodes, n_edges, n_classes, % (
n_train_samples, n_nodes,
n_val_samples, n_edges,
n_test_samples)) n_classes,
n_train_samples,
n_val_samples,
n_test_samples,
)
)
# load sampler # load sampler
kwargs = { kwargs = {
'dn': args.dataset, 'g': g, 'train_nid': train_nid, 'num_workers_sampler': args.num_workers_sampler, "dn": args.dataset,
'num_subg_sampler': args.num_subg_sampler, 'batch_size_sampler': args.batch_size_sampler, "g": g,
'online': args.online, 'num_subg': args.num_subg, 'full': args.full "train_nid": train_nid,
"num_workers_sampler": args.num_workers_sampler,
"num_subg_sampler": args.num_subg_sampler,
"batch_size_sampler": args.batch_size_sampler,
"online": args.online,
"num_subg": args.num_subg,
"full": args.full,
} }
if args.sampler == "node": if args.sampler == "node":
...@@ -68,11 +82,19 @@ def main(args, task): ...@@ -68,11 +82,19 @@ def main(args, task):
elif args.sampler == "edge": elif args.sampler == "edge":
saint_sampler = SAINTEdgeSampler(args.edge_budget, **kwargs) saint_sampler = SAINTEdgeSampler(args.edge_budget, **kwargs)
elif args.sampler == "rw": elif args.sampler == "rw":
saint_sampler = SAINTRandomWalkSampler(args.num_roots, args.length, **kwargs) saint_sampler = SAINTRandomWalkSampler(
args.num_roots, args.length, **kwargs
)
else: else:
raise NotImplementedError raise NotImplementedError
loader = DataLoader(saint_sampler, collate_fn=saint_sampler.__collate_fn__, batch_size=1, loader = DataLoader(
shuffle=True, num_workers=args.num_workers, drop_last=False) saint_sampler,
collate_fn=saint_sampler.__collate_fn__,
batch_size=1,
shuffle=True,
num_workers=args.num_workers,
drop_last=False,
)
# set device for dataset tensors # set device for dataset tensors
if args.gpu < 0: if args.gpu < 0:
cuda = False cuda = False
...@@ -82,10 +104,10 @@ def main(args, task): ...@@ -82,10 +104,10 @@ def main(args, task):
val_mask = val_mask.cuda() val_mask = val_mask.cuda()
test_mask = test_mask.cuda() test_mask = test_mask.cuda()
if not cpu_flag: if not cpu_flag:
g = g.to('cuda:{}'.format(args.gpu)) g = g.to("cuda:{}".format(args.gpu))
print('labels shape:', g.ndata['label'].shape) print("labels shape:", g.ndata["label"].shape)
print("features shape:", g.ndata['feat'].shape) print("features shape:", g.ndata["feat"].shape)
model = GCNNet( model = GCNNet(
in_dim=in_feats, in_dim=in_feats,
...@@ -94,7 +116,7 @@ def main(args, task): ...@@ -94,7 +116,7 @@ def main(args, task):
arch=args.arch, arch=args.arch,
dropout=args.dropout, dropout=args.dropout,
batch_norm=not args.no_batch_norm, batch_norm=not args.no_batch_norm,
aggr=args.aggr aggr=args.aggr,
) )
if cuda: if cuda:
...@@ -102,18 +124,19 @@ def main(args, task): ...@@ -102,18 +124,19 @@ def main(args, task):
# logger and so on # logger and so on
log_dir = save_log_dir(args) log_dir = save_log_dir(args)
logger = Logger(os.path.join(log_dir, 'loggings')) logger = Logger(os.path.join(log_dir, "loggings"))
logger.write(args) logger.write(args)
# use optimizer # use optimizer
optimizer = torch.optim.Adam(model.parameters(), optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
lr=args.lr)
# set train_nids to cuda tensor # set train_nids to cuda tensor
if cuda: if cuda:
train_nid = torch.from_numpy(train_nid).cuda() train_nid = torch.from_numpy(train_nid).cuda()
print("GPU memory allocated before training(MB)", print(
torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024) "GPU memory allocated before training (MB)",
torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024,
)
start_time = time.time() start_time = time.time()
best_f1 = -1 best_f1 = -1
...@@ -124,14 +147,18 @@ def main(args, task): ...@@ -124,14 +147,18 @@ def main(args, task):
model.train() model.train()
# forward # forward
pred = model(subg) pred = model(subg)
batch_labels = subg.ndata['label'] batch_labels = subg.ndata["label"]
if multilabel: if multilabel:
loss = F.binary_cross_entropy_with_logits(pred, batch_labels, reduction='sum', loss = F.binary_cross_entropy_with_logits(
weight=subg.ndata['l_n'].unsqueeze(1)) pred,
batch_labels,
reduction="sum",
weight=subg.ndata["l_n"].unsqueeze(1),
)
else: else:
loss = F.cross_entropy(pred, batch_labels, reduction='none') loss = F.cross_entropy(pred, batch_labels, reduction="none")
loss = (subg.ndata['l_n'] * loss).sum() loss = (subg.ndata["l_n"] * loss).sum()
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
...@@ -141,47 +168,77 @@ def main(args, task): ...@@ -141,47 +168,77 @@ def main(args, task):
if j == len(loader) - 1: if j == len(loader) - 1:
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
train_f1_mic, train_f1_mac = calc_f1(batch_labels.cpu().numpy(), train_f1_mic, train_f1_mac = calc_f1(
pred.cpu().numpy(), multilabel) batch_labels.cpu().numpy(),
print(f"epoch:{epoch + 1}/{args.n_epochs}, Iteration {j + 1}/" pred.cpu().numpy(),
f"{len(loader)}:training loss", loss.item()) multilabel,
print("Train F1-mic {:.4f}, Train F1-mac {:.4f}".format(train_f1_mic, train_f1_mac)) )
print(
f"epoch:{epoch + 1}/{args.n_epochs}, Iteration {j + 1}/"
f"{len(loader)}:training loss",
loss.item(),
)
print(
"Train F1-mic {:.4f}, Train F1-mac {:.4f}".format(
train_f1_mic, train_f1_mac
)
)
# evaluate # evaluate
model.eval() model.eval()
if epoch % args.val_every == 0: if epoch % args.val_every == 0:
if cpu_flag and cuda: # Only when we have shifted model to gpu and we need to shift it back on cpu if (
model = model.to('cpu') cpu_flag and cuda
): # Only when we have shifted model to gpu and we need to shift it back on cpu
model = model.to("cpu")
val_f1_mic, val_f1_mac = evaluate( val_f1_mic, val_f1_mac = evaluate(
model, g, labels, val_mask, multilabel) model, g, labels, val_mask, multilabel
)
print( print(
"Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(val_f1_mic, val_f1_mac)) "Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(
val_f1_mic, val_f1_mac
)
)
if val_f1_mic > best_f1: if val_f1_mic > best_f1:
best_f1 = val_f1_mic best_f1 = val_f1_mic
print('new best val f1:', best_f1) print("new best val f1:", best_f1)
torch.save(model.state_dict(), os.path.join( torch.save(
log_dir, 'best_model_{}.pkl'.format(task))) model.state_dict(),
os.path.join(log_dir, "best_model_{}.pkl".format(task)),
)
if cpu_flag and cuda: if cpu_flag and cuda:
model.cuda() model.cuda()
end_time = time.time() end_time = time.time()
print(f'training using time {end_time - start_time}') print(f"training time: {end_time - start_time:.2f}s")
# test # test
if args.use_val: if args.use_val:
model.load_state_dict(torch.load(os.path.join( model.load_state_dict(
log_dir, 'best_model_{}.pkl'.format(task)))) torch.load(os.path.join(log_dir, "best_model_{}.pkl".format(task)))
)
if cpu_flag and cuda: if cpu_flag and cuda:
model = model.to('cpu') model = model.to("cpu")
test_f1_mic, test_f1_mac = evaluate( test_f1_mic, test_f1_mac = evaluate(model, g, labels, test_mask, multilabel)
model, g, labels, test_mask, multilabel) print(
print("Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(test_f1_mic, test_f1_mac)) "Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(
test_f1_mic, test_f1_mac
)
)
if __name__ == '__main__':
warnings.filterwarnings('ignore')
parser = argparse.ArgumentParser(description='GraphSAINT') if __name__ == "__main__":
parser.add_argument("--task", type=str, default="ppi_n", help="type of tasks") warnings.filterwarnings("ignore")
parser.add_argument("--online", dest='online', action='store_true', help="sampling method in training phase")
parser = argparse.ArgumentParser(description="GraphSAINT")
parser.add_argument(
"--task", type=str, default="ppi_n", help="type of tasks"
)
parser.add_argument(
"--online",
dest="online",
action="store_true",
help="sampling method in training phase",
)
parser.add_argument("--gpu", type=int, default=0, help="the gpu index") parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
task = parser.parse_args().task task = parser.parse_args().task
args = argparse.Namespace(**CONFIG[task]) args = argparse.Namespace(**CONFIG[task])
......
import json import json
import os import os
from functools import namedtuple from functools import namedtuple
import scipy.sparse
from sklearn.preprocessing import StandardScaler
import dgl
import numpy as np import numpy as np
import scipy.sparse
import torch import torch
from sklearn.metrics import f1_score from sklearn.metrics import f1_score
from sklearn.preprocessing import StandardScaler
import dgl
class Logger(object): class Logger(object):
'''A custom logger to log stdout to a logging file.''' """A custom logger to log stdout to a logging file."""
def __init__(self, path): def __init__(self, path):
"""Initialize the logger. """Initialize the logger.
...@@ -22,14 +25,14 @@ class Logger(object): ...@@ -22,14 +25,14 @@ class Logger(object):
self.path = path self.path = path
def write(self, s): def write(self, s):
with open(self.path, 'a') as f: with open(self.path, "a") as f:
f.write(str(s)) f.write(str(s))
print(s) print(s)
return return
def save_log_dir(args): def save_log_dir(args):
log_dir = './log/{}/{}'.format(args.dataset, args.log_dir) log_dir = "./log/{}/{}".format(args.dataset, args.log_dir)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
return log_dir return log_dir
...@@ -40,8 +43,9 @@ def calc_f1(y_true, y_pred, multilabel): ...@@ -40,8 +43,9 @@ def calc_f1(y_true, y_pred, multilabel):
y_pred[y_pred <= 0] = 0 y_pred[y_pred <= 0] = 0
else: else:
y_pred = np.argmax(y_pred, axis=1) y_pred = np.argmax(y_pred, axis=1)
return f1_score(y_true, y_pred, average="micro"), \ return f1_score(y_true, y_pred, average="micro"), f1_score(
f1_score(y_true, y_pred, average="macro") y_true, y_pred, average="macro"
)
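# A toy check of calc_f1 in the single-label case (argmax over the logits, then
# micro/macro F1); the numbers are made up for illustration.
import numpy as np

y_true = np.array([0, 1, 2, 1])
logits = np.array(
    [[2.0, 0.1, 0.1], [0.2, 1.5, 0.1], [0.1, 0.2, 1.9], [0.3, 2.1, 0.2]]
)
f1_mic, f1_mac = calc_f1(y_true, logits, multilabel=False)  # both equal 1.0 here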
def evaluate(model, g, labels, mask, multilabel=False): def evaluate(model, g, labels, mask, multilabel=False):
...@@ -50,42 +54,47 @@ def evaluate(model, g, labels, mask, multilabel=False): ...@@ -50,42 +54,47 @@ def evaluate(model, g, labels, mask, multilabel=False):
logits = model(g) logits = model(g)
logits = logits[mask] logits = logits[mask]
labels = labels[mask] labels = labels[mask]
f1_mic, f1_mac = calc_f1(labels.cpu().numpy(), f1_mic, f1_mac = calc_f1(
logits.cpu().numpy(), multilabel) labels.cpu().numpy(), logits.cpu().numpy(), multilabel
)
return f1_mic, f1_mac return f1_mic, f1_mac
# load data of GraphSAINT and convert them to the format of dgl # load data of GraphSAINT and convert them to the format of dgl
def load_data(args, multilabel): def load_data(args, multilabel):
if not os.path.exists('graphsaintdata') and not os.path.exists('data'): if not os.path.exists("graphsaintdata") and not os.path.exists("data"):
raise ValueError("The directory graphsaintdata does not exist!") raise ValueError("The directory graphsaintdata does not exist!")
elif os.path.exists('graphsaintdata') and not os.path.exists('data'): elif os.path.exists("graphsaintdata") and not os.path.exists("data"):
os.rename('graphsaintdata', 'data') os.rename("graphsaintdata", "data")
prefix = "data/{}".format(args.dataset) prefix = "data/{}".format(args.dataset)
DataType = namedtuple('Dataset', ['num_classes', 'train_nid', 'g']) DataType = namedtuple("Dataset", ["num_classes", "train_nid", "g"])
adj_full = scipy.sparse.load_npz('./{}/adj_full.npz'.format(prefix)).astype(np.bool) adj_full = scipy.sparse.load_npz("./{}/adj_full.npz".format(prefix)).astype(
np.bool
)
g = dgl.from_scipy(adj_full) g = dgl.from_scipy(adj_full)
num_nodes = g.num_nodes() num_nodes = g.num_nodes()
adj_train = scipy.sparse.load_npz('./{}/adj_train.npz'.format(prefix)).astype(np.bool) adj_train = scipy.sparse.load_npz(
"./{}/adj_train.npz".format(prefix)
).astype(np.bool)
train_nid = np.array(list(set(adj_train.nonzero()[0]))) train_nid = np.array(list(set(adj_train.nonzero()[0])))
role = json.load(open('./{}/role.json'.format(prefix))) role = json.load(open("./{}/role.json".format(prefix)))
mask = np.zeros((num_nodes,), dtype=bool) mask = np.zeros((num_nodes,), dtype=bool)
train_mask = mask.copy() train_mask = mask.copy()
train_mask[role['tr']] = True train_mask[role["tr"]] = True
val_mask = mask.copy() val_mask = mask.copy()
val_mask[role['va']] = True val_mask[role["va"]] = True
test_mask = mask.copy() test_mask = mask.copy()
test_mask[role['te']] = True test_mask[role["te"]] = True
feats = np.load('./{}/feats.npy'.format(prefix)) feats = np.load("./{}/feats.npy".format(prefix))
scaler = StandardScaler() scaler = StandardScaler()
scaler.fit(feats[train_nid]) scaler.fit(feats[train_nid])
feats = scaler.transform(feats) feats = scaler.transform(feats)
class_map = json.load(open('./{}/class_map.json'.format(prefix))) class_map = json.load(open("./{}/class_map.json".format(prefix)))
class_map = {int(k): v for k, v in class_map.items()} class_map = {int(k): v for k, v in class_map.items()}
if multilabel: if multilabel:
# Multi-label binary classification # Multi-label binary classification
...@@ -99,11 +108,13 @@ def load_data(args, multilabel): ...@@ -99,11 +108,13 @@ def load_data(args, multilabel):
for k, v in class_map.items(): for k, v in class_map.items():
class_arr[k] = v class_arr[k] = v
g.ndata['feat'] = torch.tensor(feats, dtype=torch.float) g.ndata["feat"] = torch.tensor(feats, dtype=torch.float)
g.ndata['label'] = torch.tensor(class_arr, dtype=torch.float if multilabel else torch.long) g.ndata["label"] = torch.tensor(
g.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool) class_arr, dtype=torch.float if multilabel else torch.long
g.ndata['val_mask'] = torch.tensor(val_mask, dtype=torch.bool) )
g.ndata['test_mask'] = torch.tensor(test_mask, dtype=torch.bool) g.ndata["train_mask"] = torch.tensor(train_mask, dtype=torch.bool)
g.ndata["val_mask"] = torch.tensor(val_mask, dtype=torch.bool)
g.ndata["test_mask"] = torch.tensor(test_mask, dtype=torch.bool)
data = DataType(g=g, num_classes=num_classes, train_nid=train_nid) data = DataType(g=g, num_classes=num_classes, train_nid=train_nid)
return data return data
import os
import copy import copy
import os
import networkx as nx
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader, Dataset
import dgl import dgl
import networkx as nx
from torch.utils.data import Dataset, DataLoader
def build_dense_graph(n_particles): def build_dense_graph(n_particles):
...@@ -17,9 +18,9 @@ class MultiBodyDataset(Dataset): ...@@ -17,9 +18,9 @@ class MultiBodyDataset(Dataset):
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
self.zipfile = np.load(self.path) self.zipfile = np.load(self.path)
self.node_state = self.zipfile['data'] self.node_state = self.zipfile["data"]
self.node_label = self.zipfile['label'] self.node_label = self.zipfile["label"]
self.n_particles = self.zipfile['n_particles'] self.n_particles = self.zipfile["n_particles"]
def __len__(self): def __len__(self):
return self.node_state.shape[0] return self.node_state.shape[0]
...@@ -34,25 +35,30 @@ class MultiBodyDataset(Dataset): ...@@ -34,25 +35,30 @@ class MultiBodyDataset(Dataset):
class MultiBodyTrainDataset(MultiBodyDataset): class MultiBodyTrainDataset(MultiBodyDataset):
def __init__(self, data_path='./data/'): def __init__(self, data_path="./data/"):
super(MultiBodyTrainDataset, self).__init__( super(MultiBodyTrainDataset, self).__init__(
data_path+'n_body_train.npz') data_path + "n_body_train.npz"
self.stat_median = self.zipfile['median'] )
self.stat_max = self.zipfile['max'] self.stat_median = self.zipfile["median"]
self.stat_min = self.zipfile['min'] self.stat_max = self.zipfile["max"]
self.stat_min = self.zipfile["min"]
class MultiBodyValidDataset(MultiBodyDataset): class MultiBodyValidDataset(MultiBodyDataset):
def __init__(self, data_path='./data/'): def __init__(self, data_path="./data/"):
super(MultiBodyValidDataset, self).__init__( super(MultiBodyValidDataset, self).__init__(
data_path+'n_body_valid.npz') data_path + "n_body_valid.npz"
)
class MultiBodyTestDataset(MultiBodyDataset): class MultiBodyTestDataset(MultiBodyDataset):
def __init__(self, data_path='./data/'): def __init__(self, data_path="./data/"):
super(MultiBodyTestDataset, self).__init__(data_path+'n_body_test.npz') super(MultiBodyTestDataset, self).__init__(
self.test_traj = self.zipfile['test_traj'] data_path + "n_body_test.npz"
self.first_frame = torch.from_numpy(self.zipfile['first_frame']) )
self.test_traj = self.zipfile["test_traj"]
self.first_frame = torch.from_numpy(self.zipfile["first_frame"])
# Construct fully connected graph # Construct fully connected graph
......
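# The body of build_dense_graph is collapsed in this diff ("..."), so the
# following is only a sketch of one way to build a fully connected DGL graph
# over n_particles nodes (no self-loops); it is not necessarily the actual
# implementation used in this example.
import dgl
import torch

def dense_graph_sketch(n_particles):
    src, dst = [], []
    for i in range(n_particles):
        for j in range(n_particles):
            if i != j:
                src.append(i)
                dst.append(j)
    return dgl.graph(
        (torch.tensor(src), torch.tensor(dst)), num_nodes=n_particles
    )

g = dense_graph_sketch(6)  # 6 bodies -> 30 directed edges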
import dgl import copy
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
import dgl.nn as dglnn
import dgl
import dgl.function as fn import dgl.function as fn
import copy import dgl.nn as dglnn
from functools import partial
class MLP(nn.Module): class MLP(nn.Module):
...@@ -17,7 +19,7 @@ class MLP(nn.Module): ...@@ -17,7 +19,7 @@ class MLP(nn.Module):
nn.init.zeros_(layer.bias) nn.init.zeros_(layer.bias)
self.layers.append(nn.Linear(in_feats, hidden)) self.layers.append(nn.Linear(in_feats, hidden))
if num_layers > 2: if num_layers > 2:
for i in range(1, num_layers-1): for i in range(1, num_layers - 1):
layer = nn.Linear(hidden, hidden) layer = nn.Linear(hidden, hidden)
nn.init.normal_(layer.weight, std=0.1) nn.init.normal_(layer.weight, std=0.1)
nn.init.zeros_(layer.bias) nn.init.zeros_(layer.bias)
...@@ -28,7 +30,7 @@ class MLP(nn.Module): ...@@ -28,7 +30,7 @@ class MLP(nn.Module):
self.layers.append(layer) self.layers.append(layer)
def forward(self, x): def forward(self, x):
for l in range(len(self.layers)-1): for l in range(len(self.layers) - 1):
x = self.layers[l](x) x = self.layers[l](x)
x = F.relu(x) x = F.relu(x)
x = self.layers[-1](x) x = self.layers[-1](x)
...@@ -36,7 +38,7 @@ class MLP(nn.Module): ...@@ -36,7 +38,7 @@ class MLP(nn.Module):
class PrepareLayer(nn.Module): class PrepareLayer(nn.Module):
''' """
Generate the edge features used as model input, Generate the edge features used as model input,
and normalize the node features. and normalize the node features.
Parameters Parameters
...@@ -46,7 +48,7 @@ class PrepareLayer(nn.Module): ...@@ -46,7 +48,7 @@ class PrepareLayer(nn.Module):
stat : dict stat : dict
dictionary holding the statistics needed for normalization dictionary holding the statistics needed for normalization
''' """
def __init__(self, node_feats, stat): def __init__(self, node_feats, stat):
super(PrepareLayer, self).__init__() super(PrepareLayer, self).__init__()
...@@ -55,19 +57,21 @@ class PrepareLayer(nn.Module): ...@@ -55,19 +57,21 @@ class PrepareLayer(nn.Module):
self.stat = stat self.stat = stat
def normalize_input(self, node_feature): def normalize_input(self, node_feature):
return (node_feature-self.stat['median'])*(2/(self.stat['max']-self.stat['min'])) return (node_feature - self.stat["median"]) * (
2 / (self.stat["max"] - self.stat["min"])
)
def forward(self, g, node_feature): def forward(self, g, node_feature):
with g.local_scope(): with g.local_scope():
node_feature = self.normalize_input(node_feature) node_feature = self.normalize_input(node_feature)
g.ndata['feat'] = node_feature # Only dynamic feature g.ndata["feat"] = node_feature # Only dynamic feature
g.apply_edges(fn.u_sub_v('feat', 'feat', 'e')) g.apply_edges(fn.u_sub_v("feat", "feat", "e"))
edge_feature = g.edata['e'] edge_feature = g.edata["e"]
return node_feature, edge_feature return node_feature, edge_feature
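# Sketch of what PrepareLayer computes, on a toy graph: node features are
# shifted by the dataset median and scaled by 2 / (max - min), so the range
# [min, max] maps onto an interval of width 2 with the median at 0, and edge
# features are the per-edge differences src_feat - dst_feat produced by
# fn.u_sub_v. All numbers here are illustrative.
import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))  # toy cycle
x = torch.tensor([[0.0], [5.0], [10.0]])
stat = {
    "median": torch.tensor([5.0]),
    "max": torch.tensor([10.0]),
    "min": torch.tensor([0.0]),
}

x_norm = (x - stat["median"]) * (2 / (stat["max"] - stat["min"]))  # -> -1, 0, 1
with g.local_scope():
    g.ndata["feat"] = x_norm
    g.apply_edges(fn.u_sub_v("feat", "feat", "e"))
    edge_feat = g.edata["e"]  # relative feature on every edge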
class InteractionNet(nn.Module): class InteractionNet(nn.Module):
''' """
Simple Interaction Network Simple Interaction Network
A one-layer Interaction Network for stellar multi-body simulation; A one-layer Interaction Network for stellar multi-body simulation;
it can simulate the motion of at most 12 bodies. it can simulate the motion of at most 12 bodies.
...@@ -78,7 +82,7 @@ class InteractionNet(nn.Module): ...@@ -78,7 +82,7 @@ class InteractionNet(nn.Module):
stat : dict stat : dict
Statistics for denormalization Statistics for denormalization
''' """
def __init__(self, node_feats, stat): def __init__(self, node_feats, stat):
super(InteractionNet, self).__init__() super(InteractionNet, self).__init__()
...@@ -87,28 +91,34 @@ class InteractionNet(nn.Module): ...@@ -87,28 +91,34 @@ class InteractionNet(nn.Module):
edge_fn = partial(MLP, num_layers=5, hidden=150) edge_fn = partial(MLP, num_layers=5, hidden=150)
node_fn = partial(MLP, num_layers=2, hidden=100) node_fn = partial(MLP, num_layers=2, hidden=100)
self.in_layer = InteractionLayer(node_feats-3, # Use velocity only self.in_layer = InteractionLayer(
node_feats, node_feats - 3, # Use velocity only
out_node_feats=2, node_feats,
out_edge_feats=50, out_node_feats=2,
edge_fn=edge_fn, out_edge_feats=50,
node_fn=node_fn, edge_fn=edge_fn,
mode='n_n') node_fn=node_fn,
mode="n_n",
)
# Denormalize Velocity only # Denormalize Velocity only
def denormalize_output(self, out): def denormalize_output(self, out):
return out*(self.stat['max'][3:5]-self.stat['min'][3:5])/2+self.stat['median'][3:5] return (
out * (self.stat["max"][3:5] - self.stat["min"][3:5]) / 2
+ self.stat["median"][3:5]
)
def forward(self, g, n_feat, e_feat, global_feats, relation_feats): def forward(self, g, n_feat, e_feat, global_feats, relation_feats):
with g.local_scope(): with g.local_scope():
out_n, out_e = self.in_layer( out_n, out_e = self.in_layer(
g, n_feat, e_feat, global_feats, relation_feats) g, n_feat, e_feat, global_feats, relation_feats
)
out_n = self.denormalize_output(out_n) out_n = self.denormalize_output(out_n)
return out_n, out_e return out_n, out_e
class InteractionLayer(nn.Module): class InteractionLayer(nn.Module):
''' """
Implementation of single layer of interaction network Implementation of single layer of interaction network
Parameters Parameters
========== ==========
...@@ -137,20 +147,23 @@ class InteractionLayer(nn.Module): ...@@ -137,20 +147,23 @@ class InteractionLayer(nn.Module):
Function to update node feature in message aggregation Function to update node feature in message aggregation
mode : str mode : str
Type of message the edge should carry Type of message the edge should carry
nne : [src_feat,dst_feat,edge_feat] node features concatenated with the edge feature. nne : [src_feat,dst_feat,edge_feat] node features concatenated with the edge feature.
n_n : [src_feat-edge_feat] node features subtracted from each other. n_n : [src_feat-edge_feat] node features subtracted from each other.
''' """
def __init__(self, in_node_feats, def __init__(
in_edge_feats, self,
out_node_feats, in_node_feats,
out_edge_feats, in_edge_feats,
global_feats=1, out_node_feats,
relate_feats=1, out_edge_feats,
edge_fn=nn.Linear, global_feats=1,
node_fn=nn.Linear, relate_feats=1,
mode='nne'): # 'n_n' edge_fn=nn.Linear,
node_fn=nn.Linear,
mode="nne",
): # 'n_n'
super(InteractionLayer, self).__init__() super(InteractionLayer, self).__init__()
self.in_node_feats = in_node_feats self.in_node_feats = in_node_feats
self.in_edge_feats = in_edge_feats self.in_edge_feats = in_edge_feats
...@@ -158,39 +171,44 @@ class InteractionLayer(nn.Module): ...@@ -158,39 +171,44 @@ class InteractionLayer(nn.Module):
self.out_node_feats = out_node_feats self.out_node_feats = out_node_feats
self.mode = mode self.mode = mode
# MLP for message passing # MLP for message passing
input_shape = 2*self.in_node_feats + \ input_shape = (
self.in_edge_feats if mode == 'nne' else self.in_edge_feats+relate_feats 2 * self.in_node_feats + self.in_edge_feats
self.edge_fn = edge_fn(input_shape, if mode == "nne"
self.out_edge_feats) # 50 in IN paper else self.in_edge_feats + relate_feats
)
self.node_fn = node_fn(self.in_node_feats+self.out_edge_feats+global_feats, self.edge_fn = edge_fn(
self.out_node_feats) input_shape, self.out_edge_feats
) # 50 in IN paper
self.node_fn = node_fn(
self.in_node_feats + self.out_edge_feats + global_feats,
self.out_node_feats,
)
# Edge update function, invoked through g.apply_edges # Edge update function, invoked through g.apply_edges
def update_edge_fn(self, edges): def update_edge_fn(self, edges):
x = torch.cat([edges.src['feat'], edges.dst['feat'], x = torch.cat(
edges.data['feat']], dim=1) [edges.src["feat"], edges.dst["feat"], edges.data["feat"]], dim=1
ret = F.relu(self.edge_fn( )
x)) if self.mode == 'nne' else self.edge_fn(x) ret = F.relu(self.edge_fn(x)) if self.mode == "nne" else self.edge_fn(x)
return {'e': ret} return {"e": ret}
# Assume 'agg' comes from the built-in reduce function # Assume 'agg' comes from the built-in reduce function
def update_node_fn(self, nodes): def update_node_fn(self, nodes):
x = torch.cat([nodes.data['feat'], nodes.data['agg']], dim=1) x = torch.cat([nodes.data["feat"], nodes.data["agg"]], dim=1)
ret = F.relu(self.node_fn( ret = F.relu(self.node_fn(x)) if self.mode == "nne" else self.node_fn(x)
x)) if self.mode == 'nne' else self.node_fn(x) return {"n": ret}
return {'n': ret}
def forward(self, g, node_feats, edge_feats, global_feats, relation_feats): def forward(self, g, node_feats, edge_feats, global_feats, relation_feats):
# print(node_feats.shape,global_feats.shape) # print(node_feats.shape,global_feats.shape)
g.ndata['feat'] = torch.cat([node_feats, global_feats], dim=1) g.ndata["feat"] = torch.cat([node_feats, global_feats], dim=1)
g.edata['feat'] = torch.cat([edge_feats, relation_feats], dim=1) g.edata["feat"] = torch.cat([edge_feats, relation_feats], dim=1)
if self.mode == 'nne': if self.mode == "nne":
g.apply_edges(self.update_edge_fn) g.apply_edges(self.update_edge_fn)
else: else:
g.edata['e'] = self.edge_fn(g.edata['feat']) g.edata["e"] = self.edge_fn(g.edata["feat"])
g.update_all(fn.copy_e('e', 'msg'), g.update_all(
fn.sum('msg', 'agg'), fn.copy_e("e", "msg"), fn.sum("msg", "agg"), self.update_node_fn
self.update_node_fn) )
return g.ndata['n'], g.edata['e'] return g.ndata["n"], g.edata["e"]
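# A stripped-down sketch of the message-passing step InteractionLayer relies
# on: every edge carries a message "e", fn.copy_e copies it into the message
# box and fn.sum aggregates incoming messages per destination node. This is
# plain DGL usage on a toy graph, not the full layer above.
import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 2])))
g.edata["e"] = torch.ones(g.num_edges(), 4)  # pretend per-edge messages
g.update_all(fn.copy_e("e", "msg"), fn.sum("msg", "agg"))
print(g.ndata["agg"])  # node 2 sums two messages, node 1 one, node 0 none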
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse import argparse
import os import os
from math import cos, pi, radians, sin
import numpy as np import numpy as np
from math import sin, cos, radians, pi
import argparse
''' """
This is adapted from https://github.com/jsikyoon/Interaction-networks_tensorflow, This is adapted from https://github.com/jsikyoon/Interaction-networks_tensorflow,
which generates multi-body dynamic simulation data for the Interaction Network. which generates multi-body dynamic simulation data for the Interaction Network.
''' """
# 5 features on the state [mass,x,y,x_vel,y_vel] # 5 features on the state [mass,x,y,x_vel,y_vel]
fea_num = 5 fea_num = 5
...@@ -20,44 +18,59 @@ G = 10**5 ...@@ -20,44 +18,59 @@ G = 10**5
# time step # time step
diff_t = 0.001 diff_t = 0.001
def init(total_state, n_body, fea_num, orbit): def init(total_state, n_body, fea_num, orbit):
data = np.zeros((total_state, n_body, fea_num), dtype=float) data = np.zeros((total_state, n_body, fea_num), dtype=float)
if(orbit): if orbit:
data[0][0][0] = 100 data[0][0][0] = 100
data[0][0][1:5] = 0.0 data[0][0][1:5] = 0.0
# The positions are initialized randomly. # The positions are initialized randomly.
for i in range(1, n_body): for i in range(1, n_body):
data[0][i][0] = np.random.rand()*8.98+0.02 data[0][i][0] = np.random.rand() * 8.98 + 0.02
distance = np.random.rand()*90.0+10.0 distance = np.random.rand() * 90.0 + 10.0
theta = np.random.rand()*360 theta = np.random.rand() * 360
theta_rad = pi/2 - radians(theta) theta_rad = pi / 2 - radians(theta)
data[0][i][1] = distance*cos(theta_rad) data[0][i][1] = distance * cos(theta_rad)
data[0][i][2] = distance*sin(theta_rad) data[0][i][2] = distance * sin(theta_rad)
data[0][i][3] = -1*data[0][i][2]/norm(data[0][i][1:3])*( data[0][i][3] = (
G*data[0][0][0]/norm(data[0][i][1:3])**2)*distance/1000 -1
data[0][i][4] = data[0][i][1]/norm(data[0][i][1:3])*( * data[0][i][2]
G*data[0][0][0]/norm(data[0][i][1:3])**2)*distance/1000 / norm(data[0][i][1:3])
* (G * data[0][0][0] / norm(data[0][i][1:3]) ** 2)
* distance
/ 1000
)
data[0][i][4] = (
data[0][i][1]
/ norm(data[0][i][1:3])
* (G * data[0][0][0] / norm(data[0][i][1:3]) ** 2)
* distance
/ 1000
)
else: else:
for i in range(n_body): for i in range(n_body):
data[0][i][0] = np.random.rand()*8.98+0.02 data[0][i][0] = np.random.rand() * 8.98 + 0.02
distance = np.random.rand()*90.0+10.0 distance = np.random.rand() * 90.0 + 10.0
theta = np.random.rand()*360 theta = np.random.rand() * 360
theta_rad = pi/2 - radians(theta) theta_rad = pi / 2 - radians(theta)
data[0][i][1] = distance*cos(theta_rad) data[0][i][1] = distance * cos(theta_rad)
data[0][i][2] = distance*sin(theta_rad) data[0][i][2] = distance * sin(theta_rad)
data[0][i][3] = np.random.rand()*6.0-3.0 data[0][i][3] = np.random.rand() * 6.0 - 3.0
data[0][i][4] = np.random.rand()*6.0-3.0 data[0][i][4] = np.random.rand() * 6.0 - 3.0
return data return data
def norm(x): def norm(x):
return np.sqrt(np.sum(x**2)) return np.sqrt(np.sum(x**2))
def get_f(reciever, sender): def get_f(reciever, sender):
diff = sender[1:3]-reciever[1:3] diff = sender[1:3] - reciever[1:3]
distance = norm(diff) distance = norm(diff)
if(distance < 1): if distance < 1:
distance = 1 distance = 1
return G*reciever[0]*sender[0]/(distance**3)*diff return G * reciever[0] * sender[0] / (distance**3) * diff
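# Numeric sanity check of get_f above (values are illustrative): with the
# distance clamped to at least 1 to soften close encounters,
# G * m_r * m_s / d**3 * diff is the usual inverse-square attraction written
# with an un-normalised displacement vector, i.e. its magnitude is
# G * m_r * m_s / d**2.
import numpy as np

G = 10**5
reciever = np.array([2.0, 0.0, 0.0])  # [mass, x, y]
sender = np.array([3.0, 4.0, 0.0])
diff = sender[1:3] - reciever[1:3]    # displacement of length 4
distance = max(np.sqrt(np.sum(diff**2)), 1)
f = G * reciever[0] * sender[0] / distance**3 * diff
print(f, np.linalg.norm(f))           # magnitude equals G * 2 * 3 / 4**2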
# Compute normalization statistics as described in the paper # Compute normalization statistics as described in the paper
def compute_stats(train_curr): def compute_stats(train_curr):
...@@ -67,38 +80,41 @@ def compute_stats(train_curr): ...@@ -67,38 +80,41 @@ def compute_stats(train_curr):
stat_min = np.quantile(data, 0.05, axis=0) stat_min = np.quantile(data, 0.05, axis=0)
return stat_median, stat_max, stat_min return stat_median, stat_max, stat_min
def calc(cur_state, n_body): def calc(cur_state, n_body):
next_state = np.zeros((n_body, fea_num), dtype=float) next_state = np.zeros((n_body, fea_num), dtype=float)
f_mat = np.zeros((n_body, n_body, 2), dtype=float) f_mat = np.zeros((n_body, n_body, 2), dtype=float)
f_sum = np.zeros((n_body, 2), dtype=float) f_sum = np.zeros((n_body, 2), dtype=float)
acc = np.zeros((n_body, 2), dtype=float) acc = np.zeros((n_body, 2), dtype=float)
for i in range(n_body): for i in range(n_body):
for j in range(i+1, n_body): for j in range(i + 1, n_body):
if(j != i): if j != i:
f = get_f(cur_state[i][:3], cur_state[j][:3]) f = get_f(cur_state[i][:3], cur_state[j][:3])
f_mat[i, j] += f f_mat[i, j] += f
f_mat[j, i] -= f f_mat[j, i] -= f
f_sum[i] = np.sum(f_mat[i], axis=0) f_sum[i] = np.sum(f_mat[i], axis=0)
acc[i] = f_sum[i]/cur_state[i][0] acc[i] = f_sum[i] / cur_state[i][0]
next_state[i][0] = cur_state[i][0] next_state[i][0] = cur_state[i][0]
next_state[i][3:5] = cur_state[i][3:5]+acc[i]*diff_t next_state[i][3:5] = cur_state[i][3:5] + acc[i] * diff_t
next_state[i][1:3] = cur_state[i][1:3]+next_state[i][3:5]*diff_t next_state[i][1:3] = cur_state[i][1:3] + next_state[i][3:5] * diff_t
return next_state return next_state
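# The integration step in calc above is semi-implicit (symplectic) Euler:
# the velocity is updated from the acceleration first, and the *new* velocity
# is then used to advance the position. Minimal sketch with made-up numbers:
import numpy as np

diff_t = 0.001
pos = np.array([1.0, 2.0])
vel = np.array([0.5, -0.5])
acc = np.array([0.0, -10.0])

vel = vel + acc * diff_t  # corresponds to next_state[i][3:5]
pos = pos + vel * diff_t  # next_state[i][1:3] uses the updated velocity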
# The state is [mass,pos_x,pos_y,vel_x,vel_y]* n_body # The state is [mass,pos_x,pos_y,vel_x,vel_y]* n_body
def gen(n_body, num_steps, orbit): def gen(n_body, num_steps, orbit):
# Initialize only the first state # Initialize only the first state
data = init(num_steps, n_body, fea_num, orbit) data = init(num_steps, n_body, fea_num, orbit)
for i in range(1, num_steps): for i in range(1, num_steps):
data[i] = calc(data[i-1], n_body) data[i] = calc(data[i - 1], n_body)
return data return data
if __name__ == '__main__':
if __name__ == "__main__":
argparser = argparse.ArgumentParser() argparser = argparse.ArgumentParser()
argparser.add_argument('--num_bodies', type=int, default=6) argparser.add_argument("--num_bodies", type=int, default=6)
argparser.add_argument('--num_traj', type=int, default=10) argparser.add_argument("--num_traj", type=int, default=10)
argparser.add_argument('--steps', type=int, default=1000) argparser.add_argument("--steps", type=int, default=1000)
argparser.add_argument('--data_path', type=str, default='data') argparser.add_argument("--data_path", type=str, default="data")
args = argparser.parse_args() args = argparser.parse_args()
if not os.path.exists(args.data_path): if not os.path.exists(args.data_path):
...@@ -120,8 +136,8 @@ if __name__ == '__main__': ...@@ -120,8 +136,8 @@ if __name__ == '__main__':
label = np.vstack(data_next)[:, :, 3:5] label = np.vstack(data_next)[:, :, 3:5]
shuffle_idx = np.arange(data.shape[0]) shuffle_idx = np.arange(data.shape[0])
np.random.shuffle(shuffle_idx) np.random.shuffle(shuffle_idx)
train_split = int(0.9*data.shape[0]) train_split = int(0.9 * data.shape[0])
valid_split = train_split+300 valid_split = train_split + 300
data = data[shuffle_idx] data = data[shuffle_idx]
label = label[shuffle_idx] label = label[shuffle_idx]
...@@ -134,24 +150,30 @@ if __name__ == '__main__': ...@@ -134,24 +150,30 @@ if __name__ == '__main__':
test_data = data[valid_split:] test_data = data[valid_split:]
test_label = label[valid_split:] test_label = label[valid_split:]
np.savez(args.data_path+'/n_body_train.npz', np.savez(
data=train_data, args.data_path + "/n_body_train.npz",
label=train_label, data=train_data,
n_particles=args.num_bodies, label=train_label,
median=stat_median, n_particles=args.num_bodies,
max=stat_max, median=stat_median,
min=stat_min) max=stat_max,
min=stat_min,
np.savez(args.data_path+'/n_body_valid.npz', )
data=valid_data,
label=valid_label, np.savez(
n_particles=args.num_bodies) args.data_path + "/n_body_valid.npz",
data=valid_data,
label=valid_label,
n_particles=args.num_bodies,
)
test_traj = gen(args.num_bodies, args.steps, True) test_traj = gen(args.num_bodies, args.steps, True)
np.savez(args.data_path+'/n_body_test.npz', np.savez(
data=test_data, args.data_path + "/n_body_test.npz",
label=test_label, data=test_data,
n_particles=args.num_bodies, label=test_label,
first_frame=test_traj[0], n_particles=args.num_bodies,
test_traj=test_traj) first_frame=test_traj[0],
test_traj=test_traj,
)
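# Sketch of reading back the training archive written above; the keys mirror
# the np.savez call, and the path assumes the default --data_path of "data".
import numpy as np

archive = np.load("data/n_body_train.npz")
states, targets = archive["data"], archive["label"]  # inputs and velocity labels
median, vmax, vmin = archive["median"], archive["max"], archive["min"]
print(states.shape, targets.shape, int(archive["n_particles"]))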
import time
import argparse import argparse
import time
import traceback import traceback
import networkx as nx
import numpy as np import numpy as np
import torch import torch
from dataloader import (
MultiBodyGraphCollator,
MultiBodyTestDataset,
MultiBodyTrainDataset,
MultiBodyValidDataset,
)
from models import MLP, InteractionNet, PrepareLayer
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import networkx as nx from utils import make_video
import dgl
from models import MLP, InteractionNet, PrepareLayer import dgl
from dataloader import MultiBodyGraphCollator, MultiBodyTrainDataset,\
MultiBodyValidDataset, MultiBodyTestDataset
from utils import make_video
def train(optimizer, loss_fn,reg_fn, model, prep, dataloader, lambda_reg, device): def train(
optimizer, loss_fn, reg_fn, model, prep, dataloader, lambda_reg, device
):
total_loss = 0 total_loss = 0
model.train() model.train()
for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader): for i, (graph_batch, data_batch, label_batch) in enumerate(dataloader):
...@@ -25,21 +31,27 @@ def train(optimizer, loss_fn,reg_fn, model, prep, dataloader, lambda_reg, device ...@@ -25,21 +31,27 @@ def train(optimizer, loss_fn,reg_fn, model, prep, dataloader, lambda_reg, device
node_feat, edge_feat = prep(graph_batch, data_batch) node_feat, edge_feat = prep(graph_batch, data_batch)
dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device) dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)
dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device) dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)
v_pred, out_e = model(graph_batch, node_feat[:, 3:5].float( v_pred, out_e = model(
), edge_feat.float(), dummy_global, dummy_relation) graph_batch,
node_feat[:, 3:5].float(),
edge_feat.float(),
dummy_global,
dummy_relation,
)
loss = loss_fn(v_pred, label_batch) loss = loss_fn(v_pred, label_batch)
total_loss += float(loss) total_loss += float(loss)
zero_target = torch.zeros_like(out_e) zero_target = torch.zeros_like(out_e)
loss = loss + lambda_reg*reg_fn(out_e, zero_target) loss = loss + lambda_reg * reg_fn(out_e, zero_target)
reg_loss = 0 reg_loss = 0
for param in model.parameters(): for param in model.parameters():
reg_loss = reg_loss + lambda_reg * \ reg_loss = reg_loss + lambda_reg * reg_fn(
reg_fn(param, torch.zeros_like( param, torch.zeros_like(param).float().to(device)
param).float().to(device)) )
loss = loss + reg_loss loss = loss + reg_loss
loss.backward() loss.backward()
optimizer.step() optimizer.step()
return total_loss/(i+1) return total_loss / (i + 1)
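# What the regularisation terms in train() amount to: with
# reg_fn = torch.nn.MSELoss(reduction="sum"), comparing a tensor against
# zeros is just the sum of its squares, i.e. an L2 penalty on the edge
# outputs and on every parameter, scaled by lambda_reg.
import torch

reg_fn = torch.nn.MSELoss(reduction="sum")
p = torch.tensor([1.0, -2.0, 3.0])
assert torch.allclose(reg_fn(p, torch.zeros_like(p)), (p**2).sum())  # 14.0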
# One step evaluation # One step evaluation
...@@ -52,15 +64,19 @@ def eval(loss_fn, model, prep, dataloader, device): ...@@ -52,15 +64,19 @@ def eval(loss_fn, model, prep, dataloader, device):
data_batch = data_batch.to(device) data_batch = data_batch.to(device)
label_batch = label_batch.to(device) label_batch = label_batch.to(device)
node_feat, edge_feat = prep(graph_batch, data_batch) node_feat, edge_feat = prep(graph_batch, data_batch)
dummy_relation = torch.zeros( dummy_relation = torch.zeros(edge_feat.shape[0], 1).float().to(device)
edge_feat.shape[0], 1).float().to(device) dummy_global = torch.zeros(node_feat.shape[0], 1).float().to(device)
dummy_global = torch.zeros( v_pred, _ = model(
node_feat.shape[0], 1).float().to(device) graph_batch,
v_pred, _ = model(graph_batch, node_feat[:, 3:5].float( node_feat[:, 3:5].float(),
), edge_feat.float(), dummy_global, dummy_relation) edge_feat.float(),
dummy_global,
dummy_relation,
)
loss = loss_fn(v_pred, label_batch) loss = loss_fn(v_pred, label_batch)
total_loss += float(loss) total_loss += float(loss)
return total_loss/(i+1) return total_loss / (i + 1)
# Rollout evaluation based on the initial state # Rollout evaluation based on the initial state
# Predicted velocities need to be integrated into positions # Predicted velocities need to be integrated into positions
...@@ -74,42 +90,58 @@ def eval_rollout(model, prep, initial_frame, n_object, device): ...@@ -74,42 +90,58 @@ def eval_rollout(model, prep, initial_frame, n_object, device):
model.eval() model.eval()
for step in range(100): for step in range(100):
node_feats, edge_feats = prep(graph, current_frame) node_feats, edge_feats = prep(graph, current_frame)
dummy_relation = torch.zeros( dummy_relation = torch.zeros(edge_feats.shape[0], 1).float().to(device)
edge_feats.shape[0], 1).float().to(device) dummy_global = torch.zeros(node_feats.shape[0], 1).float().to(device)
dummy_global = torch.zeros( v_pred, _ = model(
node_feats.shape[0], 1).float().to(device) graph,
v_pred, _ = model(graph, node_feats[:, 3:5].float( node_feats[:, 3:5].float(),
), edge_feats.float(), dummy_global, dummy_relation) edge_feats.float(),
current_frame[:, [1, 2]] += v_pred*0.001 dummy_global,
dummy_relation,
)
current_frame[:, [1, 2]] += v_pred * 0.001
current_frame[:, 3:5] = v_pred current_frame[:, 3:5] = v_pred
pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy()) pos_buffer.append(current_frame[:, [1, 2]].cpu().numpy())
pos_buffer = np.vstack(pos_buffer).reshape(100, n_object, -1) pos_buffer = np.vstack(pos_buffer).reshape(100, n_object, -1)
make_video(pos_buffer, 'video_model.mp4') make_video(pos_buffer, "video_model.mp4")
if __name__ == '__main__': if __name__ == "__main__":
argparser = argparse.ArgumentParser() argparser = argparse.ArgumentParser()
argparser.add_argument('--lr', type=float, default=0.001, argparser.add_argument(
help='learning rate') "--lr", type=float, default=0.001, help="learning rate"
argparser.add_argument('--epochs', type=int, default=40000, )
help='Number of epochs in training') argparser.add_argument(
argparser.add_argument('--lambda_reg', type=float, default=0.001, "--epochs", type=int, default=40000, help="Number of epochs in training"
help='regularization weight') )
argparser.add_argument('--gpu', type=int, default=-1, argparser.add_argument(
help='gpu device code, -1 means cpu') "--lambda_reg", type=float, default=0.001, help="regularization weight"
argparser.add_argument('--batch_size', type=int, default=100, )
help='size of each mini batch') argparser.add_argument(
argparser.add_argument('--num_workers', type=int, default=0, "--gpu", type=int, default=-1, help="gpu device code, -1 means cpu"
help='number of workers for dataloading') )
argparser.add_argument('--visualize', action='store_true', default=False, argparser.add_argument(
help='Whether enable trajectory rollout mode for visualization') "--batch_size", type=int, default=100, help="size of each mini batch"
)
argparser.add_argument(
"--num_workers",
type=int,
default=0,
help="number of workers for dataloading",
)
argparser.add_argument(
"--visualize",
action="store_true",
default=False,
help="Whether enable trajectory rollout mode for visualization",
)
args = argparser.parse_args() args = argparser.parse_args()
# Select Device to be CPU or GPU # Select Device to be CPU or GPU
if args.gpu != -1: if args.gpu != -1:
device = torch.device('cuda:{}'.format(args.gpu)) device = torch.device("cuda:{}".format(args.gpu))
else: else:
device = torch.device('cpu') device = torch.device("cpu")
train_data = MultiBodyTrainDataset() train_data = MultiBodyTrainDataset()
valid_data = MultiBodyValidDataset() valid_data = MultiBodyValidDataset()
...@@ -117,22 +149,51 @@ if __name__ == '__main__': ...@@ -117,22 +149,51 @@ if __name__ == '__main__':
collator = MultiBodyGraphCollator(train_data.n_particles) collator = MultiBodyGraphCollator(train_data.n_particles)
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers) train_data,
args.batch_size,
True,
collate_fn=collator,
num_workers=args.num_workers,
)
valid_dataloader = DataLoader( valid_dataloader = DataLoader(
valid_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers) valid_data,
args.batch_size,
True,
collate_fn=collator,
num_workers=args.num_workers,
)
test_full_dataloader = DataLoader( test_full_dataloader = DataLoader(
test_data, args.batch_size, True, collate_fn=collator, num_workers=args.num_workers) test_data,
args.batch_size,
True,
collate_fn=collator,
num_workers=args.num_workers,
)
node_feats = 5 node_feats = 5
stat = {'median': torch.from_numpy(train_data.stat_median).to(device), stat = {
'max': torch.from_numpy(train_data.stat_max).to(device), "median": torch.from_numpy(train_data.stat_median).to(device),
'min': torch.from_numpy(train_data.stat_min).to(device)} "max": torch.from_numpy(train_data.stat_max).to(device),
print("Weight: ", train_data.stat_median[0], "min": torch.from_numpy(train_data.stat_min).to(device),
train_data.stat_max[0], train_data.stat_min[0]) }
print("Position: ", train_data.stat_median[[ print(
1, 2]], train_data.stat_max[[1, 2]], train_data.stat_min[[1, 2]]) "Weight: ",
print("Velocity: ", train_data.stat_median[[ train_data.stat_median[0],
3, 4]], train_data.stat_max[[3, 4]], train_data.stat_min[[3, 4]]) train_data.stat_max[0],
train_data.stat_min[0],
)
print(
"Position: ",
train_data.stat_median[[1, 2]],
train_data.stat_max[[1, 2]],
train_data.stat_min[[1, 2]],
)
print(
"Velocity: ",
train_data.stat_median[[3, 4]],
train_data.stat_max[[3, 4]],
train_data.stat_min[[3, 4]],
)
prepare_layer = PrepareLayer(node_feats, stat).to(device) prepare_layer = PrepareLayer(node_feats, stat).to(device)
interaction_net = InteractionNet(node_feats, stat).to(device) interaction_net = InteractionNet(node_feats, stat).to(device)
...@@ -141,24 +202,50 @@ if __name__ == '__main__': ...@@ -141,24 +202,50 @@ if __name__ == '__main__':
state_dict = interaction_net.state_dict() state_dict = interaction_net.state_dict()
loss_fn = torch.nn.MSELoss() loss_fn = torch.nn.MSELoss()
reg_fn = torch.nn.MSELoss(reduction='sum') reg_fn = torch.nn.MSELoss(reduction="sum")
try: try:
for e in range(args.epochs): for e in range(args.epochs):
last_t = time.time() last_t = time.time()
loss = train(optimizer, loss_fn,reg_fn, interaction_net, loss = train(
prepare_layer, train_dataloader, args.lambda_reg, device) optimizer,
print("Epoch time: ", time.time()-last_t) loss_fn,
reg_fn,
interaction_net,
prepare_layer,
train_dataloader,
args.lambda_reg,
device,
)
print("Epoch time: ", time.time() - last_t)
if e % 1 == 0: if e % 1 == 0:
valid_loss = eval(loss_fn, interaction_net, valid_loss = eval(
prepare_layer, valid_dataloader, device) loss_fn,
interaction_net,
prepare_layer,
valid_dataloader,
device,
)
test_full_loss = eval( test_full_loss = eval(
loss_fn, interaction_net, prepare_layer, test_full_dataloader, device) loss_fn,
print("Epoch: {}.Loss: Valid: {} Full: {}".format( interaction_net,
e, valid_loss, test_full_loss)) prepare_layer,
test_full_dataloader,
device,
)
print(
"Epoch: {}.Loss: Valid: {} Full: {}".format(
e, valid_loss, test_full_loss
)
)
except: except:
traceback.print_exc() traceback.print_exc()
finally: finally:
if args.visualize: if args.visualize:
eval_rollout(interaction_net, prepare_layer, eval_rollout(
test_data.first_frame, test_data.n_particles, device) interaction_net,
make_video(test_data.test_traj[:100, :, [1, 2]], 'video_truth.mp4') prepare_layer,
test_data.first_frame,
test_data.n_particles,
device,
)
make_video(test_data.test_traj[:100, :, [1, 2]], "video_truth.mp4")
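# Example invocation of this training script (its file name is not shown in
# the diff; "train.py" below is assumed), using the flags parsed above:
#
#   python train.py --gpu -1 --batch_size 100 --epochs 40000 --visualize
#
# --gpu -1 (the default) selects the CPU; --visualize additionally rolls out
# 100 predicted frames and writes video_model.mp4 and video_truth.mp4.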