"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "42d950174f5da973d3d35e55d3e1e49edf87a35b"
Unverified Commit a936f9d9 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Hetero] Allow Batching Graphs with Zero Nodes/Edges (#1575)

* Allow batching graphs with zero nodes and edges

* Fix the case without node/edge features

* Fix lint
parent e979685a
...@@ -134,10 +134,9 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph): ...@@ -134,10 +134,9 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
""" """
def __init__(self, graph_list, node_attrs, edge_attrs): def __init__(self, graph_list, node_attrs, edge_attrs):
# Sanity check. Make sure all graphs have the same node/edge types, in the same order. # Sanity check. Make sure all graphs have the same node/edge types, in the same order.
ref_graph = graph_list[0] ref_canonical_etypes = graph_list[0].canonical_etypes
ref_canonical_etypes = ref_graph.canonical_etypes ref_ntypes = graph_list[0].ntypes
ref_ntypes = ref_graph.ntypes ref_etypes = graph_list[0].etypes
ref_etypes = ref_graph.etypes
for i in range(1, len(graph_list)): for i in range(1, len(graph_list)):
g_i = graph_list[i] g_i = graph_list[i]
assert g_i.ntypes == ref_ntypes, \ assert g_i.ntypes == ref_ntypes, \
...@@ -145,40 +144,66 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph): ...@@ -145,40 +144,66 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
assert g_i.canonical_etypes == ref_canonical_etypes, \ assert g_i.canonical_etypes == ref_canonical_etypes, \
'The canonical edge types of graph {:d} and {:d} should be the same.'.format(0, i) 'The canonical edge types of graph {:d} and {:d} should be the same.'.format(0, i)
# Sanity check. Make sure all graphs have same # Sanity check. Make sure all graphs have same node/edge features in terms of name, size
# node/edge features in terns of name and size. # and dtype if the number of nodes is nonzero.
ref_node_feats = dict()
for nty in ref_ntypes: for nty in ref_ntypes:
ref_feats_nty = set(ref_graph.node_attr_schemes(nty).keys()) for i, graph in enumerate(graph_list):
for i in range(1, len(graph_list)): # No nodes, skip it
assert ref_feats_nty == set(graph_list[i].node_attr_schemes(nty).keys()), \ if graph.number_of_nodes(nty) == 0:
'The node features of graph {:d} and {:d} for ' \ continue
'node type {} should be the same.'.format(0, i, nty) # Use this for reference of feature names, shape and dtype
for nfeats in ref_feats_nty: if nty not in ref_node_feats:
assert ref_graph.node_attr_schemes(nty)[nfeats] == \ ref_node_feats[nty] = (i, graph.node_attr_schemes(nty))
graph_list[i].node_attr_schemes(nty)[nfeats], \ continue
'For graph {:d} and {:d}, the size and dtype for feature ' \ # Name check
'{} of {}-typed nodes should be the same.'.format(0, i, nfeats, nty) assert set(ref_node_feats[nty][1].keys()) == \
set(graph.node_attr_schemes(nty).keys()), \
'The node features of graph {:d} and {:d} for node type {} should be the ' \
'same.'.format(ref_node_feats[nty][0], i, nty)
# Size and dtype check
for nfeats in ref_node_feats[nty][1].keys():
assert ref_node_feats[nty][1][nfeats] == \
graph.node_attr_schemes(nty)[nfeats], \
'For graph {:d} and {:d}, the size and dtype for feature {} of ' \
'{}-typed nodes should be the same.'.format(
ref_node_feats[nty][0], i, nfeats, nty)
ref_edge_feats = dict()
for ety in ref_canonical_etypes: for ety in ref_canonical_etypes:
ref_feats_ety = set(ref_graph.edge_attr_schemes(ety).keys()) for i, graph in enumerate(graph_list):
for i in range(1, len(graph_list)): # No edges, skip it
assert ref_feats_ety == set(graph_list[i].edge_attr_schemes(ety).keys()), \ if graph.number_of_edges(ety) == 0:
'The edge features of graph {:d} and {:d} for ' \ continue
'edge type {} should be the same.'.format(0, i, ety) # Use this for reference of feature names, shape and dtype
for efeats in ref_feats_ety: if ety not in ref_edge_feats:
assert ref_graph.edge_attr_schemes(ety)[efeats] == \ ref_edge_feats[ety] = (i, graph.edge_attr_schemes(ety))
graph_list[i].edge_attr_schemes(ety)[efeats], \ continue
'For graph {:d} and {:d}, the size and dtype for feature ' \ # Name check
'{} of {}-typed edge should be the same.'.format(0, i, efeats, ety) assert set(ref_edge_feats[ety][1].keys()) == \
set(graph.edge_attr_schemes(ety).keys()), \
'The edge features of graph {:d} and {:d} for edge type {} should be the ' \
'same.'.format(ref_edge_feats[ety][0], i, ety)
# Size and dtype check
for efeats in ref_edge_feats[ety][1].keys():
assert ref_edge_feats[ety][1][efeats] == \
graph.edge_attr_schemes(ety)[efeats], \
'For graph {:d} and {:d}, the size and dtype for feature {} of ' \
'{}-typed edges should be the same.'.format(
ref_edge_feats[ety][0], i, efeats, ety)
def _init_attrs(types, attrs, mode): def _init_attrs(types, attrs, mode):
formatted_attrs = {t: [] for t in types} formatted_attrs = {t: [] for t in types}
if is_all(attrs): if is_all(attrs):
for typ in types: for typ in types:
if mode == 'node': if mode == 'node':
formatted_attrs[typ] = list(ref_graph.node_attr_schemes(typ).keys()) # Handle the case where the nodes of a type have no features
formatted_attrs[typ] = list(ref_node_feats.get(
typ, (None, dict()))[1].keys())
elif mode == 'edge': elif mode == 'edge':
formatted_attrs[typ] = list(ref_graph.edge_attr_schemes(typ).keys()) # Handle the case where the edges of a type have no features
formatted_attrs[typ] = list(ref_edge_feats.get(
typ, (None, dict()))[1].keys())
elif isinstance(attrs, dict): elif isinstance(attrs, dict):
for typ, v in attrs.items(): for typ, v in attrs.items():
if isinstance(v, str): if isinstance(v, str):
......
...@@ -25,11 +25,15 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N ...@@ -25,11 +25,15 @@ def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=N
if node_attrs is not None: if node_attrs is not None:
for nty in node_attrs.keys(): for nty in node_attrs.keys():
if g1.number_of_nodes(nty) == 0:
continue
for feat_name in node_attrs[nty]: for feat_name in node_attrs[nty]:
assert F.allclose(g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name]) assert F.allclose(g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name])
if edge_attrs is not None: if edge_attrs is not None:
for ety in edge_attrs.keys(): for ety in edge_attrs.keys():
if g1.number_of_edges(ety) == 0:
continue
for feat_name in edge_attrs[ety]: for feat_name in edge_attrs[ety]:
assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]) assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])
...@@ -211,7 +215,61 @@ def test_batched_features(index_dtype): ...@@ -211,7 +215,61 @@ def test_batched_features(index_dtype):
node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']}, node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
edge_attrs={('user', 'follows', 'user'): ['h1']}) edge_attrs={('user', 'follows', 'user'): ['h1']})
@parametrize_dtype
def test_batching_with_zero_nodes_edges(index_dtype):
    """Test batching and unbatching heterographs where a type has zero nodes/edges.

    ``g1`` has no 'plays' edges and sets no 'game' node features, whereas ``g2``
    has both. For those types the batched graph is expected to carry exactly the
    features of ``g2``; for the types populated in both graphs the features are
    expected to be concatenated in batch order.

    Parameters
    ----------
    index_dtype
        Index dtype supplied by ``parametrize_dtype`` and forwarded to
        ``dgl.heterograph``.
    """
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): []
    }, index_dtype=index_dtype)
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])

    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (1, 0)]
    }, index_dtype=index_dtype)
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    bg = dgl.batch_hetero([g1, g2])

    # Types populated in both graphs: features are concatenated in batch order.
    assert F.allclose(
        bg.nodes['user'].data['h1'],
        F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']], dim=0))
    assert F.allclose(
        bg.nodes['user'].data['h2'],
        F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0))
    assert F.allclose(
        bg.edges['follows'].data['h1'],
        F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0))
    # 'h2' is set on the 'follows' edges of both graphs as well, so check it too.
    assert F.allclose(
        bg.edges['follows'].data['h2'],
        F.cat([g1.edges['follows'].data['h2'], g2.edges['follows'].data['h2']], dim=0))

    # Types populated only in g2: the batched features come from g2 alone.
    assert F.allclose(bg.nodes['game'].data['h1'], g2.nodes['game'].data['h1'])
    assert F.allclose(bg.nodes['game'].data['h2'], g2.nodes['game'].data['h2'])
    assert F.allclose(bg.edges['plays'].data['h1'], g2.edges['plays'].data['h1'])

    # Test unbatching graphs
    g3, g4 = dgl.unbatch_hetero(bg)
    # The helper skips node/edge types that are empty in its first argument,
    # so listing 'game' features is safe for g1 as well.
    check_equivalence_between_heterographs(
        g1, g3,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1', 'h2']})
    # g2 has 'plays' edges with a feature, so its round trip is checked too.
    check_equivalence_between_heterographs(
        g2, g4,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1', 'h2'],
                    ('user', 'plays', 'game'): ['h1']})

    # Test graphs without edges: only g2 has 'u' nodes and a 'u' feature.
    # This is a smoke test; it only verifies that batching does not raise.
    g1 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(0, 4))
    g2 = dgl.bipartite([], 'u', 'r', 'v', num_nodes=(1, 5))
    g2.nodes['u'].data['x'] = F.tensor([1])
    dgl.batch_hetero([g1, g2])
if __name__ == '__main__': if __name__ == '__main__':
test_batching_hetero_topology() test_batching_hetero_topology()
test_batching_hetero_and_batched_hetero_topology() test_batching_hetero_and_batched_hetero_topology()
test_batched_features() test_batched_features()
test_batching_with_zero_nodes_edges()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment