"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "e4674531dd54874c0abbc786ad5635c92c34dc3e"
Unverified commit aa19df1b, authored by Quan (Andy) Gan and committed by GitHub
Browse files

[Bugfix] Fix batching isolated node types (#2035)



* [Bugfix] Fix batching isolated ntypes

* fix test

* remove test.py
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent be444e52
"""Utilities for batching/unbatching graphs.""" """Utilities for batching/unbatching graphs."""
from collections.abc import Mapping from collections.abc import Mapping
from collections import defaultdict
from . import backend as F from . import backend as F
from .base import ALL, is_all, DGLError, dgl_warning from .base import ALL, is_all, DGLError, dgl_warning
...@@ -167,24 +168,25 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None): ...@@ -167,24 +168,25 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
utils.check_all_same_device(graphs, 'graphs') utils.check_all_same_device(graphs, 'graphs')
utils.check_all_same_idtype(graphs, 'graphs') utils.check_all_same_idtype(graphs, 'graphs')
relations = graphs[0].canonical_etypes relations = graphs[0].canonical_etypes
ntypes = graphs[0].ntypes
idtype = graphs[0].idtype idtype = graphs[0].idtype
device = graphs[0].device device = graphs[0].device
# Batch graph structure for each relation graph # Batch graph structure for each relation graph
edge_dict = {} edge_dict = defaultdict(list)
num_nodes_dict = {} num_nodes_dict = defaultdict(int)
for rel in relations: for g in graphs:
srctype, etype, dsttype = rel for rel in relations:
srcnid_off = dstnid_off = 0 srctype, etype, dsttype = rel
src, dst = [], []
for g in graphs:
u, v = g.edges(order='eid', etype=rel) u, v = g.edges(order='eid', etype=rel)
src.append(u + srcnid_off) src = u + num_nodes_dict[srctype]
dst.append(v + dstnid_off) dst = v + num_nodes_dict[dsttype]
srcnid_off += g.number_of_nodes(srctype) edge_dict[rel].append((src, dst))
dstnid_off += g.number_of_nodes(dsttype) for ntype in ntypes:
num_nodes_dict[ntype] += g.number_of_nodes(ntype)
for rel in relations:
src, dst = zip(*edge_dict[rel])
edge_dict[rel] = (F.cat(src, 0), F.cat(dst, 0)) edge_dict[rel] = (F.cat(src, 0), F.cat(dst, 0))
num_nodes_dict.update({srctype : srcnid_off, dsttype : dstnid_off})
retg = convert.heterograph(edge_dict, num_nodes_dict, idtype=idtype, device=device) retg = convert.heterograph(edge_dict, num_nodes_dict, idtype=idtype, device=device)
# Compute batch num nodes # Compute batch num nodes
......
import dgl import dgl
import backend as F import backend as F
import unittest import unittest
import pytest
from dgl.base import ALL from dgl.base import ALL
from utils import parametrize_dtype from utils import parametrize_dtype
from test_utils import check_graph_equal from test_utils import check_graph_equal, get_cases
def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None): def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
assert g1.ntypes == g2.ntypes assert g1.ntypes == g2.ntypes
...@@ -40,19 +41,13 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N ...@@ -40,19 +41,13 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N
for feat_name in edge_attrs[ety]: for feat_name in edge_attrs[ety]:
assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]) assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])
@pytest.mark.parametrize('gs', get_cases(['two_hetero_batch']))
@parametrize_dtype @parametrize_dtype
def test_topology(idtype): def test_topology(gs, idtype):
"""Test batching two DGLHeteroGraphs where some nodes are isolated in some relations""" """Test batching two DGLHeteroGraphs where some nodes are isolated in some relations"""
g1 = dgl.heterograph({ g1, g2 = gs
('user', 'follows', 'user'): ([0, 1], [1, 2]), g1 = g1.astype(idtype).to(F.ctx())
('user', 'follows', 'developer'): ([0, 1], [1, 2]), g2 = g2.astype(idtype).to(F.ctx())
('user', 'plays', 'game'): ([0, 1, 2, 3], [0, 0, 1, 1])
}, idtype=idtype, device=F.ctx())
g2 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'follows', 'developer'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1])
}, idtype=idtype, device=F.ctx())
bg = dgl.batch([g1, g2]) bg = dgl.batch([g1, g2])
assert bg.idtype == idtype assert bg.idtype == idtype
......
...@@ -132,7 +132,7 @@ def _global_message_func(nodes): ...@@ -132,7 +132,7 @@ def _global_message_func(nodes):
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented") @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
@parametrize_dtype @parametrize_dtype
@pytest.mark.parametrize('g', get_cases(exclude=['dglgraph'])) @pytest.mark.parametrize('g', get_cases(exclude=['dglgraph', 'two_hetero_batch']))
def test_pickling_graph(g, idtype): def test_pickling_graph(g, idtype):
g = g.astype(idtype) g = g.astype(idtype)
new_g = _reconstruct_pickle(g) new_g = _reconstruct_pickle(g)
......
...@@ -109,3 +109,31 @@ def random_bipartite(size_src, size_dst): ...@@ -109,3 +109,31 @@ def random_bipartite(size_src, size_dst):
def random_block(size):
    """Build a block from a random Erdos-Renyi graph with ``size`` nodes.

    A NetworkX Erdos-Renyi graph (edge probability 0.1) is converted to a
    DGL graph, and the block is created with the unique destination node
    IDs of that graph's edges passed as the second argument of
    ``dgl.to_block``.
    """
    nx_graph = nx.erdos_renyi_graph(size, 0.1)
    graph = dgl.from_networkx(nx_graph)
    # Second element of ``edges()`` holds the destination endpoints.
    edge_dst = graph.edges()[1]
    unique_dst = np.unique(F.zerocopy_to_numpy(edge_dst))
    return dgl.to_block(graph, unique_dst)
@register_case(['two_hetero_batch'])
def two_hetero_batch():
    """Return a list of two heterographs used as a batching test case.

    Both graphs have the same three relations and identical 'follows'
    edges; they differ only in the ('user', 'plays', 'game') edges.
    """
    def _build(plays_edges):
        # Only the 'plays' relation varies between the two graphs.
        return dgl.heterograph({
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'follows', 'developer'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): plays_edges,
        })

    first = _build(([0, 1, 2, 3], [0, 0, 1, 1]))
    second = _build(([0, 1, 2], [0, 0, 1]))
    return [first, second]
@register_case(['two_hetero_batch'])
def two_hetero_batch_with_isolated_ntypes():
    """Return a list of two heterographs that include an edge-less node type.

    Both graphs have the same three relations and identical 'follows'
    edges; they differ in the ('user', 'plays', 'game') edges and in the
    explicit node counts. The 'platform' node type is given a node count
    but appears in none of the relations.
    """
    def _build(plays_edges, node_counts):
        # The relations are shared; 'plays' edges and node counts vary.
        return dgl.heterograph({
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'follows', 'developer'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): plays_edges,
        }, num_nodes_dict=node_counts)

    first = _build(
        ([0, 1, 2, 3], [0, 0, 1, 1]),
        {'user': 4, 'game': 2, 'developer': 3, 'platform': 2},
    )
    second = _build(
        ([0, 1, 2], [0, 0, 1]),
        {'user': 3, 'game': 2, 'developer': 3, 'platform': 3},
    )
    return [first, second]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment