Commit bfe3e9d6 authored by Zihao Ye's avatar Zihao Ye Committed by Quan (Andy) Gan
Browse files

[Feature] Add switch readonly state interface (#389)

* fix rgcn tutorial

* small fix

* add switch readonly state interface

* remove trailing ws

* fix

* fix

* change signature

* change signature

* remove trailing ws

* upd

* upd interface

* upd

* rm trailing ws

* reformat

* trailing space

* upd
parent 111a4aa7
......@@ -3043,6 +3043,38 @@ class DGLGraph(DGLBaseGraph):
edges = F.tensor(edges)
return F.boolean_mask(edges, e_mask)
def readonly(self, readonly_state=True):
"""Set this graph's readonly state in-place.
Parameters
----------
readonly_state : bool, optional
New readonly state of the graph, defaults to True.
Examples
--------
>>> G = dgl.DGLGraph()
>>> G.add_nodes(3)
>>> G.add_edge(0, 1)
>>> G.readonly()
>>> try:
>>> G.add_nodes(5)
>>> fail = False
>>> except:
>>> fail = True
>>>
>>> fail
True
>>> G.readonly(False)
>>> G.add_nodes(5)
>>> G.number_of_nodes()
8
"""
if readonly_state == self._graph.is_readonly():
return self
self._graph.readonly(readonly_state)
return self
def __repr__(self):
ret = ('DGLGraph(num_nodes={node}, num_edges={edge},\n'
' ndata_schemes={ndata}\n'
......
......@@ -54,7 +54,6 @@ class GraphIndex(object):
self._init(src, dst, utils.toindex(F.arange(0, len(src))), n_nodes)
else:
self._handle = _CAPI_DGLGraphCreateMutable(multigraph)
self.clear()
self.add_nodes(n_nodes)
self.add_edges(src, dst)
......@@ -139,6 +138,19 @@ class GraphIndex(object):
self._readonly = bool(_CAPI_DGLGraphIsReadonly(self._handle))
return self._readonly
def readonly(self, readonly_state=True):
"""Set the readonly state of graph index in-place.
Parameters
----------
readonly_state : bool
New readonly state of current graph index.
"""
n_nodes, multigraph, _, src, dst = self.__getstate__()
self.clear_cache()
state = (n_nodes, multigraph, readonly_state, src, dst)
self.__setstate__(state)
def number_of_nodes(self):
"""Return the number of nodes.
......
......@@ -5,6 +5,7 @@ import scipy.sparse as sp
import networkx as nx
import dgl
import backend as F
from dgl import DGLError
def test_graph_creation():
g = dgl.DGLGraph()
......@@ -126,9 +127,67 @@ def test_incmat_cache():
inc4 = g.incidence_matrix("in")
assert id(inc4) != id(inc35)
def test_readonly():
g = dgl.DGLGraph()
g.add_nodes(5)
g.add_edges([0, 1, 2, 3], [1, 2, 3, 4])
g.ndata['x'] = F.zeros((5, 3))
g.edata['x'] = F.zeros((4, 4))
g.readonly(False)
assert g._graph.is_readonly() == False
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 4
g.readonly()
assert g._graph.is_readonly() == True
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 4
try:
g.add_nodes(5)
fail = False
except DGLError:
fail = True
finally:
assert fail
g.readonly()
assert g._graph.is_readonly() == True
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 4
try:
g.add_nodes(5)
fail = False
except DGLError:
fail = True
finally:
assert fail
g.readonly(False)
assert g._graph.is_readonly() == False
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 4
try:
g.add_nodes(10)
g.add_edges([4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
[5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
fail = False
except DGLError:
fail = True
finally:
assert not fail
assert g.number_of_nodes() == 15
assert F.shape(g.ndata['x']) == (15, 3)
assert g.number_of_edges() == 14
assert F.shape(g.edata['x']) == (14, 4)
if __name__ == '__main__':
test_graph_creation()
test_create_from_elist()
test_adjmat_cache()
test_incmat()
test_incmat_cache()
test_readonly()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment