Unverified Commit a9520f71 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Model][Sampler] GraphSAGE model, bipartite graph conversion & remove edges API (#1297)

* remove edge and to bipartite and graphsage with sampling

* fixes

* fixes

* fixes

* reenable multigpu training

* fixes

* compatibility in DGLGraph

* rename to compact_as_bipartite

* bugfix

* lint

* add offline inference

* skip GPU tests

* fix

* addresses comments

* fix

* fix

* fix

* more tests

* more docs and unit tests

* workaround for empty slice on empty data
parent ce6e19f2
import argparse
import time
import traceback
from _thread import start_new_thread
from functools import wraps

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset

import tqdm
#### Neighbor sampler
class NeighborSampler(object):
    """Samples a list of message-passing "blocks" for a set of seed nodes.

    Parameters
    ----------
    g : DGLGraph
        The full graph to sample from.
    fanouts : list[int]
        Number of neighbors to sample per seed node, one entry per layer
        (ordered from the layer closest to the input to the output layer).
    """

    def __init__(self, g, fanouts):
        self.g = g
        self.fanouts = fanouts

    def sample_blocks(self, seeds):
        """Return a list of blocks, ordered from input layer to output layer.

        seeds : Tensor of seed node IDs whose representations are to be computed.
        """
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` in-neighbors.
            # BUG FIX: sample from self.g, not the module-level global ``g``;
            # the original silently ignored the graph passed to __init__.
            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
            # Compact the frontier into a bipartite block for message passing.
            block = dgl.to_block(frontier, seeds)
            # The source nodes of this block become the seeds of the next
            # (outer) layer's sampling step.
            seeds = block.srcdata[dgl.NID]
            # Prepend so that blocks[0] corresponds to the input layer.
            blocks.insert(0, block)
        return blocks
class SAGE(nn.Module):
    """Multi-layer GraphSAGE model that consumes a list of bipartite "blocks".

    Parameters
    ----------
    in_feats : int
        Input feature size.
    n_hidden : int
        Hidden feature size of intermediate layers.
    n_classes : int
        Output size (number of classes).
    n_layers : int
        Total number of SAGEConv layers (must match the number of blocks
        passed to ``forward``).
    activation : callable
        Activation applied inside every layer except the last.
    dropout : float
        Feature dropout rate passed to each SAGEConv.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        # Input layer: in_feats -> n_hidden.
        self.layers.append(dglnn.SAGEConv(
            in_feats, n_hidden, 'mean', feat_drop=dropout, activation=activation))
        # Hidden layers: n_hidden -> n_hidden.
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(
                n_hidden, n_hidden, 'mean', feat_drop=dropout, activation=activation))
        # Output layer: n_hidden -> n_classes, no activation (logits).
        self.layers.append(dglnn.SAGEConv(
            n_hidden, n_classes, 'mean', feat_drop=dropout))

    def forward(self, blocks, x):
        """Compute output representations from input features ``x``.

        blocks : list of bipartite blocks, one per layer (input layer first).
        x : features of the source nodes of blocks[0].
        """
        h = x
        for layer, block in zip(self.layers, blocks):
            # We need to first copy the representation of nodes on the RHS from the
            # appropriate nodes on the LHS.
            # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
            # would be (num_nodes_RHS, D).
            # This slicing is valid because to_block guarantees that the RHS nodes
            # appear first among the LHS nodes, in the same order.
            h_dst = h[:block.number_of_nodes(block.dsttype)]
            # Then we compute the updated representation on the RHS.
            # The shape of h now becomes (num_nodes_RHS, D)
            h = layer(block, (h, h_dst))
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.
        The inference code is written in a fashion that it could handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer. The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = th.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
            # Output buffer for this layer, kept on CPU: hidden size for all
            # layers except the last, which produces class logits.
            y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                # Use the full in-neighborhood (no sampling) of the batch nodes.
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                induced_nodes = block.srcdata[dgl.NID]
                # Only the needed rows of x are moved to the GPU.
                h = x[induced_nodes].to(device)
                h_dst = h[:block.number_of_nodes(block.dsttype)]
                h = layer(block, (h, h_dst))
                y[start:end] = h.cpu()
            # The output of this layer is the input of the next.
            x = y
        return y
#### Miscellaneous functions
# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
#
# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
# to standardize worker process creation since our operators are implemented with
# OpenMP.
def thread_wrapped_func(func):
    """
    Wraps a process entry point to make it work with OpenMP.

    The wrapped function runs ``func`` on a fresh thread and relays its
    return value (or re-raises its exception, with the original traceback
    text as the message) to the caller.
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = mp.Queue()
        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except Exception as e:
                exception = e
                # BUG FIX: ``traceback`` was never imported; the NameError
                # killed this thread before queue.put, so the caller hung
                # forever on queue.get whenever func raised.
                import traceback
                trace = traceback.format_exc()
            queue.put((res, exception, trace))
        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        else:
            assert isinstance(exception, Exception)
            # Re-raise the same exception type, carrying the formatted
            # traceback from the worker thread as the message.
            raise exception.__class__(trace)
    return decorated_function
def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.

    pred : (N, C) tensor of class scores; labels : (N,) tensor of class IDs.
    Returns a scalar tensor with the fraction of correct predictions.
    """
    predicted_classes = th.argmax(pred, dim=1)
    num_correct = (predicted_classes == labels).float().sum()
    return num_correct / len(pred)
def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_mask``.

    model : The model to evaluate (must expose ``inference``).
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        predictions = model.inference(g, inputs, batch_size, device)
    model.train()
    masked_pred = predictions[val_mask]
    masked_labels = labels[val_mask]
    return compute_acc(masked_pred, masked_labels)
def load_subtensor(g, labels, seeds, induced_nodes, dev_id):
    """
    Copies features and labels of a set of nodes onto GPU ``dev_id``.

    Returns (batch_inputs, batch_labels): the features of ``induced_nodes``
    and the labels of ``seeds``, both moved to the target device.
    """
    features = g.ndata['features']
    batch_inputs = features[induced_nodes].to(dev_id)
    batch_labels = labels[seeds].to(dev_id)
    return batch_inputs, batch_labels
#### Entry point
@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, data):
    """Training entry point for one process (one GPU).

    proc_id : rank of this process, in [0, n_gpus).
    n_gpus : total number of training processes.
    args : parsed command-line arguments.
    devices : list of CUDA device IDs; this process uses devices[proc_id].
    data : tuple (train_mask, val_mask, in_feats, labels, n_classes, g).
    """
    dropout = 0.2

    # Start up distributed training, if enabled.
    dev_id = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        # BUG FIX: the process rank must be proc_id (0..n_gpus-1), not the
        # CUDA device ID, which is arbitrary (e.g. ``--gpu 2,3``).
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)
    th.cuda.set_device(dev_id)

    # Unpack data
    train_mask, val_mask, in_feats, labels, n_classes, g = data
    train_nid = th.LongTensor(np.nonzero(train_mask)[0])
    val_nid = th.LongTensor(np.nonzero(val_mask)[0])
    train_mask = th.BoolTensor(train_mask)
    val_mask = th.BoolTensor(val_mask)

    # Split train_nid across processes.
    # BUG FIX: index by proc_id, not dev_id -- dev_id may exceed the number
    # of chunks (e.g. running a single process on GPU 5).
    train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id]

    # Create sampler
    sampler = NeighborSampler(g, [args.fan_out] * args.num_layers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, dropout)
    model = model.to(dev_id)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(dev_id)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()
        # Reshuffle the training node IDs each epoch.
        train_nid_batches = train_nid[th.randperm(len(train_nid))]
        n_batches = (len(train_nid_batches) + args.batch_size - 1) // args.batch_size
        for step in range(n_batches):
            seeds = train_nid_batches[step * args.batch_size:(step+1) * args.batch_size]
            if proc_id == 0:
                tic_step = time.time()

            # Sample blocks for message propagation
            blocks = sampler.sample_blocks(seeds)
            induced_nodes = blocks[0].srcdata[dgl.NID]

            # Load the input features as well as output labels
            batch_inputs, batch_labels = load_subtensor(g, labels, seeds, induced_nodes, dev_id)

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            loss.backward()
            if n_gpus > 1:
                # Manually average gradients across processes.
                for param in model.parameters():
                    if param.requires_grad and param.grad is not None:
                        th.distributed.all_reduce(param.grad.data,
                                                  op=th.distributed.ReduceOp.SUM)
                        param.grad.data /= n_gpus
            optimizer.step()
            if proc_id == 0:
                iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))
            if step % args.log_every == 0 and proc_id == 0:
                acc = compute_acc(batch_pred, batch_labels)
                # Skip the first few iterations when averaging throughput
                # (warm-up noise).
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}'.format(
                    epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:])))
        if n_gpus > 1:
            th.distributed.barrier()

        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
            # Average epoch time over epochs after the warm-up epochs.
            if epoch >= 5:
                avg += toc - tic
            if epoch % args.eval_every == 0 and epoch != 0:
                # Offline full-neighbor inference always runs on GPU 0.
                eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
                print('Eval Acc {:.4f}'.format(eval_acc))
        if n_gpus > 1:
            th.distributed.barrier()

    if proc_id == 0:
        print('Avg epoch time: {}'.format(avg / (epoch - 4)))
if __name__ == '__main__':
    # Command-line arguments for (multi-)GPU minibatch GraphSAGE training.
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument('--gpu', type=str, default='0')  # comma-separated GPU IDs
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=int, default=10)  # neighbors sampled per node per layer
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    args = argparser.parse_args()

    # One training process per listed GPU.
    devices = list(map(int, args.gpu.split(',')))
    n_gpus = len(devices)

    # load reddit data
    data = RedditDataset(self_loop=True)
    train_mask = data.train_mask
    val_mask = data.val_mask
    features = th.Tensor(data.features)
    in_feats = features.shape[1]
    labels = th.LongTensor(data.labels)
    n_classes = data.num_labels

    # Construct graph from the dataset's edge list.
    g = dgl.graph(data.graph.all_edges())
    g.ndata['features'] = features

    # Pack data shared by all worker processes.
    data = train_mask, val_mask, in_feats, labels, n_classes, g

    if n_gpus == 1:
        # Single-GPU: run in this process directly.
        run(0, n_gpus, args, devices, data)
    else:
        # Multi-GPU: spawn one process per GPU and wait for all of them.
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()
......@@ -38,22 +38,22 @@ typedef NDArray TypeArray;
* \brief Sparse format.
*/
enum class SparseFormat {
ANY = 0,
COO = 1,
CSR = 2,
CSC = 3
kAny = 0,
kCOO = 1,
kCSR = 2,
kCSC = 3
};
// Parse sparse format from string.
inline SparseFormat ParseSparseFormat(const std::string& name) {
if (name == "coo")
return SparseFormat::COO;
return SparseFormat::kCOO;
else if (name == "csr")
return SparseFormat::CSR;
return SparseFormat::kCSR;
else if (name == "csc")
return SparseFormat::CSC;
return SparseFormat::kCSC;
else
return SparseFormat::ANY;
return SparseFormat::kAny;
}
// Sparse matrix object that is exposed to python API.
......@@ -328,7 +328,7 @@ struct CSRMatrix {
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::CSR), num_rows,
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCSR), num_rows,
num_cols, {indptr, indices, data}, {sorted});
}
......@@ -408,7 +408,7 @@ struct COOMatrix {
// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::COO), num_rows,
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
}
......@@ -548,6 +548,13 @@ bool CSRHasDuplicate(CSRMatrix csr);
*/
void CSRSort_(CSRMatrix* csr);
/*!
* \brief Remove entries from CSR matrix by entry indices (data indices)
* \return A new CSR matrix as well as a mapping from the new CSR entries to the old CSR
* entries.
*/
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
......@@ -723,6 +730,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
*/
COOMatrix COOSort(COOMatrix mat, bool sort_column = false);
/*!
* \brief Remove entries from COO matrix by entry indices (data indices)
* \return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries.
*/
COOMatrix COORemove(COOMatrix coo, IdArray entries);
/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
......
......@@ -570,7 +570,18 @@ HeteroGraphPtr CreateHeteroGraph(
*/
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::ANY);
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAny);
/*!
* \brief Create a heterograph from COO input.
* \param num_vtypes Number of vertex types. Must be 1 or 2.
* \param mat The COO matrix
* \param restrict_format Sparse format for storing this graph.
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
/*!
* \brief Create a heterograph from CSR input.
......@@ -586,7 +597,45 @@ HeteroGraphPtr CreateFromCOO(
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::ANY);
SparseFormat restrict_format = SparseFormat::kAny);
/*!
* \brief Create a heterograph from CSR input.
* \param num_vtypes Number of vertex types. Must be 1 or 2.
* \param mat The CSR matrix
* \param restrict_format Sparse format for storing this graph.
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
/*!
* \brief Create a heterograph from CSC input.
* \param num_vtypes Number of vertex types. Must be 1 or 2.
* \param num_src Number of nodes in the source type.
* \param num_dst Number of nodes in the destination type.
* \param indptr Indptr array
* \param indices Indices array
* \param edge_ids Edge ids
* \param restrict_format Sparse format for storing this graph.
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::kAny);
/*!
* \brief Create a heterograph from CSC input.
* \param num_vtypes Number of vertex types. Must be 1 or 2.
* \param mat The CSC matrix
* \param restrict_format Sparse format for storing this graph.
* \return A heterograph pointer.
*/
HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
/*!
* \brief Extract the subgraph of the in edges of the given nodes.
......
......@@ -22,7 +22,7 @@ namespace transform {
* outbound edges.
*
* The graphs should have identical node ID space (i.e. should have the same set of nodes,
* including types and IDs) and metagraph.
* including types and IDs).
*
* \param graphs The list of graphs.
* \param always_preserve The list of nodes to preserve regardless of whether the inbound
......@@ -36,6 +36,48 @@ CompactGraphs(
const std::vector<HeteroGraphPtr> &graphs,
const std::vector<IdArray> &always_preserve);
/*!
* \brief Convert a graph into a bipartite-structured graph for message passing.
*
* Specifically, we create one node type \c ntype_l on the "left" side and another
* node type \c ntype_r on the "right" side for each node type \c ntype. The nodes of
* type \c ntype_r would contain the nodes designated by the caller, and node type
* \c ntype_l would contain the nodes that has an edge connecting to one of the
* designated nodes.
*
* The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r.
*
* This function is often used for constructing a series of dependency graphs for
* multi-layer message passing, where we first construct a series of frontier graphs
* on the original node space, and run the following to get the bipartite graph needed
* for message passing with each GNN layer:
*
* <code>
* bipartites = [None] * len(num_layers)
* for l in reversed(range(len(layers))):
* bipartites[l], seeds = to_bipartite(frontier[l], seeds)
* x = graph.ndata["h"][seeds]
* for g, layer in zip(bipartites, layers):
* x_src = x
* x_dst = x[:len(g.dsttype)]
* x = sageconv(g, (x_src, x_dst))
* output = x
* </code>
*
* \param graph The graph.
* \param rhs_nodes Designated nodes that would appear on the right side.
*
* \return A triplet containing
* * The bipartite-structured graph,
* * The induced node from the left side for each graph,
* * The induced edges.
*
* \note For each node type \c ntype, the nodes in rhs_nodes[ntype] would always
* appear first in the nodes of type \c ntype_l in the new graph.
*/
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes);
/*!
* \brief Convert a multigraph to a simple graph.
*
......@@ -67,6 +109,18 @@ CompactGraphs(
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToSimpleGraph(const HeteroGraphPtr graph);
/*!
* \brief Remove edges from a graph.
*
* \param graph The graph.
* \param eids The edge IDs to remove per edge type.
*
* \return A pair of the graph with edges removed, as well as the edge ID mapping from
* the original graph to the new graph per edge type.
*/
std::pair<HeteroGraphPtr, std::vector<IdArray>>
RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids);
}; // namespace transform
}; // namespace dgl
......
......@@ -190,7 +190,10 @@ def repeat(input, repeats, dim):
def gather_row(data, row_index):
# MXNet workaround for empty row index
if len(row_index) == 0:
return data[0:0]
if data.shape[0] == 0:
return data
else:
return data[0:0]
if isinstance(row_index, nd.NDArray):
return nd.take(data, row_index)
......
......@@ -859,6 +859,8 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False,
If True, the entries in the sparse matrix are treated as edge IDs.
Otherwise, the entries are ignored and edges will be added in
(source, destination) order.
Note that this option only affects CSR matrices; COO matrices' rows and cols
are always assumed to be ordered by edge ID already.
validate : bool, optional
If True, checks if node IDs are within range.
restrict_format : 'any', 'coo', 'csr', 'csc', optional
......
......@@ -1795,6 +1795,16 @@ class DGLGraph(DGLBaseGraph):
"""
return self.nodes[:].data
@property
def srcdata(self):
"""Compatibility interface with heterogeneous graphs; identical to ``ndata``"""
return self.ndata
@property
def dstdata(self):
"""Compatibility interface with heterogeneous graphs; identical to ``ndata``"""
return self.ndata
@property
def edges(self):
"""Return a edges view that can used to set/get feature data.
......
......@@ -363,6 +363,24 @@ class DGLHeteroGraph(object):
"""
return self._canonical_etypes
@property
def ntype(self):
"""Return the node type if the graph has only one node type."""
assert len(self.ntypes) == 1, "The graph has more than one node type."
return self.ntypes[0]
@property
def srctype(self):
"""Return the source node type if the graph has only one edge type."""
assert len(self.etypes) == 1, "The graph has more than one edge type."
return self.canonical_etypes[0][0]
@property
def dsttype(self):
"""Return the destination node type if the graph has only one edge type."""
assert len(self.etypes) == 1, "The graph has more than one edge type."
return self.canonical_etypes[0][2]
@property
def metagraph(self):
"""Return the metagraph as networkx.MultiDiGraph.
......@@ -542,6 +560,68 @@ class DGLHeteroGraph(object):
"""
return HeteroNodeDataView(self, None, ALL)
@property
def srcdata(self):
"""Return the data view of all source nodes.
**Only works if the graph has only one edge type.**
Examples
--------
The following example uses PyTorch backend.
To set features of all source nodes in a graph with only one edge type:
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.srcdata['h'] = torch.zeros(2, 5)
This is equivalent to
>>> g.nodes['user'].data['h'] = torch.zeros(2, 5)
Notes
-----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
See Also
--------
nodes
"""
assert len(self.etypes) == 1, "Graph has more than one edge type."
srctype = self.canonical_etypes[0][0]
return HeteroNodeDataView(self, srctype, ALL)
@property
def dstdata(self):
"""Return the data view of all destination nodes.
**Only works if the graph has only one edge type.**
Examples
--------
The following example uses PyTorch backend.
To set features of all destination nodes in a graph with only one edge type:
>>> g = dgl.bipartite([(0, 1), (1, 2)], 'user', 'plays', 'game')
>>> g.dstdata['h'] = torch.zeros(3, 5)
This is equivalent to
>>> g.nodes['game'].data['h'] = torch.zeros(3, 5)
Notes
-----
This is identical to :any:`DGLHeteroGraph.ndata` if the graph is homogeneous.
See Also
--------
nodes
"""
assert len(self.etypes) == 1, "Graph has more than one edge type."
dsttype = self.canonical_etypes[0][2]
return HeteroNodeDataView(self, dsttype, ALL)
@property
def edges(self):
"""Return an edge view that can be used to set/get feature
......
"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
from numbers import Integral
import torch
from torch import nn
from torch.nn import functional as F
......@@ -21,8 +23,16 @@ class SAGEConv(nn.Module):
Parameters
----------
in_feats : int
in_feats : int, or pair of ints
Input feature size.
If the layer is to be applied on a unidirectional bipartite graph, ``in_feats``
specifies the input feature size on both the source and destination nodes. If
a scalar is given, the source and destination node feature size would take the
same value.
If aggregator type is ``gcn``, the feature size of source and destination nodes
are required to be the same.
out_feats : int
Output feature size.
feat_drop : float
......@@ -46,7 +56,15 @@ class SAGEConv(nn.Module):
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_feats = in_feats
if isinstance(in_feats, tuple):
self._in_src_feats = in_feats[0]
self._in_dst_feats = in_feats[1]
elif isinstance(in_feats, Integral):
self._in_src_feats = self._in_dst_feats = in_feats
else:
raise TypeError('in_feats must be either int or pair of ints')
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
......@@ -54,12 +72,12 @@ class SAGEConv(nn.Module):
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(in_feats, in_feats)
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(in_feats, in_feats, batch_first=True)
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type != 'gcn':
self.fc_self = nn.Linear(in_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(in_feats, out_feats, bias=bias)
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
def reset_parameters(self):
......@@ -80,8 +98,8 @@ class SAGEConv(nn.Module):
"""
m = nodes.mailbox['m'] # (B, L, D)
batch_size = m.shape[0]
h = (m.new_zeros((1, batch_size, self._in_feats)),
m.new_zeros((1, batch_size, self._in_feats)))
h = (m.new_zeros((1, batch_size, self._in_src_feats)),
m.new_zeros((1, batch_size, self._in_src_feats)))
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}
......@@ -92,9 +110,11 @@ class SAGEConv(nn.Module):
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature of shape :math:`(N, D_{in})` where :math:`D_{in}`
is size of input feature, :math:`N` is the number of nodes.
feat : torch.Tensor or pair of torch.Tensor
If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
:math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
If a pair of torch.Tensor is given, the pair must contain two tensors of shape
:math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
Returns
-------
......@@ -103,29 +123,38 @@ class SAGEConv(nn.Module):
is size of output feature.
"""
graph = graph.local_var()
feat = self.feat_drop(feat)
h_self = feat
if torch.is_tensor(feat):
feat_src = feat_dst = self.feat_drop(feat)
else:
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
h_self = feat_dst
if self._aggre_type == 'mean':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.dstdata['h'] = feat_dst # same as above if homogeneous
graph.update_all(fn.copy_src('h', 'm'), fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().float()
degs = degs.to(feat.device)
h_neigh = (graph.ndata['neigh'] + graph.ndata['h']) / (degs.unsqueeze(-1) + 1)
degs = degs.to(feat_dst.device)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'pool':
graph.ndata['h'] = F.relu(self.fc_pool(feat))
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(fn.copy_src('h', 'm'), fn.max('m', 'neigh'))
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'lstm':
graph.ndata['h'] = feat
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.ndata['neigh']
h_neigh = graph.dstdata['neigh']
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
rst = self.fc_neigh(h_neigh)
......
"""Module for graph transformation utilities."""
from collections.abc import Iterable, Mapping
from collections import defaultdict
import numpy as np
from scipy import sparse
from ._ffi.function import _init_api
......@@ -32,9 +33,11 @@ __all__ = [
'remove_self_loop',
'metapath_reachable_graph',
'compact_graphs',
'to_block',
'to_simple',
'in_subgraph',
'out_subgraph']
'out_subgraph',
'remove_edges']
def pairwise_squared_distance(x):
......@@ -705,6 +708,191 @@ def compact_graphs(graphs, always_preserve=None):
return new_graphs
def to_block(g, rhs_nodes=None, lhs_suffix="_l", rhs_suffix="_r"):
"""Convert a graph into a bipartite-structured "block" for message passing.
Specifically, we create one node type ``ntype_l`` on the "left hand" side and another
node type ``ntype_r`` on the "right hand" side for each node type ``ntype``. The
nodes of type ``ntype_r`` would contain the nodes that have an inbound edge of any type,
while ``ntype_l`` would contain all the nodes on the right hand side, as well as any
nodes that have an outbound edge of any type pointing to any node on the right hand side.
For each relation graph of canonical edge type ``(utype, etype, vtype)``, edges
from node type ``utype`` to node type ``vtype`` are preserved, except that the
source node type and destination node type become ``utype_l`` and ``vtype_r`` in
the new graph. The resulting relation graph would have a canonical edge type
``(utype_l, etype, vtype_r)``.
We refer to such a bipartite-structured graph as a **block**.
If ``rhs_nodes`` is given, the right hand side would contain the given nodes.
Otherwise, the right hand side would be determined by DGL via the rules above.
Parameters
----------
graph : DGLHeteroGraph
The graph.
rhs_nodes : Tensor or dict[str, Tensor], optional
Optional nodes that would appear on the right hand side.
If a tensor is given, the graph must have only one node type.
lhs_suffix : str, default "_l"
The suffix attached to all node types on the left hand side.
rhs_suffix : str, default "_r"
The suffix attached to all node types on the right hand side.
Returns
-------
DGLHeteroGraph
The new graph describing the block.
The node IDs induced for each type in both sides would be stored in feature
``dgl.NID``.
The edge IDs induced for each type would be stored in feature ``dgl.EID``.
For each node type ``ntype``, the first few nodes with type ``ntype_l`` are
guaranteed to be identical to the nodes with type ``ntype_r``.
Notes
-----
This function is primarily for creating graph structures for efficient
computation of message passing. See [TODO] for a detailed example.
Examples
--------
Converting a homogeneous graph to a block as described above:
>>> g = dgl.graph([(0, 1), (1, 2), (2, 3)])
>>> block = dgl.to_block(g, torch.LongTensor([3, 2]))
The right hand side nodes would be exactly the same as the ones given: [3, 2].
>>> induced_dst = block.dstdata[dgl.NID]
>>> induced_dst
tensor([3, 2])
The first few nodes of the left hand side nodes would also be exactly the same as
the ones given. The rest of the nodes are the ones necessary for message passing
into nodes 3, 2. This means that the node 1 would be included.
>>> induced_src = block.srcdata[dgl.NID]
>>> induced_src
tensor([3, 2, 1])
We can notice that the first two nodes are identical to the given nodes as well as
the right hand side nodes.
The induced edges can also be obtained by the following:
>>> block.edata[dgl.EID]
tensor([2, 1])
This indicates that edge (2, 3) and (1, 2) are included in the result graph. We can
verify that the first edge in the block indeed maps to the edge (2, 3), and the
second edge in the block indeed maps to the edge (1, 2):
>>> src, dst = block.edges(order='eid')
>>> induced_src[src], induced_dst[dst]
(tensor([2, 1]), tensor([3, 2]))
Converting a heterogeneous graph to a block is similar, except that when specifying
the right hand side nodes, you have to give a dict:
>>> g = dgl.bipartite([(0, 1), (1, 2), (2, 3)], utype='A', vtype='B')
If you don't specify any node of type A on the right hand side, the node type ``A_r``
in the block would have zero nodes.
>>> block = dgl.to_block(g, {'B': torch.LongTensor([3, 2])})
>>> block.number_of_nodes('A_r')
0
>>> block.number_of_nodes('B_r')
2
>>> block.nodes['B_r'].data[dgl.NID]
tensor([3, 2])
The left hand side would contain all the nodes on the right hand side:
>>> block.nodes['B_l'].data[dgl.NID]
tensor([3, 2])
As well as all the nodes that have connections to the nodes on the right hand side:
>>> block.nodes['A_l'].data[dgl.NID]
tensor([2, 1])
"""
if rhs_nodes is None:
# Find all nodes that appeared as destinations
rhs_nodes = defaultdict(list)
for etype in g.canonical_etypes:
_, dst = g.edges(etype=etype)
rhs_nodes[etype[2]].append(dst)
rhs_nodes = {ntype: F.unique(F.cat(values, 0)) for ntype, values in rhs_nodes.items()}
elif not isinstance(rhs_nodes, Mapping):
# rhs_nodes is a Tensor, check if the g has only one type.
if len(g.ntypes) > 1:
raise ValueError(
'Graph has more than one node type; please specify a dict for rhs_nodes.')
rhs_nodes = {g.ntypes[0]: rhs_nodes}
# rhs_nodes is now a dict
rhs_nodes_nd = []
for ntype in g.ntypes:
nodes = rhs_nodes.get(ntype, None)
if nodes is not None:
rhs_nodes_nd.append(F.zerocopy_to_dgl_ndarray(nodes))
else:
rhs_nodes_nd.append(nd.null())
new_graph_index, lhs_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(g._graph, rhs_nodes_nd)
lhs_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd.data) for nodes_nd in lhs_nodes_nd]
rhs_nodes = [F.zerocopy_from_dgl_ndarray(nodes_nd) for nodes_nd in rhs_nodes_nd]
new_ntypes = [ntype + lhs_suffix for ntype in g.ntypes] + \
[ntype + rhs_suffix for ntype in g.ntypes]
new_graph = DGLHeteroGraph(new_graph_index, new_ntypes, g.etypes)
for i, ntype in enumerate(g.ntypes):
new_graph.nodes[ntype + lhs_suffix].data[NID] = lhs_nodes[i]
new_graph.nodes[ntype + rhs_suffix].data[NID] = rhs_nodes[i]
for i, canonical_etype in enumerate(g.canonical_etypes):
induced_edges = F.zerocopy_from_dgl_ndarray(induced_edges_nd[i].data)
utype, etype, vtype = canonical_etype
new_canonical_etype = (utype + lhs_suffix, etype, vtype + rhs_suffix)
new_graph.edges[new_canonical_etype].data[EID] = induced_edges
return new_graph
def remove_edges(g, edge_ids):
    """Return a new graph with the given edge IDs removed.

    The nodes are preserved.

    Parameters
    ----------
    g : DGLHeteroGraph
        The graph.
    edge_ids : Tensor or dict[etype, Tensor]
        The IDs of the edges to remove for each edge type.  A plain tensor
        is only allowed when the graph has a single edge type.

    Returns
    -------
    DGLHeteroGraph
        The new graph.
        The edge ID mapping from the new graph to the original graph is
        stored as ``dgl.EID`` on edge features.
    """
    if not isinstance(edge_ids, Mapping):
        if len(g.etypes) != 1:
            raise ValueError(
                "Graph has more than one edge type; specify a dict for edge_ids instead.")
        edge_ids = {g.canonical_etypes[0]: edge_ids}
    # Convert to a list of DGL NDArrays ordered by edge type ID for the C API.
    edge_ids_nd = [None] * len(g.etypes)
    for key, value in edge_ids.items():
        edge_ids_nd[g.get_etype_id(key)] = F.zerocopy_to_dgl_ndarray(value)
    new_graph_index, induced_eids_nd = _CAPI_DGLRemoveEdges(g._graph, edge_ids_nd)

    new_graph = DGLHeteroGraph(new_graph_index, g.ntypes, g.etypes)
    # Attach the new-to-original edge ID mapping per canonical edge type.
    for i, canonical_etype in enumerate(g.canonical_etypes):
        new_graph.edges[canonical_etype].data[EID] = F.zerocopy_from_dgl_ndarray(
            induced_eids_nd[i].data)

    return new_graph
def in_subgraph(g, nodes):
"""Extract the subgraph containing only the in edges of the given nodes.
......
......@@ -434,6 +434,14 @@ void CSRSort_(CSRMatrix* csr) {
});
}
// Dispatch CSRRemove to the device/ID-type-specific implementation.
// Returns a new CSR matrix with the given entries (edge IDs) removed.
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
  CSRMatrix ret;
  ATEN_CSR_SWITCH(csr, XPU, IdType, {
    ret = impl::CSRRemove<XPU, IdType>(csr, entries);
  });
  return ret;
}
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix ret;
......@@ -577,6 +585,14 @@ COOMatrix COOSort(COOMatrix mat, bool sort_column) {
return ret;
}
// Dispatch COORemove to the device/ID-type-specific implementation.
// Returns a new COO matrix with the given entries (edge IDs) removed.
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
  COOMatrix ret;
  ATEN_COO_SWITCH(coo, XPU, IdType, {
    ret = impl::COORemove<XPU, IdType>(coo, entries);
  });
  return ret;
}
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace) {
COOMatrix ret;
......
......@@ -113,6 +113,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
template <DLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr);
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(
......@@ -176,6 +179,9 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
COOMatrix COOSort(COOMatrix mat, bool sort_column);
template <DLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseSampling(
......
......@@ -30,10 +30,16 @@ class IdHashMap {
  // Construct the hashmap using the given id array.
  // The id array could contain duplicates.
  // If the id array has no duplicates, the array will be relabeled to consecutive
  // integers starting from 0.
  explicit IdHashMap(IdArray ids): filter_(kFilterSize, false) {
    // Pre-size the map to avoid rehashing while inserting all ids.
    oldv2newv_.reserve(ids->shape[0]);
    Update(ids);
  }

  // Defaulted copy constructor: member-wise copy of the filter and the id map.
  IdHashMap(const IdHashMap &other) = default;
// Update the hashmap with given id array.
// The id array could contain duplicates.
void Update(IdArray ids) {
......
/*!
* Copyright (c) 2020 by Contributors
* \file array/cpu/coo_remove.cc
* \brief COO matrix remove entries CPU implementation
*/
#include <dgl/array.h>
#include <utility>
#include <vector>
#include "array_utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
namespace {
/*! \brief COORemove implementation for COOMatrix with default consecutive edge IDs */
template <DLDeviceType XPU, typename IdType>
void COORemoveConsecutive(
    COOMatrix coo,
    IdArray entries,
    std::vector<IdType> *new_rows,
    std::vector<IdType> *new_cols,
    std::vector<IdType> *new_eids) {
  const int64_t num_edges = coo.row->shape[0];
  const int64_t num_removed = entries->shape[0];
  const IdType *rows = static_cast<IdType *>(coo.row->data);
  const IdType *cols = static_cast<IdType *>(coo.col->data);
  const IdType *removed = static_cast<IdType *>(entries->data);

  // Sort the removal list so it can be merged against the (already ascending)
  // consecutive edge IDs in a single pass.
  std::vector<IdType> removed_sorted(removed, removed + num_removed);
  std::sort(removed_sorted.begin(), removed_sorted.end());

  int64_t cursor = 0;
  for (int64_t eid = 0; eid < num_edges; ++eid) {
    if (cursor < num_removed && removed_sorted[cursor] == eid) {
      // Skip this edge; advance past any duplicate removal entries.
      do {
        ++cursor;
      } while (cursor < num_removed && removed_sorted[cursor] == eid);
      continue;
    }
    new_rows->push_back(rows[eid]);
    new_cols->push_back(cols[eid]);
    // With consecutive edge IDs, the edge ID equals its position.
    new_eids->push_back(eid);
  }
}
/*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */
template <DLDeviceType XPU, typename IdType>
void COORemoveShuffled(
    COOMatrix coo,
    IdArray entries,
    std::vector<IdType> *new_rows,
    std::vector<IdType> *new_cols,
    std::vector<IdType> *new_eids) {
  const int64_t num_edges = coo.row->shape[0];
  const IdType *rows = static_cast<IdType *>(coo.row->data);
  const IdType *cols = static_cast<IdType *>(coo.col->data);
  const IdType *eids = static_cast<IdType *>(coo.data->data);

  // Hash the edge IDs to remove so each edge needs only an O(1) membership test.
  IdHashMap<IdType> removed(entries);
  for (int64_t i = 0; i < num_edges; ++i) {
    const IdType eid = eids[i];
    if (!removed.Contains(eid)) {
      new_rows->push_back(rows[i]);
      new_cols->push_back(cols[i]);
      new_eids->push_back(eid);
    }
  }
}
}; // namespace
// Remove the given entries (edge IDs) from a COO matrix.
// The returned matrix carries, in its data field, the original edge IDs of
// the surviving edges.
template <DLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
  const int64_t nnz = coo.row->shape[0];
  const int64_t n_entries = entries->shape[0];
  // Nothing to remove; return the input unchanged.
  if (n_entries == 0)
    return coo;

  std::vector<IdType> new_rows, new_cols, new_eids;
  // ``entries`` may contain duplicates (the consecutive path explicitly skips
  // them), so nnz - n_entries can be negative; clamp the reserve hint to
  // avoid converting a negative value to a huge size_t.
  const int64_t reserve_hint = std::max<int64_t>(nnz - n_entries, 0);
  new_rows.reserve(reserve_hint);
  new_cols.reserve(reserve_hint);
  new_eids.reserve(reserve_hint);

  if (COOHasData(coo))
    COORemoveShuffled<XPU, IdType>(coo, entries, &new_rows, &new_cols, &new_eids);
  else
    // Removing from COO ordered by eid has more efficient implementation.
    COORemoveConsecutive<XPU, IdType>(coo, entries, &new_rows, &new_cols, &new_eids);

  return COOMatrix(
      coo.num_rows, coo.num_cols,
      IdArray::FromVector(new_rows),
      IdArray::FromVector(new_cols),
      IdArray::FromVector(new_eids));
}
template COOMatrix COORemove<kDLCPU, int32_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDLCPU, int64_t>(COOMatrix coo, IdArray entries);
}; // namespace impl
}; // namespace aten
}; // namespace dgl
/*!
 * Copyright (c) 2020 by Contributors
 * \file array/cpu/csr_remove.cc
 * \brief CSR matrix remove entries CPU implementation
 */
#include <dgl/array.h>
#include <utility>
#include <vector>
#include "array_utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
namespace {
/*! \brief CSRRemove implementation for CSRMatrix with default consecutive edge IDs */
template <DLDeviceType XPU, typename IdType>
void CSRRemoveConsecutive(
    CSRMatrix csr,
    IdArray entries,
    std::vector<IdType> *new_indptr,
    std::vector<IdType> *new_indices,
    std::vector<IdType> *new_eids) {
  const int64_t n_entries = entries->shape[0];
  const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
  const IdType *entry_data = static_cast<IdType *>(entries->data);
  // Sort the removal list so it can be merged against the (already ascending)
  // consecutive edge IDs in a single pass.
  std::vector<IdType> entry_data_sorted(entry_data, entry_data + n_entries);
  std::sort(entry_data_sorted.begin(), entry_data_sorted.end());

  int64_t k = 0;
  new_indptr->push_back(0);
  for (int64_t i = 0; i < csr.num_rows; ++i) {
    for (IdType j = indptr_data[i]; j < indptr_data[i + 1]; ++j) {
      if (k < n_entries && entry_data_sorted[k] == j) {
        // Move on to the next different entry
        while (k < n_entries && entry_data_sorted[k] == j)
          ++k;
        continue;
      }
      new_indices->push_back(indices_data[j]);
      // BUG FIX: with consecutive edge IDs the ID of this edge is its
      // position ``j`` in the indices array, not the removal cursor ``k``
      // (cf. COORemoveConsecutive pushing ``i`` and CSRRemoveShuffled's
      // ``eid_data ? eid_data[j] : j``).
      new_eids->push_back(j);
    }
    new_indptr->push_back(new_indices->size());
  }
}
/*! \brief CSRRemove implementation for CSRMatrix with shuffled edge IDs */
template <DLDeviceType XPU, typename IdType>
void CSRRemoveShuffled(
    CSRMatrix csr,
    IdArray entries,
    std::vector<IdType> *new_indptr,
    std::vector<IdType> *new_indices,
    std::vector<IdType> *new_eids) {
  const IdType *indptr = static_cast<IdType *>(csr.indptr->data);
  const IdType *indices = static_cast<IdType *>(csr.indices->data);
  const IdType *eids = static_cast<IdType *>(csr.data->data);

  // Hash the edge IDs to remove so each edge needs only an O(1) membership test.
  IdHashMap<IdType> removed(entries);
  new_indptr->push_back(0);
  for (int64_t row = 0; row < csr.num_rows; ++row) {
    for (IdType pos = indptr[row]; pos < indptr[row + 1]; ++pos) {
      // A null data pointer means the edge ID is the position itself.
      const IdType eid = (eids == nullptr) ? pos : eids[pos];
      if (removed.Contains(eid))
        continue;
      new_indices->push_back(indices[pos]);
      new_eids->push_back(eid);
    }
    new_indptr->push_back(new_indices->size());
  }
}
}; // namespace
// Remove the given entries (edge IDs) from a CSR matrix.
// The returned matrix carries, in its data field, the original edge IDs of
// the surviving edges.
template <DLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
  const int64_t nnz = csr.indices->shape[0];
  const int64_t n_entries = entries->shape[0];
  // Nothing to remove; return the input unchanged.
  if (n_entries == 0)
    return csr;

  std::vector<IdType> new_indptr, new_indices, new_eids;
  // ``entries`` may contain duplicates (the consecutive path explicitly skips
  // them), so nnz - n_entries can be negative; clamp the reserve hint to
  // avoid converting a negative value to a huge size_t.
  const int64_t reserve_hint = std::max<int64_t>(nnz - n_entries, 0);
  // indptr always ends up with num_rows + 1 entries regardless of removals.
  new_indptr.reserve(csr.num_rows + 1);
  new_indices.reserve(reserve_hint);
  new_eids.reserve(reserve_hint);

  if (CSRHasData(csr))
    CSRRemoveShuffled<XPU, IdType>(csr, entries, &new_indptr, &new_indices, &new_eids);
  else
    // Removing from CSR ordered by eid has more efficient implementation
    CSRRemoveConsecutive<XPU, IdType>(csr, entries, &new_indptr, &new_indices, &new_eids);

  return CSRMatrix(
      csr.num_rows, csr.num_cols,
      IdArray::FromVector(new_indptr),
      IdArray::FromVector(new_indices),
      IdArray::FromVector(new_eids));
}
template CSRMatrix CSRRemove<kDLCPU, int32_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDLCPU, int64_t>(CSRMatrix csr, IdArray entries);
}; // namespace impl
}; // namespace aten
}; // namespace dgl
......@@ -24,6 +24,13 @@ HeteroGraphPtr CreateFromCOO(
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
// Build a single-relation heterograph directly from an aten COO matrix.
HeteroGraphPtr CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    SparseFormat restrict_format) {
  auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, restrict_format);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
......@@ -33,4 +40,27 @@ HeteroGraphPtr CreateFromCSR(
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
// Build a single-relation heterograph directly from an aten CSR matrix.
HeteroGraphPtr CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, restrict_format);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
// Build a single-relation heterograph from raw CSC arrays
// (indptr over destination nodes, indices of source nodes).
HeteroGraphPtr CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids,
    SparseFormat restrict_format) {
  auto unit_g = UnitGraph::CreateFromCSC(
      num_vtypes, num_src, num_dst, indptr, indices, edge_ids, restrict_format);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
// Build a single-relation heterograph directly from an aten CSR matrix
// interpreted as CSC (transposed adjacency).
HeteroGraphPtr CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, restrict_format);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
} // namespace dgl
......@@ -18,14 +18,14 @@ HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
states.num_nodes_per_type = graph->NumVerticesPerType();
states.adjs.resize(graph->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::ANY);
SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::kAny);
states.adjs[etype] = std::make_shared<SparseMatrix>();
switch (fmt) {
case SparseFormat::COO:
case SparseFormat::kCOO:
*states.adjs[etype] = graph->GetCOOMatrix(etype).ToSparseMatrix();
break;
case SparseFormat::CSR:
case SparseFormat::CSC:
case SparseFormat::kCSR:
case SparseFormat::kCSC:
*states.adjs[etype] = graph->GetCSRMatrix(etype).ToSparseMatrix();
break;
default:
......@@ -47,15 +47,15 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
const SparseFormat fmt = static_cast<SparseFormat>(states.adjs[etype]->format);
switch (fmt) {
case SparseFormat::COO:
case SparseFormat::kCOO:
relgraphs[etype] = UnitGraph::CreateFromCOO(
num_vtypes, aten::COOMatrix(*states.adjs[etype]));
break;
case SparseFormat::CSR:
case SparseFormat::kCSR:
relgraphs[etype] = UnitGraph::CreateFromCSR(
num_vtypes, aten::CSRMatrix(*states.adjs[etype]));
break;
case SparseFormat::CSC:
case SparseFormat::kCSC:
default:
LOG(FATAL) << "Unsupported sparse format.";
}
......
......@@ -51,11 +51,11 @@ HeteroSubgraph SampleNeighbors(
induced_edges[etype] = aten::NullArray();
} else {
// sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC;
auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::kCSR : SparseFormat::kCSC;
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo;
switch (avail_fmt) {
case SparseFormat::COO:
case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)),
......@@ -65,12 +65,12 @@ HeteroSubgraph SampleNeighbors(
hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
}
break;
case SparseFormat::CSR:
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
break;
case SparseFormat::CSC:
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
......@@ -80,7 +80,8 @@ HeteroSubgraph SampleNeighbors(
LOG(FATAL) << "Unsupported sparse format.";
}
subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo);
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
induced_edges[etype] = sampled_coo.data;
}
}
......@@ -125,11 +126,11 @@ HeteroSubgraph SampleNeighborsTopk(
induced_edges[etype] = aten::NullArray();
} else {
// sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC;
auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::kCSR : SparseFormat::kCSC;
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo;
switch (avail_fmt) {
case SparseFormat::COO:
case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWiseTopk(
aten::COOTranspose(hg->GetCOOMatrix(etype)),
......@@ -139,12 +140,12 @@ HeteroSubgraph SampleNeighborsTopk(
hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
}
break;
case SparseFormat::CSR:
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWiseTopk(
hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
break;
case SparseFormat::CSC:
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWiseTopk(
hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
......@@ -154,7 +155,8 @@ HeteroSubgraph SampleNeighborsTopk(
LOG(FATAL) << "Unsupported sparse format.";
}
subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo);
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
induced_edges[etype] = sampled_coo.data;
}
}
......
/*!
* Copyright (c) 2019 by Contributors
* \file graph/transform/remove_edges.cc
* \brief Remove edges.
*/
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/container.h>
#include <vector>
#include <utility>
#include <tuple>
namespace dgl {
using namespace dgl::runtime;
using namespace dgl::aten;
namespace transform {
// Remove the given edge IDs from every edge type of the graph.
//
// For each edge type, the removal is performed on whatever sparse format the
// graph prefers (COO/CSR/CSC), and the mapping from the surviving edges back
// to their original edge IDs is returned alongside the new graph.
std::pair<HeteroGraphPtr, std::vector<IdArray>>
RemoveEdges(const HeteroGraphPtr graph, const std::vector<IdArray> &eids) {
  std::vector<IdArray> induced_eids;
  std::vector<HeteroGraphPtr> rel_graphs;
  const int64_t num_etypes = graph->NumEdgeTypes();

  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::kCOO);
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
    // A relation between a node type and itself uses one vertex type; else two.
    const int num_ntypes_rel = (srctype == dsttype) ? 1 : 2;

    HeteroGraphPtr new_rel_graph;
    IdArray induced_eids_rel;
    if (fmt == SparseFormat::kCOO) {
      const COOMatrix &coo = graph->GetCOOMatrix(etype);
      const COOMatrix &result = COORemove(coo, eids[etype]);
      new_rel_graph = CreateFromCOO(
          num_ntypes_rel, result.num_rows, result.num_cols, result.row, result.col);
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSR) {
      const CSRMatrix &csr = graph->GetCSRMatrix(etype);
      const CSRMatrix &result = CSRRemove(csr, eids[etype]);
      new_rel_graph = CreateFromCSR(
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr, result.indices,
          // TODO(BarclayII): make CSR support null eid array
          Range(0, result.indices->shape[0], result.indices->dtype.bits, result.indices->ctx));
      induced_eids_rel = result.data;
    } else if (fmt == SparseFormat::kCSC) {
      const CSRMatrix &csc = graph->GetCSCMatrix(etype);
      const CSRMatrix &result = CSRRemove(csc, eids[etype]);
      new_rel_graph = CreateFromCSC(
          num_ntypes_rel, result.num_rows, result.num_cols, result.indptr, result.indices,
          // TODO(BarclayII): make CSR support null eid array
          Range(0, result.indices->shape[0], result.indices->dtype.bits, result.indices->ctx));
      induced_eids_rel = result.data;
    } else {
      // Guard against using new_rel_graph/induced_eids_rel uninitialized if a
      // new sparse format is ever added without updating this function.
      LOG(FATAL) << "Unsupported sparse format.";
    }

    rel_graphs.push_back(new_rel_graph);
    induced_eids.push_back(induced_eids_rel);
  }

  const HeteroGraphPtr new_graph = CreateHeteroGraph(
      graph->meta_graph(), rel_graphs, graph->NumVerticesPerType());
  return std::make_pair(new_graph, induced_eids);
}
// Python entry point: transform._CAPI_DGLRemoveEdges
// args[0]: the heterograph; args[1]: list of edge ID arrays, one per edge
// type.  Returns [new_graph, [induced_eids per edge type]].
DGL_REGISTER_GLOBAL("transform._CAPI_DGLRemoveEdges")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    const HeteroGraphRef graph_ref = args[0];
    const std::vector<IdArray> &eids = ListValueToVector<IdArray>(args[1]);

    HeteroGraphPtr new_graph;
    std::vector<IdArray> induced_eids;
    std::tie(new_graph, induced_eids) = RemoveEdges(graph_ref.sptr(), eids);

    // Wrap the induced edge ID arrays into a runtime List for Python.
    List<Value> induced_eids_ref;
    for (IdArray &array : induced_eids)
      induced_eids_ref.push_back(Value(MakeValue(array)));

    List<ObjectRef> ret;
    ret.push_back(HeteroGraphRef(new_graph));
    ret.push_back(induced_eids_ref);

    *rv = ret;
  });
}; // namespace transform
}; // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file graph/transform/to_bipartite.cc
* \brief Convert a graph to a bipartite-structured graph.
*/
#include <dgl/base_heterograph.h>
#include <dgl/transform.h>
#include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/container.h>
#include <vector>
#include <tuple>
// TODO(BarclayII): currently ToBlock depend on IdHashMap<IdType> implementation which
// only works on CPU. Should fix later to make it device agnostic.
#include "../../array/cpu/array_utils.h"
namespace dgl {
using namespace dgl::runtime;
using namespace dgl::aten;
namespace transform {
namespace {
// Implementation of ToBlock for a concrete integer ID type.
//
// Given the destination ("right hand side") nodes of every node type, build a
// bipartite-structured block whose node types are duplicated into a source
// (lhs) and a destination (rhs) set.  Returns the new graph, the induced lhs
// node IDs per node type, and the induced edge IDs per edge type.
template<typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
  const int64_t num_etypes = graph->NumEdgeTypes();
  const int64_t num_ntypes = graph->NumVertexTypes();
  std::vector<EdgeArray> edge_arrays(num_etypes);

  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
    << "rhs_nodes not given for every node type";

  // The rhs mappings relabel the given destination nodes to consecutive IDs.
  // The lhs mappings start as copies so every rhs node also appears on the lhs.
  const std::vector<IdHashMap<IdType>> rhs_node_mappings(rhs_nodes.begin(), rhs_nodes.end());
  std::vector<IdHashMap<IdType>> lhs_node_mappings(rhs_node_mappings);  // copy
  std::vector<int64_t> num_nodes_per_type;
  num_nodes_per_type.reserve(2 * num_ntypes);

  // Gather the in-edges of the rhs nodes of each edge type, and extend the
  // lhs mapping with every source endpoint encountered.
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
    const EdgeArray edges = graph->InEdges(etype, rhs_nodes[dsttype]);
    lhs_node_mappings[srctype].Update(edges.src);
    edge_arrays[etype] = edges;
  }

  // The block's metagraph duplicates every node type: lhs types keep their
  // original type IDs; rhs types are offset by num_ntypes.
  const auto meta_graph = graph->meta_graph();
  const EdgeArray etypes = meta_graph->Edges("eid");
  const IdArray new_dst = Add(etypes.dst, num_ntypes);
  const auto new_meta_graph = ImmutableGraph::CreateFromCOO(
      num_ntypes * 2, etypes.src, new_dst);

  // Node counts: first all lhs types, then all rhs types.
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
    num_nodes_per_type.push_back(lhs_node_mappings[ntype].Size());
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype)
    num_nodes_per_type.push_back(rhs_node_mappings[ntype].Size());

  // Relabel each relation's endpoints into the compacted lhs/rhs ID spaces.
  std::vector<HeteroGraphPtr> rel_graphs;
  std::vector<IdArray> induced_edges;
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;
    const IdHashMap<IdType> &lhs_map = lhs_node_mappings[srctype];
    const IdHashMap<IdType> &rhs_map = rhs_node_mappings[dsttype];
    rel_graphs.push_back(CreateFromCOO(
        2, lhs_map.Size(), rhs_map.Size(),
        lhs_map.Map(edge_arrays[etype].src, -1),
        rhs_map.Map(edge_arrays[etype].dst, -1)));
    induced_edges.push_back(edge_arrays[etype].id);
  }

  const HeteroGraphPtr new_graph = CreateHeteroGraph(
      new_meta_graph, rel_graphs, num_nodes_per_type);

  // Report the original node IDs that ended up on the lhs of each node type.
  std::vector<IdArray> lhs_nodes;
  for (const IdHashMap<IdType> &lhs_map : lhs_node_mappings)
    lhs_nodes.push_back(lhs_map.Values());
  return std::make_tuple(new_graph, lhs_nodes, induced_edges);
}
}; // namespace
// Dispatch ToBlock on the graph's integer ID type.
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes) {
  std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ret;
  ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
    ret = ToBlock<IdType>(graph, rhs_nodes);
  });
  return ret;
}
// Python entry point: transform._CAPI_DGLToBlock
// args[0]: the heterograph; args[1]: list of rhs (destination) node ID
// arrays, one per node type.
// Returns [new_graph, [lhs_nodes per ntype], [induced_edges per etype]].
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    const HeteroGraphRef graph_ref = args[0];
    const std::vector<IdArray> &rhs_nodes = ListValueToVector<IdArray>(args[1]);

    HeteroGraphPtr new_graph;
    std::vector<IdArray> lhs_nodes;
    std::vector<IdArray> induced_edges;
    std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock(graph_ref.sptr(), rhs_nodes);

    // Wrap the returned arrays into runtime List objects for Python.
    List<Value> lhs_nodes_ref;
    for (IdArray &array : lhs_nodes)
      lhs_nodes_ref.push_back(Value(MakeValue(array)));
    List<Value> induced_edges_ref;
    for (IdArray &array : induced_edges)
      induced_edges_ref.push_back(Value(MakeValue(array)));

    List<ObjectRef> ret;
    ret.push_back(HeteroGraphRef(new_graph));
    ret.push_back(lhs_nodes_ref);
    ret.push_back(induced_edges_ref);

    *rv = ret;
  });
}; // namespace transform
}; // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment