"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "94ab9709a7d0512bbcc5a181e38bab09045cc28e"
Unverified commit f0fafa20, authored by Mufei Li and committed by GitHub
Browse files

[Bug fix] Fix batch information with remove_nodes/edges applied (#3119)



* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Fix lint

* Fix

* Update

* Fix test cases

* Fix

* add docstrings

* Update
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-17.us-west-2.compute.internal>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7c3e1f94
...@@ -9,6 +9,7 @@ import networkx as nx ...@@ -9,6 +9,7 @@ import networkx as nx
import numpy as np import numpy as np
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .ops import segment
from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
from . import core from . import core
from . import graph_index from . import graph_index
...@@ -556,11 +557,7 @@ class DGLHeteroGraph(object): ...@@ -556,11 +557,7 @@ class DGLHeteroGraph(object):
Notes Notes
----- -----
This function preserves the batch information.
This function discards the batch information. Please use
:func:`dgl.DGLGraph.set_batch_num_nodes`
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph
to maintain the information.
Examples Examples
-------- --------
...@@ -582,6 +579,17 @@ class DGLHeteroGraph(object): ...@@ -582,6 +579,17 @@ class DGLHeteroGraph(object):
>>> g.edata['he'] >>> g.edata['he']
tensor([[2.]]) tensor([[2.]])
Removing edges from a batched graph preserves batch information.
>>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))
>>> g2 = dgl.graph((torch.tensor([1, 2, 3]), torch.tensor([1, 3, 4])))
>>> bg = dgl.batch([g, g2])
>>> bg.batch_num_edges()
tensor([3, 3])
>>> bg.remove_edges([1, 4])
>>> bg.batch_num_edges()
tensor([2, 2])
**Heterogeneous Graphs with Multiple Edge Types** **Heterogeneous Graphs with Multiple Edge Types**
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
...@@ -627,6 +635,19 @@ class DGLHeteroGraph(object): ...@@ -627,6 +635,19 @@ class DGLHeteroGraph(object):
else: else:
edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype) edges[c_etype] = self.edges(form='eid', order='eid', etype=c_etype)
# If the graph is batched, update batch_num_edges
batched = self._batch_num_edges is not None
if batched:
c_etype = (u_type, e_type, v_type)
one_hot_removed_edges = F.zeros((self.num_edges(c_etype),), F.float32, self.device)
one_hot_removed_edges = F.scatter_row(one_hot_removed_edges, eids,
F.full_1d(len(eids), 1., F.float32, self.device))
c_etype_batch_num_edges = self._batch_num_edges[c_etype]
batch_num_removed_edges = segment.segment_reduce(c_etype_batch_num_edges,
one_hot_removed_edges, reducer='sum')
self._batch_num_edges[c_etype] = c_etype_batch_num_edges - \
F.astype(batch_num_removed_edges, F.int64)
sub_g = self.edge_subgraph(edges, relabel_nodes=False, store_ids=store_ids) sub_g = self.edge_subgraph(edges, relabel_nodes=False, store_ids=store_ids)
self._graph = sub_g._graph self._graph = sub_g._graph
self._node_frames = sub_g._node_frames self._node_frames = sub_g._node_frames
...@@ -655,11 +676,7 @@ class DGLHeteroGraph(object): ...@@ -655,11 +676,7 @@ class DGLHeteroGraph(object):
Notes Notes
----- -----
This function preserves the batch information.
This function discards the batch information. Please use
:func:`dgl.DGLGraph.set_batch_num_nodes`
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph
to maintain the information.
Examples Examples
-------- --------
...@@ -682,6 +699,19 @@ class DGLHeteroGraph(object): ...@@ -682,6 +699,19 @@ class DGLHeteroGraph(object):
>>> g.edata['he'] >>> g.edata['he']
tensor([[2.]]) tensor([[2.]])
Removing nodes from a batched graph preserves batch information.
>>> g = dgl.graph((torch.tensor([0, 0, 2]), torch.tensor([0, 1, 2])))
>>> g2 = dgl.graph((torch.tensor([1, 2, 3]), torch.tensor([1, 3, 4])))
>>> bg = dgl.batch([g, g2])
>>> bg.batch_num_nodes()
tensor([3, 5])
>>> bg.remove_nodes([1, 4])
>>> bg.batch_num_nodes()
tensor([2, 4])
>>> bg.batch_num_edges()
tensor([2, 2])
**Heterogeneous Graphs with Multiple Node Types** **Heterogeneous Graphs with Multiple Node Types**
>>> g = dgl.heterograph({ >>> g = dgl.heterograph({
...@@ -725,17 +755,61 @@ class DGLHeteroGraph(object): ...@@ -725,17 +755,61 @@ class DGLHeteroGraph(object):
nodes = {} nodes = {}
for c_ntype in self.ntypes: for c_ntype in self.ntypes:
if self.get_ntype_id(c_ntype) == ntid: if self.get_ntype_id(c_ntype) == ntid:
target_ntype = c_ntype
original_nids = self.nodes(c_ntype) original_nids = self.nodes(c_ntype)
nodes[c_ntype] = utils.compensate(nids, original_nids) nodes[c_ntype] = utils.compensate(nids, original_nids)
else: else:
nodes[c_ntype] = self.nodes(c_ntype) nodes[c_ntype] = self.nodes(c_ntype)
# If the graph is batched, update batch_num_nodes
batched = self._batch_num_nodes is not None
if batched:
one_hot_removed_nodes = F.zeros((self.num_nodes(target_ntype),),
F.float32, self.device)
one_hot_removed_nodes = F.scatter_row(one_hot_removed_nodes, nids,
F.full_1d(len(nids), 1., F.float32, self.device))
c_ntype_batch_num_nodes = self._batch_num_nodes[target_ntype]
batch_num_removed_nodes = segment.segment_reduce(
c_ntype_batch_num_nodes, one_hot_removed_nodes, reducer='sum')
self._batch_num_nodes[target_ntype] = c_ntype_batch_num_nodes - \
F.astype(batch_num_removed_nodes, F.int64)
# Record old num_edges to check later whether some edges were removed
old_num_edges = {c_etype: self._graph.number_of_edges(self.get_etype_id(c_etype))
for c_etype in self.canonical_etypes}
# node_subgraph # node_subgraph
sub_g = self.subgraph(nodes, store_ids=store_ids) # If batch_num_edges is to be updated, record the original edge IDs
sub_g = self.subgraph(nodes, store_ids=store_ids or batched)
self._graph = sub_g._graph self._graph = sub_g._graph
self._node_frames = sub_g._node_frames self._node_frames = sub_g._node_frames
self._edge_frames = sub_g._edge_frames self._edge_frames = sub_g._edge_frames
# If the graph is batched, update batch_num_edges
if batched:
canonical_etypes = [
c_etype for c_etype in self.canonical_etypes if
self._graph.number_of_edges(self.get_etype_id(c_etype)) != old_num_edges[c_etype]]
for c_etype in canonical_etypes:
if self._graph.number_of_edges(self.get_etype_id(c_etype)) == 0:
self._batch_num_edges[c_etype] = F.zeros(
(self.batch_size,), F.int64, self.device)
continue
one_hot_left_edges = F.zeros((old_num_edges[c_etype],), F.float32, self.device)
eids = self.edges[c_etype].data[EID]
one_hot_left_edges = F.scatter_row(one_hot_left_edges, eids,
F.full_1d(len(eids), 1., F.float32, self.device))
batch_num_left_edges = segment.segment_reduce(
self._batch_num_edges[c_etype], one_hot_left_edges, reducer='sum')
self._batch_num_edges[c_etype] = F.astype(batch_num_left_edges, F.int64)
if batched and not store_ids:
for c_ntype in self.ntypes:
self.nodes[c_ntype].data.pop(NID)
for c_etype in self.canonical_etypes:
self.edges[c_etype].data.pop(EID)
def _reset_cached_info(self): def _reset_cached_info(self):
"""Some info like batch_num_nodes may be stale after mutation """Some info like batch_num_nodes may be stale after mutation
Clean these cached info Clean these cached info
...@@ -5378,6 +5452,10 @@ class DGLHeteroGraph(object): ...@@ -5378,6 +5452,10 @@ class DGLHeteroGraph(object):
ret._node_frames = [fr.clone() for fr in self._node_frames] ret._node_frames = [fr.clone() for fr in self._node_frames]
ret._edge_frames = [fr.clone() for fr in self._edge_frames] ret._edge_frames = [fr.clone() for fr in self._edge_frames]
# Copy the batch information
ret._batch_num_nodes = copy.copy(self._batch_num_nodes)
ret._batch_num_edges = copy.copy(self._batch_num_edges)
return ret return ret
def local_var(self): def local_var(self):
......
...@@ -6,7 +6,6 @@ import dgl ...@@ -6,7 +6,6 @@ import dgl
import dgl.function as fn import dgl.function as fn
import dgl.partition import dgl.partition
import backend as F import backend as F
from dgl.graph_index import from_scipy_sparse_matrix
import unittest import unittest
from utils import parametrize_dtype from utils import parametrize_dtype
...@@ -1298,6 +1297,84 @@ def test_remove_edges(idtype): ...@@ -1298,6 +1297,84 @@ def test_remove_edges(idtype):
assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2], dtype=idtype)) assert F.array_equal(g.nodes['game'].data['h'], F.tensor([2, 2], dtype=idtype))
assert F.array_equal(g.nodes['developer'].data['h'], F.tensor([3, 3], dtype=idtype)) assert F.array_equal(g.nodes['developer'].data['h'], F.tensor([3, 3], dtype=idtype))
# batched graph
ctx = F.ctx()
g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=5, idtype=idtype, device=ctx)
g2 = dgl.graph(([], []), idtype=idtype, device=ctx)
g3 = dgl.graph(([2, 3, 4], [3, 2, 1]), idtype=idtype, device=ctx)
bg = dgl.batch([g1, g2, g3])
bg_r = dgl.remove_edges(bg, 2)
assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal(bg_r.batch_num_edges(), F.tensor([2, 0, 2], dtype=F.int64))
bg_r = dgl.remove_edges(bg, [0, 2])
assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal(bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64))
bg_r = dgl.remove_edges(bg, F.tensor([0, 2], dtype=idtype))
assert bg.batch_size == bg_r.batch_size
assert F.array_equal(bg.batch_num_nodes(), bg_r.batch_num_nodes())
assert F.array_equal(bg_r.batch_num_edges(), F.tensor([1, 0, 2], dtype=F.int64))
# batched heterogeneous graph
g1 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([1, 3], [0, 1])
}, num_nodes_dict={'user': 4, 'game': 3}, idtype=idtype, device=ctx)
g2 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 2], [3, 4]),
('user', 'plays', 'game'): ([], [])
}, num_nodes_dict={'user': 6, 'game': 2}, idtype=idtype, device=ctx)
g3 = dgl.heterograph({
('user', 'follows', 'user'): ([], []),
('user', 'plays', 'game'): ([1, 2], [1, 2])
}, idtype=idtype, device=ctx)
bg = dgl.batch([g1, g2, g3])
bg_r = dgl.remove_edges(bg, 1, etype='follows')
assert bg.batch_size == bg_r.batch_size
ntypes = bg.ntypes
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([1, 2, 0], dtype=F.int64))
assert F.array_equal(bg_r.batch_num_edges('plays'), bg.batch_num_edges('plays'))
bg_r = dgl.remove_edges(bg, 2, etype='plays')
assert bg.batch_size == bg_r.batch_size
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows'))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([2, 0, 1], dtype=F.int64))
bg_r = dgl.remove_edges(bg, [0, 1, 3], etype='follows')
assert bg.batch_size == bg_r.batch_size
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64))
assert F.array_equal(bg.batch_num_edges('plays'), bg_r.batch_num_edges('plays'))
bg_r = dgl.remove_edges(bg, [1, 2], etype='plays')
assert bg.batch_size == bg_r.batch_size
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows'))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64))
bg_r = dgl.remove_edges(bg, F.tensor([0, 1, 3], dtype=idtype), etype='follows')
assert bg.batch_size == bg_r.batch_size
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64))
assert F.array_equal(bg.batch_num_edges('plays'), bg_r.batch_num_edges('plays'))
bg_r = dgl.remove_edges(bg, F.tensor([1, 2], dtype=idtype), etype='plays')
assert bg.batch_size == bg_r.batch_size
for nty in ntypes:
assert F.array_equal(bg.batch_num_nodes(nty), bg_r.batch_num_nodes(nty))
assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows'))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64))
@parametrize_dtype @parametrize_dtype
def test_remove_nodes(idtype): def test_remove_nodes(idtype):
# homogeneous Graphs # homogeneous Graphs
...@@ -1396,6 +1473,83 @@ def test_remove_nodes(idtype): ...@@ -1396,6 +1473,83 @@ def test_remove_nodes(idtype):
assert F.array_equal(u, F.tensor([1], dtype=idtype)) assert F.array_equal(u, F.tensor([1], dtype=idtype))
assert F.array_equal(v, F.tensor([0], dtype=idtype)) assert F.array_equal(v, F.tensor([0], dtype=idtype))
# batched graph
ctx = F.ctx()
g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=5, idtype=idtype, device=ctx)
g2 = dgl.graph(([], []), idtype=idtype, device=ctx)
g3 = dgl.graph(([2, 3, 4], [3, 2, 1]), idtype=idtype, device=ctx)
bg = dgl.batch([g1, g2, g3])
bg_r = dgl.remove_nodes(bg, 1)
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg_r.batch_num_nodes(), F.tensor([4, 0, 5], dtype=F.int64))
assert F.array_equal(bg_r.batch_num_edges(), F.tensor([0, 0, 3], dtype=F.int64))
bg_r = dgl.remove_nodes(bg, [1, 7])
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64))
assert F.array_equal(bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64))
bg_r = dgl.remove_nodes(bg, F.tensor([1, 7], dtype=idtype))
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg_r.batch_num_nodes(), F.tensor([4, 0, 4], dtype=F.int64))
assert F.array_equal(bg_r.batch_num_edges(), F.tensor([0, 0, 1], dtype=F.int64))
# batched heterogeneous graph
g1 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([1, 3], [0, 1])
}, num_nodes_dict={'user': 4, 'game': 3}, idtype=idtype, device=ctx)
g2 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 2], [3, 4]),
('user', 'plays', 'game'): ([], [])
}, num_nodes_dict={'user': 6, 'game': 2}, idtype=idtype, device=ctx)
g3 = dgl.heterograph({
('user', 'follows', 'user'): ([], []),
('user', 'plays', 'game'): ([1, 2], [1, 2])
}, idtype=idtype, device=ctx)
bg = dgl.batch([g1, g2, g3])
bg_r = dgl.remove_nodes(bg, 1, ntype='user')
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg_r.batch_num_nodes('user'), F.tensor([3, 6, 3], dtype=F.int64))
assert F.array_equal(bg.batch_num_nodes('game'), bg_r.batch_num_nodes('game'))
assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 2, 0], dtype=F.int64))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 2], dtype=F.int64))
bg_r = dgl.remove_nodes(bg, 6, ntype='game')
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg.batch_num_nodes('user'), bg_r.batch_num_nodes('user'))
assert F.array_equal(bg_r.batch_num_nodes('game'), F.tensor([3, 2, 2], dtype=F.int64))
assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows'))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([2, 0, 1], dtype=F.int64))
bg_r = dgl.remove_nodes(bg, [1, 5, 6, 11], ntype='user')
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg_r.batch_num_nodes('user'), F.tensor([3, 4, 2], dtype=F.int64))
assert F.array_equal(bg.batch_num_nodes('game'), bg_r.batch_num_nodes('game'))
assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64))
bg_r = dgl.remove_nodes(bg, [0, 3, 4, 7], ntype='game')
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg.batch_num_nodes('user'), bg_r.batch_num_nodes('user'))
assert F.array_equal(bg_r.batch_num_nodes('game'), F.tensor([2, 0, 2], dtype=F.int64))
assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows'))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64))
bg_r = dgl.remove_nodes(bg, F.tensor([1, 5, 6, 11], dtype=idtype), ntype='user')
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg_r.batch_num_nodes('user'), F.tensor([3, 4, 2], dtype=F.int64))
assert F.array_equal(bg.batch_num_nodes('game'), bg_r.batch_num_nodes('game'))
assert F.array_equal(bg_r.batch_num_edges('follows'), F.tensor([0, 1, 0], dtype=F.int64))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64))
bg_r = dgl.remove_nodes(bg, F.tensor([0, 3, 4, 7], dtype=idtype), ntype='game')
assert bg_r.batch_size == bg.batch_size
assert F.array_equal(bg.batch_num_nodes('user'), bg_r.batch_num_nodes('user'))
assert F.array_equal(bg_r.batch_num_nodes('game'), F.tensor([2, 0, 2], dtype=F.int64))
assert F.array_equal(bg.batch_num_edges('follows'), bg_r.batch_num_edges('follows'))
assert F.array_equal(bg_r.batch_num_edges('plays'), F.tensor([1, 0, 1], dtype=F.int64))
@parametrize_dtype @parametrize_dtype
def test_add_selfloop(idtype): def test_add_selfloop(idtype):
# homogeneous graph # homogeneous graph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment