Unverified Commit 168a88e5 authored by Quan (Andy) Gan, committed by GitHub

[Sampling] NodeDataLoader for node classification (#1635)



* neighbor sampler data loader first commit

* more commit

* nodedataloader

* fix

* update RGCN example

* update OGB

* fixes

* fix minibatch RGCN crashing with self loop

* reverting gatconv test code

* fix

* change to new solution that doesn't require tf dataloader

* fix

* lint

* fix

* fixes

* change doc

* fix docstring

* docstring fixes

* return seeds and input nodes from data loader

* fixes

* fix test

* fix windows build problem

* add pytorch wrapper

* fixes

* add pytorch wrapper

* add unit test

* add -1 support to sample_neighbors & fix docstrings

* docstring fix

* lint

* add minibatch rgcn evaluations
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
Co-authored-by: Tong He <hetong007@gmail.com>
parent 781b9cec
.. _apinodeflow:

dgl.nodeflow (Deprecating)
==========================

.. warning::
    This module is going to be deprecated in favor of :ref:`api-sampling`.

.. currentmodule:: dgl
.. autoclass:: NodeFlow
......
.. _apisampler:

dgl.contrib.sampling (Deprecating)
==================================

.. warning::
    This module is going to be deprecated in favor of :ref:`api-sampling`.

Module for sampling algorithms on graphs. Each algorithm is implemented as a
data loader, which produces sampled subgraphs (called NodeFlow) at each
iteration.
......
...@@ -25,8 +25,23 @@ Neighbor sampling functions
sample_neighbors
select_topk
PyTorch DataLoaders with neighborhood sampling
----------------------------------------------
.. autoclass:: pytorch.NeighborSamplerNodeDataLoader
Builtin sampler classes for more complicated sampling algorithms
----------------------------------------------------------------
.. autoclass:: RandomWalkNeighborSampler
.. autoclass:: PinSAGESampler
Neighborhood samplers for multilayer GNNs
-----------------------------------------
.. autoclass:: MultiLayerNeighborSampler
Data loaders for minibatch iteration
------------------------------------
.. autoclass:: NodeCollator
Abstract class for neighborhood sampler
---------------------------------------
.. autoclass:: BlockSampler
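The pieces documented above are meant to be composed: a ``BlockSampler`` subclass such as ``MultiLayerNeighborSampler`` decides what to sample, ``NodeCollator`` turns a batch of seed nodes into blocks, and the data loader drives the iteration. A minimal sketch of the composition, following the usage in the examples changed below (the toy graph built with ``dgl.rand_graph`` and the fanouts are illustrative, not part of this commit):

import dgl
import torch as th

g = dgl.rand_graph(1000, 5000)            # toy homogeneous graph
train_nid = th.arange(100)                # seed nodes to classify

# One fanout per GNN layer; a fanout of None means "take all neighbors".
sampler = dgl.sampling.MultiLayerNeighborSampler([10, 25])
dataloader = dgl.sampling.NodeDataLoader(
    g, train_nid, sampler,
    batch_size=32, shuffle=True, drop_last=False, num_workers=0)

for input_nodes, output_nodes, blocks in dataloader:
    # `blocks` holds one bipartite graph per layer; features are gathered
    # for `input_nodes` and predictions are made for `output_nodes`.
    pass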
...@@ -11,14 +11,22 @@ Common algorithms on graphs.
:toctree: ../../generated/
line_graph
reverse
to_simple_graph
to_bidirected
khop_adj
khop_graph
laplacian_lambda_max
knn_graph
segmented_knn_graph
add_self_loop
remove_self_loop
metapath_reachable_graph
compact_graphs
to_block
to_simple
in_subgraph
out_subgraph
remove_edges
as_immutable_graph
as_heterograph
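Several of the newly listed transforms (``in_subgraph``, ``to_block``, ``compact_graphs``) are the primitives the samplers above build on. A rough sketch of the core pattern on an illustrative toy graph (not part of this commit):

import dgl
import torch as th

g = dgl.rand_graph(10, 40)               # toy graph
seeds = th.tensor([0, 1])                # nodes we want outputs for

# Take every in-edge of the seeds, then compact the result into a
# bipartite "block" whose destination nodes are exactly the seeds.
frontier = dgl.in_subgraph(g, seeds)
block = dgl.to_block(frontier, seeds)

# The block's source node IDs are the inputs needed for one round of
# message passing; they become the seeds of the next (outer) layer.
input_nodes = block.srcdata[dgl.NID]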
...@@ -110,6 +110,7 @@ a useful manual for in-depth developers.
api/python/propagate
api/python/transform
api/python/sampler
api/python/sampling
api/python/data
api/python/nodeflow
api/python/random
......
...@@ -21,7 +21,27 @@ import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from pyinstrument import Profiler
from train_sampling import run, SAGE, compute_acc, evaluate, load_subtensor
class NeighborSampler(object):
def __init__(self, g, fanouts, sample_neighbors):
self.g = g
self.fanouts = fanouts
self.sample_neighbors = sample_neighbors
def sample_blocks(self, seeds):
seeds = th.LongTensor(np.asarray(seeds))
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
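This copy of ``NeighborSampler`` keeps the collate-function interface so that a plain PyTorch ``DataLoader`` can drive it in the distributed script. A rough sketch of that wiring, with illustrative fanouts and assuming ``dgl.distributed.sample_neighbors`` is the sampling function injected by the surrounding script (check the script for the exact call):

# `g`, `train_nid` and `args` come from the surrounding training script.
sampler = NeighborSampler(g, [10, 25], dgl.distributed.sample_neighbors)
dataloader = DataLoader(
    dataset=train_nid.numpy(),
    batch_size=args.batch_size,
    collate_fn=sampler.sample_blocks,    # yields the list of blocks per batch
    shuffle=True,
    drop_last=False)
for blocks in dataloader:
    ...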
def start_server(args):
serv = dgl.distributed.DistGraphServer(args.id, args.ip_config, args.num_client,
......
...@@ -18,28 +18,6 @@ import traceback
from load_graph import load_reddit, load_ogb
#### Neighbor sampler
class NeighborSampler(object):
def __init__(self, g, fanouts, sample_neighbors):
self.g = g
self.fanouts = fanouts
self.sample_neighbors = sample_neighbors
def sample_blocks(self, seeds):
seeds = th.LongTensor(np.asarray(seeds))
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
class SAGE(nn.Module):
def __init__(self,
in_feats,
...@@ -94,11 +72,18 @@ class SAGE(nn.Module):
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
...@@ -107,7 +92,7 @@ class SAGE(nn.Module):
h = self.activation(h)
h = self.dropout(h)
y[output_nodes] = h.cpu()
x = y
return y
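In the inference loop above, ``MultiLayerNeighborSampler([None])`` requests the full neighborhood for a single layer, so representations are computed exactly, one layer at a time, over all nodes. This commit also adds ``-1`` as a "take every neighbor" fanout to ``sample_neighbors``; a rough sketch of the relationship on an illustrative toy graph (not part of the diff):

import dgl
import torch as th

g = dgl.rand_graph(10, 40)
seeds = th.tensor([0, 1])

# With fanout -1, neighbor sampling degenerates to taking all in-edges,
# i.e. the same edge set dgl.in_subgraph(g, seeds) would return.
frontier = dgl.sampling.sample_neighbors(g, seeds, -1)
assert frontier.number_of_edges() == dgl.in_subgraph(g, seeds).number_of_edges()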
...@@ -161,15 +146,14 @@ def run(args, device, data):
train_nid = th.nonzero(train_mask, as_tuple=True)[0]
val_nid = th.nonzero(val_mask, as_tuple=True)[0]
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
dgl.sampling.sample_neighbors)
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.sampling.NodeDataLoader(
g,
train_nid,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
...@@ -189,14 +173,9 @@ def run(args, device, data):
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
tic_step = time.time()
# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device)
......
...@@ -17,27 +17,6 @@ import traceback
from utils import thread_wrapped_func
#### Neighbor sampler
class NeighborSampler(object):
def __init__(self, g, fanouts):
self.g = g
self.fanouts = fanouts
def sample_blocks(self, seeds):
seeds = th.LongTensor(np.asarray(seeds))
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
class SAGE(nn.Module):
def __init__(self,
in_feats,
...@@ -92,11 +71,18 @@ class SAGE(nn.Module):
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
...@@ -105,7 +91,7 @@ class SAGE(nn.Module):
h = self.activation(h)
h = self.dropout(h)
y[output_nodes] = h.cpu()
x = y
return y
...@@ -177,14 +163,14 @@ def run(proc_id, n_gpus, args, devices, data):
# Split train_nid
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id]
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')])
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.sampling.NodeDataLoader(
g,
train_nid,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
...@@ -206,15 +192,10 @@ def run(proc_id, n_gpus, args, devices, data):
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
if proc_id == 0:
tic_step = time.time()
# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, dev_id)
......
...@@ -101,11 +101,18 @@ class GAT(nn.Module):
else:
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
...@@ -115,7 +122,8 @@ class GAT(nn.Module):
h = layer(block, (h, h_dst))
h = h.mean(1)
h = h.log_softmax(dim=-1)
y[output_nodes] = h.cpu()
x = y
return y
...@@ -167,14 +175,14 @@ def run(args, device, data):
# Unpack data
train_nid, val_nid, test_nid, in_feats, labels, n_classes, g, num_heads = data
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')])
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.sampling.NodeDataLoader(
g,
train_nid,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
...@@ -194,14 +202,9 @@ def run(args, device, data):
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
tic_step = time.time()
# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, device)
......
...@@ -17,30 +17,6 @@ import tqdm
import traceback
from ogb.nodeproppred import DglNodePropPredDataset
#### Neighbor sampler
class NeighborSampler(object):
def __init__(self, g, fanouts):
self.g = g
self.fanouts = fanouts
def sample_blocks(self, seeds):
seeds = th.LongTensor(np.asarray(seeds))
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
if fanout == 0:
frontier = dgl.in_subgraph(self.g, seeds)
else:
frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
class SAGE(nn.Module):
def __init__(self,
in_feats,
...@@ -94,11 +70,18 @@ class SAGE(nn.Module):
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]
h = x[input_nodes].to(device)
h_dst = h[:block.number_of_dst_nodes()]
...@@ -107,7 +90,7 @@ class SAGE(nn.Module):
h = self.activation(h)
h = self.dropout(h)
y[output_nodes] = h.cpu()
x = y
return y
...@@ -159,14 +142,14 @@ def run(args, device, data):
# Unpack data
train_nid, val_nid, test_nid, in_feats, labels, n_classes, g = data
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')])
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.sampling.NodeDataLoader(
g,
train_nid,
sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
...@@ -188,14 +171,9 @@ def run(args, device, data):
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
tic_step = time.time()
# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, device)
......
...@@ -16,59 +16,32 @@ import dgl
from dgl.data.rdf import AIFB, MUTAG, BGS, AM
from model import EntityClassify, RelGraphEmbed
class HeteroNeighborSampler:
"""Neighbor sampler on heterogeneous graphs
Parameters
----------
g : DGLHeteroGraph
Full graph
category : str
Category name of the seed nodes.
fanouts : list of int
Fanout of each hop starting from the seed nodes. If a fanout is None,
sample full neighbors.
"""
def __init__(self, g, category, fanouts):
self.g = g
self.category = category
self.fanouts = fanouts
def sample_blocks(self, seeds):
blocks = []
seeds = {self.category : th.tensor(seeds).long()}
cur = seeds
for fanout in self.fanouts:
if fanout is None:
frontier = dgl.in_subgraph(self.g, cur)
else:
frontier = dgl.sampling.sample_neighbors(self.g, cur, fanout)
block = dgl.to_block(frontier, cur)
cur = {}
for ntype in block.srctypes:
cur[ntype] = block.srcnodes[ntype].data[dgl.NID]
blocks.insert(0, block)
return seeds, blocks
def extract_embed(node_embed, input_nodes):
emb = {}
for ntype, nid in input_nodes.items():
nid = input_nodes[ntype]
emb[ntype] = node_embed[ntype][nid]
return emb
def evaluate(model, loader, node_embed, labels, category, use_cuda):
model.eval()
total_loss = 0
total_acc = 0
count = 0
for input_nodes, seeds, blocks in loader:
seeds = seeds[category]
emb = extract_embed(node_embed, input_nodes)
lbl = labels[seeds]
if use_cuda:
emb = {k : e.cuda() for k, e in emb.items()}
lbl = lbl.cuda()
logits = model(emb, blocks)[category]
loss = F.cross_entropy(logits, lbl)
acc = th.sum(logits.argmax(dim=1) == lbl).item()
total_loss += loss.item() * len(seeds)
total_acc += acc
count += len(seeds)
return total_loss / count, total_acc / count
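Because the graph is heterogeneous, the loader yields dictionaries keyed by node type, which is why ``evaluate`` indexes ``seeds[category]``. A small sketch of the shapes involved on an illustrative toy heterograph (names and numbers are not from this commit):

import dgl
import torch as th

g = dgl.heterograph({
    ('user', 'follows', 'user'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('user', 'clicks', 'item'): (th.tensor([0, 2]), th.tensor([0, 1])),
})

sampler = dgl.sampling.MultiLayerNeighborSampler([2, 2])
loader = dgl.sampling.NodeDataLoader(
    g, {'item': th.tensor([0, 1])}, sampler, batch_size=2)

for input_nodes, seeds, blocks in loader:
    # Both dicts are keyed by node type, hence `seeds[category]` above.
    print(input_nodes)    # e.g. {'user': tensor([...]), 'item': tensor([0, 1])}
    print(seeds)          # {'item': tensor([0, 1])}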
def main(args):
# load graph data
...@@ -122,20 +95,23 @@ def main(args):
model.cuda()
# train sampler
sampler = dgl.sampling.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
loader = dgl.sampling.NodeDataLoader(
g, {category: train_idx}, sampler,
batch_size=args.batch_size, shuffle=True, num_workers=0)
# validation sampler
val_sampler = dgl.sampling.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
val_loader = dgl.sampling.NodeDataLoader(
g, {category: val_idx}, val_sampler,
batch_size=args.batch_size, shuffle=True, num_workers=0)
# test sampler
test_sampler = dgl.sampling.MultiLayerNeighborSampler([args.fanout] * args.n_layers)
test_loader = dgl.sampling.NodeDataLoader(
g, {category: test_idx}, test_sampler,
batch_size=args.batch_size, shuffle=True, num_workers=0)
# optimizer
all_params = itertools.chain(model.parameters(), embed_layer.parameters())
...@@ -150,10 +126,11 @@ def main(args):
if epoch > 3:
t0 = time.time()
for i, (input_nodes, seeds, blocks) in enumerate(loader):
seeds = seeds[category]  # we only predict the nodes with type "category"
batch_tic = time.time()
emb = extract_embed(node_embed, input_nodes)
lbl = labels[seeds]
if use_cuda:
emb = {k : e.cuda() for k, e in emb.items()}
lbl = lbl.cuda()
...@@ -162,22 +139,26 @@ def main(args):
loss.backward()
optimizer.step()
train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
print("Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".
format(epoch, i, train_acc, loss.item(), time.time() - batch_tic))
if epoch > 3:
dur.append(time.time() - t0)
val_loss, val_acc = evaluate(model, val_loader, node_embed, labels, category, use_cuda)
print("Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".
format(epoch, val_acc, val_loss, np.average(dur)))
print()
if args.model_path is not None:
th.save(model.state_dict(), args.model_path)
output = model.inference(
g, args.batch_size, 'cuda' if use_cuda else 'cpu', 0, node_embed)
test_pred = output[category][test_idx]
test_labels = labels[test_idx]
test_acc = (test_pred.argmax(1) == test_labels).float().mean()
print("Test Acc: {:.4f}".format(test_acc))
print()
if __name__ == '__main__':
......
...@@ -4,8 +4,9 @@ from collections import defaultdict
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn.pytorch as dglnn
import tqdm
class RelGraphConvLayer(nn.Module):
r"""Relational graph convolution layer.
...@@ -101,9 +102,17 @@ class RelGraphConvLayer(nn.Module):
else:
wdict = {}
hs = self.conv(g, inputs, mod_kwargs=wdict)
if isinstance(inputs, tuple):
# minibatch training
inputs_dst = inputs[1]
else:
# full graph training
inputs_dst = inputs
def _apply(ntype, h):
if self.self_loop:
h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
if self.bias:
h = h + self.h_bias
if self.activation:
...@@ -192,12 +201,57 @@ class EntityClassify(nn.Module):
self_loop=self.use_self_loop))
def forward(self, h=None, blocks=None):
if h is None:
# full graph training
h = self.embed_layer()
if blocks is None:
# full graph training
for layer in self.layers:
h = layer(self.g, h)
else:
# minibatch training
for layer, block in zip(self.layers, blocks):
h_dst = {k: v[:block.number_of_dst_nodes(k)] for k, v in h.items()}
h = layer(block, (h, h_dst))
return h
def inference(self, g, batch_size, device, num_workers, x=None):
"""Minibatch inference of final representation over all node types.
***NOTE***
For node classification, the model is trained to predict on only one node type's
label. Therefore, only that type's final representation is meaningful.
"""
if x is None:
x = self.embed_layer()
for l, layer in enumerate(self.layers):
y = {
k: th.zeros(
g.number_of_nodes(k),
self.h_dim if l != len(self.layers) - 1 else self.out_dim)
for k in g.ntypes}
sampler = dgl.sampling.MultiLayerNeighborSampler([None])
dataloader = dgl.sampling.NodeDataLoader(
g,
{k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
block = blocks[0]
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h_dst = {k: v[:block.number_of_dst_nodes(k)] for k, v in h.items()}
h = layer(block, (h, h_dst))
for k in h.keys():
y[k][output_nodes[k]] = h[k].cpu()
x = y
return y
...@@ -1675,6 +1675,9 @@ class DGLHeteroGraph(object):
>>> g.find_edges([0, 2])
(tensor([0, 1]), tensor([0, 2]))
"""
if len(eid) == 0:
return F.tensor([], dtype=self.idtype), F.tensor([], dtype=self.idtype)
check_same_dtype(self._idtype_str, eid)
if F.is_tensor(eid):
max_eid = F.max(eid, dim=0)
......
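The early return added to ``find_edges`` above makes querying an empty edge ID list well defined (the following ``F.max`` over an empty tensor would otherwise fail). A quick sketch on an illustrative toy graph, not part of the diff:

import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1]), th.tensor([1, 2])))
# An empty query now returns a pair of empty ID tensors.
src, dst = g.find_edges(th.tensor([], dtype=th.int64))
print(src.shape, dst.shape)    # torch.Size([0]) torch.Size([0])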
...@@ -7,6 +7,7 @@ from mxnet import gluon
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
class GraphConv(gluon.Block):
r"""Apply graph convolution over an input signal.
...@@ -109,8 +110,15 @@
----------
graph : DGLGraph
The graph.
feat : mxnet.NDArray or pair of mxnet.NDArray
If a single tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tensors are given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Note that in the special case of graph convolutional networks, if a pair of
tensors is given, the latter element will not participate in computation.
weight : torch.Tensor, optional
Optional external weight tensor.
...@@ -120,13 +128,15 @@
The output feature
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
if self._norm == 'both':
degs = graph.out_degrees().as_in_context(feat_src.context).astype('float32')
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
norm = mx.nd.power(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.ndim - 1)
norm = norm.reshape(shp)
feat_src = feat_src * norm
if weight is not None:
if self.weight is not None:
...@@ -134,19 +144,19 @@
' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
else:
weight = self.weight.data(feat_src.context)
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
if weight is not None:
feat_src = mx.nd.dot(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata.pop('h')
...@@ -154,13 +164,13 @@
rst = mx.nd.dot(rst, weight)
if self._norm != 'none':
degs = graph.in_degrees().as_in_context(feat_dst.context).astype('float32')
degs = mx.nd.clip(degs, a_min=1, a_max=float("inf"))
if self._norm == 'both':
norm = mx.nd.power(degs, -0.5)
else:
norm = 1.0 / degs
shp = norm.shape + (1,) * (feat_dst.ndim - 1)
norm = norm.reshape(shp)
rst = rst * norm
......
...@@ -86,8 +86,9 @@ class SAGEConv(nn.Block):
graph : DGLGraph
The graph.
feat : mxnet.NDArray or pair of mxnet.NDArray
If a single tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tensors are given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
......
...@@ -6,6 +6,7 @@ from torch.nn import init
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair
# pylint: disable=W0235
class GraphConv(nn.Module):
...@@ -115,8 +116,15 @@ class GraphConv(nn.Module):
----------
graph : DGLGraph
The graph.
feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Note that in the special case of graph convolutional networks, if a pair of
tensors is given, the latter element will not participate in computation.
weight : torch.Tensor, optional
Optional external weight tensor.
...@@ -126,12 +134,14 @@
The output feature
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
if self._norm == 'both':
degs = graph.out_degrees().to(feat_src.device).float().clamp(min=1)
norm = th.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.dim() - 1)
norm = th.reshape(norm, shp)
feat_src = feat_src * norm
if weight is not None:
if self.weight is not None:
...@@ -144,14 +154,14 @@
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
if weight is not None:
feat_src = th.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
...@@ -159,12 +169,12 @@
rst = th.matmul(rst, weight)
if self._norm != 'none':
degs = graph.in_degrees().to(feat_dst.device).float().clamp(min=1)
if self._norm == 'both':
norm = th.pow(degs, -0.5)
else:
norm = 1.0 / degs
shp = norm.shape + (1,) * (feat_dst.dim() - 1)
norm = th.reshape(norm, shp)
rst = rst * norm
......
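With ``expand_as_pair`` in place, ``GraphConv`` (like ``SAGEConv`` below) accepts either a single feature tensor or a (source, destination) pair, which is what block-based minibatch training feeds it. A rough PyTorch sketch on a toy block (graph, sizes and feature dimensions are illustrative, not from this commit):

import dgl
import torch as th
from dgl.nn.pytorch import GraphConv

g = dgl.rand_graph(10, 40)
seeds = th.tensor([0, 1])
block = dgl.to_block(dgl.in_subgraph(g, seeds), seeds)

conv = GraphConv(16, 8, norm='both')
feat = th.randn(g.number_of_nodes(), 16)

# Source features for every node feeding the block; destination features
# are its leading rows (the seeds themselves, by DGL's block convention).
feat_src = feat[block.srcdata[dgl.NID]]
feat_dst = feat_src[:block.number_of_dst_nodes()]
out = conv(block, (feat_src, feat_dst))    # shape: (len(seeds), 8)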
...@@ -103,8 +103,9 @@ class SAGEConv(nn.Module):
graph : DGLGraph
The graph.
feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
......
...@@ -5,6 +5,7 @@ from tensorflow.keras import layers
import numpy as np
from .... import function as fn
from ....utils import expand_as_pair
# pylint: disable=W0235
...@@ -112,8 +113,14 @@ class GraphConv(layers.Layer):
----------
graph : DGLGraph
The graph.
feat : tf.Tensor or pair of tf.Tensor
If a single tensor is given, the input feature of shape :math:`(N, D_{in})` where
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tensors are given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Note that in the special case of graph convolutional networks, if a pair of
tensors is given, the latter element will not participate in computation.
weight : torch.Tensor, optional
Optional external weight tensor.
...@@ -123,14 +130,16 @@
The output feature
"""
with graph.local_scope():
feat_src, feat_dst = expand_as_pair(feat)
if self._norm == 'both':
degs = tf.clip_by_value(tf.cast(graph.out_degrees(), tf.float32),
clip_value_min=1,
clip_value_max=np.inf)
norm = tf.pow(degs, -0.5)
shp = norm.shape + (1,) * (feat_src.ndim - 1)
norm = tf.reshape(norm, shp)
feat_src = feat_src * norm
if weight is not None:
if self.weight is not None:
...@@ -143,14 +152,14 @@
if self._in_feats > self._out_feats:
# mult W first to reduce the feature size for aggregation.
if weight is not None:
feat_src = tf.matmul(feat_src, weight)
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
else:
# aggregate first then mult W
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
...@@ -165,7 +174,7 @@
norm = tf.pow(degs, -0.5)
else:
norm = 1.0 / degs
shp = norm.shape + (1,) * (feat_dst.ndim - 1)
norm = tf.reshape(norm, shp)
rst = rst * norm
......
...@@ -89,8 +89,9 @@ class SAGEConv(layers.Layer):
graph : DGLGraph
The graph.
feat : tf.Tensor or pair of tf.Tensor
If a single tensor is given, it represents the input feature of shape
:math:`(N, D_{in})`
where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of tensors are given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
......
...@@ -2,3 +2,9 @@
from .randomwalks import *
from .pinsage import *
from .neighbor import *
from .dataloader import *
from .. import backend as F
if F.get_preferred_backend() == 'pytorch':
from .pytorch import *