Commit 52d4535b authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by Minjie Wang
Browse files

[Hetero][RFC] Heterogeneous graph Python interfaces & Message Passing (#752)

* moving heterograph index to another file

* node view

* python interfaces

* heterograph init

* bug fixes

* docstring for readonly

* more docstring

* unit tests & lint

* oops

* oops x2

* removed node/edge addition

* addressed comments

* lint

* rw on frames with one node/edge type

* homograph with underlying heterograph demo

* view is not necessary

* bugfix

* replace

* scheduler, builtins not working yet

* moving bipartite.h to header

* moving back bipartite to bipartite.h

* oops

* asbits and copyto for bipartite

* tested update_all and send_and_recv

* lightweight node & edge type retrieval

* oops

* sorry

* removing obsolete code

* oops

* lint

* various bug fixes & more tests

* UDF tests

* multiple type number_of_nodes and number_of_edges

* docstring fixes

* more tests

* going for dict in initialization

* lint

* updated api as per discussions

* lint

* bug

* bugfix

* moving back bipartite impl to cc

* note on views

* fix
parent 66971c1a
......@@ -116,6 +116,11 @@ IdArray IndexSelect(IdArray array, IdArray index);
*/
IdArray Relabel_(const std::vector<IdArray>& arrays);
/*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
//////////////////////////////////////////////////////////////////////
// Sparse matrix
//////////////////////////////////////////////////////////////////////
......
......@@ -17,6 +17,7 @@ from .base import ALL
from .backend import load_backend
from .batched_graph import *
from .graph import DGLGraph
from .heterograph import DGLHeteroGraph
from .nodeflow import *
from .traversal import *
from .transform import *
......
......@@ -49,6 +49,14 @@ class DGLBaseGraph(object):
"""
return self._graph.number_of_nodes()
def _number_of_src_nodes(self):
"""Return number of source nodes (only used in scheduler)"""
return self.number_of_nodes()
def _number_of_dst_nodes(self):
"""Return number of destination nodes (only used in scheduler)"""
return self.number_of_nodes()
def __len__(self):
"""Return the number of nodes in the graph."""
return self.number_of_nodes()
......@@ -65,6 +73,10 @@ class DGLBaseGraph(object):
"""
return self._graph.is_readonly()
def _number_of_edges(self):
"""Return number of edges in the current view (only used for scheduler)"""
return self.number_of_edges()
def number_of_edges(self):
"""Return the number of edges in the graph.
......@@ -939,6 +951,14 @@ class DGLGraph(DGLBaseGraph):
def _set_msg_index(self, index):
self._msg_index = index
@property
def _src_frame(self):
return self._node_frame
@property
def _dst_frame(self):
return self._node_frame
def add_nodes(self, num, data=None):
"""Add multiple new nodes.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -161,7 +161,8 @@ def gen_group_apply_edge_schedule(
apply_func,
u, v, eid,
group_by,
var_nf,
var_src_nf,
var_dst_nf,
var_ef,
var_out):
"""Create degree bucketing schedule for group_apply_edge
......@@ -186,8 +187,10 @@ def gen_group_apply_edge_schedule(
Edges to apply
group_by: str
If "src", group by u. If "dst", group by v
var_nf : var.FEAT_DICT
The variable for node feature frame.
var_src_nf : var.FEAT_DICT
The variable for source feature frame.
var_dst_nf : var.FEAT_DICT
The variable for destination feature frame.
var_ef : var.FEAT_DICT
The variable for edge frame.
var_out : var.FEAT_DICT
......@@ -213,8 +216,8 @@ def gen_group_apply_edge_schedule(
var_v = var.IDX(v_bkt)
var_eid = var.IDX(eid_bkt)
# apply edge UDF on each bucket
fdsrc = ir.READ_ROW(var_nf, var_u)
fddst = ir.READ_ROW(var_nf, var_v)
fdsrc = ir.READ_ROW(var_src_nf, var_u)
fddst = ir.READ_ROW(var_dst_nf, var_v)
fdedge = ir.READ_ROW(var_ef, var_eid)
fdedge = ir.EDGE_UDF(_efunc, fdsrc, fdedge, fddst, ret=fdedge) # reuse var
# save for merge
......
......@@ -8,6 +8,8 @@ from .. import backend as F
from ..frame import frame_like, FrameRef
from ..function.base import BuiltinFunction
from ..udf import EdgeBatch, NodeBatch
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex
from . import ir
from .ir import var
......@@ -28,6 +30,15 @@ __all__ = [
"schedule_pull"
]
def _dispatch(graph, method, *args, **kwargs):
graph_index = graph._graph
if isinstance(graph_index, GraphIndex):
return getattr(graph._graph, method)(*args, **kwargs)
elif isinstance(graph_index, HeteroGraphIndex):
return getattr(graph._graph, method)(graph._current_etype_idx, *args, **kwargs)
else:
raise TypeError('unknown type %s' % type(graph_index))
def schedule_send(graph, u, v, eid, message_func):
"""get send schedule
......@@ -45,7 +56,8 @@ def schedule_send(graph, u, v, eid, message_func):
The message function
"""
var_mf = var.FEAT_DICT(graph._msg_frame)
var_nf = var.FEAT_DICT(graph._node_frame)
var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_eid = var.IDX(eid)
......@@ -54,8 +66,8 @@ def schedule_send(graph, u, v, eid, message_func):
v=v,
eid=eid,
mfunc=message_func,
var_src_nf=var_nf,
var_dst_nf=var_nf,
var_src_nf=var_src_nf,
var_dst_nf=var_dst_nf,
var_ef=var_ef)
# write tmp msg back
......@@ -83,7 +95,7 @@ def schedule_recv(graph,
inplace: bool
If True, the update will be done in place
"""
src, dst, eid = graph._graph.in_edges(recv_nodes)
src, dst, eid = _dispatch(graph, 'in_edges', recv_nodes)
if len(eid) > 0:
nonzero_idx = graph._get_msg_index().get_items(eid).nonzero()
eid = eid.get_items(nonzero_idx)
......@@ -96,7 +108,7 @@ def schedule_recv(graph,
if apply_func is not None:
schedule_apply_nodes(graph, recv_nodes, apply_func, inplace)
else:
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
# sort and unique the argument
recv_nodes, _ = F.sort_1d(F.unique(recv_nodes.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
......@@ -105,12 +117,12 @@ def schedule_recv(graph,
reduced_feat = _gen_reduce(graph, reduce_func, (src, dst, eid),
recv_nodes)
# apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
# set message indicator to 0
graph._set_msg_index(graph._get_msg_index().set_items(eid, 0))
if not graph._get_msg_index().has_nonzero():
......@@ -148,7 +160,7 @@ def schedule_snr(graph,
recv_nodes, _ = F.sort_1d(F.unique(v.tousertensor()))
recv_nodes = utils.toindex(recv_nodes)
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
......@@ -156,11 +168,11 @@ def schedule_snr(graph,
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
edge_tuples, graph.number_of_nodes())
edge_tuples, graph._number_of_src_nodes(), graph._number_of_dst_nodes())
out_map_creator = lambda nbits: _build_idx_map(recv_nodes, nbits)
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._node_frame,
dst_node_frame=graph._node_frame,
src_node_frame=graph._src_frame,
dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame,
message_func=message_func,
reduce_func=reduce_func,
......@@ -170,12 +182,12 @@ def schedule_snr(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate apply schedule
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf, reduced_feat,
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf, reduced_feat,
apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_recv_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_recv_nodes, final_feat)
ir.WRITE_ROW_(var_dst_nf, var_recv_nodes, final_feat)
def schedule_update_all(graph,
message_func,
......@@ -194,27 +206,27 @@ def schedule_update_all(graph,
apply_func: callable
The apply node function
"""
if graph.number_of_edges() == 0:
if graph._number_of_edges() == 0:
# All the nodes are zero degree; downgrade to apply nodes
if apply_func is not None:
nodes = utils.toindex(slice(0, graph.number_of_nodes()))
nodes = utils.toindex(slice(0, graph._number_of_dst_nodes()))
schedule_apply_nodes(graph, nodes, apply_func, inplace=False)
else:
eid = utils.toindex(slice(0, graph.number_of_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph.number_of_nodes())) # ALL
eid = utils.toindex(slice(0, graph._number_of_edges())) # ALL
recv_nodes = utils.toindex(slice(0, graph._number_of_dst_nodes())) # ALL
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_recv_nodes = var.IDX(recv_nodes, name='recv_nodes')
var_eid = var.IDX(eid)
# generate send + reduce
def uv_getter():
src, dst, _ = graph._graph.edges('eid')
src, dst, _ = _dispatch(graph, 'edges', 'eid')
return var.IDX(src), var.IDX(dst)
adj_creator = lambda: spmv.build_gidx_and_mapping_graph(graph)
out_map_creator = lambda nbits: None
reduced_feat = _gen_send_reduce(graph=graph,
src_node_frame=graph._node_frame,
dst_node_frame=graph._node_frame,
src_node_frame=graph._src_frame,
dst_node_frame=graph._dst_frame,
edge_frame=graph._edge_frame,
message_func=message_func,
reduce_func=reduce_func,
......@@ -224,9 +236,9 @@ def schedule_update_all(graph,
adj_creator=adj_creator,
out_map_creator=out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_recv_nodes, var_nf,
final_feat = _apply_with_accum(graph, var_recv_nodes, var_dst_nf,
reduced_feat, apply_func)
ir.WRITE_DICT_(var_nf, final_feat)
ir.WRITE_DICT_(var_dst_nf, final_feat)
def schedule_apply_nodes(graph,
v,
......@@ -326,10 +338,12 @@ def schedule_apply_edges(graph,
A list of executors for DGL Runtime
"""
# vars
var_nf = var.FEAT_DICT(graph._node_frame)
var_src_nf = var.FEAT_DICT(graph._src_frame)
var_dst_nf = var.FEAT_DICT(graph._dst_frame)
var_ef = var.FEAT_DICT(graph._edge_frame)
var_out = _gen_send(graph=graph, u=u, v=v, eid=eid, mfunc=apply_func,
var_src_nf=var_nf, var_dst_nf=var_nf, var_ef=var_ef)
var_src_nf=var_src_nf, var_dst_nf=var_dst_nf,
var_ef=var_ef)
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_eid = var.IDX(eid)
# schedule apply edges
......@@ -401,7 +415,7 @@ def schedule_push(graph,
inplace: bool
If True, the update will be done in place
"""
u, v, eid = graph._graph.out_edges(u)
u, v, eid = _dispatch(graph, 'out_edges', u)
if len(eid) == 0:
# All the pushing nodes have no out edges. No computation is scheduled.
return
......@@ -434,7 +448,7 @@ def schedule_pull(graph,
# TODO(minjie): `in_edges` can be omitted if message and reduce func pairs
# can be specialized to SPMV. This needs support for creating adjmat
# directly from pull node frontier.
u, v, eid = graph._graph.in_edges(pull_nodes)
u, v, eid = _dispatch(graph, 'in_edges', pull_nodes)
if len(eid) == 0:
# All the nodes are 0deg; downgrades to apply.
if apply_func is not None:
......@@ -443,27 +457,27 @@ def schedule_pull(graph,
pull_nodes, _ = F.sort_1d(F.unique(pull_nodes.tousertensor()))
pull_nodes = utils.toindex(pull_nodes)
# create vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='nf')
var_pull_nodes = var.IDX(pull_nodes, name='pull_nodes')
var_u = var.IDX(u)
var_v = var.IDX(v)
var_eid = var.IDX(eid)
# generate send and reduce schedule
uv_getter = lambda: (var_u, var_v)
num_nodes = graph.number_of_nodes()
adj_creator = lambda: spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes)
adj_creator = lambda: spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
out_map_creator = lambda nbits: _build_idx_map(pull_nodes, nbits)
reduced_feat = _gen_send_reduce(graph, graph._node_frame,
graph._node_frame, graph._edge_frame,
reduced_feat = _gen_send_reduce(graph, graph._src_frame,
graph._dst_frame, graph._edge_frame,
message_func, reduce_func, var_eid,
var_pull_nodes, uv_getter, adj_creator,
out_map_creator)
# generate optional apply
final_feat = _apply_with_accum(graph, var_pull_nodes, var_nf, reduced_feat, apply_func)
final_feat = _apply_with_accum(graph, var_pull_nodes, var_dst_nf, reduced_feat, apply_func)
if inplace:
ir.WRITE_ROW_INPLACE_(var_nf, var_pull_nodes, final_feat)
ir.WRITE_ROW_INPLACE_(var_dst_nf, var_pull_nodes, final_feat)
else:
ir.WRITE_ROW_(var_nf, var_pull_nodes, final_feat)
ir.WRITE_ROW_(var_dst_nf, var_pull_nodes, final_feat)
def schedule_group_apply_edge(graph,
u, v, eid,
......@@ -494,11 +508,12 @@ def schedule_group_apply_edge(graph,
A list of executors for DGL Runtime
"""
# vars
var_nf = var.FEAT_DICT(graph._node_frame, name='nf')
var_src_nf = var.FEAT_DICT(graph._src_frame, name='src_nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, name='dst_nf')
var_ef = var.FEAT_DICT(graph._edge_frame, name='ef')
var_out = var.FEAT_DICT(name='new_ef')
db.gen_group_apply_edge_schedule(graph, apply_func, u, v, eid, group_by,
var_nf, var_ef, var_out)
var_src_nf, var_dst_nf, var_ef, var_out)
var_eid = var.IDX(eid)
if inplace:
ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out)
......@@ -719,17 +734,16 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
# node frame.
# TODO(minjie): should replace this with an IR call to make the program
# stateless.
tmpframe = FrameRef(frame_like(graph._node_frame._frame, len(recv_nodes)))
tmpframe = FrameRef(frame_like(graph._dst_frame._frame, len(recv_nodes)))
# vars
var_msg = var.FEAT_DICT(graph._msg_frame, 'msg')
var_nf = var.FEAT_DICT(graph._node_frame, 'nf')
var_dst_nf = var.FEAT_DICT(graph._dst_frame, 'nf')
var_out = var.FEAT_DICT(data=tmpframe)
if rfunc_is_list:
num_nodes = graph.number_of_nodes()
adj, edge_map, nbits = spmv.build_gidx_and_mapping_uv(
(src, dst, eid), num_nodes)
(src, dst, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
# using edge map instead of message map because messages are in global
# message frame
var_out_map = _build_idx_map(recv_nodes, nbits)
......@@ -744,7 +758,7 @@ def _gen_reduce(graph, reduce_func, edge_tuples, recv_nodes):
else:
# gen degree bucketing schedule for UDF recv
db.gen_degree_bucketing_schedule(graph, rfunc, eid, dst, recv_nodes,
var_nf, var_msg, var_out)
var_dst_nf, var_msg, var_out)
return var_out
def _gen_send_reduce(
......@@ -930,12 +944,12 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef):
var_eid = var.IDX(eid)
if mfunc_is_list:
if eid.is_slice(0, graph.number_of_edges()):
if eid.is_slice(0, graph._number_of_edges()):
# full graph case
res = spmv.build_gidx_and_mapping_graph(graph)
else:
num_nodes = graph.number_of_nodes()
res = spmv.build_gidx_and_mapping_uv((u, v, eid), num_nodes)
res = spmv.build_gidx_and_mapping_uv(
(u, v, eid), graph._number_of_src_nodes(), graph._number_of_dst_nodes())
adj, edge_map, _ = res
# create a tmp message frame
tmp_mfr = FrameRef(frame_like(graph._edge_frame._frame, len(eid)))
......
"""Module for SPMV rules."""
from __future__ import absolute_import
from functools import partial
from ..base import DGLError
from .. import backend as F
from .. import utils
from .. import ndarray as nd
from ..graph_index import from_coo
from ..graph_index import GraphIndex
from ..heterograph_index import HeteroGraphIndex, create_bipartite_from_coo
from . import ir
from .ir import var
......@@ -127,8 +129,8 @@ def build_gidx_and_mapping_graph(graph):
Parameters
----------
graph : DGLGraph
The graph
graph : DGLGraph or DGLHeteroGraph
The homogeneous graph, or a bipartite view of the heterogeneous graph.
Returns
-------
......@@ -141,10 +143,17 @@ def build_gidx_and_mapping_graph(graph):
Number of ints needed to represent the graph
"""
gidx = graph._graph
return gidx.get_immutable_gidx, None, gidx.bits_needed()
if isinstance(gidx, GraphIndex):
return gidx.get_immutable_gidx, None, gidx.bits_needed()
elif isinstance(gidx, HeteroGraphIndex):
return (partial(gidx.get_bipartite, graph._current_etype_idx),
None,
gidx.bits_needed(graph._current_etype_idx))
else:
raise TypeError('unknown graph index type %s' % type(gidx))
def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
def build_gidx_and_mapping_uv(edge_tuples, num_src, num_dst):
"""Build immutable graph index and mapping using the given (u, v) edges
The matrix is of shape (len(reduce_nodes), n), where n is the number of
......@@ -155,8 +164,8 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
---------
edge_tuples : tuple of three utils.Index
A tuple of (u, v, eid)
num_nodes : int
The number of nodes.
num_src, num_dst : int
The number of source and destination nodes.
Returns
-------
......@@ -169,10 +178,10 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
Number of ints needed to represent the graph
"""
u, v, eid = edge_tuples
gidx = from_coo(num_nodes, u, v, None, True)
forward, backward = gidx.get_csr_shuffle_order()
gidx = create_bipartite_from_coo(num_src, num_dst, u, v)
forward, backward = gidx.get_csr_shuffle_order(0)
eid = eid.tousertensor()
nbits = gidx.bits_needed()
nbits = gidx.bits_needed(0)
forward_map = utils.to_nbits_int(eid[forward.tousertensor()], nbits)
backward_map = utils.to_nbits_int(eid[backward.tousertensor()], nbits)
forward_map = F.zerocopy_to_dgl_ndarray(forward_map)
......@@ -180,7 +189,7 @@ def build_gidx_and_mapping_uv(edge_tuples, num_nodes):
edge_map = utils.CtxCachedObject(
lambda ctx: (nd.array(forward_map, ctx=ctx),
nd.array(backward_map, ctx=ctx)))
return gidx.get_immutable_gidx, edge_map, nbits
return partial(gidx.get_bipartite, None), edge_map, nbits
def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
......@@ -212,6 +221,6 @@ def build_gidx_and_mapping_block(graph, block_id, edge_tuples=None):
eid = utils.toindex(eid)
else:
u, v, eid = edge_tuples
num_nodes = max(graph.layer_size(block_id), graph.layer_size(block_id + 1))
gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_nodes)
num_src, num_dst = graph.layer_size(block_id), graph.layer_size(block_id + 1)
gidx, edge_map, nbits = build_gidx_and_mapping_uv((u, v, eid), num_src, num_dst)
return gidx, edge_map, nbits
This diff is collapsed.
......@@ -52,11 +52,6 @@ enum BoolFlag {
dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
const std::vector<dgl::runtime::NDArray>& vec);
/*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
}
/*!
* \brief Copy a vector to an int64_t NDArray.
*
......
This diff is collapsed.
......@@ -7,12 +7,14 @@
#ifndef DGL_GRAPH_BIPARTITE_H_
#define DGL_GRAPH_BIPARTITE_H_
#include <dgl/graph_interface.h>
#include <dgl/base_heterograph.h>
#include <vector>
#include <string>
#include <dgl/lazy.h>
#include <dgl/array.h>
#include <utility>
#include <memory>
#include <string>
#include <vector>
#include "../c_api_common.h"
namespace dgl {
......@@ -32,6 +34,12 @@ class Bipartite : public BaseHeteroGraph {
/*! \brief edge group type */
static constexpr dgl_type_t kEType = 0;
// internal data structure
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
uint64_t NumVertexTypes() const override {
return 2;
}
......@@ -140,14 +148,11 @@ class Bipartite : public BaseHeteroGraph {
int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids);
private:
// internal data structure
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
/*! \return Return the in-edge CSR format. Create from other format if not exist. */
CSRPtr GetInCSR() const;
......@@ -158,6 +163,18 @@ class Bipartite : public BaseHeteroGraph {
/*! \return Return the COO format. Create from other format if not exist. */
COOPtr GetCOO() const;
/*! \return Return the in-edge CSR in the matrix form */
aten::CSRMatrix GetInCSRMatrix() const;
/*! \return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetOutCSRMatrix() const;
/*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix() const;
private:
Bipartite(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo);
/*! \return Return any existing format. */
HeteroGraphPtr GetAny() const;
......
This diff is collapsed.
......@@ -249,8 +249,8 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
}
IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
CHECK(IsValidIdArray(parent_vids)) << "Invalid parent id array.";
CHECK(IsValidIdArray(query)) << "Invalid query id array.";
CHECK(aten::IsValidIdArray(parent_vids)) << "Invalid parent id array.";
CHECK(aten::IsValidIdArray(query)) << "Invalid query id array.";
const auto parent_len = parent_vids->shape[0];
const auto query_len = query->shape[0];
const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);
......
This diff is collapsed.
This diff is collapsed.
......@@ -749,7 +749,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
const bool add_self_loop,
const ValueType *probability) {
// process args
CHECK(IsValidIdArray(seed_nodes));
CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
......@@ -859,7 +859,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
// process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_nodes));
CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
......@@ -1017,7 +1017,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformEdgeSampling")
// process args
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(gptr) << "sampling isn't implemented in mutable graph";
CHECK(IsValidIdArray(seed_edges));
CHECK(aten::IsValidIdArray(seed_edges));
BuildCoo(*gptr);
const int64_t num_seeds = seed_edges->shape[0];
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment