"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3d7eaf83d721ed1137ad1838b73be83c737721d4"
Unverified Commit 9628481f authored by VoVAllen's avatar VoVAllen Committed by GitHub
Browse files

Merge pull request #5 from jermainewang/master

Sync with upstream
parents 026d35c5 3de20385
...@@ -523,6 +523,7 @@ class GraphIndex(object): ...@@ -523,6 +523,7 @@ class GraphIndex(object):
""" """
src, dst, eid = self.edges() src, dst, eid = self.edges()
ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph() ret = nx.MultiDiGraph() if self.is_multigraph() else nx.DiGraph()
ret.add_nodes_from(range(self.number_of_nodes()))
for u, v, id in zip(src, dst, eid): for u, v, id in zip(src, dst, eid):
ret.add_edge(u, v, id=id) ret.add_edge(u, v, id=id)
return ret return ret
...@@ -548,16 +549,20 @@ class GraphIndex(object): ...@@ -548,16 +549,20 @@ class GraphIndex(object):
num_nodes = nx_graph.number_of_nodes() num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes) self.add_nodes(num_nodes)
has_edge_id = 'id' in next(iter(nx_graph.edges))
if nx_graph.number_of_edges() == 0:
return
# nx_graph.edges(data=True) returns src, dst, attr_dict
has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1]
if has_edge_id: if has_edge_id:
num_edges = nx_graph.number_of_edges() num_edges = nx_graph.number_of_edges()
src = np.zeros((num_edges,), dtype=np.int64) src = np.zeros((num_edges,), dtype=np.int64)
dst = np.zeros((num_edges,), dtype=np.int64) dst = np.zeros((num_edges,), dtype=np.int64)
for e, attr in nx_graph.edges.items: for u, v, attr in nx_graph.edges(data=True):
# MultiDiGraph returns a triplet in e while DiGraph returns a pair
eid = attr['id'] eid = attr['id']
src[eid] = e[0] src[eid] = u
dst[eid] = e[1] dst[eid] = v
else: else:
src = [] src = []
dst = [] dst = []
......
...@@ -24,7 +24,7 @@ DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) { ...@@ -24,7 +24,7 @@ DLManagedTensor* CreateTmpDLManagedTensor(const TVMArgValue& arg) {
PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) { PackedFunc ConvertNDArrayVectorToPackedFunc(const std::vector<NDArray>& vec) {
auto body = [vec](TVMArgs args, TVMRetValue* rv) { auto body = [vec](TVMArgs args, TVMRetValue* rv) {
size_t which = args[0]; int which = args[0];
if (which >= vec.size()) { if (which >= vec.size()) {
LOG(FATAL) << "invalid choice"; LOG(FATAL) << "invalid choice";
} else { } else {
......
...@@ -94,6 +94,27 @@ def test_nx(): ...@@ -94,6 +94,27 @@ def test_nx():
assert 0 in gi.edge_id(0, 1) assert 0 in gi.edge_id(0, 1)
assert 1 in gi.edge_id(0, 1) assert 1 in gi.edge_id(0, 1)
nxg = nx.DiGraph()
nxg.add_nodes_from(range(3))
gi = create_graph_index(nxg)
assert gi.number_of_nodes() == 3
assert gi.number_of_edges() == 0
gi = create_graph_index()
gi.add_nodes(3)
nxg = gi.to_networkx()
assert len(nxg.nodes) == 3
assert len(nxg.edges) == 0
nxg = nx.DiGraph()
nxg.add_edge(0, 1, id=0)
nxg.add_edge(1, 2, id=1)
gi = create_graph_index(nxg)
assert 0 in gi.edge_id(0, 1)
assert 1 in gi.edge_id(1, 2)
assert gi.number_of_edges() == 2
assert gi.number_of_nodes() == 3
def test_predsucc(): def test_predsucc():
gi = create_graph_index(multigraph=True) gi = create_graph_index(multigraph=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment