Unverified Commit aa19df1b authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix batching isolated node types (#2035)



* [Bugfix] Fix batching isolated ntypes

* fix test

* remove test.py
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent be444e52
"""Utilities for batching/unbatching graphs."""
from collections.abc import Mapping
from collections import defaultdict
from . import backend as F
from .base import ALL, is_all, DGLError, dgl_warning
......@@ -167,24 +168,25 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
utils.check_all_same_device(graphs, 'graphs')
utils.check_all_same_idtype(graphs, 'graphs')
relations = graphs[0].canonical_etypes
ntypes = graphs[0].ntypes
idtype = graphs[0].idtype
device = graphs[0].device
# Batch graph structure for each relation graph
edge_dict = {}
num_nodes_dict = {}
edge_dict = defaultdict(list)
num_nodes_dict = defaultdict(int)
for g in graphs:
for rel in relations:
srctype, etype, dsttype = rel
srcnid_off = dstnid_off = 0
src, dst = [], []
for g in graphs:
u, v = g.edges(order='eid', etype=rel)
src.append(u + srcnid_off)
dst.append(v + dstnid_off)
srcnid_off += g.number_of_nodes(srctype)
dstnid_off += g.number_of_nodes(dsttype)
src = u + num_nodes_dict[srctype]
dst = v + num_nodes_dict[dsttype]
edge_dict[rel].append((src, dst))
for ntype in ntypes:
num_nodes_dict[ntype] += g.number_of_nodes(ntype)
for rel in relations:
src, dst = zip(*edge_dict[rel])
edge_dict[rel] = (F.cat(src, 0), F.cat(dst, 0))
num_nodes_dict.update({srctype : srcnid_off, dsttype : dstnid_off})
retg = convert.heterograph(edge_dict, num_nodes_dict, idtype=idtype, device=device)
# Compute batch num nodes
......
import dgl
import backend as F
import unittest
import pytest
from dgl.base import ALL
from utils import parametrize_dtype
from test_utils import check_graph_equal
from test_utils import check_graph_equal, get_cases
def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
assert g1.ntypes == g2.ntypes
......@@ -40,19 +41,13 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N
for feat_name in edge_attrs[ety]:
assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])
@pytest.mark.parametrize('gs', get_cases(['two_hetero_batch']))
@parametrize_dtype
def test_topology(idtype):
def test_topology(gs, idtype):
"""Test batching two DGLHeteroGraphs where some nodes are isolated in some relations"""
g1 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'follows', 'developer'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2, 3], [0, 0, 1, 1])
}, idtype=idtype, device=F.ctx())
g2 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'follows', 'developer'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1])
}, idtype=idtype, device=F.ctx())
g1, g2 = gs
g1 = g1.astype(idtype).to(F.ctx())
g2 = g2.astype(idtype).to(F.ctx())
bg = dgl.batch([g1, g2])
assert bg.idtype == idtype
......
......@@ -132,7 +132,7 @@ def _global_message_func(nodes):
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(exclude=['dglgraph']))
@pytest.mark.parametrize('g', get_cases(exclude=['dglgraph', 'two_hetero_batch']))
def test_pickling_graph(g, idtype):
g = g.astype(idtype)
new_g = _reconstruct_pickle(g)
......
......@@ -109,3 +109,31 @@ def random_bipartite(size_src, size_dst):
def random_block(size):
g = dgl.from_networkx(nx.erdos_renyi_graph(size, 0.1))
return dgl.to_block(g, np.unique(F.zerocopy_to_numpy(g.edges()[1])))
@register_case(['two_hetero_batch'])
def two_hetero_batch():
g1 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'follows', 'developer'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2, 3], [0, 0, 1, 1])
})
g2 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'follows', 'developer'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1])
})
return [g1, g2]
@register_case(['two_hetero_batch'])
def two_hetero_batch_with_isolated_ntypes():
g1 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'follows', 'developer'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2, 3], [0, 0, 1, 1])
}, num_nodes_dict={'user': 4, 'game': 2, 'developer': 3, 'platform': 2})
g2 = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 2]),
('user', 'follows', 'developer'): ([0, 1], [1, 2]),
('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1])
}, num_nodes_dict={'user': 3, 'game': 2, 'developer': 3, 'platform': 3})
return [g1, g2]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment