"...git@developer.sourcefind.cn:OpenDAS/pytorch-encoding.git" did not exist on "c2cb2aab69d5d276fbcb847fb8277c1a52947661"
Unverified Commit 67cb7a43 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Feature] Deprecate multigraph (#1389)



* Deprecate multi-graph

* Handle heterograph and edge_ids

* lint

* Fix

* Remove multigraph in C++ end

* Fix lint

* Add some test and fix something

* Fix

* Fix

* upd

* Fix some test case

* Fix

* Fix
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 3efb5d8e
...@@ -120,7 +120,7 @@ class VGRelation(VisionDataset): ...@@ -120,7 +120,7 @@ class VGRelation(VisionDataset):
img, bbox = self.img_transform(img, bbox) img, bbox = self.img_transform(img, bbox)
# build the graph # build the graph
g = dgl.DGLGraph(multigraph=True) g = dgl.DGLGraph()
g.add_nodes(n_nodes) g.add_nodes(n_nodes)
adjmat = np.zeros((n_nodes, n_nodes)) adjmat = np.zeros((n_nodes, n_nodes))
predicate = [] predicate = []
......
...@@ -49,7 +49,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind, ...@@ -49,7 +49,7 @@ def build_graph_train(g_slice, gt_bbox, img, ids, scores, bbox, feat_ind,
n_nodes = len(inds) n_nodes = len(inds)
roi_ind = feat_ind[gi, inds].squeeze(axis=1) roi_ind = feat_ind[gi, inds].squeeze(axis=1)
g_pred = dgl.DGLGraph(multigraph=True) g_pred = dgl.DGLGraph()
g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[gi, inds], g_pred.add_nodes(n_nodes, {'pred_bbox': bbox[gi, inds],
'node_feat': spatial_feat[gi, roi_ind], 'node_feat': spatial_feat[gi, roi_ind],
'node_class_pred': ids[gi, inds, 0], 'node_class_pred': ids[gi, inds, 0],
......
...@@ -26,10 +26,10 @@ typedef std::shared_ptr<Graph> MutableGraphPtr; ...@@ -26,10 +26,10 @@ typedef std::shared_ptr<Graph> MutableGraphPtr;
class Graph: public GraphInterface { class Graph: public GraphInterface {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
explicit Graph(bool multigraph = false) : is_multigraph_(multigraph) {} Graph() {}
/*! \brief construct a graph from the coo format. */ /*! \brief construct a graph from the coo format. */
Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes, bool multigraph = false); Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes);
/*! \brief default copy constructor */ /*! \brief default copy constructor */
Graph(const Graph& other) = default; Graph(const Graph& other) = default;
...@@ -44,7 +44,6 @@ class Graph: public GraphInterface { ...@@ -44,7 +44,6 @@ class Graph: public GraphInterface {
all_edges_src_ = other.all_edges_src_; all_edges_src_ = other.all_edges_src_;
all_edges_dst_ = other.all_edges_dst_; all_edges_dst_ = other.all_edges_dst_;
read_only_ = other.read_only_; read_only_ = other.read_only_;
is_multigraph_ = other.is_multigraph_;
num_edges_ = other.num_edges_; num_edges_ = other.num_edges_;
other.Clear(); other.Clear();
} }
...@@ -102,9 +101,7 @@ class Graph: public GraphInterface { ...@@ -102,9 +101,7 @@ class Graph: public GraphInterface {
* \note not const since we have caches * \note not const since we have caches
* \return whether the graph is a multigraph * \return whether the graph is a multigraph
*/ */
bool IsMultigraph() const override { bool IsMultigraph() const override;
return is_multigraph_;
}
/*! /*!
* \return whether the graph is read-only * \return whether the graph is read-only
...@@ -350,14 +347,14 @@ class Graph: public GraphInterface { ...@@ -350,14 +347,14 @@ class Graph: public GraphInterface {
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override; std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
/*! \brief Create an empty graph */ /*! \brief Create an empty graph */
static MutableGraphPtr Create(bool multigraph = false) { static MutableGraphPtr Create() {
return std::make_shared<Graph>(multigraph); return std::make_shared<Graph>();
} }
/*! \brief Create from coo */ /*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO( static MutableGraphPtr CreateFromCOO(
int64_t num_nodes, IdArray src_ids, IdArray dst_ids, bool multigraph = false) { int64_t num_nodes, IdArray src_ids, IdArray dst_ids) {
return std::make_shared<Graph>(src_ids, dst_ids, num_nodes, multigraph); return std::make_shared<Graph>(src_ids, dst_ids, num_nodes);
} }
protected: protected:
...@@ -383,12 +380,7 @@ class Graph: public GraphInterface { ...@@ -383,12 +380,7 @@ class Graph: public GraphInterface {
/*! \brief read only flag */ /*! \brief read only flag */
bool read_only_ = false; bool read_only_ = false;
/*!
* \brief Whether if this is a multigraph.
*
* When a multiedge is added, this flag switches to true.
*/
bool is_multigraph_ = false;
/*! \brief number of edges */ /*! \brief number of edges */
uint64_t num_edges_ = 0; uint64_t num_edges_ = 0;
}; };
......
...@@ -33,28 +33,24 @@ typedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr; ...@@ -33,28 +33,24 @@ typedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr;
class CSR : public GraphInterface { class CSR : public GraphInterface {
public: public:
// Create a csr graph that has the given number of verts and edges. // Create a csr graph that has the given number of verts and edges.
CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph); CSR(int64_t num_vertices, int64_t num_edges);
// Create a csr graph whose memory is stored in the shared memory // Create a csr graph whose memory is stored in the shared memory
// that has the given number of verts and edges. // that has the given number of verts and edges.
CSR(const std::string &shared_mem_name, CSR(const std::string &shared_mem_name,
int64_t num_vertices, int64_t num_edges, bool is_multigraph); int64_t num_vertices, int64_t num_edges);
// Create a csr graph that shares the given indptr and indices. // Create a csr graph that shares the given indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids); CSR(IdArray indptr, IdArray indices, IdArray edge_ids);
CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph);
// Create a csr graph by data iterator // Create a csr graph by data iterator
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter> template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR(int64_t num_vertices, int64_t num_edges, CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin, IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin);
bool is_multigraph);
// Create a csr graph whose memory is stored in the shared memory // Create a csr graph whose memory is stored in the shared memory
// and the structure is given by the indptr and indcies. // and the structure is given by the indptr and indcies.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids, CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name); const std::string &shared_mem_name);
CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const std::string &shared_mem_name);
void AddVertices(uint64_t num_vertices) override { void AddVertices(uint64_t num_vertices) override {
LOG(FATAL) << "CSR graph does not allow mutation."; LOG(FATAL) << "CSR graph does not allow mutation.";
...@@ -262,9 +258,6 @@ class CSR : public GraphInterface { ...@@ -262,9 +258,6 @@ class CSR : public GraphInterface {
// The data field stores edge ids. // The data field stores edge ids.
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
// whether the graph is a multi-graph
Lazy<bool> is_multigraph_;
// The name of the shared memory to store data. // The name of the shared memory to store data.
// If it's empty, data isn't stored in shared memory. // If it's empty, data isn't stored in shared memory.
std::string shared_mem_name_; std::string shared_mem_name_;
...@@ -274,7 +267,6 @@ class COO : public GraphInterface { ...@@ -274,7 +267,6 @@ class COO : public GraphInterface {
public: public:
// Create a coo graph that shares the given src and dst // Create a coo graph that shares the given src and dst
COO(int64_t num_vertices, IdArray src, IdArray dst); COO(int64_t num_vertices, IdArray src, IdArray dst);
COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph);
// TODO(da): add constructor for creating COO from shared memory // TODO(da): add constructor for creating COO from shared memory
...@@ -512,9 +504,6 @@ class COO : public GraphInterface { ...@@ -512,9 +504,6 @@ class COO : public GraphInterface {
// The internal COO adjacency matrix. // The internal COO adjacency matrix.
// The data field is empty // The data field is empty
aten::COOMatrix adj_; aten::COOMatrix adj_;
/*! \brief whether the graph is a multi-graph */
Lazy<bool> is_multigraph_;
}; };
/*! /*!
...@@ -897,31 +886,18 @@ class ImmutableGraph: public GraphInterface { ...@@ -897,31 +886,18 @@ class ImmutableGraph: public GraphInterface {
static ImmutableGraphPtr CreateFromCSR( static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir); IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR( static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir, const std::string &shared_mem_name); const std::string &edge_dir, const std::string &shared_mem_name);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name);
static ImmutableGraphPtr CreateFromCSR( static ImmutableGraphPtr CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices, const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph, size_t num_edges, const std::string &edge_dir);
const std::string &edge_dir);
/*! \brief Create an immutable graph from COO. */ /*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO( static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst); int64_t num_vertices, IdArray src, IdArray dst);
static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph);
/*! /*!
* \brief Convert the given graph to an immutable graph. * \brief Convert the given graph to an immutable graph.
* *
...@@ -1025,8 +1001,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -1025,8 +1001,7 @@ class ImmutableGraph: public GraphInterface {
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter> template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR::CSR(int64_t num_vertices, int64_t num_edges, CSR::CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin, IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin) {
bool is_multigraph): is_multigraph_(is_multigraph) {
// TODO(minjie): this should be changed to a device-agnostic implementation // TODO(minjie): this should be changed to a device-agnostic implementation
// in the future // in the future
adj_.num_rows = num_vertices; adj_.num_rows = num_vertices;
......
...@@ -118,7 +118,7 @@ class EdgeDataView(MutableMapping): ...@@ -118,7 +118,7 @@ class EdgeDataView(MutableMapping):
data = self._graph.get_e_repr(self._edges) data = self._graph.get_e_repr(self._edges)
return repr({key : data[key] for key in self._graph._edge_frame}) return repr({key : data[key] for key in self._graph._edge_frame})
def _to_csr(graph_data, edge_dir, multigraph): def _to_csr(graph_data, edge_dir, multigraph=None):
try: try:
indptr = graph_data.indptr indptr = graph_data.indptr
indices = graph_data.indices indices = graph_data.indices
...@@ -128,7 +128,7 @@ def _to_csr(graph_data, edge_dir, multigraph): ...@@ -128,7 +128,7 @@ def _to_csr(graph_data, edge_dir, multigraph):
csr = graph_data.tocsr() csr = graph_data.tocsr()
return csr.indptr, csr.indices return csr.indptr, csr.indices
else: else:
idx = create_graph_index(graph_data=graph_data, multigraph=multigraph, readonly=True) idx = create_graph_index(graph_data=graph_data, readonly=True)
transpose = (edge_dir != 'in') transpose = (edge_dir != 'in')
csr = idx.adjacency_matrix_scipy(transpose, 'csr') csr = idx.adjacency_matrix_scipy(transpose, 'csr')
return csr.indptr, csr.indices return csr.indptr, csr.indices
...@@ -310,7 +310,8 @@ class SharedMemoryStoreServer(object): ...@@ -310,7 +310,8 @@ class SharedMemoryStoreServer(object):
graph_name : string graph_name : string
Define the name of the graph, so the client can use the name to access the graph. Define the name of the graph, so the client can use the name to access the graph.
multigraph : bool, optional multigraph : bool, optional
Whether the graph would be a multigraph (default: False) Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph (default: True)
num_workers : int num_workers : int
The number of workers that will connect to the server. The number of workers that will connect to the server.
port : int port : int
...@@ -318,17 +319,21 @@ class SharedMemoryStoreServer(object): ...@@ -318,17 +319,21 @@ class SharedMemoryStoreServer(object):
""" """
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port): def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
self.server = None self.server = None
if multigraph is not None:
dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.")
if isinstance(graph_data, GraphIndex): if isinstance(graph_data, GraphIndex):
graph_data = graph_data.copyto_shared_mem(edge_dir, _get_graph_path(graph_name)) graph_data = graph_data.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True) self._graph = DGLGraph(graph_data, readonly=True)
elif isinstance(graph_data, DGLGraph): elif isinstance(graph_data, DGLGraph):
graph_data = graph_data._graph.copyto_shared_mem(edge_dir, _get_graph_path(graph_name)) graph_data = graph_data._graph.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True) self._graph = DGLGraph(graph_data, readonly=True)
else: else:
indptr, indices = _to_csr(graph_data, edge_dir, multigraph) indptr, indices = _to_csr(graph_data, edge_dir)
graph_idx = from_csr(utils.toindex(indptr), utils.toindex(indices), graph_idx = from_csr(utils.toindex(indptr), utils.toindex(indices),
multigraph, edge_dir, _get_graph_path(graph_name)) edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_idx, multigraph=multigraph, readonly=True) self._graph = DGLGraph(graph_idx, readonly=True)
self._num_workers = num_workers self._num_workers = num_workers
self._graph_name = graph_name self._graph_name = graph_name
...@@ -354,7 +359,7 @@ class SharedMemoryStoreServer(object): ...@@ -354,7 +359,7 @@ class SharedMemoryStoreServer(object):
# if the integers are larger than 2^31, xmlrpc can't handle them. # if the integers are larger than 2^31, xmlrpc can't handle them.
# we convert them to strings to send them to clients. # we convert them to strings to send them to clients.
return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges()), \ return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges()), \
self._graph.is_multigraph, edge_dir True, edge_dir
# RPC command: initialize node embedding in the server. # RPC command: initialize node embedding in the server.
def init_ndata(init, ndata_name, shape, dtype): def init_ndata(init, ndata_name, shape, dtype):
...@@ -479,7 +484,7 @@ class BaseGraphStore(DGLGraph): ...@@ -479,7 +484,7 @@ class BaseGraphStore(DGLGraph):
""" """
def __init__(self, def __init__(self,
graph_data=None, graph_data=None,
multigraph=False): multigraph=None):
super(BaseGraphStore, self).__init__(graph_data, multigraph=multigraph, readonly=True) super(BaseGraphStore, self).__init__(graph_data, multigraph=multigraph, readonly=True)
@property @property
...@@ -555,12 +560,12 @@ class SharedMemoryDGLGraph(BaseGraphStore): ...@@ -555,12 +560,12 @@ class SharedMemoryDGLGraph(BaseGraphStore):
self._worker_id, self._num_workers = self.proxy.register(graph_name) self._worker_id, self._num_workers = self.proxy.register(graph_name)
if self._worker_id < 0: if self._worker_id < 0:
raise Exception('fail to get graph ' + graph_name + ' from the graph store') raise Exception('fail to get graph ' + graph_name + ' from the graph store')
num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info(graph_name) num_nodes, num_edges, _, edge_dir = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = int(num_nodes), int(num_edges) num_nodes, num_edges = int(num_nodes), int(num_edges)
graph_idx = from_shared_mem_csr_matrix(_get_graph_path(graph_name), graph_idx = from_shared_mem_csr_matrix(_get_graph_path(graph_name),
num_nodes, num_edges, edge_dir, multigraph) num_nodes, num_edges, edge_dir)
super(SharedMemoryDGLGraph, self).__init__(graph_idx, multigraph=multigraph) super(SharedMemoryDGLGraph, self).__init__(graph_idx)
self._init_manager = InitializerManager() self._init_manager = InitializerManager()
# map all ndata and edata from the server. # map all ndata and edata from the server.
...@@ -1054,7 +1059,7 @@ class SharedMemoryDGLGraph(BaseGraphStore): ...@@ -1054,7 +1059,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
def create_graph_store_server(graph_data, graph_name, store_type, num_workers, def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
multigraph=False, edge_dir='in', port=8000): multigraph=None, edge_dir='in', port=8000):
"""Create the graph store server. """Create the graph store server.
The server loads graph structure and node embeddings and edge embeddings. The server loads graph structure and node embeddings and edge embeddings.
...@@ -1085,7 +1090,8 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers, ...@@ -1085,7 +1090,8 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
num_workers : int num_workers : int
The number of workers that will connect to the server. The number of workers that will connect to the server.
multigraph : bool, optional multigraph : bool, optional
Whether the graph would be a multigraph (default: False) Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph (default: True)
edge_dir : string edge_dir : string
the edge direction for the graph structure. The supported option is the edge direction for the graph structure. The supported option is
"in" and "out". "in" and "out".
...@@ -1097,7 +1103,10 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers, ...@@ -1097,7 +1103,10 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
SharedMemoryStoreServer SharedMemoryStoreServer
The graph store server The graph store server
""" """
return SharedMemoryStoreServer(graph_data, edge_dir, graph_name, multigraph, if multigraph is not None:
dgl_warning("multigraph is deprecated." \
"DGL treat all graphs as multigraph by default.")
return SharedMemoryStoreServer(graph_data, edge_dir, graph_name, None,
num_workers, port) num_workers, port)
def create_graph_from_store(graph_name, store_type, port=8000): def create_graph_from_store(graph_name, store_type, port=8000):
......
...@@ -390,7 +390,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None): ...@@ -390,7 +390,7 @@ def hetero_from_relations(rel_graphs, num_nodes_per_type=None):
meta_edges_src.append(ntype_dict[stype]) meta_edges_src.append(ntype_dict[stype])
meta_edges_dst.append(ntype_dict[dtype]) meta_edges_dst.append(ntype_dict[dtype])
etypes.append(etype) etypes.append(etype)
metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True, True) metagraph = graph_index.from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)
# create graph index # create graph index
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
......
...@@ -315,7 +315,7 @@ class DGLBaseGraph(object): ...@@ -315,7 +315,7 @@ class DGLBaseGraph(object):
""" """
return self._graph.successors(v).tousertensor() return self._graph.successors(v).tousertensor()
def edge_id(self, u, v, force_multi=False): def edge_id(self, u, v, force_multi=None, return_array=False):
"""Return the edge ID, or an array of edge IDs, between source node """Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`. `u` and destination node `v`.
...@@ -326,15 +326,24 @@ class DGLBaseGraph(object): ...@@ -326,15 +326,24 @@ class DGLBaseGraph(object):
v : int v : int
The destination node ID. The destination node ID.
force_multi : bool force_multi : bool
If False, will return a single edge ID if the graph is a simple graph. Deprecated (Will be deleted in the future).
If False, will return a single edge ID.
If True, will always return an array.
return_array : bool
If False, will return a single edge ID.
If True, will always return an array. If True, will always return an array.
Returns Returns
------- -------
int or tensor int or tensor
The edge ID if force_multi == True and the graph is a simple graph. The edge ID if return_array is False.
The edge ID array otherwise. The edge ID array otherwise.
Notes
-----
If multiply edges exist between `u` and `v` and return_array is False,
the result is undefined.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
...@@ -351,14 +360,14 @@ class DGLBaseGraph(object): ...@@ -351,14 +360,14 @@ class DGLBaseGraph(object):
For multigraphs: For multigraphs:
>>> G = dgl.DGLGraph(multigraph=True) >>> G = dgl.DGLGraph()
>>> G.add_nodes(3) >>> G.add_nodes(3)
Adding edges (0, 1), (0, 2), (0, 1), (0, 2), so edge ID 0 and 2 both Adding edges (0, 1), (0, 2), (0, 1), (0, 2), so edge ID 0 and 2 both
connect from 0 and 1, while edge ID 1 and 3 both connect from 0 and 2. connect from 0 and 1, while edge ID 1 and 3 both connect from 0 and 2.
>>> G.add_edges([0, 0, 0, 0], [1, 2, 1, 2]) >>> G.add_edges([0, 0, 0, 0], [1, 2, 1, 2])
>>> G.edge_id(0, 1) >>> G.edge_id(0, 1, return_array=True)
tensor([0, 2]) tensor([0, 2])
See Also See Also
...@@ -366,9 +375,20 @@ class DGLBaseGraph(object): ...@@ -366,9 +375,20 @@ class DGLBaseGraph(object):
edge_ids edge_ids
""" """
idx = self._graph.edge_id(u, v) idx = self._graph.edge_id(u, v)
return idx.tousertensor() if force_multi or self.is_multigraph else idx[0] if force_multi is not None:
dgl_warning("force_multi will be deprecated." \
"Please use return_array instead")
return_array = force_multi
def edge_ids(self, u, v, force_multi=False): if return_array:
return idx.tousertensor()
else:
assert len(idx) == 1, "For return_array=False, there should be one and " \
"only one edge between u and v, but get {} edges. " \
"Please use return_array=True instead".format(len(idx))
return idx[0]
def edge_ids(self, u, v, force_multi=None, return_uv=False):
"""Return all edge IDs between source node array `u` and destination """Return all edge IDs between source node array `u` and destination
node array `v`. node array `v`.
...@@ -379,21 +399,26 @@ class DGLBaseGraph(object): ...@@ -379,21 +399,26 @@ class DGLBaseGraph(object):
v : list, tensor v : list, tensor
The destination node ID array. The destination node ID array.
force_multi : bool force_multi : bool
Deprecated (Will be deleted in the future).
Whether to always treat the graph as a multigraph. Whether to always treat the graph as a multigraph.
return_uv : bool
Whether return e or (eu, ev, e)
Returns Returns
------- -------
tensor, or (tensor, tensor, tensor) tensor, or (tensor, tensor, tensor)
If the graph is a simple graph and `force_multi` is False, return If 'return_uv` is False, return a single edge ID array `e`.
a single edge ID array `e`. `e[i]` is the edge ID between `u[i]` `e[i]` is the edge ID between `u[i]` and `v[i]`.
and `v[i]`.
Otherwise, return three arrays `(eu, ev, e)`. `e[i]` is the ID Otherwise, return three arrays `(eu, ev, e)`. `e[i]` is the ID
of an edge between `eu[i]` and `ev[i]`. All edges between `u[i]` of an edge between `eu[i]` and `ev[i]`. All edges between `u[i]`
and `v[i]` are returned. and `v[i]` are returned.
Notes Notes
----- -----
If the graph is a simple graph, `force_multi` is False, and no edge If the graph is a simple graph, `return_uv` is False, and no edge
exist between some pairs of `u[i]` and `v[i]`, the result is undefined.
If the graph is a multi graph, `return_uv` is False, and multi edges
exist between some pairs of `u[i]` and `v[i]`, the result is undefined. exist between some pairs of `u[i]` and `v[i]`, the result is undefined.
Examples Examples
...@@ -411,14 +436,14 @@ class DGLBaseGraph(object): ...@@ -411,14 +436,14 @@ class DGLBaseGraph(object):
For multigraphs For multigraphs
>>> G = dgl.DGLGraph(multigraph=True) >>> G = dgl.DGLGraph()
>>> G.add_nodes(4) >>> G.add_nodes(4)
>>> G.add_edges([0, 0, 0], [1, 1, 2]) # (0, 1), (0, 1), (0, 2) >>> G.add_edges([0, 0, 0], [1, 1, 2]) # (0, 1), (0, 1), (0, 2)
Get all edges between (0, 1), (0, 2), (0, 3). Note that there is no Get all edges between (0, 1), (0, 2), (0, 3). Note that there is no
edge between 0 and 3: edge between 0 and 3:
>>> G.edge_ids([0, 0, 0], [1, 2, 3]) >>> G.edge_ids([0, 0, 0], [1, 2, 3], return_uv=True)
(tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([0, 1, 2])) (tensor([0, 0, 0]), tensor([1, 1, 2]), tensor([0, 1, 2]))
See Also See Also
...@@ -428,9 +453,17 @@ class DGLBaseGraph(object): ...@@ -428,9 +453,17 @@ class DGLBaseGraph(object):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(u, v) src, dst, eid = self._graph.edge_ids(u, v)
if force_multi or self.is_multigraph: if force_multi is not None:
dgl_warning("force_multi will be deprecated, " \
"Please use return_uv instead")
return_uv = force_multi
if return_uv:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor() return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else: else:
assert len(eid) == max(len(u), len(v)), "If return_uv=False, there should be one and " \
"only one edge between each u and v, expect {} edges but get {}. " \
"Please use return_uv=True instead".format(max(len(u), len(v)), len(eid))
return eid.tousertensor() return eid.tousertensor()
def find_edges(self, eid): def find_edges(self, eid):
...@@ -814,8 +847,9 @@ class DGLGraph(DGLBaseGraph): ...@@ -814,8 +847,9 @@ class DGLGraph(DGLBaseGraph):
edge_frame : FrameRef, optional edge_frame : FrameRef, optional
Edge feature storage. Edge feature storage.
multigraph : bool, optional multigraph : bool, optional
Whether the graph would be a multigraph. If none, the flag will be determined Deprecated (Will be deleted in the future).
by scanning the whole graph. (default: None) Whether the graph would be a multigraph. If none, the flag will be
set to True. (default: None)
readonly : bool, optional readonly : bool, optional
Whether the graph structure is read-only (default: False). Whether the graph structure is read-only (default: False).
...@@ -956,7 +990,10 @@ class DGLGraph(DGLBaseGraph): ...@@ -956,7 +990,10 @@ class DGLGraph(DGLBaseGraph):
if sort_csr: if sort_csr:
gidx.sort_csr() gidx.sort_csr()
else: else:
gidx = graph_index.create_graph_index(graph_data, multigraph, readonly) if multigraph is not None:
dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.")
gidx = graph_index.create_graph_index(graph_data, readonly)
if sort_csr: if sort_csr:
gidx.sort_csr() gidx.sort_csr()
super(DGLGraph, self).__init__(gidx) super(DGLGraph, self).__init__(gidx)
...@@ -1805,7 +1842,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1805,7 +1842,7 @@ class DGLGraph(DGLBaseGraph):
raise DGLError('Not all edges have attribute {}.'.format(attr)) raise DGLError('Not all edges have attribute {}.'.format(attr))
self._edge_frame[attr] = _batcher(attr_dict[attr]) self._edge_frame[attr] = _batcher(attr_dict[attr])
def from_scipy_sparse_matrix(self, spmat, multigraph=False): def from_scipy_sparse_matrix(self, spmat, multigraph=None):
""" Convert from scipy sparse matrix. """ Convert from scipy sparse matrix.
Parameters Parameters
...@@ -1814,6 +1851,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1814,6 +1851,7 @@ class DGLGraph(DGLBaseGraph):
The graph's adjacency matrix The graph's adjacency matrix
multigraph : bool, optional multigraph : bool, optional
Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph. If the input scipy sparse matrix is CSR, Whether the graph would be a multigraph. If the input scipy sparse matrix is CSR,
this argument is ignored. this argument is ignored.
...@@ -1828,7 +1866,11 @@ class DGLGraph(DGLBaseGraph): ...@@ -1828,7 +1866,11 @@ class DGLGraph(DGLBaseGraph):
>>> g.from_scipy_sparse_matrix(a) >>> g.from_scipy_sparse_matrix(a)
""" """
self.clear() self.clear()
self._graph = graph_index.from_scipy_sparse_matrix(spmat, multigraph, self.is_readonly) if multigraph is not None:
dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.")
self._graph = graph_index.from_scipy_sparse_matrix(spmat, self.is_readonly)
self._node_frame.add_rows(self.number_of_nodes()) self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges()) self._edge_frame.add_rows(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges()) self._msg_frame.add_rows(self.number_of_edges())
......
...@@ -35,7 +35,6 @@ class GraphIndex(ObjectBase): ...@@ -35,7 +35,6 @@ class GraphIndex(ObjectBase):
""" """
def __new__(cls): def __new__(cls):
obj = ObjectBase.__new__(cls) obj = ObjectBase.__new__(cls)
obj._multigraph = None # python-side cache of the flag
obj._readonly = None # python-side cache of the flag obj._readonly = None # python-side cache of the flag
obj._cache = {} obj._cache = {}
return obj return obj
...@@ -43,28 +42,22 @@ class GraphIndex(ObjectBase): ...@@ -43,28 +42,22 @@ class GraphIndex(ObjectBase):
def __getstate__(self): def __getstate__(self):
src, dst, _ = self.edges() src, dst, _ = self.edges()
n_nodes = self.number_of_nodes() n_nodes = self.number_of_nodes()
# TODO(minjie): should try to avoid calling is_multigraph
multigraph = self.is_multigraph()
readonly = self.is_readonly() readonly = self.is_readonly()
return n_nodes, multigraph, readonly, src, dst return n_nodes, readonly, src, dst
def __setstate__(self, state): def __setstate__(self, state):
"""The pickle state of GraphIndex is defined as a triplet """The pickle state of GraphIndex is defined as a triplet
(number_of_nodes, multigraph, readonly, src_nodes, dst_nodes) (number_of_nodes, readonly, src_nodes, dst_nodes)
""" """
num_nodes, multigraph, readonly, src, dst = state num_nodes, readonly, src, dst = state
self._cache = {} self._cache = {}
self._multigraph = multigraph
self._readonly = readonly self._readonly = readonly
if multigraph is None:
multigraph = BoolFlag.BOOL_UNKNOWN
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_CAPI_DGLGraphCreate, _CAPI_DGLGraphCreate,
src.todgltensor(), src.todgltensor(),
dst.todgltensor(), dst.todgltensor(),
int(multigraph),
int(num_nodes), int(num_nodes),
readonly) readonly)
...@@ -118,15 +111,14 @@ class GraphIndex(ObjectBase): ...@@ -118,15 +111,14 @@ class GraphIndex(ObjectBase):
def is_multigraph(self): def is_multigraph(self):
"""Return whether the graph is a multigraph """Return whether the graph is a multigraph
The time cost will be O(E)
Returns Returns
------- -------
bool bool
True if it is a multigraph, False otherwise. True if it is a multigraph, False otherwise.
""" """
if self._multigraph is None: return bool(_CAPI_DGLGraphIsMultigraph(self))
self._multigraph = bool(_CAPI_DGLGraphIsMultigraph(self))
return self._multigraph
def is_readonly(self): def is_readonly(self):
"""Indicate whether the graph index is read-only. """Indicate whether the graph index is read-only.
...@@ -149,9 +141,9 @@ class GraphIndex(ObjectBase): ...@@ -149,9 +141,9 @@ class GraphIndex(ObjectBase):
New readonly state of current graph index. New readonly state of current graph index.
""" """
# TODO(minjie): very ugly code, should fix this # TODO(minjie): very ugly code, should fix this
n_nodes, multigraph, _, src, dst = self.__getstate__() n_nodes, _, src, dst = self.__getstate__()
self.clear_cache() self.clear_cache()
state = (n_nodes, multigraph, readonly_state, src, dst) state = (n_nodes, readonly_state, src, dst)
self.__setstate__(state) self.__setstate__(state)
def number_of_nodes(self): def number_of_nodes(self):
...@@ -829,7 +821,8 @@ class GraphIndex(ObjectBase): ...@@ -829,7 +821,8 @@ class GraphIndex(ObjectBase):
The nx graph The nx graph
""" """
src, dst, eid = self.edges() src, dst, eid = self.edges()
ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph() # xiangsx: Always treat graph as multigraph
ret = nx.MultiDiGraph()
ret.add_nodes_from(range(self.number_of_nodes())) ret.add_nodes_from(range(self.number_of_nodes()))
for u, v, e in zip(src, dst, eid): for u, v, e in zip(src, dst, eid):
ret.add_edge(u, v, id=e) ret.add_edge(u, v, id=e)
...@@ -991,7 +984,7 @@ class SubgraphIndex(ObjectBase): ...@@ -991,7 +984,7 @@ class SubgraphIndex(ObjectBase):
############################################################### ###############################################################
# Conversion functions # Conversion functions
############################################################### ###############################################################
def from_coo(num_nodes, src, dst, is_multigraph, readonly): def from_coo(num_nodes, src, dst, readonly):
"""Convert from coo arrays. """Convert from coo arrays.
Parameters Parameters
...@@ -1002,8 +995,6 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly): ...@@ -1002,8 +995,6 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly):
Src end nodes of the edges. Src end nodes of the edges.
dst : Tensor dst : Tensor
Dst end nodes of the edges. Dst end nodes of the edges.
is_multigraph : bool or None
True if the graph is a multigraph. None means determined by data.
readonly : bool readonly : bool
True if the returned graph is readonly. True if the returned graph is readonly.
...@@ -1014,25 +1005,19 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly): ...@@ -1014,25 +1005,19 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly):
""" """
src = utils.toindex(src) src = utils.toindex(src)
dst = utils.toindex(dst) dst = utils.toindex(dst)
if is_multigraph is None:
is_multigraph = BoolFlag.BOOL_UNKNOWN
if readonly: if readonly:
gidx = _CAPI_DGLGraphCreate( gidx = _CAPI_DGLGraphCreate(
src.todgltensor(), src.todgltensor(),
dst.todgltensor(), dst.todgltensor(),
int(is_multigraph),
int(num_nodes), int(num_nodes),
readonly) readonly)
else: else:
if is_multigraph is BoolFlag.BOOL_UNKNOWN: gidx = _CAPI_DGLGraphCreateMutable()
# TODO(minjie): better behavior in the future
is_multigraph = BoolFlag.BOOL_FALSE
gidx = _CAPI_DGLGraphCreateMutable(bool(is_multigraph))
gidx.add_nodes(num_nodes) gidx.add_nodes(num_nodes)
gidx.add_edges(src, dst) gidx.add_edges(src, dst)
return gidx return gidx
def from_csr(indptr, indices, is_multigraph, def from_csr(indptr, indices,
direction, shared_mem_name=""): direction, shared_mem_name=""):
"""Load a graph from CSR arrays. """Load a graph from CSR arrays.
...@@ -1042,8 +1027,6 @@ def from_csr(indptr, indices, is_multigraph, ...@@ -1042,8 +1027,6 @@ def from_csr(indptr, indices, is_multigraph,
index pointer in the CSR format index pointer in the CSR format
indices : Tensor indices : Tensor
column index array in the CSR format column index array in the CSR format
is_multigraph : bool or None
True if the graph is a multigraph. None means determined by data.
direction : str direction : str
the edge direction. Either "in" or "out". the edge direction. Either "in" or "out".
shared_mem_name : str shared_mem_name : str
...@@ -1051,19 +1034,15 @@ def from_csr(indptr, indices, is_multigraph, ...@@ -1051,19 +1034,15 @@ def from_csr(indptr, indices, is_multigraph,
""" """
indptr = utils.toindex(indptr) indptr = utils.toindex(indptr)
indices = utils.toindex(indices) indices = utils.toindex(indices)
if is_multigraph is None:
is_multigraph = BoolFlag.BOOL_UNKNOWN
gidx = _CAPI_DGLGraphCSRCreate( gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(), indptr.todgltensor(),
indices.todgltensor(), indices.todgltensor(),
shared_mem_name, shared_mem_name,
int(is_multigraph),
direction) direction)
return gidx return gidx
def from_shared_mem_csr_matrix(shared_mem_name, def from_shared_mem_csr_matrix(shared_mem_name,
num_nodes, num_edges, edge_dir, num_nodes, num_edges, edge_dir):
is_multigraph):
"""Load a graph from the shared memory in the CSR format. """Load a graph from the shared memory in the CSR format.
Parameters Parameters
...@@ -1080,7 +1059,6 @@ def from_shared_mem_csr_matrix(shared_mem_name, ...@@ -1080,7 +1059,6 @@ def from_shared_mem_csr_matrix(shared_mem_name,
gidx = _CAPI_DGLGraphCSRCreateMMap( gidx = _CAPI_DGLGraphCSRCreateMMap(
shared_mem_name, shared_mem_name,
int(num_nodes), int(num_edges), int(num_nodes), int(num_edges),
is_multigraph,
edge_dir) edge_dir)
return gidx return gidx
...@@ -1109,8 +1087,6 @@ def from_networkx(nx_graph, readonly): ...@@ -1109,8 +1087,6 @@ def from_networkx(nx_graph, readonly):
# to_directed creates a deep copy of the networkx graph even if # to_directed creates a deep copy of the networkx graph even if
# the original graph is already directed and we do not want to do it. # the original graph is already directed and we do not want to do it.
nx_graph = nx_graph.to_directed() nx_graph = nx_graph.to_directed()
is_multigraph = isinstance(nx_graph, nx.MultiDiGraph)
num_nodes = nx_graph.number_of_nodes() num_nodes = nx_graph.number_of_nodes()
# nx_graph.edges(data=True) returns src, dst, attr_dict # nx_graph.edges(data=True) returns src, dst, attr_dict
...@@ -1137,17 +1113,14 @@ def from_networkx(nx_graph, readonly): ...@@ -1137,17 +1113,14 @@ def from_networkx(nx_graph, readonly):
# We store edge Ids as an edge attribute. # We store edge Ids as an edge attribute.
src = utils.toindex(src) src = utils.toindex(src)
dst = utils.toindex(dst) dst = utils.toindex(dst)
return from_coo(num_nodes, src, dst, is_multigraph, readonly) return from_coo(num_nodes, src, dst, readonly)
def from_scipy_sparse_matrix(adj, multigraph, readonly): def from_scipy_sparse_matrix(adj, readonly):
"""Convert from scipy sparse matrix. """Convert from scipy sparse matrix.
Parameters Parameters
---------- ----------
adj : scipy sparse matrix adj : scipy sparse matrix
multigraph : bool
Whether the graph would be a multigraph. If none, the flag will be determined
by the data.
readonly : bool readonly : bool
True if the returned graph is readonly. True if the returned graph is readonly.
...@@ -1159,12 +1132,12 @@ def from_scipy_sparse_matrix(adj, multigraph, readonly): ...@@ -1159,12 +1132,12 @@ def from_scipy_sparse_matrix(adj, multigraph, readonly):
if adj.getformat() != 'csr' or not readonly: if adj.getformat() != 'csr' or not readonly:
num_nodes = max(adj.shape[0], adj.shape[1]) num_nodes = max(adj.shape[0], adj.shape[1])
adj_coo = adj.tocoo() adj_coo = adj.tocoo()
return from_coo(num_nodes, adj_coo.row, adj_coo.col, multigraph, readonly) return from_coo(num_nodes, adj_coo.row, adj_coo.col, readonly)
else: else:
# If the input matrix is csr, it's guaranteed to be a simple graph. # If the input matrix is csr, we still treat it as multigraph.
return from_csr(adj.indptr, adj.indices, False, "out") return from_csr(adj.indptr, adj.indices, "out")
def from_edge_list(elist, is_multigraph, readonly): def from_edge_list(elist, readonly):
"""Convert from an edge list. """Convert from an edge list.
Parameters Parameters
...@@ -1184,7 +1157,7 @@ def from_edge_list(elist, is_multigraph, readonly): ...@@ -1184,7 +1157,7 @@ def from_edge_list(elist, is_multigraph, readonly):
min_nodes = min(src.min(), dst.min()) min_nodes = min(src.min(), dst.min())
if min_nodes != 0: if min_nodes != 0:
raise DGLError('Invalid edge list. Nodes must start from 0.') raise DGLError('Invalid edge list. Nodes must start from 0.')
return from_coo(num_nodes, src_ids, dst_ids, is_multigraph, readonly) return from_coo(num_nodes, src_ids, dst_ids, readonly)
def map_to_subgraph_nid(induced_nodes, parent_nids): def map_to_subgraph_nid(induced_nodes, parent_nids):
"""Map parent node Ids to the subgraph node Ids. """Map parent node Ids to the subgraph node Ids.
...@@ -1274,16 +1247,13 @@ def disjoint_partition(graph, num_or_size_splits): ...@@ -1274,16 +1247,13 @@ def disjoint_partition(graph, num_or_size_splits):
int(num_or_size_splits)) int(num_or_size_splits))
return rst return rst
def create_graph_index(graph_data, multigraph, readonly): def create_graph_index(graph_data, readonly):
"""Create a graph index object. """Create a graph index object.
Parameters Parameters
---------- ----------
graph_data : graph data graph_data : graph data
Data to initialize graph. Same as networkx's semantics. Data to initialize graph. Same as networkx's semantics.
multigraph : bool
Whether the graph would be a multigraph. If none, the flag will be determined
by the data.
readonly : bool readonly : bool
Whether the graph structure is read-only. Whether the graph structure is read-only.
""" """
...@@ -1294,15 +1264,13 @@ def create_graph_index(graph_data, multigraph, readonly): ...@@ -1294,15 +1264,13 @@ def create_graph_index(graph_data, multigraph, readonly):
if graph_data is None: if graph_data is None:
if readonly: if readonly:
raise Exception("can't create an empty immutable graph") raise Exception("can't create an empty immutable graph")
if multigraph is None: return _CAPI_DGLGraphCreateMutable()
multigraph = False
return _CAPI_DGLGraphCreateMutable(multigraph)
elif isinstance(graph_data, (list, tuple)): elif isinstance(graph_data, (list, tuple)):
# edge list # edge list
return from_edge_list(graph_data, multigraph, readonly) return from_edge_list(graph_data, readonly)
elif isinstance(graph_data, scipy.sparse.spmatrix): elif isinstance(graph_data, scipy.sparse.spmatrix):
# scipy format # scipy format
return from_scipy_sparse_matrix(graph_data, multigraph, readonly) return from_scipy_sparse_matrix(graph_data, readonly)
else: else:
# networkx - any format # networkx - any format
try: try:
......
...@@ -266,8 +266,6 @@ class DGLHeteroGraph(object): ...@@ -266,8 +266,6 @@ class DGLHeteroGraph(object):
frame.set_initializer(init.zero_initializer) frame.set_initializer(init.zero_initializer)
self._msg_frames.append(frame) self._msg_frames.append(frame)
self._is_multigraph = None
def __getstate__(self): def __getstate__(self):
return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames return self._graph, self._ntypes, self._etypes, self._node_frames, self._edge_frames
...@@ -1075,10 +1073,7 @@ class DGLHeteroGraph(object): ...@@ -1075,10 +1073,7 @@ class DGLHeteroGraph(object):
bool bool
True if the graph is a multigraph, False otherwise. True if the graph is a multigraph, False otherwise.
""" """
if self._is_multigraph is None: return self._graph.is_multigraph()
return self._graph.is_multigraph()
else:
return self._is_multigraph
@property @property
def is_readonly(self): def is_readonly(self):
...@@ -1295,7 +1290,7 @@ class DGLHeteroGraph(object): ...@@ -1295,7 +1290,7 @@ class DGLHeteroGraph(object):
""" """
return self._graph.successors(self.get_etype_id(etype), v).tousertensor() return self._graph.successors(self.get_etype_id(etype), v).tousertensor()
def edge_id(self, u, v, force_multi=False, etype=None): def edge_id(self, u, v, force_multi=None, return_array=False, etype=None):
"""Return the edge ID, or an array of edge IDs, between source node """Return the edge ID, or an array of edge IDs, between source node
`u` and destination node `v`, with the specified edge type `u` and destination node `v`, with the specified edge type
...@@ -1306,7 +1301,11 @@ class DGLHeteroGraph(object): ...@@ -1306,7 +1301,11 @@ class DGLHeteroGraph(object):
v : int v : int
The node ID of destination type. The node ID of destination type.
force_multi : bool, optional force_multi : bool, optional
If False, will return a single edge ID if the graph is a simple graph. Deprecated (Will be deleted in the future).
If False, will return a single edge ID.
If True, will always return an array. (Default: False)
return_array : bool, optional
If False, will return a single edge ID.
If True, will always return an array. (Default: False) If True, will always return an array. (Default: False)
etype : str or tuple of str, optional etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type The edge type. Can be omitted if there is only one edge type
...@@ -1315,9 +1314,14 @@ class DGLHeteroGraph(object): ...@@ -1315,9 +1314,14 @@ class DGLHeteroGraph(object):
Returns Returns
------- -------
int or tensor int or tensor
The edge ID if ``force_multi == True`` and the graph is a simple graph. The edge ID if ``return_array == False``.
The edge ID array otherwise. The edge ID array otherwise.
Notes
-----
If multiply edges exist between `u` and `v` and return_array is False,
the result is undefined.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
...@@ -1332,7 +1336,7 @@ class DGLHeteroGraph(object): ...@@ -1332,7 +1336,7 @@ class DGLHeteroGraph(object):
>>> plays_g.edge_id(1, 2, etype=('user', 'plays', 'game')) >>> plays_g.edge_id(1, 2, etype=('user', 'plays', 'game'))
2 2
>>> g.edge_id(1, 2, force_multi=True, etype=('user', 'follows', 'user')) >>> g.edge_id(1, 2, return_array=True, etype=('user', 'follows', 'user'))
tensor([1, 2]) tensor([1, 2])
See Also See Also
...@@ -1340,9 +1344,20 @@ class DGLHeteroGraph(object): ...@@ -1340,9 +1344,20 @@ class DGLHeteroGraph(object):
edge_ids edge_ids
""" """
idx = self._graph.edge_id(self.get_etype_id(etype), u, v) idx = self._graph.edge_id(self.get_etype_id(etype), u, v)
return idx.tousertensor() if force_multi or self._graph.is_multigraph() else idx[0] if force_multi is not None:
dgl_warning("force_multi will be deprecated." \
"Please use return_array instead")
return_array = force_multi
if return_array:
return idx.tousertensor()
else:
assert len(idx) == 1, "For return_array=False, there should be one and " \
"only one edge between u and v, but get {} edges. " \
"Please use return_array=True instead".format(len(idx))
return idx[0]
def edge_ids(self, u, v, force_multi=False, etype=None): def edge_ids(self, u, v, force_multi=None, return_uv=False, etype=None):
"""Return all edge IDs between source node array `u` and destination """Return all edge IDs between source node array `u` and destination
node array `v` with the specified edge type. node array `v` with the specified edge type.
...@@ -1353,8 +1368,11 @@ class DGLHeteroGraph(object): ...@@ -1353,8 +1368,11 @@ class DGLHeteroGraph(object):
v : list, tensor v : list, tensor
The node ID array of destination type. The node ID array of destination type.
force_multi : bool, optional force_multi : bool, optional
Deprecated (Will be deleted in the future).
Whether to always treat the graph as a multigraph. See the Whether to always treat the graph as a multigraph. See the
"Returns" for their effects. (Default: False) "Returns" for their effects. (Default: False)
return_uv : bool
See the "Returns" for their effects. (Default: False)
etype : str or tuple of str, optional etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type The edge type. Can be omitted if there is only one edge type
in the graph. in the graph.
...@@ -1363,9 +1381,8 @@ class DGLHeteroGraph(object): ...@@ -1363,9 +1381,8 @@ class DGLHeteroGraph(object):
------- -------
tensor, or (tensor, tensor, tensor) tensor, or (tensor, tensor, tensor)
* If the graph is a simple graph and ``force_multi=False``, return * If ``return_uv=False``, return a single edge ID array ``e``.
a single edge ID array ``e``. ``e[i]`` is the edge ID between ``u[i]`` ``e[i]`` is the edge ID between ``u[i]`` and ``v[i]``.
and ``v[i]``.
* Otherwise, return three arrays ``(eu, ev, e)``. ``e[i]`` is the ID * Otherwise, return three arrays ``(eu, ev, e)``. ``e[i]`` is the ID
of an edge between ``eu[i]`` and ``ev[i]``. All edges between ``u[i]`` of an edge between ``eu[i]`` and ``ev[i]``. All edges between ``u[i]``
...@@ -1373,10 +1390,13 @@ class DGLHeteroGraph(object): ...@@ -1373,10 +1390,13 @@ class DGLHeteroGraph(object):
Notes Notes
----- -----
If the graph is a simple graph, ``force_multi=False``, and no edge If the graph is a simple graph, ``return_uv=False``, and no edge
exists between some pairs of ``u[i]`` and ``v[i]``, the result is undefined exists between some pairs of ``u[i]`` and ``v[i]``, the result is undefined
and an empty tensor is returned. and an empty tensor is returned.
If the graph is a multi graph, ``return_uv=False``, and multi edges
exist between some pairs of `u[i]` and `v[i]`, the result is undefined.
Examples Examples
-------- --------
The following example uses PyTorch backend. The following example uses PyTorch backend.
...@@ -1393,7 +1413,7 @@ class DGLHeteroGraph(object): ...@@ -1393,7 +1413,7 @@ class DGLHeteroGraph(object):
tensor([], dtype=torch.int64) tensor([], dtype=torch.int64)
>>> plays_g.edge_ids([1], [2], etype=('user', 'plays', 'game')) >>> plays_g.edge_ids([1], [2], etype=('user', 'plays', 'game'))
tensor([2]) tensor([2])
>>> g.edge_ids([1], [2], force_multi=True, etype=('user', 'follows', 'user')) >>> g.edge_ids([1], [2], return_uv=True, etype=('user', 'follows', 'user'))
(tensor([1, 1]), tensor([2, 2]), tensor([1, 2])) (tensor([1, 1]), tensor([2, 2]), tensor([1, 2]))
See Also See Also
...@@ -1403,9 +1423,17 @@ class DGLHeteroGraph(object): ...@@ -1403,9 +1423,17 @@ class DGLHeteroGraph(object):
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v) src, dst, eid = self._graph.edge_ids(self.get_etype_id(etype), u, v)
if force_multi or self._graph.is_multigraph(): if force_multi is not None:
dgl_warning("force_multi will be deprecated, " \
"Please use return_uv instead")
return_uv = force_multi
if return_uv:
return src.tousertensor(), dst.tousertensor(), eid.tousertensor() return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
else: else:
assert len(eid) == max(len(u), len(v)), "If return_uv=False, there should be one and " \
"only one edge between each u and v, expect {} edges but get {}. " \
"Please use return_uv=True instead".format(max(len(u), len(v)), len(eid))
return eid.tousertensor() return eid.tousertensor()
def find_edges(self, eid, etype=None): def find_edges(self, eid, etype=None):
...@@ -2023,7 +2051,7 @@ class DGLHeteroGraph(object): ...@@ -2023,7 +2051,7 @@ class DGLHeteroGraph(object):
induced_etypes.append(self.etypes[i]) induced_etypes.append(self.etypes[i])
edge_frames.append(self._edge_frames[i]) edge_frames.append(self._edge_frames[i])
metagraph = graph_index.from_edge_list(meta_edges, True, True) metagraph = graph_index.from_edge_list(meta_edges, True)
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_type)) metagraph, rel_graphs, utils.toindex(num_nodes_per_type))
hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames) hg = DGLHeteroGraph(hgidx, ntypes, induced_etypes, node_frames, edge_frames)
...@@ -2097,7 +2125,7 @@ class DGLHeteroGraph(object): ...@@ -2097,7 +2125,7 @@ class DGLHeteroGraph(object):
induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type induced_etypes = [self._etypes[i] for i in etype_ids] # get the "name" of edge type
num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes] num_nodes_per_induced_type = [self.number_of_nodes(ntype) for ntype in induced_ntypes]
metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True, True) metagraph = graph_index.from_edge_list((mapped_meta_src, mapped_meta_dst), True)
hgidx = heterograph_index.create_heterograph_from_relations( hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type)) metagraph, rel_graphs, utils.toindex(num_nodes_per_induced_type))
hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames) hg = DGLHeteroGraph(hgidx, induced_ntypes, induced_etypes, node_frames, edge_frames)
...@@ -3787,7 +3815,8 @@ class DGLHeteroGraph(object): ...@@ -3787,7 +3815,8 @@ class DGLHeteroGraph(object):
src, dst = self.edges() src, dst = self.edges()
src = F.asnumpy(src) src = F.asnumpy(src)
dst = F.asnumpy(dst) dst = F.asnumpy(dst)
nx_graph = nx.MultiDiGraph() if self.is_multigraph else nx.DiGraph() # xiangsx: Always treat graph as multigraph
nx_graph = nx.MultiDiGraph()
nx_graph.add_nodes_from(range(self.number_of_nodes())) nx_graph.add_nodes_from(range(self.number_of_nodes()))
for eid, (u, v) in enumerate(zip(src, dst)): for eid, (u, v) in enumerate(zip(src, dst)):
nx_graph.add_edge(u, v, id=eid) nx_graph.add_edge(u, v, id=eid)
......
...@@ -216,6 +216,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -216,6 +216,7 @@ class HeteroGraphIndex(ObjectBase):
def is_multigraph(self): def is_multigraph(self):
"""Return whether the graph is a multigraph """Return whether the graph is a multigraph
The time cost will be O(E)
Returns Returns
------- -------
......
...@@ -241,7 +241,7 @@ def khop_graph(g, k): ...@@ -241,7 +241,7 @@ def khop_graph(g, k):
col = np.repeat(adj_k.col, multiplicity) col = np.repeat(adj_k.col, multiplicity)
# TODO(zihao): we should support creating multi-graph from scipy sparse matrix # TODO(zihao): we should support creating multi-graph from scipy sparse matrix
# in the future. # in the future.
return DGLGraph(from_coo(n, row, col, True, True)) return DGLGraph(from_coo(n, row, col, True))
def reverse(g, share_ndata=False, share_edata=False): def reverse(g, share_ndata=False, share_edata=False):
"""Return the reverse of a graph """Return the reverse of a graph
...@@ -310,7 +310,7 @@ def reverse(g, share_ndata=False, share_edata=False): ...@@ -310,7 +310,7 @@ def reverse(g, share_ndata=False, share_edata=False):
[2.], [2.],
[3.]]) [3.]])
""" """
g_reversed = DGLGraph(multigraph=g.is_multigraph) g_reversed = DGLGraph()
g_reversed.add_nodes(g.number_of_nodes()) g_reversed.add_nodes(g.number_of_nodes())
g_edges = g.all_edges(order='eid') g_edges = g.all_edges(order='eid')
g_reversed.add_edges(g_edges[1], g_edges[0]) g_reversed.add_edges(g_edges[1], g_edges[0])
...@@ -357,6 +357,10 @@ def to_bidirected(g, readonly=True): ...@@ -357,6 +357,10 @@ def to_bidirected(g, readonly=True):
readonly : bool, default to be True readonly : bool, default to be True
Whether the returned bidirected graph is readonly or not. Whether the returned bidirected graph is readonly or not.
Notes
-----
Please make sure g is a single graph, otherwise the return value is undefined.
Returns Returns
------- -------
DGLGraph DGLGraph
......
...@@ -14,8 +14,7 @@ ...@@ -14,8 +14,7 @@
namespace dgl { namespace dgl {
Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes, Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) {
bool multigraph): is_multigraph_(multigraph) {
CHECK(aten::IsValidIdArray(src_ids)); CHECK(aten::IsValidIdArray(src_ids));
CHECK(aten::IsValidIdArray(dst_ids)); CHECK(aten::IsValidIdArray(dst_ids));
this->AddVertices(num_nodes); this->AddVertices(num_nodes);
...@@ -42,6 +41,33 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes, ...@@ -42,6 +41,33 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes,
} }
} }
bool Graph::IsMultigraph() const {
if (num_edges_ <= 1) {
return false;
}
typedef std::pair<int64_t, int64_t> Pair;
std::vector<Pair> pairs;
pairs.reserve(num_edges_);
for (uint64_t eid = 0; eid < num_edges_; ++eid) {
pairs.emplace_back(all_edges_src_[eid], all_edges_dst_[eid]);
}
// sort according to src and dst ids
std::sort(pairs.begin(), pairs.end(),
[] (const Pair& t1, const Pair& t2) {
return std::get<0>(t1) < std::get<0>(t2)
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
});
for (uint64_t eid = 0; eid < num_edges_-1; ++eid) {
// As src and dst are all sorted, we only need to compare i and i+1
if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid+1]) &&
std::get<1>(pairs[eid]) == std::get<1>(pairs[eid+1]))
return true;
}
return false;
}
void Graph::AddVertices(uint64_t num_vertices) { void Graph::AddVertices(uint64_t num_vertices) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed."; CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
adjlist_.resize(adjlist_.size() + num_vertices); adjlist_.resize(adjlist_.size() + num_vertices);
...@@ -447,7 +473,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const { ...@@ -447,7 +473,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
oldv2newv[vid_data[i]] = i; oldv2newv[vid_data[i]] = i;
} }
Subgraph rst; Subgraph rst;
rst.graph = std::make_shared<Graph>(IsMultigraph()); rst.graph = std::make_shared<Graph>();
rst.induced_vertices = vids; rst.induced_vertices = vids;
rst.graph->AddVertices(len); rst.graph->AddVertices(len);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
...@@ -486,7 +512,7 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -486,7 +512,7 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
nodes.push_back(dst_id); nodes.push_back(dst_id);
} }
rst.graph = std::make_shared<Graph>(IsMultigraph()); rst.graph = std::make_shared<Graph>();
rst.induced_edges = eids; rst.induced_edges = eids;
rst.graph->AddVertices(nodes.size()); rst.graph->AddVertices(nodes.size());
...@@ -500,7 +526,7 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -500,7 +526,7 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx); {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data)); std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
} else { } else {
rst.graph = std::make_shared<Graph>(IsMultigraph()); rst.graph = std::make_shared<Graph>();
rst.induced_edges = eids; rst.induced_edges = eids;
rst.graph->AddVertices(NumVertices()); rst.graph->AddVertices(NumVertices());
......
...@@ -23,8 +23,7 @@ namespace dgl { ...@@ -23,8 +23,7 @@ namespace dgl {
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
bool multigraph = args[0]; *rv = GraphRef(Graph::Create());
*rv = GraphRef(Graph::Create(multigraph));
}); });
...@@ -32,18 +31,12 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate") ...@@ -32,18 +31,12 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray src_ids = args[0]; const IdArray src_ids = args[0];
const IdArray dst_ids = args[1]; const IdArray dst_ids = args[1];
const int multigraph = args[2]; const int64_t num_nodes = args[2];
const int64_t num_nodes = args[3]; const bool readonly = args[3];
const bool readonly = args[4];
if (readonly) { if (readonly) {
if (multigraph == kBoolUnknown) { *rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids, multigraph));
}
} else { } else {
CHECK_NE(multigraph, kBoolUnknown); *rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids));
*rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids, multigraph));
} }
}); });
...@@ -52,8 +45,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") ...@@ -52,8 +45,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
const IdArray indptr = args[0]; const IdArray indptr = args[0];
const IdArray indices = args[1]; const IdArray indices = args[1];
const std::string shared_mem_name = args[2]; const std::string shared_mem_name = args[2];
const int multigraph = args[3]; const std::string edge_dir = args[3];
const std::string edge_dir = args[4];
IdArray edge_ids = IdArray::Empty({indices->shape[0]}, IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
...@@ -61,20 +53,10 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") ...@@ -61,20 +53,10 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
for (size_t i = 0; i < edge_ids->shape[0]; i++) for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i; edge_data[i] = i;
if (shared_mem_name.empty()) { if (shared_mem_name.empty()) {
if (multigraph == kBoolUnknown) { *rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
indptr, indices, edge_ids, multigraph, edge_dir));
}
} else { } else {
if (multigraph == kBoolUnknown) { *rv = GraphRef(ImmutableGraph::CreateFromCSR(
*rv = GraphRef(ImmutableGraph::CreateFromCSR( indptr, indices, edge_ids, edge_dir, shared_mem_name));
indptr, indices, edge_ids, edge_dir, shared_mem_name));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir, shared_mem_name));
}
} }
}); });
...@@ -83,11 +65,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap") ...@@ -83,11 +65,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
const std::string shared_mem_name = args[0]; const std::string shared_mem_name = args[0];
const int64_t num_vertices = args[1]; const int64_t num_vertices = args[1];
const int64_t num_edges = args[2]; const int64_t num_edges = args[2];
const bool multigraph = args[3]; const std::string edge_dir = args[3];
const std::string edge_dir = args[4];
// TODO(minjie): how to know multigraph
*rv = GraphRef(ImmutableGraph::CreateFromCSR( *rv = GraphRef(ImmutableGraph::CreateFromCSR(
shared_mem_name, num_vertices, num_edges, multigraph, edge_dir)); shared_mem_name, num_vertices, num_edges, edge_dir));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
......
...@@ -325,7 +325,7 @@ GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) { ...@@ -325,7 +325,7 @@ GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) {
indptr[src+1] = indices.size(); indptr[src+1] = indices.size();
} }
CSRPtr csr(new CSR(graph->NumVertices(), indices.size(), CSRPtr csr(new CSR(graph->NumVertices(), indices.size(),
indptr.begin(), indices.begin(), RangeIter(0), false)); indptr.begin(), indices.begin(), RangeIter(0)));
return std::make_shared<ImmutableGraph>(csr); return std::make_shared<ImmutableGraph>(csr);
} }
...@@ -397,7 +397,7 @@ GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) { ...@@ -397,7 +397,7 @@ GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {
IdArray srcs_array = aten::VecToIdArray(srcs); IdArray srcs_array = aten::VecToIdArray(srcs);
IdArray dsts_array = aten::VecToIdArray(dsts); IdArray dsts_array = aten::VecToIdArray(dsts);
return ImmutableGraph::CreateFromCOO( return ImmutableGraph::CreateFromCOO(
g->NumVertices(), srcs_array, dsts_array, g->IsMultigraph()); g->NumVertices(), srcs_array, dsts_array);
} }
HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hops) { HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hops) {
......
...@@ -187,14 +187,12 @@ HeteroGraph::HeteroGraph( ...@@ -187,14 +187,12 @@ HeteroGraph::HeteroGraph(
} }
bool HeteroGraph::IsMultigraph() const { bool HeteroGraph::IsMultigraph() const {
return const_cast<HeteroGraph*>(this)->is_multigraph_.Get([this] () { for (const auto &hg : relation_graphs_) {
for (const auto &hg : relation_graphs_) { if (hg->IsMultigraph()) {
if (hg->IsMultigraph()) { return true;
return true; }
} }
} return false;
return false;
});
} }
BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const { BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
......
...@@ -214,9 +214,6 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -214,9 +214,6 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \brief A map from vert type to the number of verts in the type */ /*! \brief A map from vert type to the number of verts in the type */
std::vector<int64_t> num_verts_per_type_; std::vector<int64_t> num_verts_per_type_;
/*! \brief True if the graph is a multigraph */
Lazy<bool> is_multigraph_;
}; };
} // namespace dgl } // namespace dgl
......
...@@ -55,8 +55,7 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( ...@@ -55,8 +55,7 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
CSR::CSR(int64_t num_vertices, int64_t num_edges, bool is_multigraph) CSR::CSR(int64_t num_vertices, int64_t num_edges) {
: is_multigraph_(is_multigraph) {
CHECK(!(num_vertices == 0 && num_edges != 0)); CHECK(!(num_vertices == 0 && num_edges != 0));
adj_ = aten::CSRMatrix{num_vertices, num_vertices, adj_ = aten::CSRMatrix{num_vertices, num_vertices,
aten::NewIdArray(num_vertices + 1), aten::NewIdArray(num_vertices + 1),
...@@ -75,17 +74,6 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) { ...@@ -75,17 +74,6 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
adj_.sorted = false; adj_.sorted = false;
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t N = indptr->shape[0] - 1;
adj_ = aten::CSRMatrix{N, N, indptr, indices, edge_ids};
adj_.sorted = false;
}
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) { const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
CHECK(aten::IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
...@@ -105,29 +93,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -105,29 +93,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
adj_.sorted = false; adj_.sorted = false;
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const std::string &shared_mem_name): is_multigraph_(is_multigraph),
shared_mem_name_(shared_mem_name) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]);
const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0];
adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays
adj_.indptr.CopyFrom(indptr);
adj_.indices.CopyFrom(indices);
adj_.data.CopyFrom(edge_ids);
adj_.sorted = false;
}
CSR::CSR(const std::string &shared_mem_name, CSR::CSR(const std::string &shared_mem_name,
int64_t num_verts, int64_t num_edges, bool is_multigraph) int64_t num_verts, int64_t num_edges): shared_mem_name_(shared_mem_name) {
: is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
CHECK(!(num_verts == 0 && num_edges != 0)); CHECK(!(num_verts == 0 && num_edges != 0));
adj_.num_rows = num_verts; adj_.num_rows = num_verts;
adj_.num_cols = num_verts; adj_.num_cols = num_verts;
...@@ -137,10 +104,7 @@ CSR::CSR(const std::string &shared_mem_name, ...@@ -137,10 +104,7 @@ CSR::CSR(const std::string &shared_mem_name,
} }
bool CSR::IsMultigraph() const { bool CSR::IsMultigraph() const {
// The lambda will be called the first time to initialize the is_multigraph flag. return aten::CSRHasDuplicate(adj_);
return const_cast<CSR*>(this)->is_multigraph_.Get([this] () {
return aten::CSRHasDuplicate(adj_);
});
} }
EdgeArray CSR::OutEdges(dgl_id_t vid) const { EdgeArray CSR::OutEdges(dgl_id_t vid) const {
...@@ -233,7 +197,6 @@ CSR CSR::CopyTo(const DLContext& ctx) const { ...@@ -233,7 +197,6 @@ CSR CSR::CopyTo(const DLContext& ctx) const {
CSR ret(adj_.indptr.CopyTo(ctx), CSR ret(adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx), adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx)); adj_.data.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
} }
...@@ -255,7 +218,6 @@ CSR CSR::AsNumBits(uint8_t bits) const { ...@@ -255,7 +218,6 @@ CSR CSR::AsNumBits(uint8_t bits) const {
CSR ret(aten::AsNumBits(adj_.indptr, bits), CSR ret(aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits), aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits)); aten::AsNumBits(adj_.data, bits));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
} }
...@@ -301,19 +263,8 @@ COO::COO(int64_t num_vertices, IdArray src, IdArray dst) { ...@@ -301,19 +263,8 @@ COO::COO(int64_t num_vertices, IdArray src, IdArray dst) {
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst}; adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
} }
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, bool is_multigraph)
: is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]);
adj_ = aten::COOMatrix{num_vertices, num_vertices, src, dst};
}
bool COO::IsMultigraph() const { bool COO::IsMultigraph() const {
// The lambda will be called the first time to initialize the is_multigraph flag. return aten::COOHasDuplicate(adj_);
return const_cast<COO*>(this)->is_multigraph_.Get([this] () {
return aten::COOHasDuplicate(adj_);
});
} }
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const { std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
...@@ -347,12 +298,12 @@ Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -347,12 +298,12 @@ Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
IdArray new_dst = aten::IndexSelect(adj_.col, eids); IdArray new_dst = aten::IndexSelect(adj_.col, eids);
induced_nodes = aten::Relabel_({new_src, new_dst}); induced_nodes = aten::Relabel_({new_src, new_dst});
const auto new_nnodes = induced_nodes->shape[0]; const auto new_nnodes = induced_nodes->shape[0];
subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst, this->IsMultigraph())); subcoo = COOPtr(new COO(new_nnodes, new_src, new_dst));
} else { } else {
IdArray new_src = aten::IndexSelect(adj_.row, eids); IdArray new_src = aten::IndexSelect(adj_.row, eids);
IdArray new_dst = aten::IndexSelect(adj_.col, eids); IdArray new_dst = aten::IndexSelect(adj_.col, eids);
induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context()); induced_nodes = aten::Range(0, NumVertices(), NumBits(), Context());
subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst, this->IsMultigraph())); subcoo = COOPtr(new COO(NumVertices(), new_src, new_dst));
} }
Subgraph subg; Subgraph subg;
subg.graph = subcoo; subg.graph = subcoo;
...@@ -373,7 +324,6 @@ COO COO::CopyTo(const DLContext& ctx) const { ...@@ -373,7 +324,6 @@ COO COO::CopyTo(const DLContext& ctx) const {
COO ret(NumVertices(), COO ret(NumVertices(),
adj_.row.CopyTo(ctx), adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx)); adj_.col.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
} }
...@@ -390,7 +340,6 @@ COO COO::AsNumBits(uint8_t bits) const { ...@@ -390,7 +340,6 @@ COO COO::AsNumBits(uint8_t bits) const {
COO ret(NumVertices(), COO ret(NumVertices(),
aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits)); aten::AsNumBits(adj_.col, bits));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
} }
...@@ -507,21 +456,7 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f ...@@ -507,21 +456,7 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f
ImmutableGraphPtr ImmutableGraph::CreateFromCSR( ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) { IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids)); CSRPtr csr(new CSR(indptr, indices, edge_ids));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
if (edge_dir == "in") { if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr)); return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
} else if (edge_dir == "out") { } else if (edge_dir == "out") {
...@@ -536,22 +471,7 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR( ...@@ -536,22 +471,7 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir, const std::string &edge_dir,
const std::string &shared_mem_name) { const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir))); CSRPtr csr(new CSR(indptr, indices, edge_ids,
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
}
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
GetSharedMemName(shared_mem_name, edge_dir))); GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") { if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name)); return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
...@@ -565,10 +485,8 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR( ...@@ -565,10 +485,8 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
ImmutableGraphPtr ImmutableGraph::CreateFromCSR( ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices, const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph, size_t num_edges, const std::string &edge_dir) {
const std::string &edge_dir) { CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges));
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
multigraph));
if (edge_dir == "in") { if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name)); return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") { } else if (edge_dir == "out") {
...@@ -585,12 +503,6 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCOO( ...@@ -585,12 +503,6 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
return std::make_shared<ImmutableGraph>(coo); return std::make_shared<ImmutableGraph>(coo);
} }
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph) {
COOPtr coo(new COO(num_vertices, src, dst, multigraph));
return std::make_shared<ImmutableGraph>(coo);
}
ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) { ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph); ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(graph);
if (ig) { if (ig) {
......
...@@ -355,7 +355,7 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -355,7 +355,7 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::vector<std::pair<dgl_id_t, int> > *sub_vers, std::vector<std::pair<dgl_id_t, int> > *sub_vers,
std::vector<neighbor_info> *neigh_pos, std::vector<neighbor_info> *neigh_pos,
const std::string &edge_type, const std::string &edge_type,
int64_t num_edges, int num_hops, bool is_multigraph) { int64_t num_edges, int num_hops) {
NodeFlow nf = NodeFlow::Create(); NodeFlow nf = NodeFlow::Create();
uint64_t num_vertices = sub_vers->size(); uint64_t num_vertices = sub_vers->size();
nf->node_mapping = aten::NewIdArray(num_vertices); nf->node_mapping = aten::NewIdArray(num_vertices);
...@@ -368,9 +368,8 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -368,9 +368,8 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf->flow_offsets->data); dgl_id_t *flow_off_data = static_cast<dgl_id_t *>(nf->flow_offsets->data);
dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf->edge_mapping->data); dgl_id_t *edge_map_data = static_cast<dgl_id_t *>(nf->edge_mapping->data);
// Construct sub_csr_graph // Construct sub_csr_graph, we treat nodeflow as multigraph by default
// TODO(minjie): is nodeflow a multigraph? auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges));
auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges, is_multigraph));
dgl_id_t* indptr_out = static_cast<dgl_id_t*>(subg_csr->indptr()->data); dgl_id_t* indptr_out = static_cast<dgl_id_t*>(subg_csr->indptr()->data);
dgl_id_t* col_list_out = static_cast<dgl_id_t*>(subg_csr->indices()->data); dgl_id_t* col_list_out = static_cast<dgl_id_t*>(subg_csr->indices()->data);
dgl_id_t* eid_out = static_cast<dgl_id_t*>(subg_csr->edge_ids()->data); dgl_id_t* eid_out = static_cast<dgl_id_t*>(subg_csr->edge_ids()->data);
...@@ -592,7 +591,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -592,7 +591,7 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
} }
return ConstructNodeFlow(neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos, return ConstructNodeFlow(neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos,
edge_type, num_edges, num_hops, graph->IsMultigraph()); edge_type, num_edges, num_hops);
} }
} // namespace } // namespace
...@@ -1090,7 +1089,6 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg, ...@@ -1090,7 +1089,6 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
int64_t num_tot_nodes = gptr_->NumVertices(); int64_t num_tot_nodes = gptr_->NumVertices();
if (neg_sample_size > num_tot_nodes) if (neg_sample_size > num_tot_nodes)
neg_sample_size = num_tot_nodes; neg_sample_size = num_tot_nodes;
bool is_multigraph = gptr_->IsMultigraph();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0]; IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2; int64_t num_pos_edges = coo->shape[0] / 2;
...@@ -1223,7 +1221,7 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg, ...@@ -1223,7 +1221,7 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
NegSubgraph neg_subg; NegSubgraph neg_subg;
// We sample negative vertices without replacement. // We sample negative vertices without replacement.
// There shouldn't be duplicated edges. // There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, is_multigraph)); COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo)); neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid; neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid; neg_subg.induced_edges = induced_neg_eid;
...@@ -1374,7 +1372,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1374,7 +1372,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
NegSubgraph neg_subg; NegSubgraph neg_subg;
// We sample negative vertices without replacement. // We sample negative vertices without replacement.
// There shouldn't be duplicated edges. // There shouldn't be duplicated edges.
COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst, gptr_->IsMultigraph())); COOPtr neg_coo(new COO(num_neg_nodes, neg_src, neg_dst));
neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo)); neg_subg.graph = GraphPtr(new ImmutableGraph(neg_coo));
neg_subg.induced_vertices = induced_neg_vid; neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid; neg_subg.induced_edges = induced_neg_eid;
......
...@@ -69,16 +69,6 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -69,16 +69,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_ = aten::COOMatrix{num_src, num_dst, src, dst}; adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
} }
COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray src, IdArray dst, bool is_multigraph)
: BaseHeteroGraph(metagraph),
is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
}
COO(GraphPtr metagraph, const aten::COOMatrix& coo) COO(GraphPtr metagraph, const aten::COOMatrix& coo)
: BaseHeteroGraph(metagraph), adj_(coo) { : BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always // Data index should not be inherited. Edges in COO format are always
...@@ -142,7 +132,6 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -142,7 +132,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_.num_rows, adj_.num_cols, adj_.num_rows, adj_.num_cols,
aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits)); aten::AsNumBits(adj_.col, bits));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
...@@ -155,14 +144,11 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -155,14 +144,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_.num_rows, adj_.num_cols, adj_.num_rows, adj_.num_cols,
adj_.row.CopyTo(ctx), adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx)); adj_.col.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
bool IsMultigraph() const override { bool IsMultigraph() const override {
return const_cast<COO*>(this)->is_multigraph_.Get([this] () { return aten::COOHasDuplicate(adj_);
return aten::COOHasDuplicate(adj_);
});
} }
bool IsReadonly() const override { bool IsReadonly() const override {
...@@ -422,9 +408,6 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -422,9 +408,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
/*! \brief internal adjacency matrix. Data array is empty */ /*! \brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_; aten::COOMatrix adj_;
/*! \brief multi-graph flag */
Lazy<bool> is_multigraph_;
}; };
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
...@@ -448,17 +431,6 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -448,17 +431,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
} }
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: BaseHeteroGraph(metagraph), is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
}
CSR(GraphPtr metagraph, const aten::CSRMatrix& csr) CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) { : BaseHeteroGraph(metagraph), adj_(csr) {
} }
...@@ -519,7 +491,6 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -519,7 +491,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
aten::AsNumBits(adj_.indptr, bits), aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits), aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits)); aten::AsNumBits(adj_.data, bits));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
} }
...@@ -534,15 +505,12 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -534,15 +505,12 @@ class UnitGraph::CSR : public BaseHeteroGraph {
adj_.indptr.CopyTo(ctx), adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx), adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx)); adj_.data.CopyTo(ctx));
ret.is_multigraph_ = is_multigraph_;
return ret; return ret;
} }
} }
bool IsMultigraph() const override { bool IsMultigraph() const override {
return const_cast<CSR*>(this)->is_multigraph_.Get([this] () { return aten::CSRHasDuplicate(adj_);
return aten::CSRHasDuplicate(adj_);
});
} }
bool IsReadonly() const override { bool IsReadonly() const override {
...@@ -775,9 +743,6 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -775,9 +743,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
/*! \brief internal adjacency matrix. Data array stores edge ids */ /*! \brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
/*! \brief multi-graph flag */
Lazy<bool> is_multigraph_;
}; };
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
...@@ -1302,20 +1267,20 @@ GraphPtr UnitGraph::AsImmutableGraph() const { ...@@ -1302,20 +1267,20 @@ GraphPtr UnitGraph::AsImmutableGraph() const {
dgl::COOPtr coo_ptr = nullptr; dgl::COOPtr coo_ptr = nullptr;
if (in_csr_) { if (in_csr_) {
aten::CSRMatrix csc = GetCSCMatrix(0); aten::CSRMatrix csc = GetCSCMatrix(0);
in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data, true)); in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
} }
if (out_csr_) { if (out_csr_) {
aten::CSRMatrix csr = GetCSRMatrix(0); aten::CSRMatrix csr = GetCSRMatrix(0);
out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data, true)); out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
} }
if (coo_) { if (coo_) {
aten::COOMatrix coo = GetCOOMatrix(0); aten::COOMatrix coo = GetCOOMatrix(0);
if (!COOHasData(coo)) { if (!COOHasData(coo)) {
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col, true)); coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col));
} else { } else {
IdArray new_src = Scatter(coo.row, coo.data); IdArray new_src = Scatter(coo.row, coo.data);
IdArray new_dst = Scatter(coo.col, coo.data); IdArray new_dst = Scatter(coo.col, coo.data);
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst, true)); coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst));
} }
} }
return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr)); return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr));
......
...@@ -177,7 +177,7 @@ def test_nx_conversion(): ...@@ -177,7 +177,7 @@ def test_nx_conversion():
n3 = F.randn((5, 4)) n3 = F.randn((5, 4))
e1 = F.randn((4, 5)) e1 = F.randn((4, 5))
e2 = F.randn((4, 7)) e2 = F.randn((4, 7))
g = DGLGraph(multigraph=True) g = DGLGraph()
g.add_nodes(5) g.add_nodes(5)
g.add_edges([0,1,3,4], [2,4,0,3]) g.add_edges([0,1,3,4], [2,4,0,3])
g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3}) g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3})
...@@ -225,7 +225,7 @@ def test_nx_conversion(): ...@@ -225,7 +225,7 @@ def test_nx_conversion():
for _, _, attr in nxg.edges(data=True): for _, _, attr in nxg.edges(data=True):
attr.pop('id') attr.pop('id')
# test with a new graph # test with a new graph
g = DGLGraph(multigraph=True) g = DGLGraph()
g.from_networkx(nxg, node_attrs=['n1'], edge_attrs=['e1']) g.from_networkx(nxg, node_attrs=['n1'], edge_attrs=['e1'])
# check graph size # check graph size
assert g.number_of_nodes() == 7 assert g.number_of_nodes() == 7
...@@ -512,7 +512,7 @@ def test_pull_0deg(): ...@@ -512,7 +512,7 @@ def test_pull_0deg():
assert F.allclose(new[1], old[1]) assert F.allclose(new[1], old[1])
def test_send_multigraph(): def test_send_multigraph():
g = DGLGraph(multigraph=True) g = DGLGraph()
g.add_nodes(3) g.add_nodes(3)
g.add_edge(0, 1) g.add_edge(0, 1)
g.add_edge(0, 1) g.add_edge(0, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment