Unverified commit f370e628 authored by Da Zheng, committed by GitHub

[Feature] add NodeFlow API (#361)

* sample layer subgraphs.

* fix.

* fix.

* add layered subgraph.

* fix lint.

* fix.

* fix tutorial.

* fix.

* remove copy_to_parent.

* add num_layers

* move sampling code to sampler.cc

* fix.

* move subgraph construction out.

* Revert "move subgraph construction out."

This reverts commit 24b3d13b0d8ed5f91847ea75a7674ee8f7d45cff.

* change to NodeFlow.

* use NodeFlow in Python.

* use NodeFlowIndex.

* add node_mapping and edge_mapping.

* remove unnecessary code in SSE tutorial.

* Revert "remove unnecessary code in SSE tutorial."

This reverts commit 093f0413d5fa2e63ca5f80c46c80a126a9fb720c.

* fix tutorial.

* move to node_flow.

* update gcn cv updater.

* import NodeFlow.

* update.

* add demo code for vanilla control variate sampler.

* update.

* update.

* add neighbor sampling.

* return flow offsets.

* update node_flow.

* add test.

* fix sampler.

* fix graph index.

* fix a bug in sampler.

* fix map_to_layer_nid and map_to_flow_eid.

* fix apply_flow.

* remove model code.

* implement flow_compute.

* fix a bug.

* reverse the csr physically.

* add mini-batch test.

* add mini batch test.

* update flow_compute.

* add prop_flows

* run on specific nodes.

* test copy

* fix a bug in creating frame in NodeFlow.

* add init gcn_cv_updater.

* fix a minor bug.

* fix gcn_cv_updater.

* fix a bug.

* fix a bug in NodeFlow.

* use new h in gcn_cv_updater.

* add layer_in_degree and layer_out_degree.

* fix gcn_cv_updater for gpu.

* temp fix in NodeFlow for diff context.

* allow enabling/disabling copy back.

* add with-updater option.

* fix a bug in computing degree.

* add with-cv option.

* rename and add comments.

* fix lint complaints.

* fix lint.

* avoid assert.

* remove assert.

* fix.

* fix.

* fix.

* fix.

* fix the methods in NodeFlow.

* fix lint.

* update SSE.

* remove gcn_cv_updater.

* correct comments for the schedulers.

* update comment.

* add map_to_nodeflow_nid

* address comment.

* remove duplicated test.

* fix int.

* fix comments.

* fix lint

* fix.

* replace subgraph with NodeFlow.

* move view.

* address comments.

* fix lint.

* fix lint.

* remove static_cast.

* fix docstring.

* fix comments.

* break SampleSubgraph.

* move neighbor sampling to sampler.cc

* fix comments.

* rename.

* split neighbor_list.

* address comments.

* fix.

* remove TODO.
parent 220a1e68
......@@ -268,7 +268,7 @@ def main(args, data):
dur = []
sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, neigh_expand,
neighbor_type='in', num_workers=args.num_parallel_subgraphs, seed_nodes=train_vs,
shuffle=True, return_seed_id=True)
shuffle=True)
if args.cache_subgraph:
sampler = CachedSubgraphLoader(sampler, shuffle=True)
for epoch in range(args.n_epochs):
......@@ -279,7 +279,7 @@ def main(args, data):
start1 = time.time()
for subg, aux_infos in sampler:
seeds = aux_infos['seeds']
subg_seeds = subg.map_to_subgraph_nid(seeds)
subg_seeds = subg.layer_nid(0)
subg.copy_from_parent()
losses = []
......
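For orientation: the hunk above replaces `map_to_subgraph_nid(seeds)` with `layer_nid(0)`, since in a NodeFlow the seed nodes form layer 0 and their NodeFlow-local IDs can be read off directly instead of being translated from parent-graph IDs. A minimal sketch of the updated training loop, assuming only the NeighborSampler/NodeFlow API visible in this diff (variable names are placeholders):

    # Sketch only; mirrors the example above.
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, batch_size, neigh_expand, neighbor_type='in',
        num_workers=num_workers, seed_nodes=train_vs, shuffle=True)
    for nflow, aux_infos in sampler:
        subg_seeds = nflow.layer_nid(0)  # the seeds are layer 0 of the NodeFlow
        nflow.copy_from_parent()         # pull node/edge features from the parent graph
        # ... run the model on the NodeFlow ...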
......@@ -6,6 +6,7 @@
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <string>
#include <vector>
#include <cstdint>
......@@ -369,17 +370,6 @@ class Graph: public GraphInterface {
*/
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const;
/*!
* \brief Sample a subgraph from the seed vertices with neighbor sampling.
* The neighbors are sampled with a uniform distribution.
* \return a subgraph
*/
virtual SampledSubgraph NeighborUniformSample(IdArray seeds, const std::string &neigh_type,
int num_hops, int expand_factor) const {
LOG(FATAL) << "NeighborUniformSample isn't supported in mutable graph";
return SampledSubgraph();
}
protected:
friend class GraphOp;
/*! \brief Internal edge list type */
......
......@@ -20,7 +20,7 @@ typedef dgl::runtime::NDArray BoolArray;
typedef dgl::runtime::NDArray IntArray;
struct Subgraph;
struct SampledSubgraph;
struct NodeFlow;
/*!
* \brief This class references data in std::vector.
......@@ -332,14 +332,6 @@ class GraphInterface {
* \return a vector of IdArrays.
*/
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0;
/*!
* \brief Sample a subgraph from the seed vertices with neighbor sampling.
* The neighbors are sampled with a uniform distribution.
* \return a subgraph
*/
virtual SampledSubgraph NeighborUniformSample(IdArray seeds, const std::string &neigh_type,
int num_hops, int expand_factor) const = 0;
};
/*! \brief Subgraph data structure */
......@@ -358,21 +350,6 @@ struct Subgraph {
IdArray induced_edges;
};
/*!
* \brief When we sample a subgraph, we need to store extra information,
* such as the layer Ids of the vertices and the sampling probability.
*/
struct SampledSubgraph: public Subgraph {
/*!
* \brief the layer of a sampled vertex in the subgraph.
*/
IdArray layer_ids;
/*!
* \brief the probability that a vertex is sampled.
*/
runtime::NDArray sample_prob;
};
} // namespace dgl
#endif // DGL_GRAPH_INTERFACE_H_
......@@ -56,6 +56,11 @@ class ImmutableGraph: public GraphInterface {
return indices.size();
}
/*! \brief Get the sum of the degrees of all vertices in the range [start, end). */
uint64_t GetDegree(dgl_id_t start, dgl_id_t end) const {
return indptr[end] - indptr[start];
}
uint64_t GetDegree(dgl_id_t vid) const {
return indptr[vid + 1] - indptr[vid];
}
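The range overload works because CSR stores the out-edges of consecutive vertices contiguously, so the degree sum over the half-open range [start, end) is a single subtraction of indptr entries. A quick Python illustration with invented numbers:

    import numpy as np

    # Hypothetical CSR index pointer for a 4-vertex graph.
    indptr = np.array([0, 2, 5, 5, 7])
    deg = lambda v: indptr[v + 1] - indptr[v]   # per-vertex degree
    assert deg(1) == 3 and deg(2) == 0
    # Degree sum over vertices [1, 3): one subtraction, no loop.
    assert indptr[3] - indptr[1] == deg(1) + deg(2) == 3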
......@@ -456,14 +461,6 @@ class ImmutableGraph: public GraphInterface {
return gptr;
}
/*!
* \brief Sample a subgraph from the seed vertices with neighbor sampling.
* The neighbors are sampled with a uniform distribution.
* \return a subgraph
*/
SampledSubgraph NeighborUniformSample(IdArray seeds, const std::string &neigh_type,
int num_hops, int expand_factor) const;
/*!
* \brief Get the adjacency matrix of the graph.
*
......@@ -475,10 +472,6 @@ class ImmutableGraph: public GraphInterface {
*/
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const;
protected:
DGLIdIters GetInEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
DGLIdIters GetOutEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
/*
* The immutable graph may only contain one of the CSRs (e.g., the sampled subgraphs).
* When we get in csr or out csr, we try to get the one cached in the structure.
......@@ -503,6 +496,10 @@ class ImmutableGraph: public GraphInterface {
}
}
protected:
DGLIdIters GetInEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
DGLIdIters GetOutEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
/*!
* \brief Get the CSR array that represents the in-edges.
* This method copies data from std::vector to IdArray.
......@@ -517,10 +514,6 @@ class ImmutableGraph: public GraphInterface {
*/
CSRArray GetOutCSRArray() const;
SampledSubgraph SampleSubgraph(IdArray seed_arr, const float* probability,
const std::string &neigh_type,
int num_hops, size_t num_neighbor) const;
/*!
* \brief Compact a subgraph.
* In a sampled subgraph, the vertex IDs are still those of the original graph.
......
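The compaction step described above relabels parent-graph vertex IDs into consecutive local IDs. A small Python sketch of the idea (a hypothetical helper, not the C++ implementation):

    def compact(induced_nodes, src, dst):
        # Map each parent vertex ID to a local ID 0..n-1.
        local = {pid: i for i, pid in enumerate(induced_nodes)}
        return [local[u] for u in src], [local[v] for v in dst]

    # Parent vertices 7, 3, 9 become local vertices 0, 1, 2.
    src, dst = compact([7, 3, 9], [7, 9, 3], [3, 3, 9])
    assert (src, dst) == ([0, 2, 1], [1, 1, 2])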
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/sampler.h
* \brief DGL sampler header.
*/
#ifndef DGL_SAMPLER_H_
#define DGL_SAMPLER_H_
#include "graph_interface.h"
namespace dgl {
class ImmutableGraph;
/*!
* \brief A NodeFlow graph stores the sampling results for a sampler that samples
* nodes/edges in layers.
*
* We store multiple layers of the sampling results in a single graph, which gives
* a more compact format. We also store extra information, such as the node and
* edge mappings from the NodeFlow graph to the parent graph.
*/
struct NodeFlow {
/*! \brief The graph. */
GraphPtr graph;
/*!
* \brief the offsets of each layer.
*/
IdArray layer_offsets;
/*!
* \brief the offsets of each flow.
*/
IdArray flow_offsets;
/*!
* \brief The node mapping from the NodeFlow graph to the parent graph.
*/
IdArray node_mapping;
/*!
* \brief The edge mapping from the NodeFlow graph to the parent graph.
*/
IdArray edge_mapping;
};
class SamplerOp {
public:
/*!
* \brief Sample a graph from the seed vertices with neighbor sampling.
* The neighbors are sampled with a uniform distribution.
*
* \param graph the graph to sample from.
* \param seeds the nodes from which sampling starts.
* \param edge_type the type of edges along which neighbors are sampled.
* \param num_hops the number of hops to sample neighbors.
* \param expand_factor the maximum number of neighbors to sample per node.
* \return a NodeFlow graph.
*/
static NodeFlow NeighborUniformSample(const ImmutableGraph *graph, IdArray seeds,
const std::string &edge_type,
int num_hops, int expand_factor);
};
} // namespace dgl
#endif // DGL_SAMPLER_H_
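To make the layered layout concrete, here is a Python sketch of how the offset arrays would be read, assuming layer_offsets partitions node_mapping by layer and flow_offsets partitions edge_mapping by flow, with layer 0 holding the seeds as in the example diff above; all values are invented for illustration:

    # Hypothetical 2-hop NodeFlow with 2 seed nodes: 3 layers, 2 flows.
    layer_offsets = [0, 2, 5, 9]       # layer i spans node_mapping[layer_offsets[i]:layer_offsets[i+1]]
    node_mapping  = [7, 3,             # layer 0: parent IDs of the seed nodes
                     1, 4, 8,          # layer 1: sampled neighbors of the seeds
                     0, 2, 5, 6]       # layer 2: sampled neighbors of layer 1
    flow_offsets  = [0, 4, 9]          # flow j spans edge_mapping[flow_offsets[j]:flow_offsets[j+1]]
    edge_mapping  = [12, 3, 9, 40,     # flow 0: parent edge IDs between layers 0 and 1
                     5, 7, 21, 2, 30]  # flow 1: parent edge IDs between layers 1 and 2

    def layer_parent_nid(i):
        # Parent-graph IDs of the nodes in layer i (hypothetical helper).
        return node_mapping[layer_offsets[i]:layer_offsets[i + 1]]

    assert layer_parent_nid(0) == [7, 3]  # the seeds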
# This file contains subgraph samplers.
# This file contains NodeFlow samplers.
import sys
import numpy as np
......@@ -7,7 +7,7 @@ import random
import traceback
from ... import utils
from ...subgraph import DGLSubGraph
from ...node_flow import NodeFlow
from ... import backend as F
try:
import Queue as queue
......@@ -22,7 +22,7 @@ class NSSubgraphLoader(object):
shuffle=False, num_workers=1, return_seed_id=False):
self._g = g
if not g._graph.is_readonly():
raise NotImplementedError("subgraph loader only support read-only graphs.")
raise NotImplementedError("NodeFlow loader only support read-only graphs.")
self._batch_size = batch_size
self._expand_factor = expand_factor
self._num_hops = num_hops
......@@ -39,27 +39,26 @@ class NSSubgraphLoader(object):
self._seed_nodes = F.rand_shuffle(self._seed_nodes)
self._num_workers = num_workers
self._neighbor_type = neighbor_type
self._subgraphs = []
self._nflows = []
self._seed_ids = []
self._subgraph_idx = 0
self._nflow_idx = 0
def _prefetch(self):
seed_ids = []
num_nodes = len(self._seed_nodes)
for i in range(self._num_workers):
start = self._subgraph_idx * self._batch_size
start = self._nflow_idx * self._batch_size
# if we have visited all nodes, don't do anything.
if start >= num_nodes:
break
end = min((self._subgraph_idx + 1) * self._batch_size, num_nodes)
end = min((self._nflow_idx + 1) * self._batch_size, num_nodes)
seed_ids.append(utils.toindex(self._seed_nodes[start:end]))
self._subgraph_idx += 1
self._nflow_idx += 1
sgi = self._g._graph.neighbor_sampling(seed_ids, self._expand_factor,
self._num_hops, self._neighbor_type,
self._node_prob)
subgraphs = [DGLSubGraph(self._g, i.induced_nodes, i.induced_edges, \
i) for i in sgi]
self._subgraphs.extend(subgraphs)
nflows = [NodeFlow(self._g, i) for i in sgi]
self._nflows.extend(nflows)
if self._return_seed_id:
self._seed_ids.extend(seed_ids)
......@@ -67,17 +66,17 @@ class NSSubgraphLoader(object):
return self
def __next__(self):
# If we don't have prefetched subgraphs, let's prefetch them.
if len(self._subgraphs) == 0:
# If we don't have prefetched NodeFlows, let's prefetch them.
if len(self._nflows) == 0:
self._prefetch()
# At this point, if we still don't have subgraphs, we must have
# iterate all subgraphs and we should stop the iterator now.
if len(self._subgraphs) == 0:
# At this point, if we still don't have NodeFlows, we must have
# iterated over all of them, so stop the iterator now.
if len(self._nflows) == 0:
raise StopIteration
aux_infos = {}
if self._return_seed_id:
aux_infos['seeds'] = self._seed_ids.pop(0).tousertensor()
return self._subgraphs.pop(0), aux_infos
return self._nflows.pop(0), aux_infos
class _Prefetcher(object):
"""Internal shared prefetcher logic. It can be sub-classed by a Thread-based implementation
......@@ -199,19 +198,19 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
return_seed_id=False, prefetch=False):
'''Create a sampler that samples neighborhoods.
This creates a subgraph data loader that samples subgraphs from the input graph
This creates a NodeFlow loader that samples subgraphs from the input graph
with neighbor sampling. This sampling method is implemented in C and can perform
sampling very efficiently.
A subgraph grows from a seed vertex. It contains sampled neighbors
A NodeFlow grows from a seed vertex. It contains sampled neighbors
of the seed vertex as well as the edges that connect neighbor nodes with
seed nodes. When the number of hops is k (>1), the neighbors are sampled
from the k-hop neighborhood. In this case, the sampled edges are the ones
that connect the source nodes and the sampled neighbor nodes of the source
nodes.
The subgraph loader returns a list of subgraphs and a dictionary of additional
information about the subgraphs. The size of the subgraph list is the number of workers.
The NodeFlow loader returns a list of NodeFlows and a dictionary of additional
information about the NodeFlows. The size of the NodeFlow list is the number of workers.
The dictionary contains:
......@@ -219,8 +218,8 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
Parameters
----------
g: the DGLGraph where we sample subgraphs.
batch_size: The number of subgraphs in a batch.
g: the DGLGraph where we sample NodeFlows.
batch_size: The number of NodeFlows in a batch.
expand_factor: the number of neighbors sampled from the neighbor list
of a vertex. The value of this parameter can be
an integer: indicates the number of neighbors sampled from a neighbor list.
......@@ -234,20 +233,20 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
node_prob: the probability that a neighbor node is sampled.
1D Tensor. None means uniform sampling. Otherwise, the number of elements
should be the same as the number of vertices in the graph.
seed_nodes: a list of nodes where we sample subgraphs from.
seed_nodes: a list of nodes where we sample NodeFlows from.
If it's None, the seed vertices are all vertices in the graph.
shuffle: indicates the sampled subgraphs are shuffled.
num_workers: the number of worker threads that sample subgraphs in parallel.
return_seed_id: indicates whether to return seed ids along with the subgraphs.
shuffle: indicates the sampled NodeFlows are shuffled.
num_workers: the number of worker threads that sample NodeFlows in parallel.
return_seed_id: indicates whether to return seed ids along with the NodeFlows.
The seed Ids are in the parent graph.
prefetch : bool, default False
Whether to prefetch the samples in the next batch.
Returns
-------
A subgraph iterator
The iterator returns a list of batched subgraphs and a dictionary of additional
information about the subgraphs.
A NodeFlow iterator
The iterator returns a list of batched NodeFlows and a dictionary of additional
information about the NodeFlows.
'''
loader = NSSubgraphLoader(g, batch_size, expand_factor, num_hops, neighbor_type, node_prob,
seed_nodes, shuffle, num_workers, return_seed_id)
......
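End to end, a minimal usage sketch based on the NeighborSampler signature documented above (g is assumed to be a read-only DGLGraph; training code elided):

    import dgl

    # Iterate over NodeFlows sampled from 2-hop in-neighborhoods,
    # 32 seed nodes per NodeFlow, uniform sampling (node_prob=None).
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, batch_size=32, expand_factor=10, num_hops=2,
        neighbor_type='in', shuffle=True, num_workers=4,
        return_seed_id=True)
    for nflow, aux_infos in sampler:
        seeds = aux_infos['seeds']  # seed IDs in the parent graph
        # ... train on the NodeFlow ...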
......@@ -2,27 +2,24 @@
from __future__ import absolute_import
from collections import defaultdict
import dgl
import networkx as nx
from .base import ALL, is_all, DGLError
from . import backend as F
from . import init
from .frame import FrameRef, Frame
from .graph_index import create_graph_index
from .runtime import ir, scheduler, Runtime
from . import subgraph
from . import utils
from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph']
class DGLGraph(object):
class DGLBaseGraph(object):
"""Base graph class.
The graph stores nodes, edges and also their features.
A DGL graph is always directed. An undirected graph can be represented by
adding each edge in both directions.
......@@ -33,472 +30,497 @@ class DGLGraph(object):
of addition, i.e. the first edge being added has an ID of 0, the second
being 1, so on so forth.
Node and edge features are stored as a dictionary from the feature name
to the feature data (in tensor).
Parameters
----------
graph_data : graph data, optional
Data to initialize graph. Same as networkx's semantics.
node_frame : FrameRef, optional
Node feature storage.
edge_frame : FrameRef, optional
Edge feature storage.
multigraph : bool, optional
Whether the graph would be a multigraph (default: False)
readonly : bool, optional
Whether the graph structure is read-only (default: False).
graph : graph index, optional
Data to initialize graph.
"""
def __init__(self, graph):
self._graph = graph
Examples
--------
Create an empty graph with no nodes and edges.
def number_of_nodes(self):
"""Return the number of nodes in the graph.
>>> G = dgl.DGLGraph()
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()
G can be grown in several ways.
def __len__(self):
"""Return the number of nodes in the graph."""
return self.number_of_nodes()
**Nodes:**
@property
def is_multigraph(self):
"""True if the graph is a multigraph, False otherwise.
"""
return self._graph.is_multigraph()
Add N nodes:
@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise.
"""
return self._graph.is_readonly()
>>> G.add_nodes(10) # 10 isolated nodes are added
def number_of_edges(self):
"""Return the number of edges in the graph.
**Edges:**
Returns
-------
int
The number of edges
"""
return self._graph.number_of_edges()
Add one edge at a time,
def has_node(self, vid):
"""Return True if the graph contains node `vid`.
>>> G.add_edge(0, 1)
Identical to `vid in G`.
or multiple edges,
Parameters
----------
vid : int
The node ID.
>>> G.add_edges([1, 2, 3], [3, 4, 5]) # three edges: 1->3, 2->4, 3->5
Returns
-------
bool
True if the node exists
or multiple edges starting from the same node,
Examples
--------
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.has_node(0)
True
>>> G.has_node(4)
False
>>> G.add_edges(4, [7, 8, 9]) # three edges: 4->7, 4->8, 4->9
Equivalently,
or multiple edges pointing to the same node,
>>> 0 in G
True
>>> G.add_edges([2, 6, 8], 5) # three edges: 2->5, 6->5, 8->5
See Also
--------
has_nodes
"""
return self._graph.has_node(vid)
or multiple edges using tensor type
def __contains__(self, vid):
"""Return True if the graph contains node `vid`.
.. note:: Here we use pytorch syntax for demo. The general idea applies
to other frameworks with minor syntax change (e.g. replace
``torch.tensor`` with ``mxnet.ndarray``).
Examples
--------
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> 0 in G
True
"""
return self._graph.has_node(vid)
>>> import torch as th
>>> G.add_edges(th.tensor([3, 4, 5]), 1) # three edges: 3->1, 4->1, 5->1
def has_nodes(self, vids):
"""Return a 0-1 array ``a`` given the node ID array ``vids``.
NOTE: Removing nodes and edges is not supported by DGLGraph.
``a[i]`` is 1 if the graph contains node ``vids[i]``, 0 otherwise.
**Features:**
Parameters
----------
vids : list or tensor
The array of node IDs.
Both nodes and edges can have feature data. Features are stored as
key/value pair. The key must be hashable while the value must be tensor
type. Features are batched on the first dimension.
Returns
-------
a : tensor
0-1 array indicating existence
Use G.ndata to get/set features for all nodes.
Examples
--------
The following example uses PyTorch backend.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.ndata['x'] = th.zeros((3, 5)) # init 3 nodes with zero vector(len=5)
>>> G.ndata
{'x' : tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])}
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.has_nodes([0, 1, 2, 3, 4])
tensor([1, 1, 1, 0, 0])
Use G.nodes to get/set features for some nodes.
See Also
--------
has_node
"""
vids = utils.toindex(vids)
rst = self._graph.has_nodes(vids)
return rst.tousertensor()
>>> G.nodes[[0, 2]].data['x'] = th.ones((2, 5))
>>> G.ndata
{'x' : tensor([[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.]])}
def has_edge_between(self, u, v):
"""Return True if the edge (u, v) is in the graph.
Similarly, use G.edata and G.edges to get/set features for edges.
Parameters
----------
u : int
The source node ID.
v : int
The destination node ID.
>>> G.add_edges([0, 1], 2) # 0->2, 1->2
>>> G.edata['y'] = th.zeros((2, 4)) # init 2 edges with zero vector(len=4)
>>> G.edata
{'y' : tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])}
>>> G.edges[1, 2].data['y'] = th.ones((1, 4))
>>> G.edata
{'y' : tensor([[0., 0., 0., 0.],
[1., 1., 1., 1.]])}
Returns
-------
bool
True if the edge is in the graph, False otherwise.
Note that each edge is assigned a unique ID equal to its order of
addition. So edge 1->2 has ID 1. DGL supports using edge IDs directly
to access edge features.
Examples
--------
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edge(0, 1)
>>> G.has_edge_between(0, 1)
True
>>> G.has_edge_between(1, 0)
False
>>> G.edges[0].data['y'] += 2.
>>> G.edata
{'y' : tensor([[2., 2., 2., 2.],
[1., 1., 1., 1.]])}
See Also
--------
has_edges_between
"""
return self._graph.has_edge_between(u, v)
**Message Passing:**
def has_edges_between(self, u, v):
"""Return a 0-1 array `a` given the source node ID array `u` and
destination node ID array `v`.
One common operation for updating node features is message passing,
where the source nodes send messages through edges to the destinations.
With :class:`DGLGraph`, we can do this with :func:`send` and :func:`recv`.
`a[i]` is 1 if the graph contains edge `(u[i], v[i])`, 0 otherwise.
In the example below, the source nodes add 1 to their node features as
the messages and send the messages to the destinations.
Parameters
----------
u : list, tensor
The source node ID array.
v : list, tensor
The destination node ID array.
>>> # Define the function for sending messages.
>>> def send_source(edges): return {'m': edges.src['x'] + 1}
>>> # Set the function defined to be the default message function.
>>> G.register_message_func(send_source)
>>> # Send messages through all edges.
>>> G.send(G.edges())
Returns
-------
a : tensor
0-1 array indicating existence.
Just like you need to go to your mailbox for retrieving mails, the destination
nodes also need to receive the messages and potentially update their features.
Examples
--------
The following example uses PyTorch backend.
>>> # Define a function for summing messages received and replacing the original feature.
>>> def simple_reduce(nodes): return {'x': nodes.mailbox['m'].sum(1)}
>>> # Set the function defined to be the default message reduce function.
>>> G.register_reduce_func(simple_reduce)
>>> # All existing edges have node 2 as the destination.
>>> # Receive the messages for node 2 and update its feature.
>>> G.recv(v=2)
>>> G.ndata
{'x': tensor([[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.],
[3., 3., 3., 3., 3.]])} # 3 = (1 + 1) + (0 + 1)
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
For more examples about message passing, please read our tutorials.
"""
def __init__(self,
graph_data=None,
node_frame=None,
edge_frame=None,
multigraph=False,
readonly=False):
# graph
self._graph = create_graph_index(graph_data, multigraph, readonly)
# node and edge frame
if node_frame is None:
self._node_frame = FrameRef(Frame(num_rows=self.number_of_nodes()))
else:
self._node_frame = node_frame
if edge_frame is None:
self._edge_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
else:
self._edge_frame = edge_frame
# message indicator:
# if self._msg_index[eid] == 1, then edge eid has message
self._msg_index = utils.zero_index(size=self.number_of_edges())
# message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame
self._msg_frame.set_initializer(init.zero_initializer)
# registered functions
self._message_func = None
self._reduce_func = None
self._apply_node_func = None
self._apply_edge_func = None
Check if (0, 1), (0, 2), (1, 0), (2, 0) exist in the graph above:
def add_nodes(self, num, data=None):
"""Add multiple new nodes.
>>> G.has_edges_between([0, 0, 1, 2], [1, 2, 0, 0])
tensor([1, 1, 0, 0])
See Also
--------
has_edge_between
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges_between(u, v)
return rst.tousertensor()
def predecessors(self, v):
"""Return the predecessors of node `v` in the graph.
Node `u` is a predecessor of `v` if an edge `(u, v)` exists in the
graph.
Parameters
----------
num : int
Number of nodes to be added.
data : dict, optional
Feature data of the added nodes.
v : int
The node.
Notes
-----
If new nodes are added with features, and any of the old nodes
do not have some of the feature fields, those fields are filled
by initializers defined with ``set_n_initializer`` (default filling
with zeros).
Returns
-------
tensor
Array of predecessor node IDs.
Examples
--------
The following example uses PyTorch backend.
>>> G = dgl.DGLGraph()
>>> g.add_nodes(2)
>>> g.number_of_nodes()
2
>>> g.add_nodes(3)
>>> g.number_of_nodes()
5
>>> G.add_nodes(3)
>>> G.add_edges([1, 2], [0, 0]) # (1, 0), (2, 0)
>>> G.predecessors(0)
tensor([1, 2])
Adding new nodes with features:
See Also
--------
successors
"""
return self._graph.predecessors(v).tousertensor()
.. note:: Here we use pytorch syntax for demo. The general idea applies
to other frameworks with minor syntax change (e.g. replace
``torch.tensor`` with ``mxnet.ndarray``).
def successors(self, v):
"""Return the successors of node `v` in the graph.
>>> import torch as th
>>> g.add_nodes(2, {'x': th.ones(2, 4)}) # default zero initializer
>>> g.ndata['x']
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
Node `u` is a successor of `v` if an edge `(v, u)` exists in the
graph.
Parameters
----------
v : int
The node.
Returns
-------
tensor
Array of successor node IDs.
Examples
--------
The following example uses PyTorch backend.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
>>> G.successors(0)
tensor([1, 2])
See Also
--------
predecessors
"""
self._graph.add_nodes(num)
if data is None:
# Initialize feature placeholders if there are features existing
self._node_frame.add_rows(num)
else:
self._node_frame.append(data)
return self._graph.successors(v).tousertensor()
def add_edge(self, u, v, data=None):
"""Add one new edge between u and v.
def edge_id(self, u, v, force_multi=False):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`.
Parameters
----------
u : int
The source node ID. Must exist in the graph.
The source node ID.
v : int
The destination node ID. Must exist in the graph.
data : dict, optional
Feature data of the added edges.
The destination node ID.
force_multi : bool
If False, will return a single edge ID if the graph is a simple graph.
If True, will always return an array.
Notes
-----
If new edges are added with features, and any of the old edges
do not have some of the feature fields, those fields are filled
by initializers defined with ``set_e_initializer`` (default filling
with zeros).
Returns
-------
int or tensor
The edge ID if force_multi == False and the graph is a simple graph.
The edge ID array otherwise.
Examples
--------
The following example uses PyTorch backend.
For simple graphs:
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edge(0, 1)
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
>>> G.edge_id(0, 2)
1
>>> G.edge_id(0, 1)
0
Adding new edge with features
For multigraphs:
.. note:: Here we use pytorch syntax for demo. The general idea applies
to other frameworks with minor syntax change (e.g. replace
``torch.tensor`` with ``mxnet.ndarray``).
>>> G = dgl.DGLGraph(multigraph=True)
>>> G.add_nodes(3)
>>> import torch as th
>>> G.add_edge(0, 2, {'x': th.ones(1, 4)})
>>> G.edges()
(tensor([0, 0]), tensor([1, 2]))
>>> G.edata['x']
tensor([[0., 0., 0., 0.],
[1., 1., 1., 1.]])
>>> G.edges[0, 2].data['x']
tensor([[1., 1., 1., 1.]])
Adding edges (0, 1), (0, 2), (0, 1), (0, 2), so edge IDs 0 and 2 both
connect node 0 to node 1, while edge IDs 1 and 3 both connect node 0 to node 2.
>>> G.add_edges([0, 0, 0, 0], [1, 2, 1, 2])
>>> G.edge_id(0, 1)
tensor([0, 2])
See Also
--------
add_edges
edge_ids
"""
self._graph.add_edge(u, v)
if data is None:
# Initialize feature placeholders if there are features existing
self._edge_frame.add_rows(1)
else:
self._edge_frame.append(data)
# resize msg_index and msg_frame
self._msg_index = self._msg_index.append_zeros(1)
self._msg_frame.add_rows(1)
idx = self._graph.edge_id(u, v)
return idx.tousertensor() if force_multi or self.is_multigraph else idx[0]
def add_edges(self, u, v, data=None):
"""Add multiple edges for list of source nodes u and destination nodes
v. A single edge is added between every pair of ``u[i]`` and ``v[i]``.
def edge_ids(self, u, v, force_multi=False):
"""Return all edge IDs between source node array `u` and destination
node array `v`.
Parameters
----------
u : list, tensor
The source node IDs. All nodes must exist in the graph.
The source node ID array.
v : list, tensor
The destination node IDs. All nodes must exist in the graph.
data : dict, optional
Feature data of the added edges.
The destination node ID array.
force_multi : bool
Whether to always treat the graph as a multigraph.
Returns
-------
tensor, or (tensor, tensor, tensor)
If the graph is a simple graph and `force_multi` is False, return
a single edge ID array `e`. `e[i]` is the edge ID between `u[i]`
and `v[i]`.
Otherwise, return three arrays `(eu, ev, e)`. `e[i]` is the ID
of an edge between `eu[i]` and `ev[i]`. All edges between `u[i]`
and `v[i]` are returned.
Notes
-----
If new edges are added with features, and any of the old edges
do not have some of the feature fields, those fields are filled
by initializers defined with ``set_e_initializer`` (default filling
with zeros).
If the graph is a simple graph, `force_multi` is False, and no edge
exists between some pair of `u[i]` and `v[i]`, the result is undefined.
Examples
--------
The following example uses PyTorch backend.
For simple graphs:
>>> G = dgl.DGLGraph()
>>> G.add_nodes(4)
>>> G.add_edges([0, 2], [1, 3]) # add edges (0, 1) and (2, 3)
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
>>> G.edge_ids([0, 0], [2, 1]) # get edge ID of (0, 2) and (0, 1)
>>> G.edge_ids([0, 0], [2, 1])
tensor([1, 0])
Adding new edges with features
For multigraphs
.. note:: Here we use pytorch syntax for demo. The general idea applies
to other frameworks with minor syntax change (e.g. replace
``torch.tensor`` with ``mxnet.ndarray``).
>>> G = dgl.DGLGraph(multigraph=True)
>>> G.add_nodes(4)
>>> G.add_edges([0, 0, 0], [1, 1, 2]) # (0, 1), (0, 1), (0, 2)
>>> import torch as th
>>> G.add_edges([1, 3], [2, 0], {'x': th.ones(2, 4)}) # (1, 2), (3, 0)
>>> G.edata['x']
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
Get all edges between (0, 1), (0, 2), (0, 3). Note that there is no
edge between 0 and 3:
>>> G.edge_ids([0, 0, 0], [1, 2, 3])
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([0, 1, 2]))
See Also
--------
add_edge
edge_id
"""
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
num = max(len(u), len(v))
if data is None:
# Initialize feature placeholders if there are features existing
# NOTE: use max due to edge broadcasting syntax
self._edge_frame.add_rows(num)
src, dst, eid = self._graph.edge_ids(u, v)
if force_multi or self.is_multigraph:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
self._edge_frame.append(data)
# initialize feature placeholder for messages
self._msg_index = self._msg_index.append_zeros(num)
self._msg_frame.add_rows(num)
def clear(self):
"""Remove all nodes and edges, as well as their features, from the
graph.
Examples
--------
>>> G = dgl.DGLGraph()
>>> G.add_nodes(4)
>>> G.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
>>> G.number_of_nodes()
4
>>> G.number_of_edges()
4
>>> G.clear()
>>> G.number_of_nodes()
0
>>> G.number_of_edges()
0
"""
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_index = utils.zero_index(0)
self._msg_frame.clear()
def clear_cache(self):
"""Clear all cached graph structures such as adjmat.
return eid.tousertensor()
By default, all graph structure related sparse matrices (e.g. adjmat, incmat)
are cached so they could be reused with the cost of extra memory consumption.
This function can be used to clear the cached matrices if memory is an issue.
"""
self._graph.clear_cache()
def find_edges(self, eid):
"""Given an edge ID array, return the source and destination node ID
array `s` and `d`. `s[i]` and `d[i]` are source and destination node
ID for edge `eid[i]`.
def number_of_nodes(self):
"""Return the number of nodes in the graph.
Parameters
----------
eid : list, tensor
The edge ID array.
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()
def __len__(self):
"""Return the number of nodes in the graph."""
return self.number_of_nodes()
@property
def is_multigraph(self):
"""True if the graph is a multigraph, False otherwise.
"""
return self._graph.is_multigraph()
@property
def is_readonly(self):
"""True if the graph is readonly, False otherwise.
"""
return self._graph.is_readonly()
tensor
The source node ID array.
tensor
The destination node ID array.
def number_of_edges(self):
"""Return the number of edges in the graph.
Examples
--------
The following example uses PyTorch backend.
Returns
-------
int
The number of edges
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.find_edges([0, 2])
(tensor([0, 1]), tensor([1, 2]))
"""
return self._graph.number_of_edges()
def has_node(self, vid):
"""Return True if the graph contains node `vid`.
eid = utils.toindex(eid)
src, dst, _ = self._graph.find_edges(eid)
return src.tousertensor(), dst.tousertensor()
Identical to `vid in G`.
def in_edges(self, v, form='uv'):
"""Return the inbound edges of the node(s).
Parameters
----------
vid : int
The node ID.
v : int, list, tensor
The node(s).
form : str, optional
The return form. Currently supported:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
Returns
-------
bool
True if the node exists
A tuple of Tensors ``(eu, ev, eid)`` if ``form == 'all'``.
``eid[i]`` is the ID of an inbound edge to ``ev[i]`` from ``eu[i]``.
All inbound edges to ``v`` are returned.
A pair of Tensors (eu, ev) if form == 'uv'
``eu[i]`` is the source node of an inbound edge to ``ev[i]``.
All inbound edges to ``v`` are returned.
One Tensor if form == 'eid'
``eid[i]`` is the ID of an inbound edge to any of the nodes in ``v``.
Examples
--------
The following example uses PyTorch backend.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.has_node(0)
True
>>> G.has_node(4)
False
Equivalently,
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> 0 in G
True
For a single node:
See Also
--------
has_nodes
"""
return self._graph.has_node(vid)
>>> G.in_edges(2)
(tensor([0, 1]), tensor([2, 2]))
>>> G.in_edges(2, 'all')
(tensor([0, 1]), tensor([2, 2]), tensor([1, 2]))
>>> G.in_edges(2, 'eid')
tensor([1, 2])
def __contains__(self, vid):
"""Return True if the graph contains node `vid`.
For multiple nodes:
Examples
--------
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> 0 in G
True
>>> G.in_edges([1, 2])
(tensor([0, 0, 1]), tensor([1, 2, 2]))
>>> G.in_edges([1, 2], 'all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
"""
return self._graph.has_node(vid)
def has_nodes(self, vids):
"""Return a 0-1 array ``a`` given the node ID array ``vids``.
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
``a[i]`` is 1 if the graph contains node ``vids[i]``, 0 otherwise.
def out_edges(self, v, form='uv'):
"""Return the outbound edges of the node(s).
Parameters
----------
vids : list or tensor
The array of node IDs.
v : int, list, tensor
The node(s).
form : str, optional
The return form. Currently supported:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
Returns
-------
a : tensor
0-1 array indicating existence
A tuple of Tensors ``(eu, ev, eid)`` if ``form == 'all'``.
``eid[i]`` is the ID of an outbound edge from ``eu[i]`` to ``ev[i]``.
All outbound edges from ``v`` are returned.
A pair of Tensors (eu, ev) if form == 'uv'
``ev[i]`` is the destination node of an outbound edge from ``eu[i]``.
All outbound edges from ``v`` are returned.
One Tensor if form == 'eid'
``eid[i]`` is the ID of an outbound edge from any of the nodes in ``v``.
Examples
--------
......@@ -506,103 +528,128 @@ class DGLGraph(object):
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.has_nodes([0, 1, 2, 3, 4])
tensor([1, 1, 1, 0, 0])
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
See Also
--------
has_node
For a single node:
>>> G.out_edges(0)
(tensor([0, 0]), tensor([1, 2]))
>>> G.out_edges(0, 'all')
(tensor([0, 0]), tensor([1, 2]), tensor([0, 1]))
>>> G.out_edges(0, 'eid')
tensor([0, 1])
For multiple nodes:
>>> G.out_edges([0, 1])
(tensor([0, 0, 1]), tensor([1, 2, 2]))
>>> G.out_edges([0, 1], 'all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
"""
vids = utils.toindex(vids)
rst = self._graph.has_nodes(vids)
return rst.tousertensor()
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def has_edge_between(self, u, v):
"""Return True if the edge (u, v) is in the graph.
def all_edges(self, form='uv', order=None):
"""Return all the edges.
Parameters
----------
u : int
The source node ID.
v : int
The destination node ID.
form : str, optional
The return form. Currently supported:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
order : string
The order of the returned edges. Currently supported:
- 'srcdst' : sorted by source and destination IDs.
- 'eid' : sorted by edge IDs.
- None : arbitrary order.
Returns
-------
bool
True if the edge is in the graph, False otherwise.
A tuple of Tensors (u, v, eid) if form == 'all'
``eid[i]`` is the ID of an edge between ``u[i]`` and ``v[i]``.
All edges are returned.
A pair of Tensors (u, v) if form == 'uv'
An edge exists between ``u[i]`` and ``v[i]``.
If ``n`` edges exist between ``u`` and ``v``, then ``u`` and ``v`` as a pair
will appear ``n`` times.
One Tensor if form == 'eid'
``eid[i]`` is the ID of an edge in the graph.
Examples
--------
The following example uses PyTorch backend.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edge(0, 1)
>>> G.has_edge_between(0, 1)
True
>>> G.has_edge_between(1, 0)
False
See Also
--------
has_edges_between
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.all_edges()
(tensor([0, 0, 1]), tensor([1, 2, 2]))
>>> G.all_edges('all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
"""
return self._graph.has_edge_between(u, v)
def has_edges_between(self, u, v):
"""Return a 0-1 array `a` given the source node ID array `u` and
destination node ID array `v`.
src, dst, eid = self._graph.edges(order)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
`a[i]` is 1 if the graph contains edge `(u[i], v[i])`, 0 otherwise.
def in_degree(self, v):
"""Return the in-degree of node ``v``.
Parameters
----------
u : list, tensor
The source node ID array.
v : list, tensor
The destination node ID array.
v : int
The node ID.
Returns
-------
a : tensor
0-1 array indicating existence.
int
The in-degree.
Examples
--------
The following example uses PyTorch backend.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
Check if (0, 1), (0, 2), (1, 0), (2, 0) exist in the graph above:
>>> G.has_edges_between([0, 0, 1, 2], [1, 2, 0, 0])
tensor([1, 1, 0, 0])
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.in_degree(2)
2
See Also
--------
has_edge_between
in_degrees
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges_between(u, v)
return rst.tousertensor()
return self._graph.in_degree(v)
def predecessors(self, v):
"""Return the predecessors of node `v` in the graph.
def in_degrees(self, v=ALL):
"""Return the array `d` of in-degrees of the node array `v`.
Node `u` is a predecessor of `v` if an edge `(u, v)` exists in the
graph.
`d[i]` is the in-degree of node `v[i]`.
Parameters
----------
v : int
The node.
v : list, tensor, optional.
The node ID array. Default is to return the degrees of all the nodes.
Returns
-------
tensor
Array of predecessor node IDs.
d : tensor
The in-degree array.
Examples
--------
......@@ -610,410 +657,333 @@ class DGLGraph(object):
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([1, 2], [0, 0]) # (1, 0), (2, 0)
>>> G.predecessors(0)
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.in_degrees([1, 2])
tensor([1, 2])
See Also
--------
successors
in_degree
"""
return self._graph.predecessors(v).tousertensor()
def successors(self, v):
"""Return the successors of node `v` in the graph.
if is_all(v):
v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(v)
return self._graph.in_degrees(v).tousertensor()
Node `u` is a successor of `v` if an edge `(v, u)` exists in the
graph.
def out_degree(self, v):
"""Return the out-degree of node `v`.
Parameters
----------
v : int
The node.
The node ID.
Returns
-------
tensor
Array of successor node IDs.
int
The out-degree.
Examples
--------
The following example uses PyTorch backend.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
>>> G.successors(0)
tensor([1, 2])
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.out_degree(0)
2
See Also
--------
predecessors
out_degrees
"""
return self._graph.successors(v).tousertensor()
return self._graph.out_degree(v)
def edge_id(self, u, v, force_multi=False):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`.
def out_degrees(self, v=ALL):
"""Return the array `d` of out-degrees of the node array `v`.
`d[i]` is the out-degree of node `v[i]`.
Parameters
----------
u : int
The source node ID.
v : int
The destination node ID.
force_multi : bool
If False, will return a single edge ID if the graph is a simple graph.
If True, will always return an array.
v : list, tensor
The node ID array. Default is to return the degrees of all the nodes.
Returns
-------
int or tensor
The edge ID if force_multi == False and the graph is a simple graph.
The edge ID array otherwise.
d : tensor
The out-degree array.
Examples
--------
The following example uses PyTorch backend.
For simple graphs:
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
>>> G.edge_id(0, 2)
1
>>> G.edge_id(0, 1)
0
For multigraphs:
>>> G = dgl.DGLGraph(multigraph=True)
>>> G.add_nodes(3)
Adding edges (0, 1), (0, 2), (0, 1), (0, 2), so edge IDs 0 and 2 both
connect node 0 to node 1, while edge IDs 1 and 3 both connect node 0 to node 2.
>>> G.add_edges([0, 0, 0, 0], [1, 2, 1, 2])
>>> G.edge_id(0, 1)
tensor([0, 2])
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.out_degrees([0, 1])
tensor([2, 1])
See Also
--------
edge_ids
out_degree
"""
idx = self._graph.edge_id(u, v)
return idx.tousertensor() if force_multi or self.is_multigraph else idx[0]
if is_all(v):
v = utils.toindex(slice(0, self.number_of_nodes()))
else:
v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor()
def edge_ids(self, u, v, force_multi=False):
"""Return all edge IDs between source node array `u` and destination
node array `v`.
class DGLGraph(DGLBaseGraph):
"""Base graph class.
Parameters
----------
u : list, tensor
The source node ID array.
v : list, tensor
The destination node ID array.
force_multi : bool
Whether to always treat the graph as a multigraph.
The graph stores nodes, edges and also their features.
Returns
-------
tensor, or (tensor, tensor, tensor)
If the graph is a simple graph and `force_multi` is False, return
a single edge ID array `e`. `e[i]` is the edge ID between `u[i]`
and `v[i]`.
Otherwise, return three arrays `(eu, ev, e)`. `e[i]` is the ID
of an edge between `eu[i]` and `ev[i]`. All edges between `u[i]`
and `v[i]` are returned.
A DGL graph is always directed. An undirected graph can be represented by
adding each edge in both directions.
Notes
-----
If the graph is a simple graph, `force_multi` is False, and no edge
exists between some pair of `u[i]` and `v[i]`, the result is undefined.
Nodes are identified by consecutive integers starting from zero.
Examples
--------
The following example uses PyTorch backend.
Edges can be specified by two end points (u, v) or the integer id assigned
when the edges are added. Edge IDs are automatically assigned by the order
of addition, i.e. the first edge being added has an ID of 0, the second
being 1, so on so forth.
For simple graphs:
Node and edge features are stored as a dictionary from the feature name
to the feature data (in tensor).
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0], [1, 2]) # (0, 1), (0, 2)
>>> G.edge_ids([0, 0], [2, 1]) # get edge ID of (0, 2) and (0, 1)
>>> G.edge_ids([0, 0], [2, 1])
tensor([1, 0])
Parameters
----------
graph_data : graph data, optional
Data to initialize graph. Same as networkx's semantics.
node_frame : FrameRef, optional
Node feature storage.
edge_frame : FrameRef, optional
Edge feature storage.
multigraph : bool, optional
Whether the graph would be a multigraph (default: False)
readonly : bool, optional
Whether the graph structure is read-only (default: False).
For multigraphs
Examples
--------
Create an empty graph with no nodes and edges.
>>> G = dgl.DGLGraph(multigraph=True)
>>> G.add_nodes(4)
>>> G.add_edges([0, 0, 0], [1, 1, 2]) # (0, 1), (0, 1), (0, 2)
>>> G = dgl.DGLGraph()
Get all edges between (0, 1), (0, 2), (0, 3). Note that there is no
edge between 0 and 3:
G can be grown in several ways.
>>> G.edge_ids([0, 0, 0], [1, 2, 3])
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([0, 1, 2]))
**Nodes:**
See Also
--------
edge_id
"""
u = utils.toindex(u)
v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(u, v)
if force_multi or self.is_multigraph:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
return eid.tousertensor()
Add N nodes:
def find_edges(self, eid):
"""Given an edge ID array, return the source and destination node ID
array `s` and `d`. `s[i]` and `d[i]` are source and destination node
ID for edge `eid[i]`.
>>> G.add_nodes(10) # 10 isolated nodes are added
Parameters
----------
eid : list, tensor
The edge ID array.
**Edges:**
Returns
-------
tensor
The source node ID array.
tensor
The destination node ID array.
Examples
--------
The following example uses PyTorch backend.
Add one edge at a time,
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.find_edges([0, 2])
(tensor([0, 1]), tensor([1, 2]))
"""
eid = utils.toindex(eid)
src, dst, _ = self._graph.find_edges(eid)
return src.tousertensor(), dst.tousertensor()
>>> G.add_edge(0, 1)
def in_edges(self, v, form='uv'):
"""Return the inbound edges of the node(s).
or multiple edges,
Parameters
----------
v : int, list, tensor
The node(s).
form : str, optional
The return form. Currently supported:
>>> G.add_edges([1, 2, 3], [3, 4, 5]) # three edges: 1->3, 2->4, 3->5
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
or multiple edges starting from the same node,
Returns
-------
A tuple of Tensors ``(eu, ev, eid)`` if ``form == 'all'``.
``eid[i]`` is the ID of an inbound edge to ``ev[i]`` from ``eu[i]``.
All inbound edges to ``v`` are returned.
A pair of Tensors (eu, ev) if form == 'uv'
``eu[i]`` is the source node of an inbound edge to ``ev[i]``.
All inbound edges to ``v`` are returned.
One Tensor if form == 'eid'
``eid[i]`` is the ID of an inbound edge to any of the nodes in ``v``.
>>> G.add_edges(4, [7, 8, 9]) # three edges: 4->7, 4->8, 4->9
Examples
--------
The following example uses PyTorch backend.
or multiple edges pointing to the same node,
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.add_edges([2, 6, 8], 5) # three edges: 2->5, 6->5, 8->5
For a single node:
or multiple edges using tensor type
>>> G.in_edges(2)
(tensor([0, 1]), tensor([2, 2]))
>>> G.in_edges(2, 'all')
(tensor([0, 1]), tensor([2, 2]), tensor([1, 2]))
>>> G.in_edges(2, 'eid')
tensor([1, 2])
.. note:: Here we use pytorch syntax for demo. The general idea applies
to other frameworks with minor syntax change (e.g. replace
``torch.tensor`` with ``mxnet.ndarray``).
For multiple nodes:
>>> import torch as th
>>> G.add_edges(th.tensor([3, 4, 5]), 1) # three edges: 3->1, 4->1, 5->1
>>> G.in_edges([1, 2])
(tensor([0, 0, 1]), tensor([1, 2, 2]))
>>> G.in_edges([1, 2], 'all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
"""
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
NOTE: Removing nodes and edges is not supported by DGLGraph.
def out_edges(self, v, form='uv'):
"""Return the outbound edges of the node(s).
**Features:**
Parameters
----------
v : int, list, tensor
The node(s).
form : str, optional
The return form. Currently supported:
Both nodes and edges can have feature data. Features are stored as
key/value pair. The key must be hashable while the value must be tensor
type. Features are batched on the first dimension.
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
Use G.ndata to get/set features for all nodes.
Returns
-------
A tuple of Tensors ``(eu, ev, eid)`` if ``form == 'all'``.
``eid[i]`` is the ID of an outbound edge from ``eu[i]`` to ``ev[i]``.
All outbound edges from ``v`` are returned.
A pair of Tensors (eu, ev) if form == 'uv'
``ev[i]`` is the destination node of an outbound edge from ``eu[i]``.
All outbound edges from ``v`` are returned.
One Tensor if form == 'eid'
``eid[i]`` is the ID of an outbound edge from any of the nodes in ``v``.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.ndata['x'] = th.zeros((3, 5)) # init 3 nodes with zero vector(len=5)
>>> G.ndata
{'x' : tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])}
Examples
--------
The following example uses PyTorch backend.
Use G.nodes to get/set features for some nodes.
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.nodes[[0, 2]].data['x'] = th.ones((2, 5))
>>> G.ndata
{'x' : tensor([[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.]])}
For a single node:
Similarly, use G.edata and G.edges to get/set features for edges.
>>> G.out_edges(0)
(tensor([0, 0]), tensor([1, 2]))
>>> G.out_edges(0, 'all')
(tensor([0, 0]), tensor([1, 2]), tensor([0, 1]))
>>> G.out_edges(0, 'eid')
tensor([0, 1])
>>> G.add_edges([0, 1], 2) # 0->2, 1->2
>>> G.edata['y'] = th.zeros((2, 4)) # init 2 edges with zero vector(len=4)
>>> G.edata
{'y' : tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.]])}
>>> G.edges[1, 2].data['y'] = th.ones((1, 4))
>>> G.edata
{'y' : tensor([[0., 0., 0., 0.],
[1., 1., 1., 1.]])}
For multiple nodes:
Note that each edge is assigned a unique ID equal to its order of
addition. So edge 1->2 has ID 1. DGL supports using edge IDs directly
to access edge features.
>>> G.out_edges([0, 1])
(tensor([0, 0, 1]), tensor([1, 2, 2]))
>>> G.out_edges([0, 1], 'all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
"""
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
>>> G.edges[0].data['y'] += 2.
>>> G.edata
{'y' : tensor([[2., 2., 2., 2.],
[1., 1., 1., 1.]])}
def all_edges(self, form='uv', order=None):
"""Return all the edges.
**Message Passing:**
Parameters
----------
form : str, optional
The return form. Currently supported:
One common operation for updating node features is message passing,
where the source nodes send messages through edges to the destinations.
With :class:`DGLGraph`, we can do this with :func:`send` and :func:`recv`.
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
order : string
The order of the returned edges. Currently supported:
In the example below, the source nodes add 1 to their node features as
the messages and send the messages to the destinations.
- 'srcdst' : sorted by source and destination IDs.
- 'eid' : sorted by edge IDs.
- None : arbitrary order.
>>> # Define the function for sending messages.
>>> def send_source(edges): return {'m': edges.src['x'] + 1}
>>> # Set the function defined to be the default message function.
>>> G.register_message_func(send_source)
>>> # Send messages through all edges.
>>> G.send(G.edges())
Returns
-------
A tuple of Tensors (u, v, eid) if form == 'all'
``eid[i]`` is the ID of an edge between ``u[i]`` and ``v[i]``.
All edges are returned.
A pair of Tensors (u, v) if form == 'uv'
An edge exists between ``u[i]`` and ``v[i]``.
If ``n`` edges exist between ``u`` and ``v``, then ``u`` and ``v`` as a pair
will appear ``n`` times.
One Tensor if form == 'eid'
``eid[i]`` is the ID of an edge in the graph.
Just like you need to go to your mailbox for retrieving mails, the destination
nodes also need to receive the messages and potentially update their features.
Examples
--------
The following example uses PyTorch backend.
>>> # Define a function for summing messages received and replacing the original feature.
>>> def simple_reduce(nodes): return {'x': nodes.mailbox['m'].sum(1)}
>>> # Set the function defined to be the default message reduce function.
>>> G.register_reduce_func(simple_reduce)
>>> # All existing edges have node 2 as the destination.
>>> # Receive the messages for node 2 and update its feature.
>>> G.recv(v=2)
>>> G.ndata
{'x': tensor([[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.],
[3., 3., 3., 3., 3.]])} # 3 = (1 + 1) + (0 + 1)
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edges([0, 0, 1], [1, 2, 2]) # (0, 1), (0, 2), (1, 2)
>>> G.all_edges()
(tensor([0, 0, 1]), tensor([1, 2, 2]))
>>> G.all_edges('all')
(tensor([0, 0, 1]), tensor([1, 2, 2]), tensor([0, 1, 2]))
"""
src, dst, eid = self._graph.edges(order)
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
For more examples about message passing, please read our tutorials.
"""
def __init__(self,
graph_data=None,
node_frame=None,
edge_frame=None,
multigraph=False,
readonly=False):
# graph
super(DGLGraph, self).__init__(create_graph_index(graph_data, multigraph, readonly))
# node and edge frame
if node_frame is None:
self._node_frame = FrameRef(Frame(num_rows=self.number_of_nodes()))
else:
raise DGLError('Invalid form:', form)
self._node_frame = node_frame
if edge_frame is None:
self._edge_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
else:
self._edge_frame = edge_frame
# message indicator:
# if self._msg_index[eid] == 1, then edge eid has message
self._msg_index = utils.zero_index(size=self.number_of_edges())
# message frame
self._msg_frame = FrameRef(Frame(num_rows=self.number_of_edges()))
# set initializer for message frame
self._msg_frame.set_initializer(init.zero_initializer)
# registered functions
self._message_func = None
self._reduce_func = None
self._apply_node_func = None
self._apply_edge_func = None
def in_degree(self, v):
    """Return the in-degree of node ``v``.

    Parameters
    ----------
    v : int
        The node ID.

    Returns
    -------
    int
        The in-degree.

    Examples
    --------
    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(3)
    >>> G.add_edges([0, 0, 1], [1, 2, 2])  # (0, 1), (0, 2), (1, 2)
    >>> G.in_degree(2)
    2

    See Also
    --------
    in_degrees
    """
    return self._graph.in_degree(v)

def add_nodes(self, num, data=None):
    """Add multiple new nodes.

    Parameters
    ----------
    num : int
        Number of nodes to be added.
    data : dict, optional
        Feature data of the added nodes.

    Notes
    -----
    If new nodes are added with features, and any of the old nodes
    do not have some of the feature fields, those fields are filled
    by initializers defined with ``set_n_initializer`` (default filling
    with zeros).

    Examples
    --------
    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(2)
    >>> g.number_of_nodes()
    2
    >>> g.add_nodes(3)
    >>> g.number_of_nodes()
    5

    Adding new nodes with features:

    .. note:: Here we use pytorch syntax for demo. The general idea applies
        to other frameworks with minor syntax change (e.g. replace
        ``torch.tensor`` with ``mxnet.ndarray``).

    >>> import torch as th
    >>> g.add_nodes(2, {'x': th.ones(2, 4)})  # default zero initializer
    >>> g.ndata['x']
    tensor([[0., 0., 0., 0.],
            [0., 0., 0., 0.],
            [0., 0., 0., 0.],
            [0., 0., 0., 0.],
            [0., 0., 0., 0.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])
    """
    self._graph.add_nodes(num)
    if data is None:
        # Initialize feature placeholders if there are features existing
        self._node_frame.add_rows(num)
    else:
        self._node_frame.append(data)
......@@ -1021,81 +991,131 @@ class DGLGraph(object):
def in_degrees(self, v=ALL):
    """Return the array `d` of in-degrees of the node array `v`.

    `d[i]` is the in-degree of node `v[i]`.

    Parameters
    ----------
    v : list, tensor, optional.
        The node ID array. Default is to return the degrees of all the nodes.

    Returns
    -------
    d : tensor
        The in-degree array.

    Examples
    --------
    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(3)
    >>> G.add_edges([0, 0, 1], [1, 2, 2])  # (0, 1), (0, 2), (1, 2)
    >>> G.in_degrees([1, 2])
    tensor([1, 2])

    See Also
    --------
    in_degree
    """
    if is_all(v):
        v = utils.toindex(slice(0, self.number_of_nodes()))
    else:
        v = utils.toindex(v)
    return self._graph.in_degrees(v).tousertensor()

def add_edge(self, u, v, data=None):
    """Add one new edge between u and v.

    Parameters
    ----------
    u : int
        The source node ID. Must exist in the graph.
    v : int
        The destination node ID. Must exist in the graph.
    data : dict, optional
        Feature data of the added edges.

    Notes
    -----
    If new edges are added with features, and any of the old edges
    do not have some of the feature fields, those fields are filled
    by initializers defined with ``set_e_initializer`` (default filling
    with zeros).

    Examples
    --------
    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(3)
    >>> G.add_edge(0, 1)

    Adding new edge with features

    .. note:: Here we use pytorch syntax for demo. The general idea applies
        to other frameworks with minor syntax change (e.g. replace
        ``torch.tensor`` with ``mxnet.ndarray``).

    >>> import torch as th
    >>> G.add_edge(0, 2, {'x': th.ones(1, 4)})
    >>> G.edges()
    (tensor([0, 0]), tensor([1, 2]))
    >>> G.edata['x']
    tensor([[0., 0., 0., 0.],
            [1., 1., 1., 1.]])
    >>> G.edges[0, 2].data['x']
    tensor([[1., 1., 1., 1.]])

    See Also
    --------
    add_edges
    """
    self._graph.add_edge(u, v)
    if data is None:
        # Initialize feature placeholders if there are features existing
        self._edge_frame.add_rows(1)
    else:
        self._edge_frame.append(data)
    # resize msg_index and msg_frame
    self._msg_index = self._msg_index.append_zeros(1)
    self._msg_frame.add_rows(1)
def out_degree(self, v):
    """Return the out-degree of node `v`.

    Parameters
    ----------
    v : int
        The node ID.

    Returns
    -------
    int
        The out-degree.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(3)
    >>> G.add_edges([0, 0, 1], [1, 2, 2])  # (0, 1), (0, 2), (1, 2)
    >>> G.out_degree(0)
    2

    See Also
    --------
    out_degrees
    """
    return self._graph.out_degree(v)

def add_edges(self, u, v, data=None):
    """Add multiple edges for a list of source nodes u and destination nodes
    v. A single edge is added between every pair of ``u[i]`` and ``v[i]``.

    Parameters
    ----------
    u : list, tensor
        The source node IDs. All nodes must exist in the graph.
    v : list, tensor
        The destination node IDs. All nodes must exist in the graph.
    data : dict, optional
        Feature data of the added edges.

    Notes
    -----
    If new edges are added with features, and any of the old edges
    do not have some of the feature fields, those fields are filled
    by initializers defined with ``set_e_initializer`` (default filling
    with zeros).

    Examples
    --------
    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(4)
    >>> G.add_edges([0, 2], [1, 3])  # add edges (0, 1) and (2, 3)

    Adding new edges with features

    .. note:: Here we use pytorch syntax for demo. The general idea applies
        to other frameworks with minor syntax change (e.g. replace
        ``torch.tensor`` with ``mxnet.ndarray``).

    >>> import torch as th
    >>> G.add_edges([1, 3], [2, 0], {'x': th.ones(2, 4)})  # (1, 2), (3, 0)
    >>> G.edata['x']
    tensor([[0., 0., 0., 0.],
            [0., 0., 0., 0.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])

    See Also
    --------
    add_edge
    """
    u = utils.toindex(u)
    v = utils.toindex(v)
    self._graph.add_edges(u, v)
    num = max(len(u), len(v))
    if data is None:
        # Initialize feature placeholders if there are features existing
        # NOTE: use max due to edge broadcasting syntax
        self._edge_frame.add_rows(num)
    else:
        self._edge_frame.append(data)
    # initialize feature placeholder for messages
    self._msg_index = self._msg_index.append_zeros(num)
    self._msg_frame.add_rows(num)

def out_degrees(self, v=ALL):
    """Return the array `d` of out-degrees of the node array `v`.

    `d[i]` is the out-degree of node `v[i]`.

    Parameters
    ----------
    v : list, tensor
        The node ID array. Default is to return the degrees of all the nodes.

    Returns
    -------
    d : tensor
        The out-degree array.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(3)
    >>> G.add_edges([0, 0, 1], [1, 2, 2])  # (0, 1), (0, 2), (1, 2)
    >>> G.out_degrees([0, 1])
    tensor([2, 1])

    See Also
    --------
    out_degree
    """
    if is_all(v):
        v = utils.toindex(slice(0, self.number_of_nodes()))
    else:
        v = utils.toindex(v)
    return self._graph.out_degrees(v).tousertensor()

def clear(self):
    """Remove all nodes and edges, as well as their features, from the
    graph.

    Examples
    --------
    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(4)
    >>> G.add_edges([0, 1, 2, 3], [1, 2, 3, 0])
    >>> G.number_of_nodes()
    4
    >>> G.number_of_edges()
    4
    >>> G.clear()
    >>> G.number_of_nodes()
    0
    >>> G.number_of_edges()
    0
    """
    self._graph.clear()
    self._node_frame.clear()
    self._edge_frame.clear()
    self._msg_index = utils.zero_index(0)
    self._msg_frame.clear()

def clear_cache(self):
    """Clear all cached graph structures such as adjmat.

    By default, all graph structure related sparse matrices (e.g. adjmat, incmat)
    are cached so they could be reused with the cost of extra memory consumption.
    This function can be used to clear the cached matrices if memory is an issue.
    """
    self._graph.clear_cache()
def to_networkx(self, node_attrs=None, edge_attrs=None):
"""Convert to networkx graph.
......@@ -2730,6 +2750,7 @@ class DGLGraph(object):
subgraphs
edge_subgraph
"""
from . import subgraph
induced_nodes = utils.toindex(nodes)
sgi = self._graph.node_subgraph(induced_nodes)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
......@@ -2757,6 +2778,7 @@ class DGLGraph(object):
DGLSubGraph
subgraph
"""
from . import subgraph
induced_nodes = [utils.toindex(n) for n in nodes]
sgis = self._graph.node_subgraphs(induced_nodes)
return [subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
......@@ -2804,6 +2826,7 @@ class DGLGraph(object):
DGLSubGraph
subgraph
"""
from . import subgraph
induced_edges = utils.toindex(edges)
sgi = self._graph.edge_subgraph(induced_edges)
return subgraph.DGLSubGraph(self, sgi.induced_nodes, sgi.induced_edges, sgi)
......
......@@ -675,8 +675,10 @@ class GraphIndex(object):
rst = _nonuniform_sampling(self, node_prob, seed_ids, neighbor_type, num_hops,
expand_factor)
return [NodeFlowIndex(rst(i), self, utils.toindex(rst(num_subgs + i)),
utils.toindex(rst(num_subgs * 2 + i)),
utils.toindex(rst(num_subgs * 3 + i)),
utils.toindex(rst(num_subgs * 4 + i))) for i in range(num_subgs)]
def to_networkx(self):
"""Convert to networkx graph.
......@@ -821,8 +823,7 @@ class SubgraphIndex(GraphIndex):
The parent edge ids in this subgraph.
"""
def __init__(self, handle, parent, induced_nodes, induced_edges):
super(SubgraphIndex, self).__init__(handle, parent.is_multigraph(), parent.is_readonly())
self._parent = parent
self._induced_nodes = induced_nodes
self._induced_edges = induced_edges
......@@ -869,6 +870,75 @@ class SubgraphIndex(GraphIndex):
raise NotImplementedError(
"SubgraphIndex unpickling is not supported yet.")
class NodeFlowIndex(GraphIndex):
"""Graph index for a NodeFlow graph.
Parameters
----------
handle : GraphIndexHandle
The capi handle.
parent : GraphIndex
    The parent graph index.
node_mapping : utils.Index
    This maps nodes to the parent graph.
edge_mapping : utils.Index
    This maps edges to the parent graph.
layers: utils.Index
The offsets of the layers.
flows: utils.Index
The offsets of the flows.
"""
def __init__(self, handle, parent, node_mapping, edge_mapping, layers, flows):
super(NodeFlowIndex, self).__init__(handle, parent.is_multigraph(), parent.is_readonly())
self._parent = parent
self._node_mapping = node_mapping
self._edge_mapping = edge_mapping
self._layers = layers
self._flows = flows
@property
def node_mapping(self):
"""Return the node mapping to the parent graph.
Returns
-------
utils.Index
The node mapping.
"""
return self._node_mapping
@property
def edge_mapping(self):
"""Return the edge mapping to the parent graph.
Returns
-------
utils.Index
The edge mapping.
"""
return self._edge_mapping
@property
def layers(self):
"""Return layer offsets.
"""
return self._layers
@property
def flows(self):
"""Return flow offsets.
"""
return self._flows
def __getstate__(self):
    raise NotImplementedError(
        "NodeFlowIndex pickling is not supported yet.")

def __setstate__(self, state):
    raise NotImplementedError(
        "NodeFlowIndex unpickling is not supported yet.")
def map_to_subgraph_nid(subgraph, parent_nids):
"""Map parent node Ids to the subgraph node Ids.
......@@ -888,6 +958,34 @@ def map_to_subgraph_nid(subgraph, parent_nids):
return utils.toindex(_CAPI_DGLMapSubgraphNID(subgraph.induced_nodes.todgltensor(),
parent_nids.todgltensor()))
def map_to_nodeflow_nid(nflow, layer_id, parent_nids):
"""Map parent node Ids to NodeFlow node Ids in a certain layer.
Parameters
----------
nflow : NodeFlowIndex
The graph index of a NodeFlow.
layer_id : int
The layer Id
parent_nids: utils.Index
Node Ids in the parent graph.
Returns
-------
utils.Index
Node Ids in the NodeFlow.
"""
mapping = nflow.node_mapping.tousertensor()
layers = nflow.layers.tonumpy()
start = int(layers[layer_id])
end = int(layers[layer_id + 1])
mapping = mapping[start:end]
mapping = utils.toindex(mapping)
return utils.toindex(_CAPI_DGLMapSubgraphNID(mapping.todgltensor(),
parent_nids.todgltensor()))
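A short usage sketch (hypothetical: ``nfi`` is assumed to be a NodeFlowIndex
returned by neighbor sampling, and the seed vertices are assumed to sit in
layer 0, as in the sampler examples elsewhere in this change):

seeds = utils.toindex([3, 7, 42])                # node Ids in the parent graph
local_nids = map_to_nodeflow_nid(nfi, 0, seeds)  # their Ids inside layer 0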
def disjoint_union(graphs):
"""Return a disjoint union of the input graphs.
......
"""Class for NodeFlow data structure."""
from __future__ import absolute_import
from .base import ALL, is_all, DGLError
from . import backend as F
from .frame import Frame, FrameRef
from .graph import DGLBaseGraph
from .runtime import ir, scheduler, Runtime
from . import utils
from .view import LayerView, BlockView
def _copy_to_like(arr1, arr2):
return F.copy_to(arr1, F.context(arr2))
def _get_frame(frame, names, ids):
col_dict = {name: frame[name][_copy_to_like(ids, frame[name])] for name in names}
if len(col_dict) == 0:
return FrameRef(Frame(num_rows=len(ids)))
else:
return FrameRef(Frame(col_dict))
def _update_frame(frame, names, ids, new_frame):
col_dict = {name: new_frame[name] for name in names}
if len(col_dict) > 0:
frame.update_rows(ids, FrameRef(Frame(col_dict)), inplace=True)
class NodeFlow(DGLBaseGraph):
"""The NodeFlow class stores the sampling results of Neighbor sampling and Layer-wise sampling.
These sampling algorithms generate graphs with multiple layers. The edges connect the nodes
between two layers while there don't exist edges between the nodes in the same layer.
We store multiple layers of the sampling results in a single graph. We store extra information,
such as the node and edge mapping from the NodeFlow graph to the parent graph.
Parameters
----------
parent : DGLGraph
The parent graph
graph_idx : NodeFlowIndex
    The graph index of the NodeFlow graph.
"""
def __init__(self, parent, graph_idx):
super(NodeFlow, self).__init__(graph_idx)
self._parent = parent
self._node_mapping = graph_idx.node_mapping
self._edge_mapping = graph_idx.edge_mapping
self._layer_offsets = graph_idx.layers.tonumpy()
self._block_offsets = graph_idx.flows.tonumpy()
self._node_frames = [FrameRef(Frame(num_rows=self.layer_size(i))) \
for i in range(self.num_layers)]
self._edge_frames = [FrameRef(Frame(num_rows=self.block_size(i))) \
for i in range(self.num_blocks)]
# registered functions
self._message_funcs = [None] * self.num_blocks
self._reduce_funcs = [None] * self.num_blocks
self._apply_node_funcs = [None] * self.num_blocks
self._apply_edge_funcs = [None] * self.num_blocks
def _get_layer_id(self, layer_id):
"""The layer Id might be negative. We need to convert it to the actual layer Id.
"""
if layer_id >= 0:
return layer_id
else:
return self.num_layers + layer_id
def _get_block_id(self, block_id):
"""The block Id might be negative. We need to convert it to the actual block Id.
"""
if block_id >= 0:
return block_id
else:
return self.num_blocks + block_id
def _get_node_frame(self, layer_id):
return self._node_frames[layer_id]
def _get_edge_frame(self, flow_id):
return self._edge_frames[flow_id]
@property
def num_layers(self):
"""Get the number of layers.
Returns
-------
int
the number of layers
"""
return len(self._layer_offsets) - 1
@property
def num_blocks(self):
"""Get the number of blocks.
Returns
-------
int
the number of blocks
"""
return self.num_layers - 1
@property
def layers(self):
"""Return a LayerView of this NodeFlow.
This is mainly for usage like:
* `g.layers[2].data['h']` to get the node features of layer#2.
* `g.layers(2)` to get the nodes of layer#2.
"""
return LayerView(self)
@property
def blocks(self):
"""Return a BlockView of this NodeFlow.
This is mainly for usage like:
* `g.blocks[1,2].data['h']` to get the edge features of blocks from layer#1 to layer#2.
* `g.blocks(1, 2)` to get the edge ids of blocks #1->#2.
"""
return BlockView(self)
def layer_size(self, layer_id):
"""Return the number of nodes in a specified layer.
Parameters
----------
layer_id : int
the specified layer to return the number of nodes.
"""
layer_id = self._get_layer_id(layer_id)
return int(self._layer_offsets[layer_id + 1]) - int(self._layer_offsets[layer_id])
def block_size(self, block_id):
"""Return the number of edges in a specified block.
Parameters
----------
block_id : int
the specified block to return the number of edges.
"""
block_id = self._get_block_id(block_id)
return int(self._block_offsets[block_id + 1]) - int(self._block_offsets[block_id])
def copy_from_parent(self, node_embed_names=ALL, edge_embed_names=ALL):
"""Copy node/edge features from the parent graph.
Parameters
----------
node_embed_names : a list of lists of strings, optional
The names of embeddings in each layer.
edge_embed_names : a list of lists of strings, optional
The names of embeddings in each block.
"""
if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
if is_all(node_embed_names):
for i in range(self.num_layers):
nid = utils.toindex(self.layer_parent_nid(i))
self._node_frames[i] = FrameRef(Frame(self._parent._node_frame[nid]))
elif node_embed_names is not None:
assert isinstance(node_embed_names, list) \
and len(node_embed_names) == self.num_layers, \
"The specified embedding names should be the same as the number of layers."
for i in range(self.num_layers):
nid = self.layer_parent_nid(i)
self._node_frames[i] = _get_frame(self._parent._node_frame,
node_embed_names[i], nid)
if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
if is_all(edge_embed_names):
for i in range(self.num_blocks):
eid = utils.toindex(self.block_parent_eid(i))
self._edge_frames[i] = FrameRef(Frame(self._parent._edge_frame[eid]))
elif edge_embed_names is not None:
assert isinstance(edge_embed_names, list) \
and len(edge_embed_names) == self.num_blocks, \
"The specified embedding names should be the same as the number of flows."
for i in range(self.num_blocks):
eid = self.block_parent_eid(i)
self._edge_frames[i] = _get_frame(self._parent._edge_frame,
edge_embed_names[i], eid)
def copy_to_parent(self, node_embed_names=ALL, edge_embed_names=ALL):
"""Copy node/edge embeddings to the parent graph.
Parameters
----------
node_embed_names : a list of lists of strings, optional
The names of embeddings in each layer.
edge_embed_names : a list of lists of strings, optional
The names of embeddings in each block.
"""
if self._parent._node_frame.num_rows != 0 and self._parent._node_frame.num_columns != 0:
if is_all(node_embed_names):
for i in range(self.num_layers):
nid = utils.toindex(self.layer_parent_nid(i))
# We should write data back directly.
self._parent._node_frame.update_rows(nid, self._node_frames[i], inplace=True)
elif node_embed_names is not None:
assert isinstance(node_embed_names, list) \
and len(node_embed_names) == self.num_layers, \
"The specified embedding names should be the same as the number of layers."
for i in range(self.num_layers):
nid = utils.toindex(self.layer_parent_nid(i))
_update_frame(self._parent._node_frame, node_embed_names[i], nid,
self._node_frames[i])
if self._parent._edge_frame.num_rows != 0 and self._parent._edge_frame.num_columns != 0:
if is_all(edge_embed_names):
for i in range(self.num_blocks):
eid = utils.toindex(self.block_parent_eid(i))
self._parent._edge_frame.update_rows(eid, self._edge_frames[i], inplace=True)
elif edge_embed_names is not None:
assert isinstance(edge_embed_names, list) \
and len(edge_embed_names) == self.num_blocks, \
"The specified embedding names should be the same as the number of flows."
for i in range(self.num_blocks):
eid = utils.toindex(self.block_parent_eid(i))
_update_frame(self._parent._edge_frame, edge_embed_names[i], eid,
self._edge_frames[i])
def map_to_parent_nid(self, nid):
"""This maps the child node Ids to the parent Ids.
Parameters
----------
nid : tensor
The node ID array in the NodeFlow graph.
Returns
-------
Tensor
The parent node id array.
"""
return self._node_mapping.tousertensor()[nid]
def map_to_parent_eid(self, eid):
"""This maps the child edge Ids to the parent Ids.
Parameters
----------
eid : tensor
The edge ID array in the NodeFlow graph.
Returns
-------
Tensor
The parent edge id array.
"""
return self._edge_mapping.tousertensor()[eid]
def layer_in_degree(self, layer_id):
"""Return the in-degree of the nodes in the specified layer.
Parameters
----------
layer_id : int
The layer Id.
Returns
-------
Tensor
The degree of the nodes in the specified layer.
"""
return self._graph.in_degrees(utils.toindex(self.layer_nid(layer_id))).tousertensor()
def layer_out_degree(self, layer_id):
"""Return the out-degree of the nodes in the specified layer.
Parameters
----------
layer_id : int
The layer Id.
Returns
-------
Tensor
The degree of the nodes in the specified layer.
"""
return self._graph.out_degrees(utils.toindex(self.layer_nid(layer_id))).tousertensor()
def layer_nid(self, layer_id):
"""Get the node Ids in the specified layer.
Parameters
----------
layer_id : int
The layer to get the node Ids.
Returns
-------
Tensor
The node id array.
"""
layer_id = self._get_layer_id(layer_id)
assert layer_id + 1 < len(self._layer_offsets)
start = self._layer_offsets[layer_id]
end = self._layer_offsets[layer_id + 1]
return F.arange(int(start), int(end))
def layer_parent_nid(self, layer_id):
"""Get the node Ids of the parent graph in the specified layer
Parameters
----------
layer_id : int
The layer to get the node Ids.
Returns
-------
Tensor
The parent node id array.
"""
layer_id = self._get_layer_id(layer_id)
assert layer_id + 1 < len(self._layer_offsets)
start = self._layer_offsets[layer_id]
end = self._layer_offsets[layer_id + 1]
return self._node_mapping.tousertensor()[start:end]
def block_eid(self, block_id):
"""Get the edge Ids in the specified block.
Parameters
----------
block_id : int
the specified block to get edge Ids.
Returns
-------
Tensor
The edge id array.
"""
block_id = self._get_block_id(block_id)
start = self._block_offsets[block_id]
end = self._block_offsets[block_id + 1]
return F.arange(int(start), int(end))
def block_parent_eid(self, block_id):
"""Get the edge Ids of the parent graph in the specified block.
Parameters
----------
block_id : int
the specified block to get edge Ids.
Returns
-------
Tensor
The parent edge id array.
"""
block_id = self._get_block_id(block_id)
start = self._block_offsets[block_id]
end = self._block_offsets[block_id + 1]
return self._edge_mapping.tousertensor()[start:end]
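# Note how the two Id spaces relate (with [start, end) a layer's offsets):
# layer_nid(i) returns arange(start, end) (NodeFlow-local Ids) while
# layer_parent_nid(i) returns node_mapping[start:end] (parent Ids), so
# map_to_parent_nid(self.layer_nid(i)) equals self.layer_parent_nid(i).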
def set_n_initializer(self, initializer, layer_id=ALL, field=None):
"""Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape, data type
and device context.
When a subset of the nodes are assigned a new feature, the initializer is
used to create features for the rest of the nodes.
Parameters
----------
initializer : callable
The initializer.
layer_id : int
the layer to set the initializer.
field : str, optional
The feature field name. Default is set an initializer for all the
feature fields.
"""
if is_all(layer_id):
for i in range(self.num_layers):
self._node_frames[i].set_initializer(initializer, field)
else:
    self._node_frames[layer_id].set_initializer(initializer, field)
def set_e_initializer(self, initializer, block_id=ALL, field=None):
"""Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape, data
type and device context.
When a subset of the edges are assigned a new feature, the initializer is
used to create features for the rest of the edges.
Parameters
----------
initializer : callable
The initializer.
block_id : int
the block to set the initializer.
field : str, optional
    The feature field name. Default is to set the initializer for all the
    feature fields.
"""
if is_all(block_id):
for i in range(self.num_blocks):
self._edge_frames[i].set_initializer(initializer, field)
else:
self._edge_frames[block_id].set_initializer(initializer, field)
def register_message_func(self, func, block_id=ALL):
"""Register global message function for a block.
Once registered, ``func`` will be used as the default
message function in message passing operations, including
:func:`block_compute`, :func:`prop_flow`.
Parameters
----------
func : callable
Message function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
block_id : int or ALL
the block to register the message function.
"""
if is_all(block_id):
self._message_funcs = [func] * self.num_blocks
else:
self._message_funcs[block_id] = func
def register_reduce_func(self, func, block_id=ALL):
"""Register global message reduce function for a block.
Once registered, ``func`` will be used as the default
message reduce function in message passing operations, including
:func:`block_compute`, :func:`prop_flow`.
Parameters
----------
func : callable
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
block_id : int or ALL
the block to register the reduce function.
"""
if is_all(block_id):
self._reduce_funcs = [func] * self.num_blocks
else:
self._reduce_funcs[block_id] = func
def register_apply_node_func(self, func, block_id=ALL):
"""Register global node apply function for a block.
Once registered, ``func`` will be used as the default apply
node function. Related operations include :func:`apply_layer`,
:func:`block_compute`, :func:`prop_flow`.
Parameters
----------
func : callable
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
block_id : int or ALL
the block to register the apply node function.
"""
if is_all(block_id):
self._apply_node_funcs = [func] * self.num_blocks
else:
self._apply_node_funcs[block_id] = func
def register_apply_edge_func(self, func, block_id=ALL):
"""Register global edge apply function for a block.
Once registered, ``func`` will be used as the default apply
edge function in :func:`apply_block`.
Parameters
----------
func : callable
Apply function on the edge. The function should be
an :mod:`Edge UDF <dgl.udf>`.
block_id : int or ALL
the block to register the apply edge function.
"""
if is_all(block_id):
self._apply_edge_funcs = [func] * self.num_blocks
else:
self._apply_edge_funcs[block_id] = func
def apply_layer(self, layer_id, func="default", v=ALL, inplace=False):
"""Apply node update function on the node embeddings in the specified layer.
Parameters
----------
layer_id : int
The specified layer to update node embeddings.
func : callable or None, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
v : a list of vertex Ids or ALL.
The vertices to run the node update function.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
"""
if func == "default":
func = self._apply_node_funcs[layer_id]
if is_all(v):
v = utils.toindex(slice(0, self.layer_size(layer_id)))
else:
v = v - int(self._layer_offsets[layer_id])
v = utils.toindex(v)
with ir.prog() as prog:
scheduler.schedule_nodeflow_apply_nodes(graph=self,
layer_id=layer_id,
v=v,
apply_func=func,
inplace=inplace)
Runtime.run(prog)
def apply_block(self, block_id, func="default", edges=ALL, inplace=False):
"""Apply edge update function on the edge embeddings in the specified layer.
Parameters
----------
block_id : int
The specified block to update edge embeddings.
func : callable or None, optional
    Apply function on the edges. The function should be
    an :mod:`Edge UDF <dgl.udf>`.
edges : a list of edge Ids or ALL.
The edges to run the edge update function.
inplace : bool, optional
If True, update will be done in place, but autograd will break.
"""
if func == "default":
func = self._apply_edge_funcs[block_id]
assert func is not None
def _layer_local_nid(layer_id):
return F.arange(0, self.layer_size(layer_id))
if is_all(edges):
u = utils.toindex(_layer_local_nid(block_id))
v = utils.toindex(_layer_local_nid(block_id + 1))
eid = utils.toindex(slice(0, self.block_size(block_id)))
elif isinstance(edges, tuple):
u, v = edges
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(utils.toindex(u), utils.toindex(v))
u = utils.toindex(u.tousertensor() - int(self._layer_offsets[block_id]))
v = utils.toindex(v.tousertensor() - int(self._layer_offsets[block_id + 1]))
eid = utils.toindex(eid.tousertensor() - int(self._block_offsets[block_id]))
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
u = utils.toindex(u.tousertensor() - int(self._layer_offsets[block_id]))
v = utils.toindex(v.tousertensor() - int(self._layer_offsets[block_id + 1]))
eid = utils.toindex(edges - int(self._block_offsets[block_id]))
with ir.prog() as prog:
scheduler.schedule_nodeflow_apply_edges(graph=self,
block_id=block_id,
u=u,
v=v,
eid=eid,
apply_func=func,
inplace=inplace)
Runtime.run(prog)
def _glb2lcl_nid(self, nid, layer_id):
layer_id = self._get_layer_id(layer_id)
return nid - int(self._layer_offsets[layer_id])
def _glb2lcl_eid(self, eid, block_id):
block_id = self._get_block_id(block_id)
return eid - int(self._block_offsets[block_id])
def block_compute(self, block_id, message_func="default", reduce_func="default",
apply_node_func="default", v=ALL, inplace=False):
"""Perform the computation on the specified block. It's similar to `pull`
in DGLGraph.
On the given block i, it runs `pull` on nodes in layer i+1, which generates
messages on edges in block i, runs the reduce function and node update
function on nodes in layer i+1.
Parameters
----------
block_id : int
The block to run the computation.
message_func : callable, optional
Message function on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
reduce_func : callable, optional
Reduce function on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
apply_node_func : callable, optional
Apply function on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
v : a list of vertex Ids or ALL.
The specified nodes in layer i+1 to run the computation.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
if message_func == "default":
message_func = self._message_funcs[block_id]
if reduce_func == "default":
reduce_func = self._reduce_funcs[block_id]
if apply_node_func == "default":
apply_node_func = self._apply_node_funcs[block_id]
assert message_func is not None
assert reduce_func is not None
if is_all(v):
dest_nodes = utils.toindex(self.layer_nid(block_id + 1))
u, v, _ = self._graph.in_edges(dest_nodes)
u = utils.toindex(self._glb2lcl_nid(u.tousertensor(), block_id))
v = utils.toindex(self._glb2lcl_nid(v.tousertensor(), block_id + 1))
dest_nodes = utils.toindex(F.arange(0, self.layer_size(block_id + 1)))
eid = utils.toindex(F.arange(0, self.block_size(block_id)))
else:
dest_nodes = utils.toindex(v)
u, v, eid = self._graph.in_edges(dest_nodes)
assert len(u) > 0, "block_compute must run on edges"
u = utils.toindex(self._glb2lcl_nid(u.tousertensor(), block_id))
v = utils.toindex(self._glb2lcl_nid(v.tousertensor(), block_id + 1))
dest_nodes = utils.toindex(self._glb2lcl_nid(dest_nodes.tousertensor(),
block_id + 1))
eid = utils.toindex(self._glb2lcl_eid(eid.tousertensor(), block_id))
with ir.prog() as prog:
scheduler.schedule_nodeflow_compute(graph=self,
block_id=block_id,
u=u,
v=v,
eid=eid,
dest_nodes=dest_nodes,
message_func=message_func,
reduce_func=reduce_func,
apply_func=apply_node_func,
inplace=inplace)
Runtime.run(prog)
def prop_flow(self, message_funcs="default", reduce_funcs="default",
apply_node_funcs="default", flow_range=ALL, inplace=False):
"""Perform the computation on flows. By default, it runs on all blocks, one-by-one.
On block i, it runs `pull` on nodes in layer i+1, which generates
messages on edges in block i, runs the reduce function and node update
function on nodes in layer i+1.
Users can specify a list of message functions, reduce functions and
node apply functions, one for each block. Thus, when a list is given,
the length of the list should be the same as the number of blocks.
Parameters
----------
message_funcs : a callable, a list of callable, optional
Message functions on the edges. The function should be
an :mod:`Edge UDF <dgl.udf>`.
reduce_funcs : a callable, a list of callable, optional
Reduce functions on the node. The function should be
a :mod:`Node UDF <dgl.udf>`.
apply_node_funcs : a callable, a list of callable, optional
Apply functions on the nodes. The function should be
a :mod:`Node UDF <dgl.udf>`.
flow_range : a slice or ALL.
    The specified blocks to run the computation.
inplace: bool, optional
If True, update will be done in place, but autograd will break.
"""
if is_all(flow_range):
flow_range = range(0, self.num_blocks)
elif isinstance(flow_range, slice):
if flow_range.step not in (None, 1):
    raise DGLError("We can't propagate flows and skip some of them")
flow_range = range(flow_range.start, flow_range.stop)
else:
raise DGLError("unknown flow range")
for i in flow_range:
if message_funcs == "default":
message_func = self._message_funcs[i]
elif isinstance(message_funcs, list):
message_func = message_funcs[i]
else:
message_func = message_funcs
if reduce_funcs == "default":
reduce_func = self._reduce_funcs[i]
elif isinstance(reduce_funcs, list):
reduce_func = reduce_funcs[i]
else:
reduce_func = reduce_funcs
if apply_node_funcs == "default":
apply_node_func = self._apply_node_funcs[i]
elif isinstance(apply_node_funcs, list):
apply_node_func = apply_node_funcs[i]
else:
apply_node_func = apply_node_funcs
self.block_compute(i, message_func, reduce_func, apply_node_func,
inplace=inplace)
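For example, prop_flow accepts one function per block; a minimal sketch with
plain UDFs (assuming ``nf`` is a NodeFlow with two blocks whose layers already
carry a feature ``'h'``, e.g. after ``copy_from_parent``):

msg = lambda edges: {'m': edges.src['h']}             # message UDF per edge
red = lambda nodes: {'h': nodes.mailbox['m'].sum(1)}  # reduce UDF per node
# run block 0 and block 1 one-by-one, each with its own function pair
nf.prop_flow(message_funcs=[msg, msg], reduce_funcs=[red, red])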
def create_full_node_flow(g, num_layers):
"""Convert a full graph to NodeFlow to run a L-layer GNN model.
Parameters
----------
g : DGLGraph
a DGL graph
num_layers : int
The number of layers
Returns
-------
NodeFlow
a NodeFlow with a specified number of layers.
"""
seeds = [utils.toindex(F.arange(0, g.number_of_nodes()))]
nfi = g._graph.neighbor_sampling(seeds, g.number_of_nodes(), num_layers, 'in', None)
return NodeFlow(g, nfi[0])
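Putting the pieces together, a typical NodeFlow round trip looks roughly like
this (a sketch, assuming ``g`` is a DGLGraph whose nodes carry a feature
``'h'``; builtin functions let the scheduler specialize message passing to SPMV):

import dgl.function as fn

nf = create_full_node_flow(g, 2)      # the full graph as a NodeFlow
nf.copy_from_parent()                 # pull 'h' into every layer
nf.prop_flow(fn.copy_src('h', 'm'),   # message on every block
             fn.sum('m', 'h'))        # reduce on every block
h_last = nf.layers[-1].data['h']      # result on the last layer
nf.copy_to_parent()                   # write embeddings back to g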
......@@ -51,7 +51,7 @@ def schedule_send(graph, u, v, eid, message_func):
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
msg = _gen_send(graph, var_nf, var_nf, var_ef, var_u, var_v, var_eid, message_func)
ir.WRITE_ROW_(var_mf, var_eid, msg)
# set message indicator to 1
graph._msg_index = graph._msg_index.set_items(eid, 1)
......@@ -144,9 +144,10 @@ def schedule_snr(graph,
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_adj_matrix_uv((u, v), recv_nodes, graph.number_of_nodes())
inc_creator = lambda: spmv.build_inc_matrix_dst(v, recv_nodes)
reduced_feat = _gen_send_reduce(graph, graph._node_frame, graph._node_frame,
graph._edge_frame, message_func, reduce_func,
var_eid, var_recv_nodes,
uv_getter, adj_creator, inc_creator)
# generate apply schedule
......@@ -191,7 +192,8 @@ def schedule_update_all(graph,
return var.IDX(src), var.IDX(dst)
adj_creator = lambda: spmv.build_adj_matrix_graph(graph)
inc_creator = lambda: spmv.build_inc_matrix_graph(graph)
reduced_feat = _gen_send_reduce(graph, graph._node_frame, graph._node_frame,
graph._edge_frame, message_func, reduce_func,
var_eid, var_recv_nodes,
uv_getter, adj_creator, inc_creator)
# generate optional apply
......@@ -232,6 +234,43 @@ def schedule_apply_nodes(graph,
else:
ir.WRITE_ROW_(var_nf, var_v, applied_feat)
def schedule_nodeflow_apply_nodes(graph,
layer_id,
v,
apply_func,
inplace):
"""get apply nodes schedule in NodeFlow.
Parameters
----------
graph : NodeFlow
    The NodeFlow to use
layer_id : int
The layer where we apply node update function.
v : utils.Index
Nodes to apply
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
Returns
-------
A list of executors for DGL Runtime
"""
var_nf = var.FEAT_DICT(graph._get_node_frame(layer_id), name='nf')
var_v = var.IDX(v)
v_nf = ir.READ_ROW(var_nf, var_v)
def _afunc_wrapper(node_data):
nbatch = NodeBatch(graph, v, node_data)
return apply_func(nbatch)
afunc = var.FUNC(_afunc_wrapper)
applied_feat = ir.NODE_UDF(afunc, v_nf)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_v, applied_feat)
else:
ir.WRITE_ROW_(var_nf, var_v, applied_feat)
def schedule_apply_edges(graph,
u, v, eid,
apply_func,
......@@ -277,6 +316,54 @@ def schedule_apply_edges(graph,
else:
ir.WRITE_ROW_(var_ef, var_eid, new_fdedge)
def schedule_nodeflow_apply_edges(graph, block_id,
u, v, eid,
apply_func,
inplace):
"""get apply edges schedule in NodeFlow.
Parameters
----------
graph : NodeFlow
    The NodeFlow to use
block_id : int
    The block whose edges the update function is applied to.
u : utils.Index
Source nodes of edges to apply
v : utils.Index
Destination nodes of edges to apply
eid : utils.Index
Ids of sending edges
apply_func: callable
The apply edge function
inplace: bool
If True, the update will be done in place
Returns
-------
A list of executors for DGL Runtime
"""
# vars
in_var_nf = var.FEAT_DICT(graph._get_node_frame(block_id), name='in_nf')
out_var_nf = var.FEAT_DICT(graph._get_node_frame(block_id + 1), name='out_nf')
var_ef = var.FEAT_DICT(graph._get_edge_frame(block_id), name='ef')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
# schedule apply edges
fdsrc = ir.READ_ROW(in_var_nf, var_u)
fddst = ir.READ_ROW(out_var_nf, var_v)
fdedge = ir.READ_ROW(var_ef, var_eid)
def _efunc_wrapper(src_data, edge_data, dst_data):
ebatch = EdgeBatch(graph, (u, v, eid), src_data, edge_data, dst_data)
return apply_func(ebatch)
_efunc = var.FUNC(_efunc_wrapper)
new_fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst)
if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, new_fdedge)
else:
ir.WRITE_ROW_(var_ef, var_eid, new_fdedge)
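The ``apply_func`` wrapped above is an ordinary edge UDF; a minimal
illustrative one (the field name ``'w'`` is hypothetical):

def halve_weight(edges):
    # `edges` is an EdgeBatch like the one constructed in _efunc_wrapper
    return {'w': edges.data['w'] * 0.5}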
def schedule_push(graph,
u,
message_func,
......@@ -349,9 +436,10 @@ def schedule_pull(graph,
var_eid = var.IDX(eid)
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_adj_matrix_uv((u, v), pull_nodes, graph.number_of_nodes())
inc_creator = lambda: spmv.build_inc_matrix_dst(v, pull_nodes)
reduced_feat = _gen_send_reduce(graph, graph._node_frame, graph._node_frame,
graph._edge_frame, message_func, reduce_func,
var_eid, var_pull_nodes,
uv_getter, adj_creator, inc_creator)
# generate optional apply
......@@ -361,7 +449,6 @@ def schedule_pull(graph,
else:
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
def schedule_group_apply_edge(graph,
u, v, eid,
apply_func,
......@@ -403,6 +490,68 @@ def schedule_group_apply_edge(graph,
else:
ir.WRITE_ROW_(var_ef, var_eid, var_out)
def schedule_nodeflow_compute(graph,
block_id,
u, v, eid,
dest_nodes,
message_func,
reduce_func,
apply_func,
inplace):
"""get flow compute schedule in NodeFlow
Parameters
----------
graph : NodeFlow
    The NodeFlow to use
block_id : int
The block where we perform computation.
u : utils.Index
Source nodes of edges to apply
v : utils.Index
Destination nodes of edges to apply
eid : utils.Index
Ids of sending edges
message_func: callable or list of callable
The message function
reduce_func: callable or list of callable
The reduce function
apply_func: callable
The apply node function
inplace: bool
If True, the update will be done in place
"""
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat
# directly from pull node frontier.
if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply.
if apply_func is not None:
schedule_nodeflow_apply_nodes(graph, block_id + 1, v, apply_func, inplace)
else:
# create vars
var_nf = var.FEAT_DICT(graph._get_node_frame(block_id + 1), name='out_nf')
var_dest_nodes = var.IDX(dest_nodes, name='dest_nodes')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_adj_matrix_uv((u, v), dest_nodes,
graph.layer_size(block_id))
inc_creator = lambda: spmv.build_inc_matrix_dst(v, dest_nodes)
reduced_feat = _gen_send_reduce(graph, graph._get_node_frame(block_id),
graph._get_node_frame(block_id + 1),
graph._get_edge_frame(block_id),
message_func, reduce_func,
var_eid, var_dest_nodes,
uv_getter, adj_creator, inc_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_dest_nodes, var_nf, reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_dest_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_dest_nodes, final_feat)
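From user code this schedule is reached through ``NodeFlow.block_compute``; a
sketch with builtin functions, which the v2v-SPMV analysis can fuse into a
sparse matmul (assuming ``nf`` is a NodeFlow carrying a feature ``'h'``):

import dgl.function as fn

# compute layer 1 from layer 0 with a builtin message/reduce pair
nf.block_compute(0, fn.copy_src('h', 'm'), fn.sum('m', 'h'))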
def _check_builtin_func_list(func_list):
"""Check whether func_list only contains builtin functions."""
......@@ -513,6 +662,9 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
def _gen_send_reduce(
graph,
src_node_frame,
dst_node_frame,
edge_frame,
message_func,
reduce_func,
var_send_edges,
......@@ -529,6 +681,12 @@ def _gen_send_reduce(
----------
graph : DGLGraph
The graph
src_node_frame : NodeFrame
The node frame of the source nodes.
dst_node_frame : NodeFrame
The node frame of the destination nodes.
edge_frame : NodeFrame
The frame for the edges between the source and destination nodes.
message_func : callable, list of builtins
The message func(s).
reduce_func : callable, list of builtins
......@@ -553,8 +711,9 @@ def _gen_send_reduce(
reduce_nodes = var_reduce_nodes.data
# arg vars
var_src_nf = var.FEAT_DICT(src_node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(dst_node_frame, name='nf')
var_ef = var.FEAT_DICT(edge_frame, name='ef')
var_eid = var_send_edges
# format the input functions
......@@ -567,7 +726,7 @@ def _gen_send_reduce(
# The frame has the same size and schemes of the
# node frame.
# TODO(minjie): should replace this with an IR call to make the program stateless.
tmpframe = FrameRef(frame_like(dst_node_frame._frame, len(reduce_nodes)))
var_out = var.FEAT_DICT(data=tmpframe)
if mfunc_is_list and rfunc_is_list:
......@@ -575,7 +734,7 @@ def _gen_send_reduce(
# analyze v2v spmv
spmv_pairs, mfunc, rfunc = spmv.analyze_v2v_spmv(graph, mfunc, rfunc)
adj = adj_creator()
spmv.gen_v2v_spmv_schedule(adj, spmv_pairs, var_src_nf, var_ef, var_eid, var_out)
if len(mfunc) == 0:
# All mfunc and rfunc have been converted to v2v spmv.
......@@ -591,7 +750,7 @@ def _gen_send_reduce(
# generate UDF send schedule
var_u, var_v = uv_getter()
var_mf = _gen_send(graph, var_src_nf, var_dst_nf, var_ef, var_u, var_v, var_eid, mfunc)
if rfunc_is_list:
# UDF message + builtin reducer
......@@ -610,13 +769,13 @@ def _gen_send_reduce(
# gen degree bucketing schedule for UDF recv
mid = utils.toindex(slice(0, len(var_v.data))) # message id is from 0~|dst|
db.gen_degree_bucketing_schedule(
graph, rfunc, mid, var_v.data, reduce_nodes, var_dst_nf, var_mf, var_out)
return var_out
def _gen_send(graph, src_nfr, dst_nfr, efr, u, v, eid, mfunc):
"""Internal function to generate send schedule."""
fdsrc = ir.READ_ROW(src_nfr, u)
fddst = ir.READ_ROW(dst_nfr, v)
fdedge = ir.READ_ROW(efr, eid)
def _mfunc_wrapper(src_data, edge_data, dst_data):
ebatch = EdgeBatch(graph, (u.data, v.data, eid.data),
......
......@@ -151,7 +151,7 @@ def build_adj_matrix_graph(graph):
_, shuffle_idx = gidx.adjacency_matrix(False, F.cpu())
return lambda ctx: gidx.adjacency_matrix(False, ctx)[0], shuffle_idx
def _build_adj_matrix_index_uv(edges, reduce_nodes, num_sources):
"""Build adj matrix index and shape using the given (u, v) edges.
The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
......@@ -164,13 +164,13 @@ def _build_adj_matrix_index_uv(graph, edges, reduce_nodes):
Parameters
----------
edges : tuple of utils.Index
(u, v)
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
num_sources : int
The number of source nodes.
Returns
-------
......@@ -185,14 +185,14 @@ def _build_adj_matrix_index_uv(graph, edges, reduce_nodes):
u = u.tousertensor()
v = v.tousertensor()
new_v = old2new[v] # FIXME(minjie): no use []
n = num_sources
m = len(reduce_nodes)
row = F.unsqueeze(new_v, 0)
col = F.unsqueeze(u, 0)
idx = F.cat([row, col], dim=0)
return ('coo', idx), (m, n)
def build_adj_matrix_uv(edges, reduce_nodes, num_sources):
"""Build adj matrix using the given (u, v) edges and target nodes.
The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
......@@ -201,13 +201,13 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes):
Parameters
----------
edges : tuple of utils.Index
(u, v)
reduce_nodes : utils.Index
The nodes to reduce messages, which will be target dimension
of the adjmat. The nodes include unique(v) and zero-degree-nodes.
num_sources : int
The number of source nodes.
Returns
-------
......@@ -217,7 +217,7 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes):
A index for data shuffling due to sparse format change. Return None
if shuffle is not required.
"""
sp_idx, shape = _build_adj_matrix_index_uv(edges, reduce_nodes, num_sources)
u, _ = edges
nnz = len(u)
# FIXME(minjie): data type
......
......@@ -142,3 +142,103 @@ class EdgeDataView(MutableMapping):
def __repr__(self):
data = self._graph.get_e_repr(self._edges)
return repr({key : data[key] for key in self._graph._edge_frame})
class LayerView(object):
"""A LayerView class to act as nflow.layers for a NodeFlow.
Can be used to get a list of current nodes and get and set node data.
"""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __len__(self):
return self._graph.num_layers
def __getitem__(self, layer):
if not isinstance(layer, int):
raise DGLError('Currently we only support the view of one layer')
return NodeSpace(data=LayerDataView(self._graph, layer))
def __call__(self):
"""Return the nodes."""
return F.arange(0, len(self))
class LayerDataView(MutableMapping):
"""The data view class when G.layers[...].data is called.
"""
__slots__ = ['_graph', '_layer']
def __init__(self, graph, layer):
self._graph = graph
self._layer = layer
def __getitem__(self, key):
return self._graph._node_frames[self._layer][key]
def __setitem__(self, key, val):
self._graph._node_frames[self._layer][key] = val
def __delitem__(self, key):
del self._graph._node_frames[self._layer][key]
def __len__(self):
return len(self._graph._node_frames[self._layer])
def __iter__(self):
return iter(self._graph._node_frames[self._layer])
def __repr__(self):
data = self._graph._node_frames[self._layer]
return repr({key : data[key] for key in data})
class BlockView(object):
"""A BlockView class to act as nflow.blocks for a NodeFlow.
Can be used to get a list of current edges and get and set edge data.
"""
__slots__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __len__(self):
return self._graph.num_blocks
def __getitem__(self, flow):
if not isinstance(flow, int):
raise DGLError('Currently we only support the view of one flow')
return EdgeSpace(data=BlockDataView(self._graph, flow))
def __call__(self, *args, **kwargs):
"""Return all the edges."""
return self._graph.all_edges(*args, **kwargs)
class BlockDataView(MutableMapping):
"""The data view class when G.blocks[...].data is called.
"""
__slots__ = ['_graph', '_flow']
def __init__(self, graph, flow):
self._graph = graph
self._flow = flow
def __getitem__(self, key):
return self._graph._edge_frames[self._flow][key]
def __setitem__(self, key, val):
self._graph._edge_frames[self._flow][key] = val
def __delitem__(self, key):
del self._graph._edge_frames[self._flow][key]
def __len__(self):
return len(self._graph._edge_frames[self._flow])
def __iter__(self):
return iter(self._graph._edge_frames[self._flow])
def __repr__(self):
data = self._graph._edge_frames[self._flow]
return repr({key : data[key] for key in data})
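A brief sketch of the two views in action (``nf`` is an assumed NodeFlow; the
field names and ``edge_weights`` are illustrative):

h0 = nf.layers[0].data['h']            # node data of layer 0 via LayerView
nf.blocks[0].data['w'] = edge_weights  # edge data of block 0 via BlockView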
......@@ -24,7 +24,7 @@ DLManagedTensor* CreateTmpDLManagedTensor(const DGLArgValue& arg) {
PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
auto body = [vec](DGLArgs args, DGLRetValue* rv) {
const uint64_t which = args[0];
if (which >= vec.size()) {
LOG(FATAL) << "invalid choice";
} else {
......
......@@ -4,6 +4,7 @@
* \brief DGL graph index implementation
*/
#include <dgl/graph.h>
#include <dgl/sampler.h>
#include <algorithm>
#include <unordered_map>
#include <set>
......
......@@ -6,6 +6,7 @@
#include <dgl/graph.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_op.h>
#include <dgl/sampler.h>
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
......@@ -68,21 +69,21 @@ PackedFunc ConvertSubgraphToPackedFunc(const Subgraph& sg) {
}
// Convert Sampled Subgraph structures to PackedFunc.
PackedFunc ConvertSubgraphToPackedFunc(const std::vector<NodeFlow>& sg) {
auto body = [sg] (DGLArgs args, DGLRetValue* rv) {
const uint64_t which = args[0];
if (which < sg.size()) {
GraphInterface* gptr = sg[which].graph->Reset();
GraphHandle ghandle = gptr;
*rv = ghandle;
} else if (which >= sg.size() && which < sg.size() * 2) {
*rv = std::move(sg[which - sg.size()].node_mapping);
} else if (which >= sg.size() * 2 && which < sg.size() * 3) {
*rv = std::move(sg[which - sg.size() * 2].edge_mapping);
} else if (which >= sg.size() * 3 && which < sg.size() * 4) {
*rv = std::move(sg[which - sg.size() * 3].layer_offsets);
} else if (which >= sg.size() * 4 && which < sg.size() * 5) {
*rv = std::move(sg[which - sg.size() * 4].flow_offsets);
} else {
LOG(FATAL) << "invalid choice";
}
......@@ -446,10 +447,11 @@ void CAPI_NeighborUniformSample(DGLArgs args, DGLRetValue* rv) {
const ImmutableGraph *gptr = dynamic_cast<const ImmutableGraph*>(ptr);
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(num_valid_seeds <= num_seeds);
std::vector<NodeFlow> subgs(seeds.size());
#pragma omp parallel for
for (int i = 0; i < num_valid_seeds; i++) {
subgs[i] = SamplerOp::NeighborUniformSample(gptr, seeds[i],
neigh_type, num_hops, num_neighbors);
}
*rv = ConvertSubgraphToPackedFunc(subgs);
}
......
......@@ -9,10 +9,6 @@
#include <cmath>
#ifdef _MSC_VER
// rand in MS compiler works well in multi-threading.
int rand_r(unsigned *seed) {
return rand();
}
#define _CRT_RAND_S
#endif
......@@ -574,440 +570,4 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f
}
}
////////////////////////////// Graph Sampling ///////////////////////////////
/*
* ArrayHeap is used to sample elements from vector
*/
class ArrayHeap {
public:
explicit ArrayHeap(const std::vector<float>& prob) {
vec_size_ = prob.size();
bit_len_ = ceil(log2(vec_size_));
limit_ = 1 << bit_len_;
// allocate twice the size
heap_.resize(limit_ << 1, 0);
// allocate the leaves
for (int i = limit_; i < vec_size_+limit_; ++i) {
heap_[i] = prob[i-limit_];
}
// iterate up the tree (this is O(m))
for (int i = bit_len_-1; i >= 0; --i) {
for (int j = (1 << i); j < (1 << (i + 1)); ++j) {
heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1];
}
}
}
~ArrayHeap() {}
/*
* Remove term from index (this costs O(log m) steps)
*/
void Delete(size_t index) {
size_t i = index + limit_;
float w = heap_[i];
for (int j = bit_len_; j >= 0; --j) {
heap_[i] -= w;
i = i >> 1;
}
}
/*
* Add value w to index (this costs O(log m) steps)
*/
void Add(size_t index, float w) {
size_t i = index + limit_;
for (int j = bit_len_; j >= 0; --j) {
heap_[i] += w;
i = i >> 1;
}
}
/*
* Sample from arrayHeap
*/
size_t Sample(unsigned int* seed) {
float xi = heap_[1] * (rand_r(seed)%100/101.0);
int i = 1;
while (i < limit_) {
i = i << 1;
if (xi >= heap_[i]) {
xi -= heap_[i];
i += 1;
}
}
return i - limit_;
}
/*
* Sample n elements without replacement (each sampled element is removed from the heap)
*/
void SampleWithoutReplacement(size_t n, std::vector<size_t>* samples, unsigned int* seed) {
// sample n elements
for (size_t i = 0; i < n; ++i) {
samples->at(i) = this->Sample(seed);
this->Delete(samples->at(i));
}
}
private:
int vec_size_;  // size of the probability vector
int bit_len_;   // depth of the binary tree (in bits)
int limit_;
std::vector<float> heap_;
};
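A Python sketch of the same sum-tree idea (illustrative only, not part of DGL):
each internal node stores the total weight of its subtree, so Sample, Add and
Delete all cost O(log m).

import math
import random

class SumTree:
    """Weighted sampling without replacement over a non-empty weight list."""
    def __init__(self, prob):
        self.bits = math.ceil(math.log2(len(prob))) if len(prob) > 1 else 0
        self.limit = 1 << self.bits
        self.heap = [0.0] * (self.limit * 2)
        self.heap[self.limit:self.limit + len(prob)] = [float(p) for p in prob]
        for j in range(self.limit - 1, 0, -1):  # fill internal nodes bottom-up
            self.heap[j] = self.heap[2 * j] + self.heap[2 * j + 1]

    def sample(self):
        x, i = random.uniform(0.0, self.heap[1]), 1
        while i < self.limit:       # descend, picking the left or right child
            i *= 2
            if x >= self.heap[i]:
                x -= self.heap[i]
                i += 1
        return i - self.limit

    def delete(self, index):        # zero a leaf, fix ancestor sums: O(log m)
        i = index + self.limit
        w = self.heap[i]
        while i >= 1:
            self.heap[i] -= w
            i //= 2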
/*
* Uniformly sample integers from [0, set_size) without replacement.
*/
static void RandomSample(size_t set_size,
size_t num,
std::vector<size_t>* out,
unsigned int* seed) {
std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) {
sampled_idxs.insert(rand_r(seed) % set_size);
}
out->clear();
for (auto it = sampled_idxs.begin(); it != sampled_idxs.end(); it++) {
out->push_back(*it);
}
}
/*
* For a sparse array whose non-zeros are represented by nz_idxs,
* negate the sparse array and outputs the non-zeros in the negated array.
*/
static void NegateArray(const std::vector<size_t> &nz_idxs,
size_t arr_size,
std::vector<size_t>* out) {
// nz_idxs must have been sorted.
auto it = nz_idxs.begin();
size_t i = 0;
CHECK_GT(arr_size, nz_idxs.back());
for (; i < arr_size && it != nz_idxs.end(); i++) {
if (*it == i) {
it++;
continue;
}
out->push_back(i);
}
for (; i < arr_size; i++) {
out->push_back(i);
}
}
/*
* Uniform sample vertices from a list of vertices.
*/
static void GetUniformSample(const dgl_id_t* val_list,
const dgl_id_t* ver_list,
const size_t ver_len,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge,
unsigned int* seed) {
// Copy ver_list to output
if (ver_len <= max_num_neighbor) {
for (size_t i = 0; i < ver_len; ++i) {
out_ver->push_back(ver_list[i]);
out_edge->push_back(val_list[i]);
}
return;
}
// If the requested number is small relative to the neighbor list, sample the
// indices directly; otherwise sample the complement (see NegateArray below).
std::vector<size_t> sorted_idxs;
if (ver_len > max_num_neighbor * 2) {
sorted_idxs.reserve(max_num_neighbor);
RandomSample(ver_len, max_num_neighbor, &sorted_idxs, seed);
std::sort(sorted_idxs.begin(), sorted_idxs.end());
} else {
std::vector<size_t> negate;
negate.reserve(ver_len - max_num_neighbor);
RandomSample(ver_len, ver_len - max_num_neighbor,
&negate, seed);
std::sort(negate.begin(), negate.end());
NegateArray(negate, ver_len, &sorted_idxs);
}
// verify the result.
CHECK_EQ(sorted_idxs.size(), max_num_neighbor);
for (size_t i = 1; i < sorted_idxs.size(); i++) {
CHECK_GT(sorted_idxs[i], sorted_idxs[i - 1]);
}
for (auto idx : sorted_idxs) {
out_ver->push_back(ver_list[idx]);
out_edge->push_back(val_list[idx]);
}
}
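The second branch above is a complement trick: when most of the neighbor list
is kept, it is cheaper to sample the ``ver_len - max_num_neighbor`` indices to
exclude (via NegateArray) than to rejection-sample the kept ones. A Python
sketch of the idea (illustrative only):

import random

def sample_sorted_indices(n, k):
    # pick k sorted indices from range(n) without replacement
    if n > k * 2:
        chosen = set()
        while len(chosen) < k:          # few wanted: sample them directly
            chosen.add(random.randrange(n))
        return sorted(chosen)
    excluded = set()
    while len(excluded) < n - k:        # many wanted: sample the complement
        excluded.add(random.randrange(n))
    return [i for i in range(n) if i not in excluded]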
/*
* Non-uniform sample via ArrayHeap
*/
static void GetNonUniformSample(const float* probability,
const dgl_id_t* val_list,
const dgl_id_t* ver_list,
const size_t ver_len,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge,
unsigned int* seed) {
// Copy ver_list to output
if (ver_len <= max_num_neighbor) {
for (size_t i = 0; i < ver_len; ++i) {
out_ver->push_back(ver_list[i]);
out_edge->push_back(val_list[i]);
}
return;
}
// Make sample
std::vector<size_t> sp_index(max_num_neighbor);
std::vector<float> sp_prob(ver_len);
for (size_t i = 0; i < ver_len; ++i) {
sp_prob[i] = probability[ver_list[i]];
}
ArrayHeap arrayHeap(sp_prob);
arrayHeap.SampleWithoutReplacement(max_num_neighbor, &sp_index, seed);
out_ver->resize(max_num_neighbor);
out_edge->resize(max_num_neighbor);
for (size_t i = 0; i < max_num_neighbor; ++i) {
size_t idx = sp_index[i];
out_ver->at(i) = ver_list[idx];
out_edge->at(i) = val_list[idx];
}
sort(out_ver->begin(), out_ver->end());
sort(out_edge->begin(), out_edge->end());
}
/*
* Used for subgraph sampling
*/
struct neigh_list {
std::vector<dgl_id_t> neighs;
std::vector<dgl_id_t> edges;
neigh_list(const std::vector<dgl_id_t> &_neighs,
const std::vector<dgl_id_t> &_edges)
: neighs(_neighs), edges(_edges) {}
};
SampledSubgraph ImmutableGraph::SampleSubgraph(IdArray seed_arr,
const float* probability,
const std::string &neigh_type,
int num_hops,
size_t num_neighbor) const {
unsigned int time_seed = time(nullptr);
size_t num_seeds = seed_arr->shape[0];
auto orig_csr = neigh_type == "in" ? GetInCSR() : GetOutCSR();
const dgl_id_t* val_list = orig_csr->edge_ids.data();
const dgl_id_t* col_list = orig_csr->indices.data();
const int64_t* indptr = orig_csr->indptr.data();
const dgl_id_t* seed = static_cast<dgl_id_t*>(seed_arr->data);
// BFS traverse the graph and sample vertices
// <vertex_id, layer_id>
std::unordered_set<dgl_id_t> sub_ver_map;
std::vector<std::pair<dgl_id_t, int> > sub_vers;
sub_vers.reserve(num_seeds * 10);
// add seed vertices
for (size_t i = 0; i < num_seeds; ++i) {
auto ret = sub_ver_map.insert(seed[i]);
// If the vertex is inserted successfully.
if (ret.second) {
sub_vers.emplace_back(seed[i], 0);
}
}
std::vector<dgl_id_t> tmp_sampled_src_list;
std::vector<dgl_id_t> tmp_sampled_edge_list;
// ver_id, position
std::vector<std::pair<dgl_id_t, size_t> > neigh_pos;
neigh_pos.reserve(num_seeds);
std::vector<dgl_id_t> neighbor_list;
int64_t num_edges = 0;
// sub_vers is used both as a node collection and a queue.
// In the while loop, we iterate over sub_vers while new nodes are appended to it.
// Each vertex in the vector is visited exactly once. If the vertex at idx is not
// in the last level, we sample its neighbors; once only last-level vertices
// remain, the loop terminates.
size_t idx = 0;
while (idx < sub_vers.size()) {
dgl_id_t dst_id = sub_vers[idx].first;
int cur_node_level = sub_vers[idx].second;
idx++;
// If the node is in the last level, we don't need to sample neighbors
// from this node.
if (cur_node_level >= num_hops)
continue;
tmp_sampled_src_list.clear();
tmp_sampled_edge_list.clear();
dgl_id_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
if (probability == nullptr) { // uniform-sample
GetUniformSample(val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id),
ver_len,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
} else { // non-uniform-sample
GetNonUniformSample(probability,
val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id),
ver_len,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
}
CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
size_t pos = neighbor_list.size();
neigh_pos.emplace_back(dst_id, pos);
// First we push the size of neighbor vector
neighbor_list.push_back(tmp_sampled_edge_list.size());
// Then push the vertices
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
neighbor_list.push_back(tmp_sampled_src_list[i]);
}
// Finally we push the edge list
for (size_t i = 0; i < tmp_sampled_edge_list.size(); ++i) {
neighbor_list.push_back(tmp_sampled_edge_list[i]);
}
num_edges += tmp_sampled_src_list.size();
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
      // We need to add the neighbor to the hashtable here. This ensures that
      // vertices in the queue are unique. If we have seen a vertex before, we
      // don't need to add it to the queue again.
      auto ret = sub_ver_map.insert(tmp_sampled_src_list[i]);
      // If the sampled neighbor is inserted into the map successfully.
if (ret.second)
sub_vers.emplace_back(tmp_sampled_src_list[i], cur_node_level + 1);
}
}
  // Check whether any vertex remains whose neighbors we haven't sampled.
for (; idx < sub_vers.size(); idx++) {
if (sub_vers[idx].second < num_hops) {
LOG(WARNING)
<< "The sampling is truncated because we have reached the max number of vertices\n"
<< "Please use a smaller number of seeds or a small neighborhood";
break;
}
}
  // Copy the sampled vertices and their layer Ids to the output arrays.
uint64_t num_vertices = sub_ver_map.size();
  std::sort(sub_vers.begin(), sub_vers.end(),
            [](const std::pair<dgl_id_t, int> &a1, const std::pair<dgl_id_t, int> &a2) {
              return a1.first < a2.first;
            });
SampledSubgraph subg;
subg.induced_vertices = IdArray::Empty({static_cast<int64_t>(num_vertices)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
subg.induced_edges = IdArray::Empty({static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
subg.layer_ids = IdArray::Empty({static_cast<int64_t>(num_vertices)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
subg.sample_prob = runtime::NDArray::Empty({static_cast<int64_t>(num_vertices)},
DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0});
dgl_id_t *out = static_cast<dgl_id_t *>(subg.induced_vertices->data);
dgl_id_t *out_layer = static_cast<dgl_id_t *>(subg.layer_ids->data);
for (size_t i = 0; i < sub_vers.size(); i++) {
out[i] = sub_vers[i].first;
out_layer[i] = sub_vers[i].second;
}
// Copy sub_probability
float *sub_prob = static_cast<float *>(subg.sample_prob->data);
if (probability != nullptr) {
for (size_t i = 0; i < sub_ver_map.size(); ++i) {
dgl_id_t idx = out[i];
sub_prob[i] = probability[idx];
}
}
// Construct sub_csr_graph
auto subg_csr = std::make_shared<CSR>(num_vertices, num_edges);
subg_csr->indices.resize(num_edges);
subg_csr->edge_ids.resize(num_edges);
dgl_id_t* val_list_out = static_cast<dgl_id_t *>(subg.induced_edges->data);
dgl_id_t* col_list_out = subg_csr->indices.data();
int64_t* indptr_out = subg_csr->indptr.data();
size_t collected_nedges = 0;
// Both the out array and neigh_pos are sorted. By scanning the two arrays, we can see
// which vertices have neighbors and which don't.
std::sort(neigh_pos.begin(), neigh_pos.end(),
[](const std::pair<dgl_id_t, size_t> &a1, const std::pair<dgl_id_t, size_t> &a2) {
return a1.first < a2.first;
});
size_t idx_with_neigh = 0;
for (size_t i = 0; i < num_vertices; i++) {
dgl_id_t dst_id = *(out + i);
// If a vertex is in sub_ver_map but not in neigh_pos, this vertex must not
// have edges.
size_t edge_size = 0;
if (idx_with_neigh < neigh_pos.size() && dst_id == neigh_pos[idx_with_neigh].first) {
size_t pos = neigh_pos[idx_with_neigh].second;
CHECK_LT(pos, neighbor_list.size());
edge_size = neighbor_list[pos];
CHECK_LE(pos + edge_size * 2 + 1, neighbor_list.size());
std::copy_n(neighbor_list.begin() + pos + 1,
edge_size,
col_list_out + collected_nedges);
std::copy_n(neighbor_list.begin() + pos + edge_size + 1,
edge_size,
val_list_out + collected_nedges);
collected_nedges += edge_size;
idx_with_neigh++;
}
indptr_out[i+1] = indptr_out[i] + edge_size;
}
for (size_t i = 0; i < subg_csr->edge_ids.size(); i++)
subg_csr->edge_ids[i] = i;
if (neigh_type == "in")
subg.graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr, IsMultigraph()));
else
subg.graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr, IsMultigraph()));
return subg;
}
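/*
 * Remap the column indices of a subgraph CSR from the parent graph's vertex
 * Ids to subgraph-local Ids according to id_map.
 */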
void CompactSubgraph(ImmutableGraph::CSR *subg,
const std::unordered_map<dgl_id_t, dgl_id_t> &id_map) {
for (size_t i = 0; i < subg->indices.size(); i++) {
auto it = id_map.find(subg->indices[i]);
CHECK(it != id_map.end());
subg->indices[i] = it->second;
}
}
void ImmutableGraph::CompactSubgraph(IdArray induced_vertices) {
// The key is the old id, the value is the id in the subgraph.
std::unordered_map<dgl_id_t, dgl_id_t> id_map;
const dgl_id_t *vdata = static_cast<dgl_id_t *>(induced_vertices->data);
size_t len = induced_vertices->shape[0];
for (size_t i = 0; i < len; i++)
id_map.insert(std::pair<dgl_id_t, dgl_id_t>(vdata[i], i));
if (in_csr_)
dgl::CompactSubgraph(in_csr_.get(), id_map);
if (out_csr_)
dgl::CompactSubgraph(out_csr_.get(), id_map);
}
SampledSubgraph ImmutableGraph::NeighborUniformSample(IdArray seeds,
const std::string &neigh_type,
int num_hops, int expand_factor) const {
auto ret = SampleSubgraph(seeds, // seed vector
nullptr, // sample_id_probability
neigh_type,
num_hops,
expand_factor);
std::static_pointer_cast<ImmutableGraph>(ret.graph)->CompactSubgraph(ret.induced_vertices);
return ret;
}
} // namespace dgl
/*!
* Copyright (c) 2018 by Contributors
* \file graph/sampler.cc
* \brief DGL sampler implementation
*/
#include <dgl/sampler.h>
#include <dgl/immutable_graph.h>
#include <algorithm>
#ifdef _MSC_VER
// MSVC has no rand_r. Its rand() keeps per-thread state, so it is safe in
// multi-threaded code and we simply ignore the seed.
int rand_r(unsigned *seed) {
return rand();
}
#endif
namespace dgl {
namespace {
/*
 * ArrayHeap is a sum-tree used to sample indices from a vector of weights
*/
class ArrayHeap {
public:
explicit ArrayHeap(const std::vector<float>& prob) {
vec_size_ = prob.size();
bit_len_ = ceil(log2(vec_size_));
limit_ = 1 << bit_len_;
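    // The heap is an implicit complete binary tree stored in an array:
    // leaves heap_[limit_ .. limit_ + vec_size_) hold the weights and every
    // internal node holds the sum of its two children, so heap_[1] is the
    // total weight.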
// allocate twice the size
heap_.resize(limit_ << 1, 0);
// allocate the leaves
for (int i = limit_; i < vec_size_+limit_; ++i) {
heap_[i] = prob[i-limit_];
}
// iterate up the tree (this is O(m))
for (int i = bit_len_-1; i >= 0; --i) {
for (int j = (1 << i); j < (1 << (i + 1)); ++j) {
heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1];
}
}
}
~ArrayHeap() {}
/*
 * Zero out the weight at index (this costs O(log m) steps)
*/
void Delete(size_t index) {
size_t i = index + limit_;
float w = heap_[i];
for (int j = bit_len_; j >= 0; --j) {
heap_[i] -= w;
i = i >> 1;
}
}
/*
* Add value w to index (this costs O(log m) steps)
*/
void Add(size_t index, float w) {
size_t i = index + limit_;
for (int j = bit_len_; j >= 0; --j) {
heap_[i] += w;
i = i >> 1;
}
}
/*
* Sample from arrayHeap
*/
size_t Sample(unsigned int* seed) {
float xi = heap_[1] * (rand_r(seed)%100/101.0);
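    // Walk from the root to a leaf: descend to the left child and, if xi
    // exceeds the left child's mass, subtract that mass and step to the
    // right sibling. The leaf reached maps back to index i - limit_.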
int i = 1;
while (i < limit_) {
i = i << 1;
if (xi >= heap_[i]) {
xi -= heap_[i];
i += 1;
}
}
return i - limit_;
}
/*
 * Draw n samples without replacement; each drawn index is removed from the heap
*/
void SampleWithoutReplacement(size_t n, std::vector<size_t>* samples, unsigned int* seed) {
// sample n elements
for (size_t i = 0; i < n; ++i) {
samples->at(i) = this->Sample(seed);
this->Delete(samples->at(i));
}
}
private:
  int vec_size_;  // number of elements in the original vector
  int bit_len_;   // depth of the tree: ceil(log2(vec_size_))
  int limit_;     // number of leaves: the smallest power of two >= vec_size_
std::vector<float> heap_;
};
/*
* Uniformly sample integers from [0, set_size) without replacement.
*/
void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out, unsigned int* seed) {
std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) {
sampled_idxs.insert(rand_r(seed) % set_size);
}
out->clear();
out->insert(out->end(), sampled_idxs.begin(), sampled_idxs.end());
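  // Note: the iteration order of unordered_set is unspecified, so the output
  // is not sorted; callers sort it afterwards when order matters.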
}
/*
 * For a sparse array whose non-zero positions are listed in nz_idxs,
 * negate the array and output the non-zero positions of the negated
 * array, i.e. the complement of nz_idxs in [0, arr_size).
*/
void NegateArray(const std::vector<size_t> &nz_idxs,
size_t arr_size,
std::vector<size_t>* out) {
// nz_idxs must have been sorted.
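  // Example: nz_idxs = {1, 3} with arr_size = 5 yields out = {0, 2, 4}.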
auto it = nz_idxs.begin();
size_t i = 0;
CHECK_GT(arr_size, nz_idxs.back());
for (; i < arr_size && it != nz_idxs.end(); i++) {
if (*it == i) {
it++;
continue;
}
out->push_back(i);
}
for (; i < arr_size; i++) {
out->push_back(i);
}
}
/*
 * Uniformly sample up to max_num_neighbor vertices from a neighbor list
*/
void GetUniformSample(const dgl_id_t* edge_id_list,
const dgl_id_t* vid_list,
const size_t ver_len,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge,
unsigned int* seed) {
  // If there are not enough neighbors, simply copy the whole neighbor list to the output.
if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
out_edge->insert(out_edge->end(), edge_id_list, edge_id_list + ver_len);
return;
}
  // If we keep well under half of the neighbors, sample the kept positions
  // directly; otherwise it is cheaper to sample the (smaller) complement of
  // dropped positions and negate it.
std::vector<size_t> sorted_idxs;
if (ver_len > max_num_neighbor * 2) {
sorted_idxs.reserve(max_num_neighbor);
RandomSample(ver_len, max_num_neighbor, &sorted_idxs, seed);
std::sort(sorted_idxs.begin(), sorted_idxs.end());
} else {
std::vector<size_t> negate;
negate.reserve(ver_len - max_num_neighbor);
RandomSample(ver_len, ver_len - max_num_neighbor,
&negate, seed);
std::sort(negate.begin(), negate.end());
NegateArray(negate, ver_len, &sorted_idxs);
}
// verify the result.
CHECK_EQ(sorted_idxs.size(), max_num_neighbor);
for (size_t i = 1; i < sorted_idxs.size(); i++) {
CHECK_GT(sorted_idxs[i], sorted_idxs[i - 1]);
}
for (auto idx : sorted_idxs) {
out_ver->push_back(vid_list[idx]);
out_edge->push_back(edge_id_list[idx]);
}
}
/*
* Non-uniform sample via ArrayHeap
*/
void GetNonUniformSample(const float* probability,
const dgl_id_t* edge_id_list,
const dgl_id_t* vid_list,
const size_t ver_len,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge,
unsigned int* seed) {
  // If there are not enough neighbors, simply copy the whole neighbor list to the output.
if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
out_edge->insert(out_edge->end(), edge_id_list, edge_id_list + ver_len);
return;
}
// Make sample
std::vector<size_t> sp_index(max_num_neighbor);
std::vector<float> sp_prob(ver_len);
for (size_t i = 0; i < ver_len; ++i) {
sp_prob[i] = probability[vid_list[i]];
}
ArrayHeap arrayHeap(sp_prob);
arrayHeap.SampleWithoutReplacement(max_num_neighbor, &sp_index, seed);
out_ver->resize(max_num_neighbor);
out_edge->resize(max_num_neighbor);
  // Sort the sampled positions so that the output follows the order of the
  // parent CSR row and the vertex/edge Ids stay paired. Sorting out_ver and
  // out_edge independently would break the pairing.
  std::sort(sp_index.begin(), sp_index.end());
  for (size_t i = 0; i < max_num_neighbor; ++i) {
    size_t idx = sp_index[i];
    out_ver->at(i) = vid_list[idx];
    out_edge->at(i) = edge_id_list[idx];
  }
}
/*
* Used for subgraph sampling
*/
struct neigh_list {
std::vector<dgl_id_t> neighs;
std::vector<dgl_id_t> edges;
neigh_list(const std::vector<dgl_id_t> &_neighs,
const std::vector<dgl_id_t> &_edges)
: neighs(_neighs), edges(_edges) {}
};
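/*
 * For one destination vertex, records where its sampled neighbors start in
 * neighbor_list/edge_list (pos) and how many edges were sampled (num_edges).
 */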
struct neighbor_info {
dgl_id_t id;
size_t pos;
size_t num_edges;
neighbor_info(dgl_id_t id, size_t pos, size_t num_edges) {
this->id = id;
this->pos = pos;
this->num_edges = num_edges;
}
};
NodeFlow ConstructNodeFlow(const std::vector<dgl_id_t> &neighbor_list,
                           const std::vector<dgl_id_t> &edge_list,
                           const std::vector<size_t> &layer_offsets,
std::vector<std::pair<dgl_id_t, int> > *sub_vers,
std::vector<neighbor_info> *neigh_pos,
const std::string &edge_type,
int64_t num_edges, int num_hops, bool is_multigraph) {
NodeFlow nf;
uint64_t num_vertices = sub_vers->size();
nf.node_mapping = IdArray::Empty({static_cast<int64_t>(num_vertices)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
nf.edge_mapping = IdArray::Empty({static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
nf.layer_offsets = IdArray::Empty({static_cast<int64_t>(num_hops + 1)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
nf.flow_offsets = IdArray::Empty({static_cast<int64_t>(num_hops)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t *node_map_data = static_cast<dgl_id_t *>(nf.node_mapping->data);
dgl_id_t *layer_off_data = static_cast<dgl_id_t *>(nf.layer_offsets->data);
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf.flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf.edge_mapping->data);
// Construct sub_csr_graph
auto subg_csr = std::make_shared<ImmutableGraph::CSR>(num_vertices, num_edges);
subg_csr->indices.resize(num_edges);
subg_csr->edge_ids.resize(num_edges);
dgl_id_t* col_list_out = subg_csr->indices.data();
int64_t* indptr_out = subg_csr->indptr.data();
size_t collected_nedges = 0;
// The data from the previous steps:
// * node data: sub_vers (vid, layer), neigh_pos,
// * edge data: neighbor_list, edge_list, probability.
// * layer_offsets: the offset in sub_vers.
dgl_id_t ver_id = 0;
std::vector<std::unordered_map<dgl_id_t, dgl_id_t>> layer_ver_maps;
layer_ver_maps.resize(num_hops);
size_t out_node_idx = 0;
for (int layer_id = num_hops - 1; layer_id >= 0; layer_id--) {
    // We sort the vertices in a layer so that we don't need to sort the
    // neighbor Ids after remapping them to subgraph Ids.
std::sort(sub_vers->begin() + layer_offsets[layer_id],
sub_vers->begin() + layer_offsets[layer_id + 1],
              [](const std::pair<dgl_id_t, int> &a1,
                 const std::pair<dgl_id_t, int> &a2) {
return a1.first < a2.first;
});
    // Save the sampled vertices and their layer Ids.
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) {
node_map_data[out_node_idx++] = sub_vers->at(i).first;
layer_ver_maps[layer_id].insert(std::pair<dgl_id_t, dgl_id_t>(sub_vers->at(i).first,
ver_id++));
assert(sub_vers->at(i).second == layer_id);
}
}
CHECK(out_node_idx == num_vertices);
  // Sampling starts from the seed nodes, so the seeds are in the first layer
  // and the input nodes are in the last layer. When we expose the sampled
  // graph to a Python user, however, the input nodes are in the first layer
  // and the seed nodes are in the last layer. Thus, when we copy the sampled
  // results to a CSR, we need to reverse the order of the layers. The input
  // layer has no incoming edges in the NodeFlow, so its indptr entries are
  // all zero.
size_t row_idx = 0;
for (size_t i = layer_offsets[num_hops - 1]; i < layer_offsets[num_hops]; i++) {
indptr_out[row_idx++] = 0;
}
layer_off_data[0] = 0;
layer_off_data[1] = layer_offsets[num_hops] - layer_offsets[num_hops - 1];
size_t out_layer_idx = 1;
for (int layer_id = num_hops - 2; layer_id >= 0; layer_id--) {
std::sort(neigh_pos->begin() + layer_offsets[layer_id],
neigh_pos->begin() + layer_offsets[layer_id + 1],
[](const neighbor_info &a1, const neighbor_info &a2) {
return a1.id < a2.id;
});
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) {
dgl_id_t dst_id = sub_vers->at(i).first;
assert(dst_id == neigh_pos->at(i).id);
size_t pos = neigh_pos->at(i).pos;
CHECK_LT(pos, neighbor_list.size());
      size_t nedges = neigh_pos->at(i).num_edges;
      // We need to map the Ids of the neighbors to the subgraph.
      auto neigh_it = neighbor_list.begin() + pos;
      for (size_t j = 0; j < nedges; j++) {
        dgl_id_t neigh = *(neigh_it + j);
        assert(layer_ver_maps[layer_id + 1].find(neigh) != layer_ver_maps[layer_id + 1].end());
        col_list_out[collected_nedges + j] = layer_ver_maps[layer_id + 1][neigh];
      }
      // We can simply copy the edge Ids.
      std::copy_n(edge_list.begin() + pos,
                  nedges, edge_map_data + collected_nedges);
      collected_nedges += nedges;
      indptr_out[row_idx + 1] = indptr_out[row_idx] + nedges;
row_idx++;
}
layer_off_data[out_layer_idx + 1] = layer_off_data[out_layer_idx]
+ layer_offsets[layer_id + 1] - layer_offsets[layer_id];
out_layer_idx++;
}
CHECK(row_idx == num_vertices);
CHECK(indptr_out[row_idx] == num_edges);
CHECK(out_layer_idx == num_hops);
CHECK(layer_off_data[out_layer_idx] == num_vertices);
// Copy flow offsets.
flow_off_data[0] = 0;
size_t out_flow_idx = 0;
  for (size_t i = 0; i + 2 < layer_offsets.size(); i++) {
    size_t block_edges = subg_csr->GetDegree(layer_off_data[i + 1], layer_off_data[i + 2]);
    flow_off_data[out_flow_idx + 1] = flow_off_data[out_flow_idx] + block_edges;
out_flow_idx++;
}
CHECK(out_flow_idx == num_hops - 1);
CHECK(flow_off_data[num_hops - 1] == num_edges);
for (size_t i = 0; i < subg_csr->edge_ids.size(); i++) {
subg_csr->edge_ids[i] = i;
}
if (edge_type == "in") {
nf.graph = GraphPtr(new ImmutableGraph(subg_csr, nullptr, is_multigraph));
} else {
nf.graph = GraphPtr(new ImmutableGraph(nullptr, subg_csr, is_multigraph));
}
return nf;
}
NodeFlow SampleSubgraph(const ImmutableGraph *graph,
IdArray seed_arr,
const float* probability,
const std::string &edge_type,
int num_hops,
size_t num_neighbor) {
unsigned int time_seed = time(nullptr);
size_t num_seeds = seed_arr->shape[0];
auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t* val_list = orig_csr->edge_ids.data();
const dgl_id_t* col_list = orig_csr->indices.data();
const int64_t* indptr = orig_csr->indptr.data();
const dgl_id_t* seed = static_cast<dgl_id_t*>(seed_arr->data);
std::unordered_set<dgl_id_t> sub_ver_map; // The vertex Ids in a layer.
std::vector<std::pair<dgl_id_t, int> > sub_vers;
sub_vers.reserve(num_seeds * 10);
// add seed vertices
for (size_t i = 0; i < num_seeds; ++i) {
auto ret = sub_ver_map.insert(seed[i]);
// If the vertex is inserted successfully.
if (ret.second) {
sub_vers.emplace_back(seed[i], 0);
}
}
std::vector<dgl_id_t> tmp_sampled_src_list;
std::vector<dgl_id_t> tmp_sampled_edge_list;
// ver_id, position
std::vector<neighbor_info> neigh_pos;
neigh_pos.reserve(num_seeds);
std::vector<dgl_id_t> neighbor_list;
std::vector<dgl_id_t> edge_list;
std::vector<size_t> layer_offsets(num_hops + 1);
int64_t num_edges = 0;
layer_offsets[0] = 0;
layer_offsets[1] = sub_vers.size();
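  // layer_offsets[i] and layer_offsets[i + 1] delimit the nodes of layer i
  // inside sub_vers; the seed nodes form layer 0.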
  for (int layer_id = 1; layer_id < num_hops; layer_id++) {
// We need to avoid resampling the same node in a layer, but we allow a node
// to be resampled in multiple layers. We use `sub_ver_map` to keep track of
// sampled nodes in a layer, and clear it when entering a new layer.
sub_ver_map.clear();
    // The previous iteration collected all nodes of the previous layer in
    // sub_vers; sub_vers serves both as the node collection and the queue.
for (size_t idx = layer_offsets[layer_id - 1]; idx < layer_offsets[layer_id]; idx++) {
dgl_id_t dst_id = sub_vers[idx].first;
const int cur_node_level = sub_vers[idx].second;
tmp_sampled_src_list.clear();
tmp_sampled_edge_list.clear();
dgl_id_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
if (probability == nullptr) { // uniform-sample
GetUniformSample(val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id),
ver_len,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
} else { // non-uniform-sample
GetNonUniformSample(probability,
val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id),
ver_len,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
}
CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
neigh_pos.emplace_back(dst_id, neighbor_list.size(), tmp_sampled_src_list.size());
      // Push the sampled neighbor vertices ...
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
neighbor_list.push_back(tmp_sampled_src_list[i]);
}
      // ... and the corresponding sampled edge Ids.
for (size_t i = 0; i < tmp_sampled_edge_list.size(); ++i) {
edge_list.push_back(tmp_sampled_edge_list[i]);
}
num_edges += tmp_sampled_src_list.size();
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
        // We need to add the neighbor to the hashtable here. This ensures that
        // vertices in the queue are unique. If we have seen a vertex before, we
        // don't need to add it to the queue again.
        auto ret = sub_ver_map.insert(tmp_sampled_src_list[i]);
        // If the sampled neighbor is inserted into the map successfully.
if (ret.second) {
sub_vers.emplace_back(tmp_sampled_src_list[i], cur_node_level + 1);
}
}
}
layer_offsets[layer_id + 1] = layer_offsets[layer_id] + sub_ver_map.size();
CHECK_EQ(layer_offsets[layer_id + 1], sub_vers.size());
}
return ConstructNodeFlow(neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos,
edge_type, num_edges, num_hops, graph->IsMultigraph());
}
} // namespace anonymous
NodeFlow SamplerOp::NeighborUniformSample(const ImmutableGraph *graph, IdArray seeds,
const std::string &edge_type,
int num_hops, int expand_factor) {
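  // Sampling num_hops hops from the seeds yields num_hops + 1 node layers
  // (the seed layer plus one layer per hop), hence num_hops + 1 below.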
return SampleSubgraph(graph,
seeds, // seed vector
nullptr, // sample_id_probability
edge_type,
num_hops + 1,
expand_factor);
}
} // namespace dgl
import backend as F
import numpy as np
import scipy as sp
import dgl
from dgl.node_flow import create_full_node_flow
from dgl import utils
import dgl.function as fn
from functools import partial
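# A NodeFlow stores a sampled computation graph as layers of nodes connected
# by blocks of edges; the tests below exercise its construction and
# computation APIs.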
def generate_rand_graph(n, connect_more=False):
arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
    # Optionally have one node connect to and from all other nodes.
if connect_more:
arr[0] = 1
arr[:,0] = 1
g = dgl.DGLGraph(arr, readonly=True)
g.ndata['h1'] = F.randn((g.number_of_nodes(), 10))
g.edata['h2'] = F.randn((g.number_of_edges(), 3))
return g
def create_mini_batch(g, num_hops):
seed_ids = np.array([0, 1, 2, 3])
seed_ids = utils.toindex(seed_ids)
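    # Call the graph index's sampling API directly; with expand_factor set to
    # the number of nodes, every in-neighbor is kept for num_hops hops.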
sgi = g._graph.neighbor_sampling([seed_ids], g.number_of_nodes(), num_hops, "in", None)
assert len(sgi) == 1
return dgl.node_flow.NodeFlow(g, sgi[0])
def check_basic(g, nf):
num_nodes = 0
for i in range(nf.num_layers):
num_nodes += nf.layer_size(i)
assert nf.number_of_nodes() == num_nodes
num_edges = 0
for i in range(nf.num_blocks):
num_edges += nf.block_size(i)
assert nf.number_of_edges() == num_edges
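    # Layer 0 (the input layer) has no incoming block and the last layer has
    # no outgoing block, so those degrees must all be zero.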
deg = nf.layer_in_degree(0)
assert F.array_equal(deg, F.zeros((nf.layer_size(0)), F.int64))
deg = nf.layer_out_degree(-1)
assert F.array_equal(deg, F.zeros((nf.layer_size(-1)), F.int64))
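    # Every edge in block i - 1 contributes one out-degree in layer i - 1 and
    # one in-degree in layer i, so the totals must match.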
for i in range(1, nf.num_layers):
in_deg = nf.layer_in_degree(i)
out_deg = nf.layer_out_degree(i - 1)
assert F.asnumpy(F.sum(in_deg, 0) == F.sum(out_deg, 0))
def test_basic():
num_layers = 2
g = generate_rand_graph(100, connect_more=True)
nf = create_full_node_flow(g, num_layers)
assert nf.number_of_nodes() == g.number_of_nodes() * (num_layers + 1)
assert nf.number_of_edges() == g.number_of_edges() * num_layers
assert nf.num_layers == num_layers + 1
assert nf.layer_size(0) == g.number_of_nodes()
assert nf.layer_size(1) == g.number_of_nodes()
check_basic(g, nf)
parent_nids = F.arange(0, g.number_of_nodes())
nids = dgl.graph_index.map_to_nodeflow_nid(nf._graph, 0,
utils.toindex(parent_nids)).tousertensor()
assert F.array_equal(nids, parent_nids)
g = generate_rand_graph(100)
nf = create_mini_batch(g, num_layers)
assert nf.num_layers == num_layers + 1
check_basic(g, nf)
def check_apply_nodes(create_node_flow):
num_layers = 2
for i in range(num_layers):
g = generate_rand_graph(100)
nf = create_node_flow(g, num_layers)
nf.copy_from_parent()
new_feats = F.randn((nf.layer_size(i), 5))
def update_func(nodes):
return {'h1' : new_feats}
nf.apply_layer(i, update_func)
assert F.array_equal(nf.layers[i].data['h1'], new_feats)
new_feats = F.randn((4, 5))
def update_func1(nodes):
return {'h1' : new_feats}
nf.apply_layer(i, update_func1, v=nf.layer_nid(i)[0:4])
assert F.array_equal(nf.layers[i].data['h1'][0:4], new_feats)
def test_apply_nodes():
check_apply_nodes(create_full_node_flow)
check_apply_nodes(create_mini_batch)
def check_apply_edges(create_node_flow):
num_layers = 2
for i in range(num_layers):
g = generate_rand_graph(100)
nf = create_node_flow(g, num_layers)
nf.copy_from_parent()
new_feats = F.randn((nf.block_size(i), 5))
def update_func(nodes):
return {'h2' : new_feats}
nf.apply_block(i, update_func)
assert F.array_equal(nf.blocks[i].data['h2'], new_feats)
def test_apply_edges():
check_apply_edges(create_full_node_flow)
check_apply_edges(create_mini_batch)
def check_flow_compute(create_node_flow):
num_layers = 2
g = generate_rand_graph(100)
nf = create_node_flow(g, num_layers)
nf.copy_from_parent()
g.ndata['h'] = g.ndata['h1']
nf.layers[0].data['h'] = nf.layers[0].data['h1']
# Test the computation on a layer at a time.
for i in range(num_layers):
nf.block_compute(i, fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h' : nodes.data['t'] + 1})
g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h' : nodes.data['t'] + 1})
assert F.array_equal(nf.layers[i + 1].data['h'], g.ndata['h'][nf.layer_parent_nid(i + 1)])
# Test the computation when only a few nodes are active in a layer.
g.ndata['h'] = g.ndata['h1']
for i in range(num_layers):
vs = nf.layer_nid(i+1)[0:4]
nf.block_compute(i, fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h' : nodes.data['t'] + 1}, v=vs)
g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h' : nodes.data['t'] + 1})
data1 = nf.layers[i + 1].data['h'][0:4]
data2 = g.ndata['h'][nf.map_to_parent_nid(vs)]
assert F.array_equal(data1, data2)
def test_flow_compute():
check_flow_compute(create_full_node_flow)
check_flow_compute(create_mini_batch)
def check_prop_flows(create_node_flow):
num_layers = 2
g = generate_rand_graph(100)
g.ndata['h'] = g.ndata['h1']
nf2 = create_node_flow(g, num_layers)
nf2.copy_from_parent()
    # Compute the ground truth on the full graph, one layer at a time.
for i in range(num_layers):
g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h' : nodes.data['t'] + 1})
# Test the computation on all layers.
nf2.prop_flow(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h' : nodes.data['t'] + 1})
assert F.array_equal(nf2.layers[-1].data['h'], g.ndata['h'][nf2.layer_parent_nid(-1)])
def test_prop_flows():
check_prop_flows(create_full_node_flow)
check_prop_flows(create_mini_batch)
def test_copy():
num_layers = 2
g = generate_rand_graph(100)
g.ndata['h'] = g.ndata['h1']
nf = create_mini_batch(g, num_layers)
nf.copy_from_parent()
for i in range(nf.num_layers):
assert len(g.ndata.keys()) == len(nf.layers[i].data.keys())
for key in g.ndata.keys():
assert key in nf.layers[i].data.keys()
assert F.array_equal(nf.layers[i].data[key], g.ndata[key][nf.layer_parent_nid(i)])
for i in range(nf.num_blocks):
assert len(g.edata.keys()) == len(nf.blocks[i].data.keys())
for key in g.edata.keys():
assert key in nf.blocks[i].data.keys()
assert F.array_equal(nf.blocks[i].data[key], g.edata[key][nf.block_parent_eid(i)])
nf = create_mini_batch(g, num_layers)
node_embed_names = [['h'], ['h1'], ['h']]
edge_embed_names = [['h2'], ['h2']]
nf.copy_from_parent(node_embed_names=node_embed_names, edge_embed_names=edge_embed_names)
for i in range(nf.num_layers):
assert len(node_embed_names[i]) == len(nf.layers[i].data.keys())
for key in node_embed_names[i]:
assert key in nf.layers[i].data.keys()
assert F.array_equal(nf.layers[i].data[key], g.ndata[key][nf.layer_parent_nid(i)])
for i in range(nf.num_blocks):
assert len(edge_embed_names[i]) == len(nf.blocks[i].data.keys())
for key in edge_embed_names[i]:
assert key in nf.blocks[i].data.keys()
assert F.array_equal(nf.blocks[i].data[key], g.edata[key][nf.block_parent_eid(i)])
nf = create_mini_batch(g, num_layers)
g.ndata['h0'] = F.clone(g.ndata['h'])
node_embed_names = [['h0'], [], []]
nf.copy_from_parent(node_embed_names=node_embed_names, edge_embed_names=None)
for i in range(num_layers):
nf.block_compute(i, fn.copy_src(src='h%d' % i, out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h%d' % (i+1) : nodes.data['t'] + 1})
g.update_all(fn.copy_src(src='h', out='m'), fn.sum(msg='m', out='t'),
lambda nodes: {'h' : nodes.data['t'] + 1})
assert F.array_equal(nf.layers[i + 1].data['h%d' % (i+1)],
g.ndata['h'][nf.layer_parent_nid(i + 1)])
nf.copy_to_parent(node_embed_names=[['h0'], ['h1'], ['h2']])
for i in range(num_layers + 1):
assert F.array_equal(nf.layers[i].data['h%d' % i],
g.ndata['h%d' % i][nf.layer_parent_nid(i)])
nf = create_mini_batch(g, num_layers)
g.ndata['h0'] = F.clone(g.ndata['h'])
g.ndata['h1'] = F.clone(g.ndata['h'])
g.ndata['h2'] = F.clone(g.ndata['h'])
node_embed_names = [['h0'], ['h1'], ['h2']]
nf.copy_from_parent(node_embed_names=node_embed_names, edge_embed_names=None)
def msg_func(edge, ind):
assert 'h%d' % ind in edge.src.keys()
return {'m' : edge.src['h%d' % ind]}
def reduce_func(node, ind):
assert 'h%d' % (ind + 1) in node.data.keys()
return {'h' : F.sum(node.mailbox['m'], 1) + node.data['h%d' % (ind + 1)]}
for i in range(num_layers):
nf.block_compute(i, partial(msg_func, ind=i), partial(reduce_func, ind=i))
if __name__ == '__main__':
test_basic()
test_copy()
test_apply_nodes()
test_apply_edges()
test_flow_compute()
test_prop_flows()
......@@ -16,37 +16,36 @@ def test_1neighbor_sampler_all():
seed_ids = aux['seeds']
assert len(seed_ids) == 1
src, dst, eid = g.in_edges(seed_ids, form='all')
# Test if there is a self loop
self_loop = F.asnumpy(F.sum(src == dst, 0)) == 1
if self_loop:
assert subg.number_of_nodes() == len(src)
else:
assert subg.number_of_nodes() == len(src) + 1
assert subg.number_of_edges() >= len(src)
assert subg.number_of_nodes() == len(src) + 1
assert subg.number_of_edges() == len(src)
child_ids = subg.map_to_subgraph_nid(seed_ids)
child_src, child_dst, child_eid = subg.in_edges(child_ids, form='all')
assert seed_ids == subg.layer_parent_nid(-1)
child_src, child_dst, child_eid = subg.in_edges(subg.layer_nid(-1), form='all')
assert F.array_equal(child_src, subg.layer_nid(0))
child_src1 = subg.map_to_subgraph_nid(src)
assert F.asnumpy(F.sum(child_src1 == child_src, 0)) == len(src)
src1 = subg.map_to_parent_nid(child_src)
assert F.array_equal(src1, src)
def is_sorted(arr):
return np.sum(np.sort(arr) == arr, 0) == len(arr)
def verify_subgraph(g, subg, seed_id):
seed_id = F.asnumpy(seed_id)
seeds = F.asnumpy(subg.map_to_parent_nid(subg.layer_nid(-1)))
assert seed_id in seeds
child_seed = F.asnumpy(subg.layer_nid(-1))[seeds == seed_id]
src, dst, eid = g.in_edges(seed_id, form='all')
child_id = subg.map_to_subgraph_nid(seed_id)
child_src, child_dst, child_eid = subg.in_edges(child_id, form='all')
child_src, child_dst, child_eid = subg.in_edges(child_seed, form='all')
child_src = F.asnumpy(child_src)
# We don't allow duplicate elements in the neighbor list.
assert(len(np.unique(child_src)) == len(child_src))
# The neighbor list also needs to be sorted.
assert(is_sorted(child_src))
child_src1 = F.asnumpy(subg.map_to_subgraph_nid(src))
child_src1 = child_src1[child_src1 >= 0]
for i in child_src:
assert i in child_src1
    # A neighbor in the subgraph must also exist in the parent graph.
for i in subg.map_to_parent_nid(child_src):
assert i in src
def test_1neighbor_sampler():
g = generate_rand_graph(100)
......@@ -76,13 +75,12 @@ def test_10neighbor_sampler_all():
for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
num_workers=4, return_seed_id=True):
seed_ids = aux['seeds']
src, dst, eid = g.in_edges(seed_ids, form='all')
assert F.array_equal(seed_ids, subg.map_to_parent_nid(subg.layer_nid(-1)))
child_ids = subg.map_to_subgraph_nid(seed_ids)
child_src, child_dst, child_eid = subg.in_edges(child_ids, form='all')
child_src1 = subg.map_to_subgraph_nid(src)
assert F.asnumpy(F.sum(child_src1 == child_src, 0)) == len(src)
src, dst, eid = g.in_edges(seed_ids, form='all')
child_src, child_dst, child_eid = subg.in_edges(subg.layer_nid(-1), form='all')
src1 = subg.map_to_parent_nid(child_src)
assert F.array_equal(src1, src)
def check_10neighbor_sampler(g, seeds):
# In this case, NeighborSampling simply gets the neighborhood of a single vertex.
......
......@@ -202,9 +202,9 @@ while True:
#
import mxnet.gluon as gluon
class SteadyStateOperator(gluon.Block):
class FullGraphSteadyStateOperator(gluon.Block):
def __init__(self, n_hidden, activation, **kwargs):
super(SteadyStateOperator, self).__init__(**kwargs)
super(FullGraphSteadyStateOperator, self).__init__(**kwargs)
with self.name_scope():
self.dense1 = gluon.nn.Dense(n_hidden, activation=activation)
self.dense2 = gluon.nn.Dense(n_hidden)
......@@ -249,9 +249,9 @@ class Predictor(gluon.Block):
self.dense1 = gluon.nn.Dense(n_hidden, activation=activation)
self.dense2 = gluon.nn.Dense(2) ## binary classifier
def forward(self, g):
g.ndata['z'] = self.dense2(self.dense1(g.ndata['h']))
return g.ndata['z']
def forward(self, h):
return self.dense2(self.dense1(h))
##############################################################################
# The predictor’s decision rule is just a decision rule for binary
# classification:
......@@ -339,11 +339,11 @@ nodes_test = np.where(test_bitmap)[0]
# :math:`\mathcal{L}_\text{SteadyState}`. Note that ``g`` in the following is
# :math:`\mathcal{G}_y` instead of :math:`\mathcal{G}`.
#
def update_parameters(g, label_nodes, steady_state_operator, predictor, trainer):
def fullgraph_update_parameters(g, label_nodes, steady_state_operator, predictor, trainer):
n = g.number_of_nodes()
with mx.autograd.record():
steady_state_operator(g)
z = predictor(g)[label_nodes]
z = predictor(g.ndata['h'][label_nodes])
y = g.ndata['y'].reshape(n)[label_nodes] # label
loss = mx.nd.softmax_cross_entropy(z, y)
loss.backward()
......@@ -369,7 +369,8 @@ def train(g, label_nodes, steady_state_operator, predictor, trainer):
update_embeddings(g, steady_state_operator)
# second phase
for i in range(n_parameter_updates):
loss = update_parameters(g, label_nodes, steady_state_operator, predictor, trainer)
loss = fullgraph_update_parameters(g, label_nodes, steady_state_operator,
predictor, trainer)
return loss
##############################################################################
# Scaling up with Stochastic Subgraph Training
......@@ -417,18 +418,15 @@ def train(g, label_nodes, steady_state_operator, predictor, trainer):
# at a time with neighbor sampling.
#
# The following code demonstrates how to use the ``NeighborSampler`` to
# sample subgraphs, and stores the nodes and edges of the subgraph, as
# well as seed nodes in each iteration:
# sample subgraphs and store the seed nodes of each subgraph in each iteration:
#
nx_G = nx.erdos_renyi_graph(36, 0.06)
G = dgl.DGLGraph(nx_G.to_directed(), readonly=True)
sampler = dgl.contrib.sampling.NeighborSampler(
G, 2, 3, num_hops=2, shuffle=True)
nid = []
eid = []
seeds = []
for subg, aux_info in sampler:
nid.append(subg.parent_nid.asnumpy())
eid.append(subg.parent_eid.asnumpy())
seeds.append(subg.layer_parent_nid(-1))
##############################################################################
# Sampler with DGL
......@@ -436,13 +434,46 @@ for subg, aux_info in sampler:
#
# The code illustrates the training process in mini-batches.
#
def update_embeddings_subgraph(g, seed_nodes, steady_state_operator):
class SubgraphSteadyStateOperator(gluon.Block):
def __init__(self, n_hidden, activation, **kwargs):
super(SubgraphSteadyStateOperator, self).__init__(**kwargs)
with self.name_scope():
self.dense1 = gluon.nn.Dense(n_hidden, activation=activation)
self.dense2 = gluon.nn.Dense(n_hidden)
def forward(self, subg):
def message_func(edges):
x = edges.src['x']
h = edges.src['h']
return {'m' : mx.nd.concat(x, h, dim=1)}
def reduce_func(nodes):
m = mx.nd.sum(nodes.mailbox['m'], axis=1)
z = mx.nd.concat(nodes.data['x'], m, dim=1)
return {'h' : self.dense2(self.dense1(z))}
subg.block_compute(0, message_func, reduce_func)
return subg.layers[-1].data['h']
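##############################################################################
# ``block_compute`` runs one message-passing step along a single block of the
# NodeFlow: messages flow from the block's source layer to its destination
# layer, so only the sampled neighborhood participates in the computation.
#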
def update_parameters_subgraph(subg, steady_state_operator, predictor, trainer):
n = subg.layer_size(-1)
with mx.autograd.record():
steady_state_operator(subg)
z = predictor(subg.layers[-1].data['h'])
y = subg.layers[-1].data['y'].reshape(n) # label
loss = mx.nd.softmax_cross_entropy(z, y)
loss.backward()
trainer.step(n) # divide gradients by the number of labelled nodes
return loss.asnumpy()[0]
def update_embeddings_subgraph(g, steady_state_operator):
# Note that we are only updating the embeddings of seed nodes here.
# The reason is that only the seed nodes have ample information
    # from neighbors, especially if the subgraph is small (e.g. 1-hop).
prev_h = g.ndata['h'][seed_nodes]
next_h = steady_state_operator(g)[seed_nodes]
g.ndata['h'][seed_nodes] = (1 - alpha) * prev_h + alpha * next_h
prev_h = g.layers[-1].data['h']
next_h = steady_state_operator(g)
g.layers[-1].data['h'] = (1 - alpha) * prev_h + alpha * next_h
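##############################################################################
# The embedding update is a damped fixed-point iteration: the new embedding
# mixes the previous value with the steady-state operator's output through
# ``alpha``, which stabilizes the stochastic updates.
#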
def train_on_subgraphs(g, label_nodes, batch_size,
steady_state_operator, predictor, trainer):
......@@ -459,40 +490,34 @@ def train_on_subgraphs(g, label_nodes, batch_size,
return_seed_id=True)
for i in range(n_embedding_updates):
subg, aux_info = next(sampler)
seeds = aux_info['seeds']
# Currently, subgraphing does not copy or share features
# automatically. Therefore, we need to copy the node
# embeddings of the subgraph from the parent graph with
# `copy_from_parent()` before computing...
subg.copy_from_parent()
subg_seeds = subg.map_to_subgraph_nid(seeds)
update_embeddings_subgraph(subg, subg_seeds, steady_state_operator)
# ... and copy them back to the parent graph with
# `copy_to_parent()` afterwards.
subg.copy_to_parent()
update_embeddings_subgraph(subg, steady_state_operator)
# ... and copy them back to the parent graph.
g.ndata['h'][subg.layer_parent_nid(-1)] = subg.layers[-1].data['h']
for i in range(n_parameter_updates):
try:
subg, aux_info = next(sampler_train)
seeds = aux_info['seeds']
        except StopIteration:
break
# Again we need to copy features from parent graph
subg.copy_from_parent()
subg_seeds = subg.map_to_subgraph_nid(seeds)
loss = update_parameters(subg, subg_seeds,
steady_state_operator, predictor, trainer)
loss = update_parameters_subgraph(subg, steady_state_operator, predictor, trainer)
# We don't need to copy the features back to parent graph.
return loss
##############################################################################
# We also define a helper function that reports prediction accuracy:
def test(g, test_nodes, steady_state_operator, predictor):
predictor(g)
y_bar = mx.nd.argmax(g.ndata['z'], axis=1)[test_nodes]
def test(g, test_nodes, predictor):
z = predictor(g.ndata['h'][test_nodes])
y_bar = mx.nd.argmax(z, axis=1)
y = g.ndata['y'].reshape(n)[test_nodes]
accuracy = mx.nd.sum(y_bar == y) / len(test_nodes)
return accuracy.asnumpy()[0]
return accuracy.asnumpy()[0], z
##############################################################################
# Some routine preparations for training:
......@@ -500,11 +525,11 @@ def test(g, test_nodes, steady_state_operator, predictor):
lr = 1e-3
activation = 'relu'
steady_state_operator = SteadyStateOperator(n_hidden, activation)
subgraph_steady_state_operator = SubgraphSteadyStateOperator(n_hidden, activation)
predictor = Predictor(n_hidden, activation)
steady_state_operator.initialize()
subgraph_steady_state_operator.initialize()
predictor.initialize()
params = steady_state_operator.collect_params()
params = subgraph_steady_state_operator.collect_params()
params.update(predictor.collect_params())
trainer = gluon.Trainer(params, 'adam', {'learning_rate' : lr})
......@@ -520,12 +545,13 @@ batch_size = 64
y_bars = []
for i in range(n_epochs):
loss = train_on_subgraphs(g, nodes_train, batch_size, steady_state_operator, predictor, trainer)
loss = train_on_subgraphs(g, nodes_train, batch_size, subgraph_steady_state_operator,
predictor, trainer)
accuracy_train = test(g, nodes_train, steady_state_operator, predictor)
accuracy_test = test(g, nodes_test, steady_state_operator, predictor)
accuracy_train, _ = test(g, nodes_train, predictor)
accuracy_test, z = test(g, nodes_test, predictor)
print("Iter {:05d} | Train acc {:.4} | Test acc {:.4f}".format(i, accuracy_train, accuracy_test))
y_bar = mx.nd.argmax(g.ndata['z'], axis=1)
y_bar = mx.nd.argmax(z, axis=1)
y_bars.append(y_bar)
##############################################################################
......