Unverified Commit 67cb7a43 authored by xiang song(charlie.song), committed by GitHub
Browse files

[Feature] Deprecate multigraph (#1389)



* Deprecate multi-graph

* Handle heterograph and edge_ids

* lint

* Fix

* Remove multigraph in C++ end

* Fix lint

* Add some test and fix something

* Fix

* Fix

* upd

* Fix some test case

* Fix

* Fix
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 3efb5d8e
......@@ -120,7 +120,7 @@ class VGRelation(VisionDataset):
img, bbox = self.img_transform(img, bbox)
# build the graph
g = dgl.DGLGraph(multigraph=True)
g = dgl.DGLGraph()
g.add_nodes(n_nodes)
adjmat = np.zeros((n_nodes, n_nodes))
predicate = []
......
......@@ -49,7 +49,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
n_nodes = len(inds)
roi_ind = feat_ind[gi, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph(multigraph=True)
g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[gi, inds],
'node_feat': spatial_feat[gi, roi_ind],
'node_class_pred': ids[gi, inds, 0],
......
......@@ -26,10 +26,10 @@ typedef std::shared_ptr<Graph> MutableGraphPtr;
class Graph: public GraphInterface {
public:
/*! \brief default constructor */
explicit Graph(bool multigraph = false) : is_multigraph_(multigraph) {}
Graph() {}
/*! \brief construct a graph from the coo format. */
Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes, bool multigraph = false);
Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes);
/*! \brief default copy constructor */
Graph(const Graph& other) = default;
......@@ -44,7 +44,6 @@ class Graph: public GraphInterface {
all_edges_src_ = other.all_edges_src_;
all_edges_dst_ = other.all_edges_dst_;
read_only_ = other.read_only_;
is_multigraph_ = other.is_multigraph_;
num_edges_ = other.num_edges_;
other.Clear();
}
......@@ -102,9 +101,7 @@ class Graph: public GraphInterface {
* \note not const since we have caches
* \return whether the graph is a multigraph
*/
bool IsMultigraph() const override {
return is_multigraph_;
}
bool IsMultigraph() const override;
/*!
* \return whether the graph is read-only
......@@ -350,14 +347,14 @@ class Graph: public GraphInterface {
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
/*! \brief Create an empty graph */
static MutableGraphPtr Create(bool multigraph = false) {
return std::make_shared<Graph>(multigraph);
static MutableGraphPtr Create() {
return std::make_shared<Graph>();
}
/*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO(
int64_t num_nodes, IdArray src_ids, IdArray dst_ids, bool multigraph = false) {
return std::make_shared<Graph>(src_ids, dst_ids, num_nodes, multigraph);
int64_t num_nodes, IdArray src_ids, IdArray dst_ids) {
return std::make_shared<Graph>(src_ids, dst_ids, num_nodes);
}
protected:
......@@ -383,12 +380,7 @@ class Graph: public GraphInterface {
/*! \brief read only flag */
bool read_only_ = false;
/*!
* \brief Whether if this is a multigraph.
*
* When a multiedge is added, this flag switches to true.
*/
bool is_multigraph_ = false;
/*! \brief number of edges */
uint64_t num_edges_ = 0;
};
......
......@@ -33,28 +33,24 @@ typedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr;
class CSR : public GraphInterface {
public:
// Create a csr graph that has the given number of verts and edges.
CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph);
CSR(int64_t num_vertices, int64_t num_edges);
// Create a csr graph whose memory is stored in the shared memory
// that has the given number of verts and edges.
CSR(const std::string &shared_mem_name,
int64_t num_vertices, int64_t num_edges, bool is_multigraph);
int64_t num_vertices, int64_t num_edges);
// Create a csr graph that shares the given indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph);
// Create a csr graph by data iterator
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
bool is_multigraph);
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin);
// Create a csr graph whose memory is stored in the shared memory
// and the structure is given by the indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name);
CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const std::string &shared_mem_name);
void AddVertices(uint64_t num_vertices) override {
LOG(FATAL) << "CSR graph does not allow mutation.";
......@@ -262,9 +258,6 @@ class CSR : public GraphInterface {
// The data field stores edge ids.
aten::CSRMatrix adj_;
// whether the graph is a multi-graph
Lazy<bool> is_multigraph_;
// The name of the shared memory to store data.
// If it's empty, data isn't stored in shared memory.
std::string shared_mem_name_;
......@@ -274,7 +267,6 @@ class COO : public GraphInterface {
public:
// Create a coo graph that shares the given src and dst
COO(int64_t num_vertices, IdArray src, IdArray dst);
COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph);
// TODO(da): add constructor for creating COO from shared memory
......@@ -512,9 +504,6 @@ class COO : public GraphInterface {
// The internal COO adjacency matrix.
// The data field is empty
aten::COOMatrix adj_;
/*! \brief whether the graph is a multi-graph */
Lazy<bool> is_multigraph_;
};
/*!
......@@ -897,31 +886,18 @@ class ImmutableGraph: public GraphInterface {
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir, const std::string &shared_mem_name);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name);
static ImmutableGraphPtr CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir);
size_t num_edges, const std::string &edge_dir);
/*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst);
static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph);
/*!
* \brief Convert the given graph to an immutable graph.
*
......@@ -1025,8 +1001,7 @@ class ImmutableGraph: public GraphInterface {
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin,
bool is_multigraph): is_multigraph_(is_multigraph) {
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin) {
// TODO(minjie): this should be changed to a device-agnostic implementation
// in the future
adj_.num_rows = num_vertices;
......
......@@ -118,7 +118,7 @@ class EdgeDataView(MutableMapping):
data = self._graph.get_e_repr(self._edges)
return repr({key : data[key] for key in self._graph._edge_frame})
def _to_csr(graph_data, edge_dir, multigraph):
def _to_csr(graph_data, edge_dir, multigraph=None):
try:
indptr = graph_data.indptr
indices = graph_data.indices
......@@ -128,7 +128,7 @@ def _to_csr(graph_data, edge_dir, multigraph):
csr = graph_data.tocsr()
return csr.indptr, csr.indices
else:
idx = create_graph_index(graph_data=graph_data, multigraph=multigraph, readonly=True)
idx = create_graph_index(graph_data=graph_data, readonly=True)
transpose = (edge_dir != 'in')
csr = idx.adjacency_matrix_scipy(transpose, 'csr')
return csr.indptr, csr.indices
......@@ -310,7 +310,8 @@ class SharedMemoryStoreServer(object):
graph_name : string
Define the name of the graph, so the client can use the name to access the graph.
multigraph : bool, optional
Whether the graph would be a multigraph (default: False)
Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph (default: True)
num_workers : int
The number of workers that will connect to the server.
port : int
......@@ -318,17 +319,21 @@ class SharedMemoryStoreServer(object):
"""
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
self.server = None
if multigraph is not None:
dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.")
if isinstance(graph_data, GraphIndex):
graph_data = graph_data.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True)
self._graph = DGLGraph(graph_data, readonly=True)
elif isinstance(graph_data, DGLGraph):
graph_data = graph_data._graph.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True)
self._graph = DGLGraph(graph_data, readonly=True)
else:
indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
indptr, indices = _to_csr(graph_data, edge_dir)
graph_idx = from_csr(utils.toindex(indptr), utils.toindex(indices),
multigraph, edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_idx, multigraph=multigraph, readonly=True)
edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_idx, readonly=True)
self._num_workers = num_workers
self._graph_name = graph_name
......@@ -354,7 +359,7 @@ class SharedMemoryStoreServer(object):
# if the integers are larger than 2^31, xmlrpc can't handle them.
# we convert them to strings to send them to clients.
return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges()), \
self._graph.is_multigraph, edge_dir
True, edge_dir
# RPC command: initialize node embedding in the server.
def init_ndata(init, ndata_name, shape, dtype):
......@@ -479,7 +484,7 @@ class BaseGraphStore(DGLGraph):
"""
def __init__(self,
graph_data=None,
multigraph=False):
multigraph=None):
super(BaseGraphStore, self).__init__(graph_data, multigraph=multigraph, readonly=True)
@property
......@@ -555,12 +560,12 @@ class SharedMemoryDGLGraph(BaseGraphStore):
self._worker_id, self._num_workers = self.proxy.register(graph_name)
if self._worker_id < 0:
raise Exception('fail to get graph ' + graph_name + ' from the graph store')
num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges, _, edge_dir = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = int(num_nodes), int(num_edges)
graph_idx = from_shared_mem_csr_matrix(_get_graph_path(graph_name),
num_nodes, num_edges, edge_dir, multigraph)
super(SharedMemoryDGLGraph, self).__init__(graph_idx, multigraph=multigraph)
num_nodes, num_edges, edge_dir)
super(SharedMemoryDGLGraph, self).__init__(graph_idx)
self._init_manager = InitializerManager()
# map all ndata and edata from the server.
......@@ -1054,7 +1059,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
multigraph=False, edge_dir='in', port=8000):
multigraph=None, edge_dir='in', port=8000):
"""Create the graph store server.
The server loads graph structure and node embeddings and edge embeddings.
......@@ -1085,7 +1090,8 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
num_workers : int
The number of workers that will connect to the server.
multigraph : bool, optional
Whether the graph would be a multigraph (default: False)
Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph (default: True)
edge_dir : string
the edge direction for the graph structure. The supported option is
"in" and "out".
......@@ -1097,7 +1103,10 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
SharedMemoryStoreServer
The graph store server
"""
return SharedMemoryStoreServer(graph_data, edge_dir, graph_name, multigraph,
if multigraph is not None:
dgl_warning("multigraph is deprecated." \
"DGL treat all graphs as multigraph by default.")
return SharedMemoryStoreServer(graph_data, edge_dir, graph_name, None,
num_workers, port)
def create_graph_from_store(graph_name, store_type, port=8000):
......
......@@ -390,7 +390,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
meta_edges_src.append(ntype_dict[stype])
meta_edges_dst.append(ntype_dict[dtype])
etypes.append(etype)
metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True, True)
metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)
# create graph index
hgidx = heterograph_index.create_heterograph_from_relations(
......
......@@ -315,7 +315,7 @@ class DGLBaseGraph(object):
"""
return self._graph.successors(v).tousertensor()
def edge_id(self, u, v, force_multi=False):
def edge_id(self, u, v, force_multi=None, return_array=False):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`.
......@@ -326,15 +326,24 @@ class DGLBaseGraph(object):
v : int
The destination node ID.
force_multi : bool
If False, will return a single edge ID if the graph is a simple graph.
Deprecated (Will be deleted in the future).
If False, will return a single edge ID.
If True, will always return an array.
return_array : bool
If False, will return a single edge ID.
If True, will always return an array.
Returns
-------
int or tensor
The edge ID if force_multi == True and the graph is a simple graph.
The edge ID if return_array is False.
The edge ID array otherwise.
Notes
-----
If multiple edges exist between `u` and `v` and return_array is False,
the result is undefined.
Examples
--------
The following example uses PyTorch backend.
......@@ -351,14 +360,14 @@ class DGLBaseGraph(object):
For multigraphs:
>>> G = dgl.DGLGraph(multigraph=True)
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
Adding edges (0, 1), (0, 2), (0, 1), (0, 2), so edge ID 0 and 2 both
connect from 0 and 1, while edge ID 1 and 3 both connect from 0 and 2.
>>> G.add_edges([0, 0, 0, 0], [1, 2, 1, 2])
>>> G.edge_id(0, 1)
>>> G.edge_id(0, 1, return_array=True)
tensor([0, 2])
See Also
......@@ -366,9 +375,20 @@ class DGLBaseGraph(object):
edge_ids
"""
idx = self._graph.edge_id(u, v)
return idx.tousertensor() if force_multi or self.is_multigraph else idx[0]
if force_multi is not None:
dgl_warning("force_multi will be deprecated." \
"Please use return_array instead")
return_array = force_multi
def edge_ids(self, u, v, force_multi=False):
if return_array:
return idx.tousertensor()
else:
assert len(idx) == 1, "For return_array=False, there should be one and " \
"only one edge between u and v, but get {} edges. " \
"Please use return_array=True instead".format(len(idx))
return idx[0]
def edge_ids(self, u, v, force_multi=None, return_uv=False):
"""Return all edge IDs between source node array `u` and destination
node array `v`.
......@@ -379,21 +399,26 @@ class DGLBaseGraph(object):
v : list, tensor
The destination node ID array.
force_multi : bool
Deprecated (Will be deleted in the future).
Whether to always treat the graph as a multigraph.
return_uv : bool
Whether to return e or (eu, ev, e)
Returns
-------
tensor, or (tensor, tensor, tensor)
If the graph is a simple graph and `force_multi` is False, return
a single edge ID array `e`. `e[i]` is the edge ID between `u[i]`
and `v[i]`.
If `return_uv` is False, return a single edge ID array `e`.
`e[i]` is the edge ID between `u[i]` and `v[i]`.
Otherwise, return three arrays `(eu, ev, e)`. `e[i]` is the ID
of an edge between `eu[i]` and `ev[i]`. All edges between `u[i]`
and `v[i]` are returned.
Notes
-----
If the graph is a simple graph, `force_multi` is False, and no edge
If the graph is a simple graph, `return_uv` is False, and no edge
exists between some pairs of `u[i]` and `v[i]`, the result is undefined.
If the graph is a multigraph, `return_uv` is False, and multiple edges
exist between some pairs of `u[i]` and `v[i]`, the result is undefined.
Examples
......@@ -411,14 +436,14 @@ class DGLBaseGraph(object):
For multigraphs
>>> G = dgl.DGLGraph(multigraph=True)
>>> G = dgl.DGLGraph()
>>> G.add_nodes(4)
>>> G.add_edges([0, 0, 0], [1, 1, 2]) # (0, 1), (0, 1), (0, 2)
Get all edges between (0, 1), (0, 2), (0, 3). Note that there is no
edge between 0 and 3:
>>> G.edge_ids([0, 0, 0], [1, 2, 3])
>>> G.edge_ids([0, 0, 0], [1, 2, 3], return_uv=True)
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([0, 1, 2]))
See Also
......@@ -428,9 +453,17 @@ class DGLBaseGraph(object):
u = utils.toindex(u)
v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(u, v)
if force_multi or self.is_multigraph:
if force_multi is not None:
dgl_warning("force_multi will be deprecated, " \
"Please use return_uv instead")
return_uv = force_multi
if return_uv:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
assert len(eid) == max(len(u), len(v)), "If return_uv=False, there should be one and " \
"only one edge between each u and v, expect {} edges but get {}. " \
"Please use return_uv=True instead".format(max(len(u), len(v)), len(eid))
return eid.tousertensor()
def find_edges(self, eid):
......@@ -814,8 +847,9 @@ class DGLGraph(DGLBaseGraph):
edge_frame : FrameRef, optional
Edge feature storage.
multigraph : bool, optional
Whether the graph would be a multigraph. If none, the flag will be determined
by scanning the whole graph. (default: None)
Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph. If none, the flag will be
set to True. (default: None)
readonly : bool, optional
Whether the graph structure is read-only (default: False).
......@@ -956,7 +990,10 @@ class DGLGraph(DGLBaseGraph):
if sort_csr:
gidx.sort_csr()
else:
gidx = graph_index.create_graph_index(graph_data, multigraph, readonly)
if multigraph is not None:
dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.")
gidx = graph_index.create_graph_index(graph_data, readonly)
if sort_csr:
gidx.sort_csr()
super(DGLGraph, self).__init__(gidx)
......@@ -1805,7 +1842,7 @@ class DGLGraph(DGLBaseGraph):
raise DGLError('Not all edges have attribute {}.'.format(attr))
self._edge_frame[attr] = _batcher(attr_dict[attr])
def from_scipy_sparse_matrix(self, spmat, multigraph=False):
def from_scipy_sparse_matrix(self, spmat, multigraph=None):
""" Convert from scipy sparse matrix.
Parameters
......@@ -1814,6 +1851,7 @@ class DGLGraph(DGLBaseGraph):
The graph's adjacency matrix
multigraph : bool, optional
Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph. If the input scipy sparse matrix is CSR,
this argument is ignored.
......@@ -1828,7 +1866,11 @@ class DGLGraph(DGLBaseGraph):
>>> g.from_scipy_sparse_matrix(a)
"""
self.clear()
self._graph = graph_index.from_scipy_sparse_matrix(spmat, multigraph, self.is_readonly)
if multigraph is not None:
dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.")
self._graph = graph_index.from_scipy_sparse_matrix(spmat, self.is_readonly)
self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges())
......
......@@ -35,7 +35,6 @@ class GraphIndex(ObjectBase):
"""
def __new__(cls):
obj = ObjectBase.__new__(cls)
obj._multigraph = None # python-side cache of the flag
obj._readonly = None # python-side cache of the flag
obj._cache = {}
return obj
......@@ -43,28 +42,22 @@ class GraphIndex(ObjectBase):
def __getstate__(self):
src, dst, _ = self.edges()
n_nodes = self.number_of_nodes()
# TODO(minjie): should try to avoid calling is_multigraph
multigraph = self.is_multigraph()
readonly = self.is_readonly()
return n_nodes, multigraph, readonly, src, dst
return n_nodes, readonly, src, dst
def __setstate__(self, state):
"""The pickle state of GraphIndex is defined as a triplet
(number_of_nodes, multigraph, readonly, src_nodes, dst_nodes)
(number_of_nodes, readonly, src_nodes, dst_nodes)
"""
num_nodes, multigraph, readonly, src, dst = state
num_nodes, readonly, src, dst = state
self._cache = {}
self._multigraph = multigraph
self._readonly = readonly
if multigraph is None:
multigraph = BoolFlag.BOOL_UNKNOWN
self.__init_handle_by_constructor__(
_CAPI_DGLGraphCreate,
src.todgltensor(),
dst.todgltensor(),
int(multigraph),
int(num_nodes),
readonly)
......@@ -118,15 +111,14 @@ class GraphIndex(ObjectBase):
def is_multigraph(self):
"""Return whether the graph is a multigraph
The time cost will be O(E)
Returns
-------
bool
True if it is a multigraph, False otherwise.
"""
if self._multigraph is None:
self._multigraph = bool(_CAPI_DGLGraphIsMultigraph(self))
return self._multigraph
return bool(_CAPI_DGLGraphIsMultigraph(self))
def is_readonly(self):
"""Indicate whether the graph index is read-only.
......@@ -149,9 +141,9 @@ class GraphIndex(ObjectBase):
New readonly state of current graph index.
"""
# TODO(minjie): very ugly code, should fix this
n_nodes, multigraph, _, src, dst = self.__getstate__()
n_nodes, _, src, dst = self.__getstate__()
self.clear_cache()
state = (n_nodes, multigraph, readonly_state, src, dst)
state = (n_nodes, readonly_state, src, dst)
self.__setstate__(state)
def number_of_nodes(self):
......@@ -829,7 +821,8 @@ class GraphIndex(ObjectBase):
The nx graph
"""
src, dst, eid = self.edges()
ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
# xiangsx: Always treat graph as multigraph
ret = nx.MultiDiGraph()
ret.add_nodes_from(range(self.number_of_nodes()))
for u, v, e in zip(src, dst, eid):
ret.add_edge(u, v, id=e)
......@@ -991,7 +984,7 @@ class SubgraphIndex(ObjectBase):
###############################################################
# Conversion functions
###############################################################
def from_coo(num_nodes, src, dst, is_multigraph, readonly):
def from_coo(num_nodes, src, dst, readonly):
"""Convert from coo arrays.
Parameters
......@@ -1002,8 +995,6 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly):
Src end nodes of the edges.
dst : Tensor
Dst end nodes of the edges.
is_multigraph : bool or None
True if the graph is a multigraph. None means determined by data.
readonly : bool
True if the returned graph is readonly.
......@@ -1014,25 +1005,19 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly):
"""
src = utils.toindex(src)
dst = utils.toindex(dst)
if is_multigraph is None:
is_multigraph = BoolFlag.BOOL_UNKNOWN
if readonly:
gidx = _CAPI_DGLGraphCreate(
src.todgltensor(),
dst.todgltensor(),
int(is_multigraph),
int(num_nodes),
readonly)
else:
if is_multigraph is BoolFlag.BOOL_UNKNOWN:
# TODO(minjie): better behavior in the future
is_multigraph = BoolFlag.BOOL_FALSE
gidx = _CAPI_DGLGraphCreateMutable(bool(is_multigraph))
gidx = _CAPI_DGLGraphCreateMutable()
gidx.add_nodes(num_nodes)
gidx.add_edges(src, dst)
return gidx
def from_csr(indptr, indices, is_multigraph,
def from_csr(indptr, indices,
direction, shared_mem_name=""):
"""Load a graph from CSR arrays.
......@@ -1042,8 +1027,6 @@ def from_csr(indptr, indices, is_multigraph,
index pointer in the CSR format
indices : Tensor
column index array in the CSR format
is_multigraph : bool or None
True if the graph is a multigraph. None means determined by data.
direction : str
the edge direction. Either "in" or "out".
shared_mem_name : str
......@@ -1051,19 +1034,15 @@ def from_csr(indptr, indices, is_multigraph,
"""
indptr = utils.toindex(indptr)
indices = utils.toindex(indices)
if is_multigraph is None:
is_multigraph = BoolFlag.BOOL_UNKNOWN
gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(),
indices.todgltensor(),
shared_mem_name,
int(is_multigraph),
direction)
return gidx
def from_shared_mem_csr_matrix(shared_mem_name,
num_nodes, num_edges, edge_dir,
is_multigraph):
num_nodes, num_edges, edge_dir):
"""Load a graph from the shared memory in the CSR format.
Parameters
......@@ -1080,7 +1059,6 @@ def from_shared_mem_csr_matrix(shared_mem_name,
gidx = _CAPI_DGLGraphCSRCreateMMap(
shared_mem_name,
int(num_nodes), int(num_edges),
is_multigraph,
edge_dir)
return gidx
......@@ -1109,8 +1087,6 @@ def from_networkx(nx_graph, readonly):
# to_directed creates a deep copy of the networkx graph even if
# the original graph is already directed and we do not want to do it.
nx_graph = nx_graph.to_directed()
is_multigraph = isinstance(nx_graph, nx.MultiDiGraph)
num_nodes = nx_graph.number_of_nodes()
# nx_graph.edges(data=True) returns src, dst, attr_dict
......@@ -1137,17 +1113,14 @@ def from_networkx(nx_graph, readonly):
# We store edge Ids as an edge attribute.
src = utils.toindex(src)
dst = utils.toindex(dst)
return from_coo(num_nodes, src, dst, is_multigraph, readonly)
return from_coo(num_nodes, src, dst, readonly)
def from_scipy_sparse_matrix(adj, multigraph, readonly):
def from_scipy_sparse_matrix(adj, readonly):
"""Convert from scipy sparse matrix.
Parameters
----------
adj : scipy sparse matrix
multigraph : bool
Whether the graph would be a multigraph. If none, the flag will be determined
by the data.
readonly : bool
True if the returned graph is readonly.
......@@ -1159,12 +1132,12 @@ def from_scipy_sparse_matrix(adj, multigraph, readonly):
if adj.getformat() != 'csr' or not readonly:
num_nodes = max(adj.shape[0], adj.shape[1])
adj_coo = adj.tocoo()
return from_coo(num_nodes, adj_coo.row, adj_coo.col, multigraph, readonly)
return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly)
else:
# If the input matrix is csr, it's guaranteed to be a simple graph.
return from_csr(adj.indptr, adj.indices, False, "out")
# If the input matrix is csr, we still treat it as multigraph.
return from_csr(adj.indptr, adj.indices, "out")
def from_edge_list(elist, is_multigraph, readonly):
def from_edge_list(elist, readonly):
"""Convert from an edge list.
Parameters
......@@ -1184,7 +1157,7 @@ def from_edge_list(elist, is_multigraph, readonly):
min_nodes = min(src.min(), dst.min())
if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.')
return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly)
return from_coo(num_nodes, src_ids, dst_ids, readonly)
def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids.
......@@ -1274,16 +1247,13 @@ def disjoint_partition(graph, num_or_size_splits):
int(num_or_size_splits))
return rst
def create_graph_index(graph_data, multigraph, readonly):
def create_graph_index(graph_data, readonly):
"""Create a graph index object.
Parameters
----------
graph_data : graph data
Data to initialize graph. Same as networkx's semantics.
multigraph : bool
Whether the graph would be a multigraph. If none, the flag will be determined
by the data.
readonly : bool
Whether the graph structure is read-only.
"""
......@@ -1294,15 +1264,13 @@ def create_graph_index(graph_data, multigraph, readonly):
if graph_data is None:
if readonly:
raise Exception("can't create an empty immutable graph")
if multigraph is None:
multigraph = False
return _CAPI_DGLGraphCreateMutable(multigraph)
return _CAPI_DGLGraphCreateMutable()
elif isinstance(graph_data, (list, tuple)):
# edge list
return from_edge_list(graph_data, multigraph, readonly)
return from_edge_list(graph_data, readonly)
elif isinstance(graph_data, scipy.sparse.spmatrix):
# scipy format
return from_scipy_sparse_matrix(graph_data, multigraph, readonly)
return from_scipy_sparse_matrix(graph_data, readonly)
else:
# networkx - any format
try:
......
......@@ -266,8 +266,6 @@ class DGLHeteroGraph(object):
frame.set_initializer(init.zero_initializer)
self._msg_frames.append(frame)
self._is_multigraph = None
def __getstate__(self):
return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames
......@@ -1075,10 +1073,7 @@ class DGLHeteroGraph(object):
bool
True if the graph is a multigraph, False otherwise.
"""
if self._is_multigraph is None:
return self._graph.is_multigraph()
else:
return self._is_multigraph
@property
def is_readonly(self):
......@@ -1295,7 +1290,7 @@ class DGLHeteroGraph(object):
"""
return self._graph.successors(self.get_etype_id(etype), v).tousertensor()
def edge_id(self, u, v, force_multi=False, etype=None):
def edge_id(self, u, v, force_multi=None, return_array=False, etype=None):
"""Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`, with the specified edge type
......@@ -1306,7 +1301,11 @@ class DGLHeteroGraph(object):
v : int
The node ID of destination type.
force_multi : bool, optional
If False, will return a single edge ID if the graph is a simple graph.
Deprecated (Will be deleted in the future).
If False, will return a single edge ID.
If True, will always return an array. (Default: False)
return_array : bool, optional
If False, will return a single edge ID.
If True, will always return an array. (Default: False)
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
......@@ -1315,9 +1314,14 @@ class DGLHeteroGraph(object):
Returns
-------
int or tensor
The edge ID if ``force_multi == True`` and the graph is a simple graph.
The edge ID if ``return_array == False``.
The edge ID array otherwise.
Notes
-----
If multiple edges exist between `u` and `v` and return_array is False,
the result is undefined.
Examples
--------
The following example uses PyTorch backend.
......@@ -1332,7 +1336,7 @@ class DGLHeteroGraph(object):
>>> plays_g.edge_id(1, 2, etype=('user', 'plays', 'game'))
2
>>> g.edge_id(1, 2, force_multi=True, etype=('user', 'follows', 'user'))
>>> g.edge_id(1, 2, return_array=True, etype=('user', 'follows', 'user'))
tensor([1, 2])
See Also
......@@ -1340,9 +1344,20 @@ class DGLHeteroGraph(object):
edge_ids
"""
idx = self._graph.edge_id(self.get_etype_id(etype), u, v)
return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0]
if force_multi is not None:
dgl_warning("force_multi will be deprecated." \
"Please use return_array instead")
return_array = force_multi
if return_array:
return idx.tousertensor()
else:
assert len(idx) == 1, "For return_array=False, there should be one and " \
"only one edge between u and v, but get {} edges. " \
"Please use return_array=True instead".format(len(idx))
return idx[0]
def edge_ids(self, u, v, force_multi=False, etype=None):
def edge_ids(self, u, v, force_multi=None, return_uv=False, etype=None):
"""Return all edge IDs between source node array `u` and destination
node array `v` with the specified edge type.
......@@ -1353,8 +1368,11 @@ class DGLHeteroGraph(object):
v : list, tensor
The node ID array of destination type.
force_multi : bool, optional
Deprecated (Will be deleted in the future).
Whether to always treat the graph as a multigraph. See the
"Returns" for their effects. (Default: False)
return_uv : bool
See the "Returns" for their effects. (Default: False)
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
......@@ -1363,9 +1381,8 @@ class DGLHeteroGraph(object):
-------
tensor, or (tensor, tensor, tensor)
* If the graph is a simple graph and ``force_multi=False``, return
a single edge ID array ``e``. ``e[i]`` is the edge ID between ``u[i]``
and ``v[i]``.
* If ``return_uv=False``, return a single edge ID array ``e``.
``e[i]`` is the edge ID between ``u[i]`` and ``v[i]``.
* Otherwise, return three arrays ``(eu, ev, e)``. ``e[i]`` is the ID
of an edge between ``eu[i]`` and ``ev[i]``. All edges between ``u[i]``
......@@ -1373,10 +1390,13 @@ class DGLHeteroGraph(object):
Notes
-----
If the graph is a simple graph, ``force_multi=False``, and no edge
If the graph is a simple graph, ``return_uv=False``, and no edge
exists between some pairs of ``u[i]`` and ``v[i]``, the result is undefined
and an empty tensor is returned.
If the graph is a multigraph, ``return_uv=False``, and multiple edges
exist between some pairs of `u[i]` and `v[i]`, the result is undefined.
Examples
--------
The following example uses PyTorch backend.
......@@ -1393,7 +1413,7 @@ class DGLHeteroGraph(object):
tensor([], dtype=torch.int64)
>>> plays_g.edge_ids([1], [2], etype=('user', 'plays', 'game'))
tensor([2])
>>> g.edge_ids([1], [2], force_multi=True, etype=('user', 'follows', 'user'))
>>> g.edge_ids([1], [2], return_uv=True, etype=('user', 'follows', 'user'))
(tensor([1, 1]), tensor([2, 2]), tensor([1, 2]))
See Also
......@@ -1403,9 +1423,17 @@ class DGLHeteroGraph(object):
u = utils.toindex(u)
v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
if force_multi or self._graph.is_multigraph():
if force_multi is not None:
dgl_warning("force_multi will be deprecated, " \
"Please use return_uv instead")
return_uv = force_multi
if return_uv:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else:
assert len(eid) == max(len(u), len(v)), "If return_uv=False, there should be one and " \
"only one edge between each u and v, expect {} edges but get {}. " \
"Please use return_uv=True instead".format(max(len(u), len(v)), len(eid))
return eid.tousertensor()
def find_edges(self, eid, etype=None):
......@@ -2023,7 +2051,7 @@ class DGLHeteroGraph(object):
induced_etypes.append(self.etypes[i])
edge_frames.append(self._edge_frames[i])
metagraph = graph_index.from_edge_list(meta_edges, True, True)
metagraph = graph_index.from_edge_list(meta_edges, True)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
......@@ -2097,7 +2125,7 @@ class DGLHeteroGraph(object):
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True)
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True)
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
......@@ -3787,7 +3815,8 @@ class DGLHeteroGraph(object):
src, dst = self.edges()
src = F.asnumpy(src)
dst = F.asnumpy(dst)
nx_graph = nx.MultiDiGraph() if self.is_multigraph else nx.DiGraph()
# xiangsx: Always treat graph as multigraph
nx_graph = nx.MultiDiGraph()
nx_graph.add_nodes_from(range(self.number_of_nodes()))
for eid, (u, v) in enumerate(zip(src, dst)):
nx_graph.add_edge(u, v, id=eid)
......
......@@ -216,6 +216,7 @@ class HeteroGraphIndex(ObjectBase):
def is_multigraph(self):
"""Return whether the graph is a multigraph
The time cost will be O(E)
Returns
-------
......
......@@ -241,7 +241,7 @@ def khop_graph(g, k):
col = np.repeat(adj_k.col, multiplicity)
# TODO(zihao): we should support creating multi-graph from scipy sparse matrix
# in the future.
return DGLGraph(from_coo(n, row, col, True, True))
return DGLGraph(from_coo(n, row, col, True))
def reverse(g, share_ndata=False, share_edata=False):
"""Return the reverse of a graph
......@@ -310,7 +310,7 @@ def reverse(g, share_ndata=False, share_edata=False):
[2.],
[3.]])
"""
g_reversed = DGLGraph(multigraph=g.is_multigraph)
g_reversed = DGLGraph()
g_reversed.add_nodes(g.number_of_nodes())
g_edges = g.all_edges(order='eid')
g_reversed.add_edges(g_edges[1], g_edges[0])
......@@ -357,6 +357,10 @@ def to_bidirected(g, readonly=True):
readonly : bool, default to be True
Whether the returned bidirected graph is readonly or not.
Notes
-----
Please make sure g is a simple graph, otherwise the return value is undefined.
Returns
-------
DGLGraph
......
......@@ -14,8 +14,7 @@
namespace dgl {
Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes,
bool multigraph): is_multigraph_(multigraph) {
Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) {
CHECK(aten::IsValidIdArray(src_ids));
CHECK(aten::IsValidIdArray(dst_ids));
this->AddVertices(num_nodes);
......@@ -42,6 +41,33 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes,
}
}
/*!
 * \brief Check whether the graph has at least two edges sharing the same
 *        (source, destination) pair.
 *
 * The endpoints of every edge are copied into a temporary vector, which is
 * then sorted so that any duplicated pair ends up in neighbouring slots.
 *
 * \return true if some (src, dst) pair occurs more than once, false otherwise.
 */
bool Graph::IsMultigraph() const {
  // With fewer than two edges there is nothing that could be duplicated.
  if (num_edges_ < 2) {
    return false;
  }

  using Endpoints = std::pair<int64_t, int64_t>;
  std::vector<Endpoints> endpoints;
  endpoints.reserve(num_edges_);
  for (uint64_t e = 0; e != num_edges_; ++e) {
    endpoints.emplace_back(all_edges_src_[e], all_edges_dst_[e]);
  }

  // std::pair compares lexicographically (first by src, then by dst), which
  // is exactly the ordering needed to bring equal pairs next to each other.
  std::sort(endpoints.begin(), endpoints.end());

  // After sorting, a duplicate exists iff two adjacent entries are equal.
  return std::adjacent_find(endpoints.begin(), endpoints.end()) != endpoints.end();
}
void Graph::AddVertices(uint64_t num_vertices) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
adjlist_.resize(adjlist_.size() + num_vertices);
......@@ -447,7 +473,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
oldv2newv[vid_data[i]] = i;
}
Subgraph rst;
rst.graph = std::make_shared<Graph>(IsMultigraph());
rst.graph = std::make_shared<Graph>();
rst.induced_vertices = vids;
rst.graph->AddVertices(len);
for (int64_t i = 0; i < len; ++i) {
......@@ -486,7 +512,7 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
nodes.push_back(dst_id);
}
rst.graph = std::make_shared<Graph>(IsMultigraph());
rst.graph = std::make_shared<Graph>();
rst.induced_edges = eids;
rst.graph->AddVertices(nodes.size());
......@@ -500,7 +526,7 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
} else {
rst.graph = std::make_shared<Graph>(IsMultigraph());
rst.graph = std::make_shared<Graph>();
rst.induced_edges = eids;
rst.graph->AddVertices(NumVertices());
......
......@@ -23,8 +23,7 @@ namespace dgl {
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
bool multigraph = args[0];
*rv = GraphRef(Graph::Create(multigraph));
*rv = GraphRef(Graph::Create());
});
......@@ -32,18 +31,12 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray src_ids = args[0];
const IdArray dst_ids = args[1];
const int multigraph = args[2];
const int64_t num_nodes = args[3];
const bool readonly = args[4];
const int64_t num_nodes = args[2];
const bool readonly = args[3];
if (readonly) {
if (multigraph == kBoolUnknown) {
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids, multigraph));
}
} else {
CHECK_NE(multigraph, kBoolUnknown);
*rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids, multigraph));
*rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids));
}
});
......@@ -52,8 +45,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
const IdArray indptr = args[0];
const IdArray indices = args[1];
const std::string shared_mem_name = args[2];
const int multigraph = args[3];
const std::string edge_dir = args[4];
const std::string edge_dir = args[3];
IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
......@@ -61,20 +53,10 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i;
if (shared_mem_name.empty()) {
if (multigraph == kBoolUnknown) {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
indptr, indices, edge_ids, multigraph, edge_dir));
}
} else {
if (multigraph == kBoolUnknown) {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
indptr, indices, edge_ids, edge_dir, shared_mem_name));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir, shared_mem_name));
}
}
});
......@@ -83,11 +65,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
const std::string shared_mem_name = args[0];
const int64_t num_vertices = args[1];
const int64_t num_edges = args[2];
const bool multigraph = args[3];
const std::string edge_dir = args[4];
// TODO(minjie): how to know multigraph
const std::string edge_dir = args[3];
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
shared_mem_name, num_vertices, num_edges, multigraph, edge_dir));
shared_mem_name, num_vertices, num_edges, edge_dir));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
......
......@@ -325,7 +325,7 @@ GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) {
indptr[src+1] = indices.size();
}
CSRPtr csr(new CSR(graph->NumVertices(), indices.size(),
indptr.begin(), indices.begin(), RangeIter(0), false));
indptr.begin(), indices.begin(), RangeIter(0)));
return std::make_shared<ImmutableGraph>(csr);
}
......@@ -397,7 +397,7 @@ GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {
IdArray srcs_array = aten::VecToIdArray(srcs);
IdArray dsts_array = aten::VecToIdArray(dsts);
return ImmutableGraph::CreateFromCOO(
g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph());
g->NumVertices(), srcs_array, dsts_array);
}
HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hops) {
......
......@@ -187,14 +187,12 @@ HeteroGraph::HeteroGraph(
}
bool HeteroGraph::IsMultigraph() const {
return const_cast<HeteroGraph*>(this)->is_multigraph_.Get([this] () {
for (const auto &hg : relation_graphs_) {
if (hg->IsMultigraph()) {
return true;
}
}
return false;
});
}
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
......
......@@ -214,9 +214,6 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \brief A map from vert type to the number of verts in the type */
std::vector<int64_t> num_verts_per_type_;
/*! \brief True if the graph is a multigraph */
Lazy<bool> is_multigraph_;
};
} // namespace dgl
......
......@@ -55,8 +55,7 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
//
//////////////////////////////////////////////////////////
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph)
: is_multigraph_(is_multigraph) {
CSR::CSR(int64_t num_vertices, int64_t num_edges) {
CHECK(!(num_vertices == 0 && num_edges != 0));
adj_ = aten::CSRMatrix{num_vertices, num_vertices,
aten::NewIdArray(num_vertices + 1),
......@@ -75,17 +74,6 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
adj_.sorted = false;
}
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
adj_.sorted = false;
}
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
CHECK(aten::IsValidIdArray(indptr));
......@@ -105,29 +93,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
adj_.sorted = false;
}
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const std::string &shared_mem_name): is_multigraph_(is_multigraph),
shared_mem_name_(shared_mem_name) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0];
adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays
adj_.indptr.CopyFrom(indptr);
adj_.indices.CopyFrom(indices);
adj_.data.CopyFrom(edge_ids);
adj_.sorted = false;
}
CSR::CSR(const std::string &shared_mem_name,
int64_t num_verts, int64_t num_edges, bool is_multigraph)
: is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
int64_t num_verts, int64_t num_edges): shared_mem_name_(shared_mem_name) {
CHECK(!(num_verts == 0 && num_edges != 0));
adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
......@@ -137,10 +104,7 @@ CSR::CSR(const std::string &shared_mem_name,
}
bool CSR::IsMultigraph() const {
// The lambda will be called the first time to initialize the is_multigraph flag.
return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
return aten::CSRHasDuplicate(adj_);
});
}
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
......@@ -233,7 +197,6 @@ CSR CSR::CopyTo(const DLContext& ctx) const {
CSR ret(adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
......@@ -255,7 +218,6 @@ CSR CSR::AsNumBits(uint8_t bits) const {
CSR ret(aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
......@@ -301,19 +263,8 @@ COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
}
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
: is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]);
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
}
bool COO::IsMultigraph() const {
// The lambda will be called the first time to initialize the is_multigraph flag.
return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
return aten::COOHasDuplicate(adj_);
});
}
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
......@@ -347,12 +298,12 @@ Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
IdArray new_dst = aten::IndexSelect(adj_.col, eids);
induced_nodes = aten::Relabel_({new_src, new_dst});
const auto new_nnodes = induced_nodes->shape[0];
subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst, this->IsMultigraph()));
subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));
} else {
IdArray new_src = aten::IndexSelect(adj_.row, eids);
IdArray new_dst = aten::IndexSelect(adj_.col, eids);
induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst, this->IsMultigraph()));
subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst));
}
Subgraph subg;
subg.graph = subcoo;
......@@ -373,7 +324,6 @@ COO COO::CopyTo(const DLContext& ctx) const {
COO ret(NumVertices(),
adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
......@@ -390,7 +340,6 @@ COO COO::AsNumBits(uint8_t bits) const {
COO ret(NumVertices(),
aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
......@@ -518,40 +467,11 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
CSRPtr csr(new CSR(indptr, indices, edge_ids,
GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
......@@ -565,10 +485,8 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
multigraph));
size_t num_edges, const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
......@@ -585,12 +503,6 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
COOPtr coo(new COO(num_vertices, src, dst, multigraph));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
if (ig) {
......
......@@ -355,7 +355,7 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::vector<std::pair<dgl_id_t, int> > *sub_vers,
std::vector<neighbor_info> *neigh_pos,
const std::string &edge_type,
int64_t num_edges, int num_hops, bool is_multigraph) {
int64_t num_edges, int num_hops) {
NodeFlow nf = NodeFlow::Create();
uint64_t num_vertices = sub_vers->size();
nf->node_mapping = aten::NewIdArray(num_vertices);
......@@ -368,9 +368,8 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf->flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf->edge_mapping->data);
// Construct sub_csr_graph
// TODO(minjie): is nodeflow a multigraph?
auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges, is_multigraph));
// Construct sub_csr_graph, we treat nodeflow as multigraph by default
auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges));
dgl_id_t* indptr_out = static_cast<dgl_id_t*>(subg_csr->indptr()->data);
dgl_id_t* col_list_out = static_cast<dgl_id_t*>(subg_csr->indices()->data);
dgl_id_t* eid_out = static_cast<dgl_id_t*>(subg_csr->edge_ids()->data);
......@@ -592,7 +591,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
}
return ConstructNodeFlow(neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos,
edge_type, num_edges, num_hops, graph->IsMultigraph());
edge_type, num_edges, num_hops);
}
} // namespace
......@@ -1090,7 +1089,6 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
int64_t num_tot_nodes = gptr_->NumVertices();
if (neg_sample_size > num_tot_nodes)
neg_sample_size = num_tot_nodes;
bool is_multigraph = gptr_->IsMultigraph();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2;
......@@ -1223,7 +1221,7 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
NegSubgraph neg_subg;
// We sample negative vertices without replacement.
// There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph));
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid;
......@@ -1374,7 +1372,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
NegSubgraph neg_subg;
// We sample negative vertices without replacement.
// There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, gptr_->IsMultigraph()));
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid;
......
......@@ -69,16 +69,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
}
COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray src, IdArray dst, bool is_multigraph)
: BaseHeteroGraph(metagraph),
is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
}
COO(GraphPtr metagraph, const aten::COOMatrix& coo)
: BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always
......@@ -142,7 +132,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_.num_rows, adj_.num_cols,
aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
......@@ -155,14 +144,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_.num_rows, adj_.num_cols,
adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
bool IsMultigraph() const override {
return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
return aten::COOHasDuplicate(adj_);
});
}
bool IsReadonly() const override {
......@@ -422,9 +408,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
/*! \brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_;
/*! \brief multi-graph flag */
Lazy<bool> is_multigraph_;
};
//////////////////////////////////////////////////////////
......@@ -448,17 +431,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
}
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: BaseHeteroGraph(metagraph), is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
}
CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) {
}
......@@ -519,7 +491,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
......@@ -534,15 +505,12 @@ class UnitGraph::CSR : public BaseHeteroGraph {
adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret;
}
}
bool IsMultigraph() const override {
return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
return aten::CSRHasDuplicate(adj_);
});
}
bool IsReadonly() const override {
......@@ -775,9 +743,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
/*! \brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_;
/*! \brief multi-graph flag */
Lazy<bool> is_multigraph_;
};
//////////////////////////////////////////////////////////
......@@ -1302,20 +1267,20 @@ GraphPtr UnitGraph::AsImmutableGraph() const {
dgl::COOPtr coo_ptr = nullptr;
if (in_csr_) {
aten::CSRMatrix csc = GetCSCMatrix(0);
in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data, true));
in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
}
if (out_csr_) {
aten::CSRMatrix csr = GetCSRMatrix(0);
out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data, true));
out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
}
if (coo_) {
aten::COOMatrix coo = GetCOOMatrix(0);
if (!COOHasData(coo)) {
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col, true));
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col));
} else {
IdArray new_src = Scatter(coo.row, coo.data);
IdArray new_dst = Scatter(coo.col, coo.data);
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst, true));
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst));
}
}
return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
......
......@@ -177,7 +177,7 @@ def test_nx_conversion():
n3 = F.randn((5, 4))
e1 = F.randn((4, 5))
e2 = F.randn((4, 7))
g = DGLGraph(multigraph=True)
g = DGLGraph()
g.add_nodes(5)
g.add_edges([0,1,3,4], [2,4,0,3])
g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3})
......@@ -225,7 +225,7 @@ def test_nx_conversion():
for _, _, attr in nxg.edges(data=True):
attr.pop('id')
# test with a new graph
g = DGLGraph(multigraph=True)
g = DGLGraph()
g.from_networkx(nxg, node_attrs=['n1'], edge_attrs=['e1'])
# check graph size
assert g.number_of_nodes() == 7
......@@ -512,7 +512,7 @@ def test_pull_0deg():
assert F.allclose(new[1], old[1])
def test_send_multigraph():
g = DGLGraph(multigraph=True)
g = DGLGraph()
g.add_nodes(3)
g.add_edge(0, 1)
g.add_edge(0, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment