Unverified Commit ead64de9 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Hetero] Batching/Unbatching DGLHeteroGraph (#1017)

* Update

* Update

* Update

* Fix

* CI style fix

* CI fix style

* Fix

* Try CI

* Fix test

* Update

* Update

* Update

* Update
parent 85c9ff01
.. _apibatch_heterograph:
BatchedDGLHeteroGraph -- Enable batched graph operations for heterographs
=========================================================================
.. currentmodule:: dgl
.. autoclass:: BatchedDGLHeteroGraph
Merge and decompose
-------------------
.. autosummary::
:toctree: ../../generated/
batch_hetero
unbatch_hetero
Query batch summary
----------------------
.. autosummary::
:toctree: ../../generated/
BatchedDGLHeteroGraph.batch_size
BatchedDGLHeteroGraph.batch_num_nodes
BatchedDGLHeteroGraph.batch_num_edges
...@@ -8,6 +8,7 @@ API Reference ...@@ -8,6 +8,7 @@ API Reference
heterograph heterograph
init init
batch batch
batch_heterograph
function function
traversal traversal
propagate propagate
......
...@@ -16,6 +16,7 @@ from ._ffi.base import DGLError, __version__ ...@@ -16,6 +16,7 @@ from ._ffi.base import DGLError, __version__
from .base import ALL, NTYPE, NID, ETYPE, EID from .base import ALL, NTYPE, NID, ETYPE, EID
from .backend import load_backend from .backend import load_backend
from .batched_graph import * from .batched_graph import *
from .batched_heterograph import *
from .convert import * from .convert import *
from .graph import DGLGraph from .graph import DGLGraph
from .heterograph import DGLHeteroGraph from .heterograph import DGLHeteroGraph
......
...@@ -163,7 +163,6 @@ class BatchedDGLGraph(DGLGraph): ...@@ -163,7 +163,6 @@ class BatchedDGLGraph(DGLGraph):
# Check if all the graphs with mode items have the same associated features. # Check if all the graphs with mode items have the same associated features.
if len(attrs) > 0: if len(attrs) > 0:
for i, g in enumerate(graph_list): for i, g in enumerate(graph_list):
g = graph_list[i]
g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode) g_num_items, g_attrs = _get_num_item_and_attr_types(g, mode)
if g_attrs != attrs and g_num_items > 0: if g_attrs != attrs and g_num_items > 0:
raise ValueError('Expect graph {0} and {1} to have the same {2} ' raise ValueError('Expect graph {0} and {1} to have the same {2} '
...@@ -345,7 +344,7 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL): ...@@ -345,7 +344,7 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
Returns Returns
------- -------
BatchedDGLGraph BatchedDGLGraph
one single batched graph One single batched graph
See Also See Also
-------- --------
......
"""Classes and functions for batching multiple heterographs together."""
from collections.abc import Iterable
from . import backend as F
from . import heterograph_index
from .base import ALL, is_all
from .frame import FrameRef, Frame
from .heterograph import DGLHeteroGraph
__all__ = ['BatchedDGLHeteroGraph', 'unbatch_hetero', 'batch_hetero']
class BatchedDGLHeteroGraph(DGLHeteroGraph):
"""Class for batched DGLHeteroGraphs.
A :class:`BatchedDGLHeteroGraph` basically merges a list of small graphs into a giant
graph so that one can perform message passing and readout over a batch of graphs
simultaneously.
For a given node/edge type, the nodes/edges are re-indexed with a new id in the
batched graph with the rule below:
====== ========== ======================== === ==========================
item Graph 1 Graph 2 ... Graph k
====== ========== ======================== === ==========================
raw id 0, ..., N1 0, ..., N2 ... ..., Nk
new id 0, ..., N1 N1 + 1, ..., N1 + N2 + 1 ... ..., N1 + ... + Nk + k - 1
====== ========== ======================== === ==========================
To modify the features in :class:`BatchedDGLHeteroGraph` has no effect on the original
graphs. See the examples below about how to work around.
Parameters
----------
graph_list : iterable
A collection of :class:`~dgl.DGLHeteroGraph` to be batched.
node_attrs : None or dict
The node attributes to be batched. If ``None``, the resulted graph will not have
features. If ``dict``, it maps str to str or iterable. The keys represent names of
node types and the values represent the node features to be batched for the
corresponding type. By default, we use all features for all types of nodes.
edge_attrs : None or dict
Same as for the case of :attr:`node_attrs`.
Examples
--------
>>> import dgl
>>> import torch as th
**Example 1**
We start with a simple example.
>>> # Create the first graph and set features for nodes of type 'user'
>>> g1 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0), (1, 0)]})
>>> g1.nodes['user'].data['h1'] = th.tensor([[0.], [1.]])
>>> # Create the second graph and set features for nodes of type 'user'
>>> g2 = dgl.heterograph({('user', 'plays', 'game'): [(0, 0)]})
>>> g2.nodes['user'].data['h1'] = th.tensor([[0.]])
>>> # Batch the graphs
>>> bg = dgl.batch_hetero([g1, g2])
With the batching operation, the nodes and edges are re-indexed.
>>> bg.nodes('user')
tensor([0, 1, 2])
By default, we also copy and concatenate all the node and edge features.
>>> bg.nodes['user'].data['h1']
tensor([[0.],
[1.],
[0.]])
**Example 2**
We will now see a more complex example and the
various operations one can play with a batched graph.
>>> g1 = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0)]
... })
>>> g1.nodes['user'].data['h1'] = th.tensor([[0.], [1.], [2.]])
>>> g1.nodes['user'].data['h2'] = th.tensor([[3.], [4.], [5.]])
>>> g1.nodes['game'].data['h1'] = th.tensor([[0.]])
>>> g1.edges['plays'].data['h1'] = th.tensor([[0.], [1.]])
>>> g2 = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0)]
... })
>>> g2.nodes['user'].data['h1'] = th.tensor([[0.], [1.], [2.]])
>>> g2.nodes['user'].data['h2'] = th.tensor([[3.], [4.], [5.]])
>>> g2.nodes['game'].data['h1'] = th.tensor([[0.]])
>>> g2.edges['plays'].data['h1'] = th.tensor([[0.], [1.]])
Merge two :class:`~dgl.DGLHeteroGraph` objects into one :class:`BatchedDGLHeteroGraph` object.
When merging a list of graphs, we can choose to include only a subset of the attributes.
>>> # For edge types, only canonical edge types are allowed to avoid ambiguity.
>>> bg = dgl.batch_hetero([g1, g2], node_attrs={'user': ['h1', 'h2'], 'game': None},
... edge_attrs={('user', 'plays', 'game'): 'h1'})
>>> list(bg.nodes['user'].data.keys())
['h1', 'h2']
>>> list(bg.nodes['game'].data.keys())
[]
>>> list(bg.edges['follows'].data.keys())
[]
>>> list(bg.edges['plays'].data.keys())
['h1']
We can get a brief summary of the graphs that constitute the batched graph.
>>> bg.batch_size
2
>>> bg.batch_num_nodes('user')
[3, 3]
>>> bg.batch_num_edges(('user', 'plays', 'game'))
[2, 2]
Updating the attributes of the batched graph has no effect on the original graphs.
>>> bg.nodes['game'].data['h1'] = th.tensor([[1.], [1.]])
>>> g2.nodes['game'].data['h1']
tensor([[0.]])
Instead, we can decompose the batched graph back into a list of graphs and use them
to replace the original graphs.
>>> g3, g4 = dgl.unbatch_hetero(bg) # returns a list of DGLHeteroGraph objects
>>> g4.nodes['game'].data['h1']
tensor([[1.]])
"""
def __init__(self, graph_list, node_attrs, edge_attrs):
# Sanity check. Make sure all graphs have the same node/edge types, in the same order.
ref_graph = graph_list[0]
ref_canonical_etypes = ref_graph.canonical_etypes
ref_ntypes = ref_graph.ntypes
ref_etypes = ref_graph.etypes
for i in range(1, len(graph_list)):
g_i = graph_list[i]
assert g_i.ntypes == ref_ntypes, \
'The node types of graph {:d} and {:d} should be the same.'.format(0, i)
assert g_i.canonical_etypes == ref_canonical_etypes, \
'The canonical edge types of graph {:d} and {:d} should be the same.'.format(0, i)
# Sanity check. Make sure all graphs have same
# node/edge features in terns of name and size.
for nty in ref_ntypes:
ref_feats_nty = set(ref_graph.node_attr_schemes(nty).keys())
for i in range(1, len(graph_list)):
assert ref_feats_nty == set(graph_list[i].node_attr_schemes(nty).keys()), \
'The node features of graph {:d} and {:d} for ' \
'node type {} should be the same.'.format(0, i, nty)
for nfeats in ref_feats_nty:
assert ref_graph.node_attr_schemes(nty)[nfeats] == \
graph_list[i].node_attr_schemes(nty)[nfeats], \
'For graph {:d} and {:d}, the size and dtype for feature ' \
'{} of {}-typed nodes should be the same.'.format(0, i, nfeats, nty)
for ety in ref_canonical_etypes:
ref_feats_ety = set(ref_graph.edge_attr_schemes(ety).keys())
for i in range(1, len(graph_list)):
assert ref_feats_ety == set(graph_list[i].edge_attr_schemes(ety).keys()), \
'The edge features of graph {:d} and {:d} for ' \
'edge type {} should be the same.'.format(0, i, ety)
for efeats in ref_feats_ety:
assert ref_graph.edge_attr_schemes(ety)[efeats] == \
graph_list[i].edge_attr_schemes(ety)[efeats], \
'For graph {:d} and {:d}, the size and dtype for feature ' \
'{} of {}-typed edge should be the same.'.format(0, i, efeats, ety)
def _init_attrs(types, attrs, mode):
formatted_attrs = {t: [] for t in types}
if is_all(attrs):
for typ in types:
if mode == 'node':
formatted_attrs[typ] = list(ref_graph.node_attr_schemes(typ).keys())
elif mode == 'edge':
formatted_attrs[typ] = list(ref_graph.edge_attr_schemes(typ).keys())
elif isinstance(attrs, dict):
for typ, v in attrs.items():
if isinstance(v, str):
formatted_attrs[typ] = [v]
elif isinstance(v, Iterable):
formatted_attrs[typ] = list(v)
elif v is not None:
raise ValueError('Expected {} attrs for type {} to be str '
'or iterable, got {}'.format(mode, typ, type(v)))
elif attrs is not None:
raise ValueError('Expected {} attrs to be of type None or dict,'
'got type {}'.format(mode, type(attrs)))
return formatted_attrs
node_attrs = _init_attrs(ref_ntypes, node_attrs, 'node')
edge_attrs = _init_attrs(ref_canonical_etypes, edge_attrs, 'edge')
node_frames = []
for tid, typ in enumerate(ref_ntypes):
if len(node_attrs[typ]) == 0:
# Emtpy frames will be created when we instantiate a DGLHeteroGraph.
node_frames.append(None)
else:
# NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._node_frames[tid][key] for gr in graph_list
if gr.number_of_nodes(typ) > 0], dim=0)
for key in node_attrs[typ]}
node_frames.append(FrameRef(Frame(cols)))
edge_frames = []
for tid, typ in enumerate(ref_canonical_etypes):
if len(edge_attrs[typ]) == 0:
# Emtpy frames will be created when we instantiate a DGLHeteroGraph.
edge_frames.append(None)
else:
# NOTE: following code will materialize the columns of the input graphs.
cols = {key: F.cat([gr._edge_frames[tid][key] for gr in graph_list
if gr.number_of_edges(typ) > 0], dim=0)
for key in edge_attrs[typ]}
edge_frames.append(FrameRef(Frame(cols)))
# Create graph index for the batched graph
metagraph = graph_list[0]._graph.metagraph
batched_index = heterograph_index.disjoint_union(
metagraph, [g._graph for g in graph_list])
super(BatchedDGLHeteroGraph, self).__init__(gidx=batched_index,
ntypes=ref_ntypes,
etypes=ref_etypes,
node_frames=node_frames,
edge_frames=edge_frames)
# extra members
self._batch_size = 0
# Store number of nodes/edge based on the id of node/edge types as we need
# to handle both edge type and canonical edge type.
self._batch_num_nodes = [[] for _ in range(len(ref_ntypes))]
self._batch_num_edges = [[] for _ in range(len(ref_etypes))]
for grh in graph_list:
if isinstance(grh, BatchedDGLHeteroGraph):
# Handle input graphs that are already batched
self._batch_size += grh._batch_size
for ntype_id in range(len(ref_ntypes)):
self._batch_num_nodes[ntype_id].extend(grh._batch_num_nodes[ntype_id])
for etype_id in range(len(ref_etypes)):
self._batch_num_edges[etype_id].extend(grh._batch_num_edges[etype_id])
else:
self._batch_size += 1
for ntype_id in range(len(ref_ntypes)):
self._batch_num_nodes[ntype_id].append(grh._graph.number_of_nodes(ntype_id))
for etype_id in range(len(ref_etypes)):
self._batch_num_edges[etype_id].append(grh._graph.number_of_edges(etype_id))
@property
def batch_size(self):
"""Number of graphs in this batch.
Returns
-------
int
Number of graphs in this batch."""
return self._batch_size
def batch_num_nodes(self, ntype=None):
"""Return the numbers of nodes of the given type for all heterographs in the batch.
Parameters
----------
ntype : str, optional
The node type. Can be omitted if there is only one node type
in the graph. (Default: None)
Returns
-------
list of int
The ith element gives the number of nodes of the specified type in the ith graph.
Examples
--------
>>> g1 = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
... })
>>> g2 = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
... })
>>> bg = dgl.batch_hetero([g1, g2])
>>> bg.batch_num_nodes('user')
[4, 3]
>>> bg.batch_num_nodes('game')
[2, 2]
"""
return self._batch_num_nodes[self.get_ntype_id(ntype)]
def batch_num_edges(self, etype=None):
"""Return the numbers of edges of the given type for all heterographs in the batch.
Parameters
----------
etype : str or tuple of str, optional
The edge type. Can be omitted if there is only one edge type
in the graph.
Returns
-------
list of int
The ith element gives the number of edges of the specified type in the ith graph.
Examples
--------
>>> g1 = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'follows', 'developer'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
... })
>>> g2 = dgl.heterograph({
... ('user', 'follows', 'user'): [(0, 1), (1, 2)],
... ('user', 'follows', 'developer'): [(0, 1), (1, 2)],
... ('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
... })
>>> bg = dgl.batch_hetero([g1, g2])
>>> bg.batch_num_edges('plays')
[4, 3]
>>> # 'follows' is ambiguous and we use ('user', 'follows', 'user') instead.
>>> bg.batch_num_edges(('user', 'follows', 'user'))
[2, 2]
"""
return self._batch_num_edges[self.get_etype_id(etype)]
def unbatch_hetero(graph):
"""Return the list of heterographs in this batch.
Parameters
----------
graph : BatchedDGLHeteroGraph
The batched heterograph.
Returns
-------
list
A list of :class:`~dgl.BatchedDGLHeteroGraph` objects whose attributes are
obtained by partitioning the attributes of the :attr:`graph`. The length of
the list is the same as the batch size of :attr:`graph`.
Notes
-----
Unbatching will break each field tensor of the batched graph into smaller
partitions.
For simpler tasks such as node/edge state aggregation, try to slice graphs along
edge types and use readout functions.
See Also
--------
batch_hetero
"""
assert isinstance(graph, BatchedDGLHeteroGraph), \
'Expect the input to be of type BatchedDGLHeteroGraph, got type {}'.format(type(graph))
bsize = graph.batch_size
bnn_all_types = graph._batch_num_nodes
bne_all_types = graph._batch_num_edges
ntypes = graph._ntypes
etypes = graph._etypes
node_frames = [[FrameRef(Frame(num_rows=bnn_all_types[tid][i])) for tid in range(len(ntypes))]
for i in range(bsize)]
edge_frames = [[FrameRef(Frame(num_rows=bne_all_types[tid][i])) for tid in range(len(etypes))]
for i in range(bsize)]
for tid in range(len(ntypes)):
for attr, col in graph._node_frames[tid].items():
col_splits = F.split(col, bnn_all_types[tid], dim=0)
for i in range(bsize):
node_frames[i][tid][attr] = col_splits[i]
for tid in range(len(etypes)):
for attr, col in graph._edge_frames[tid].items():
col_splits = F.split(col, bne_all_types[tid], dim=0)
for i in range(bsize):
edge_frames[i][tid][attr] = col_splits[i]
unbatched_graph_indices = heterograph_index.disjoint_partition(
graph._graph, bnn_all_types, bne_all_types)
return [DGLHeteroGraph(gidx=unbatched_graph_indices[i],
ntypes=ntypes,
etypes=etypes,
node_frames=node_frames[i],
edge_frames=edge_frames[i]) for i in range(bsize)]
def batch_hetero(graph_list, node_attrs=ALL, edge_attrs=ALL):
"""Batch a collection of :class:`~dgl.DGLHeteroGraph` and return a
:class:`BatchedDGLHeteroGraph` object that is independent of the :attr:`graph_list`.
Parameters
----------
graph_list : iterable
A collection of :class:`~dgl.DGLHeteroGraph` to be batched.
node_attrs : None or dict
The node attributes to be batched. If ``None``, the resulted graph will not have
features. If ``dict``, it maps str to str or iterable. The keys represent names of
node types and the values represent the node features to be batched for the
corresponding type. By default, we use all features for all types of nodes.
edge_attrs : None or dict
Same as for the case of :attr:`node_attrs`.
Returns
-------
BatchedDGLHeteroGraph
One single batched heterograph
See Also
--------
BatchedDGLHeteroGraph
unbatch_hetero
"""
return BatchedDGLHeteroGraph(graph_list, node_attrs, edge_attrs)
...@@ -235,7 +235,13 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): ...@@ -235,7 +235,13 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
raise DGLError('Unsupported graph data type:', type(data)) raise DGLError('Unsupported graph data type:', type(data))
def hetero_from_relations(rel_graphs): def hetero_from_relations(rel_graphs):
"""Create a heterograph from per-relation graphs. """Create a heterograph from graphs representing connections of each relation.
The input is a list of heterographs where the ``i``th graph contains edges of type
:math:`(s_i, e_i, d_i)`.
If two graphs share a same node type, the number of nodes for the corresponding type
should be the same. See **Examples** for details.
Parameters Parameters
---------- ----------
...@@ -246,6 +252,52 @@ def hetero_from_relations(rel_graphs): ...@@ -246,6 +252,52 @@ def hetero_from_relations(rel_graphs):
------- -------
DGLHeteroGraph DGLHeteroGraph
A heterograph consisting of all relations. A heterograph consisting of all relations.
Examples
--------
>>> import dgl
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')
>>> plays_g = dgl.bipartite([(0, 0), (3, 1)], 'user', 'plays', 'game')
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
will raise an error as we have 3 nodes of type 'user' in follows_g and 4 nodes of type
'user' in plays_g.
We have two possible methods to avoid the construction.
**Method 1**: Manually specify the number of nodes for all types when constructing
the relation graphs.
>>> # A graph with 4 nodes of type 'user'
>>> follows_g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows', card=4)
>>> # A bipartite graph with 4 nodes of src type ('user') and 2 nodes of dst type ('game')
>>> plays_g = dgl.bipartite([(0, 0), (3, 1)], 'user', 'plays', 'game', card=(4, 2))
>>> devs_g = dgl.bipartite([(0, 0), (1, 1)], 'developer', 'develops', 'game')
>>> g = dgl.hetero_from_relations([follows_g, plays_g, devs_g])
>>> print(g)
Graph(num_nodes={'user': 4, 'game': 2, 'developer': 2},
num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 2,
('developer', 'develops', 'game'): 2},
metagraph=[('user', 'user'), ('user', 'game'), ('developer', 'game')])
``devs_g`` does not have nodes of type ``'user'`` so no error will be raised.
**Method 2**: Construct a heterograph at once without intermediate relation graphs,
in which case we will infer the number of nodes for each type.
>>> g = dgl.heterograph({
>>> ('user', 'follows', 'user'): [(0, 1), (1, 2)],
>>> ('user', 'plays', 'game'): [(0, 0), (3, 1)],
>>> ('developer', 'develops', 'game'): [(0, 0), (1, 1)]
>>> })
>>> print(g)
Graph(num_nodes={'user': 4, 'game': 2, 'developer': 2},
num_edges={('user', 'follows', 'user'): 2,
('user', 'plays', 'game'): 2,
('developer', 'develops', 'game'): 2},
metagraph=[('user', 'user'), ('user', 'game'), ('developer', 'game')])
""" """
# TODO(minjie): this API can be generalized as a union operation of the input graphs # TODO(minjie): this API can be generalized as a union operation of the input graphs
# TODO(minjie): handle node/edge data # TODO(minjie): handle node/edge data
...@@ -412,18 +464,18 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -412,18 +464,18 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
>>> hetero_g = dgl.hetero_from_relations([g1, g2]) >>> hetero_g = dgl.hetero_from_relations([g1, g2])
>>> print(hetero_g) >>> print(hetero_g)
Graph(num_nodes={'user': 2, 'activity': 3, 'developer': 2, 'game': 2}, Graph(num_nodes={'user': 2, 'activity': 3, 'developer': 2, 'game': 2},
num_edges={'develops': 2}, num_edges={('user', 'develops', 'activity'): 2, ('developer', 'develops', 'game'): 2},
metagraph=[('user', 'activity'), ('developer', 'game')]) metagraph=[('user', 'activity'), ('developer', 'game')])
We first convert the heterogeneous graph to a homogeneous graph. We first convert the heterogeneous graph to a homogeneous graph.
>>> homo_g = dgl.to_homo(hetero_g) >>> homo_g = dgl.to_homo(hetero_g)
>>> print(homo_g) >>> print(homo_g)
Graph(num_nodes=9, num_edges=4, Graph(num_nodes=9, num_edges=4,
ndata_schemes={'_TYPE': Scheme(shape=(), dtype=torch.int64), ndata_schemes={'_TYPE': Scheme(shape=(), dtype=torch.int64),
'_ID': Scheme(shape=(), dtype=torch.int64)} '_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_TYPE': Scheme(shape=(), dtype=torch.int64), edata_schemes={'_TYPE': Scheme(shape=(), dtype=torch.int64),
'_ID': Scheme(shape=(), dtype=torch.int64)}) '_ID': Scheme(shape=(), dtype=torch.int64)})
>>> homo_g.ndata >>> homo_g.ndata
{'_TYPE': tensor([0, 0, 1, 1, 1, 2, 2, 3, 3]), '_ID': tensor([0, 1, 0, 1, 2, 0, 1, 0, 1])} {'_TYPE': tensor([0, 0, 1, 1, 1, 2, 2, 3, 3]), '_ID': tensor([0, 1, 0, 1, 2, 0, 1, 0, 1])}
Nodes 0, 1 for 'user', 2, 3, 4 for 'activity', 5, 6 for 'developer', 7, 8 for 'game' Nodes 0, 1 for 'user', 2, 3, 4 for 'activity', 5, 6 for 'developer', 7, 8 for 'game'
...@@ -436,8 +488,8 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -436,8 +488,8 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
>>> hetero_g_2 = dgl.to_hetero(homo_g, hetero_g.ntypes, hetero_g.etypes) >>> hetero_g_2 = dgl.to_hetero(homo_g, hetero_g.ntypes, hetero_g.etypes)
>>> print(hetero_g_2) >>> print(hetero_g_2)
Graph(num_nodes={'user': 2, 'activity': 3, 'developer': 2, 'game': 2}, Graph(num_nodes={'user': 2, 'activity': 3, 'developer': 2, 'game': 2},
num_edges={'develops': 2}, num_edges={('user', 'develops', 'activity'): 2, ('developer', 'develops', 'game'): 2},
metagraph=[('user', 'activity'), ('developer', 'game')]) metagraph=[('user', 'activity'), ('developer', 'game')])
See Also See Also
-------- --------
......
...@@ -264,7 +264,7 @@ class DGLHeteroGraph(object): ...@@ -264,7 +264,7 @@ class DGLHeteroGraph(object):
' metagraph={meta})') ' metagraph={meta})')
nnode_dict = {self.ntypes[i] : self._graph.number_of_nodes(i) nnode_dict = {self.ntypes[i] : self._graph.number_of_nodes(i)
for i in range(len(self.ntypes))} for i in range(len(self.ntypes))}
nedge_dict = {self.etypes[i] : self._graph.number_of_edges(i) nedge_dict = {self.canonical_etypes[i] : self._graph.number_of_edges(i)
for i in range(len(self.etypes))} for i in range(len(self.etypes))}
meta = str(self.metagraph.edges()) meta = str(self.metagraph.edges())
return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta) return ret.format(node=nnode_dict, edge=nedge_dict, meta=meta)
...@@ -1501,7 +1501,7 @@ class DGLHeteroGraph(object): ...@@ -1501,7 +1501,7 @@ class DGLHeteroGraph(object):
>>> sub_g = g.subgraph({'user': [1, 2]}) >>> sub_g = g.subgraph({'user': [1, 2]})
>>> print(sub_g) >>> print(sub_g)
Graph(num_nodes={'user': 2, 'game': 0}, Graph(num_nodes={'user': 2, 'game': 0},
num_edges={'plays': 0, 'follows': 2}, num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')]) metagraph=[('user', 'game'), ('user', 'user')])
Get the original node/edge indices. Get the original node/edge indices.
...@@ -1582,7 +1582,7 @@ class DGLHeteroGraph(object): ...@@ -1582,7 +1582,7 @@ class DGLHeteroGraph(object):
>>> ('user', 'plays', 'game'): [2]}) >>> ('user', 'plays', 'game'): [2]})
>>> print(sub_g) >>> print(sub_g)
Graph(num_nodes={'user': 2, 'game': 1}, Graph(num_nodes={'user': 2, 'game': 1},
num_edges={'plays': 1, 'follows': 2}, num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
metagraph=[('user', 'game'), ('user', 'user')]) metagraph=[('user', 'game'), ('user', 'user')])
Get the original node/edge indices. Get the original node/edge indices.
......
"""Module for heterogeneous graph index class definition.""" """Module for heterogeneous graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import itertools
import numpy as np import numpy as np
import scipy import scipy
...@@ -1006,6 +1007,45 @@ def create_heterograph_from_relations(metagraph, rel_graphs): ...@@ -1006,6 +1007,45 @@ def create_heterograph_from_relations(metagraph, rel_graphs):
""" """
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs) return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
def disjoint_union(metagraph, graphs):
"""Return a disjoint union of the input heterographs.
Parameters
----------
metagraph : GraphIndex
Meta-graph.
graphs : list of HeteroGraphIndex
Heterographs to be batched.
Returns
-------
HeteroGraphIndex
Batched Heterograph.
"""
return _CAPI_DGLHeteroDisjointUnion(metagraph, graphs)
def disjoint_partition(graph, bnn_all_types, bne_all_types):
"""Partition the graph disjointly.
Parameters
----------
graph : HeteroGraphIndex
The graph to be partitioned.
bnn_all_types : list of list of int
bnn_all_types[t] gives the number of nodes with t-th type in the batch.
bne_all_types : list of list of int
bne_all_types[t] gives the number of edges with t-th type in the batch.
Returns
--------
list of HeteroGraphIndex
Heterographs unbatched.
"""
bnn_all_types = utils.toindex(list(itertools.chain.from_iterable(bnn_all_types)))
bne_all_types = utils.toindex(list(itertools.chain.from_iterable(bne_all_types)))
return _CAPI_DGLHeteroDisjointPartitionBySizes(
graph, bnn_all_types.todgltensor(), bne_all_types.todgltensor())
@register_object("graph.FlattenedHeteroGraph") @register_object("graph.FlattenedHeteroGraph")
class FlattenedHeteroGraph(ObjectBase): class FlattenedHeteroGraph(ObjectBase):
"""FlattenedHeteroGraph object class in C++ backend.""" """FlattenedHeteroGraph object class in C++ backend."""
......
...@@ -294,6 +294,125 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp ...@@ -294,6 +294,125 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(const std::vector<dgl_type_t>& etyp
return FlattenedHeteroGraphPtr(result); return FlattenedHeteroGraphPtr(result);
} }
HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
// Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
dgl_id_t src_offset = 0, dst_offset = 0;
std::vector<dgl_id_t> result_src, result_dst;
// Loop over all graphs
for (size_t i = 0; i < component_graphs.size(); ++i) {
const auto& cg = component_graphs[i];
EdgeArray edges = cg->Edges(etype);
size_t num_edges = cg->NumEdges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
// Loop over all edges
for (size_t j = 0; j < num_edges; ++j) {
// TODO(mufei): Should use array operations to implement this.
result_src.push_back(edges_src_data[j] + src_offset);
result_dst.push_back(edges_dst_data[j] + dst_offset);
}
// Update offsets
src_offset += cg->NumVertices(src_vtype);
dst_offset += cg->NumVertices(dst_vtype);
}
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
src_offset,
dst_offset,
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
rel_graphs[etype] = rgptr;
}
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
}
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
// Sanity check for vertex sizes
const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
const uint64_t* vertex_sizes_data = static_cast<uint64_t*>(vertex_sizes->data);
const uint64_t num_vertex_types = meta_graph->NumVertices();
const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
// Map vertex type to the corresponding node cum sum
std::vector<std::vector<uint64_t>> vertex_cumsum;
vertex_cumsum.resize(num_vertex_types);
// Loop over all vertex types
for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
vertex_cumsum[vtype].push_back(0);
for (uint64_t g = 0; g < batch_size; ++g) {
// We've flattened the number of vertices in the batch for all types
vertex_cumsum[vtype].push_back(
vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
}
CHECK_EQ(vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
<< "Sum of the given sizes must equal to the number of nodes for type " << vtype;
}
// Sanity check for edge sizes
const uint64_t* edge_sizes_data = static_cast<uint64_t*>(edge_sizes->data);
const uint64_t num_edge_types = meta_graph->NumEdges();
// Map edge type to the corresponding edge cum sum
std::vector<std::vector<uint64_t>> edge_cumsum;
edge_cumsum.resize(num_edge_types);
// Loop over all edge types
for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
edge_cumsum[etype].push_back(0);
for (uint64_t g = 0; g < batch_size; ++g) {
// We've flattened the number of edges in the batch for all types
edge_cumsum[etype].push_back(
edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
}
CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
<< "Sum of the given sizes must equal to the number of edges for type " << etype;
}
// Construct relation graphs for unbatched graphs
std::vector<std::vector<HeteroGraphPtr>> rel_graphs;
rel_graphs.resize(batch_size);
// Loop over all edge types
for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
EdgeArray edges = batched_graph->Edges(etype);
const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
// Loop over all graphs to be unbatched
for (uint64_t g = 0; g < batch_size; ++g) {
std::vector<dgl_id_t> result_src, result_dst;
// Loop over the chunk of edges for the specified graph and edge type
for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1]; ++e) {
// TODO(mufei): Should use array operations to implement this.
result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]);
result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
}
HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
vertex_sizes_data[src_vtype * batch_size + g],
vertex_sizes_data[dst_vtype * batch_size + g],
aten::VecToIdArray(result_src),
aten::VecToIdArray(result_dst));
rel_graphs[g].push_back(rgptr);
}
}
std::vector<HeteroGraphPtr> rst;
for (uint64_t g = 0; g < batch_size; ++g) {
rst.push_back(HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs[g])));
}
return rst;
}
// creator implementation // creator implementation
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
...@@ -371,6 +490,33 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph") ...@@ -371,6 +490,33 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
*rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec)); *rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec));
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1];
std::vector<HeteroGraphPtr> component_ptrs;
component_ptrs.reserve(component_graphs.size());
for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr());
}
auto hgptr = DisjointUnionHeteroGraph(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const IdArray vertex_sizes = args[1];
const IdArray edge_sizes = args[2];
const auto& ret = DisjointPartitionHeteroBySizes(
hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);
List<HeteroGraphRef> ret_list;
for (HeteroGraphPtr hgptr : ret) {
ret_list.push_back(HeteroGraphRef(hgptr));
}
*rv = ret_list;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
......
import dgl import dgl
from dgl import DGLGraph
import backend as F import backend as F
def tree1(): def tree1():
......
import dgl
import backend as F
from dgl.base import ALL
def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
assert g1.ntypes == g2.ntypes
assert g1.etypes == g2.etypes
assert g1.canonical_etypes == g2.canonical_etypes
for nty in g1.ntypes:
assert g1.number_of_nodes(nty) == g2.number_of_nodes(nty)
for ety in g1.etypes:
if len(g1._etype2canonical[ety]) > 0:
assert g1.number_of_edges(ety) == g2.number_of_edges(ety)
for ety in g1.canonical_etypes:
assert g1.number_of_edges(ety) == g2.number_of_edges(ety)
src1, dst1 = g1.all_edges(etype=ety)
src2, dst2 = g2.all_edges(etype=ety)
assert F.allclose(src1, src2)
assert F.allclose(dst1, dst2)
if node_attrs is not None:
for nty in node_attrs.keys():
for feat_name in node_attrs[nty]:
assert F.allclose(g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name])
if edge_attrs is not None:
for ety in edge_attrs.keys():
for feat_name in edge_attrs[ety]:
assert F.allclose(g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name])
def test_batching_hetero_topology():
"""Test batching two DGLHeteroGraphs where some nodes are isolated in some relations"""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1), (3, 1)]
})
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'follows', 'developer'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0), (2, 1)]
})
bg = dgl.batch_hetero([g1, g2])
assert bg.ntypes == g2.ntypes
assert bg.etypes == g2.etypes
assert bg.canonical_etypes == g2.canonical_etypes
assert bg.batch_size == 2
# Test number of nodes
for ntype in bg.ntypes:
assert bg.batch_num_nodes(ntype) == [
g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]
assert bg.number_of_nodes(ntype) == (
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype))
# Test number of edges
assert bg.batch_num_edges('plays') == [
g1.number_of_edges('plays'), g2.number_of_edges('plays')]
assert bg.number_of_edges('plays') == (
g1.number_of_edges('plays') + g2.number_of_edges('plays'))
for etype in bg.canonical_etypes:
assert bg.batch_num_edges(etype) == [
g1.number_of_edges(etype), g2.number_of_edges(etype)]
assert bg.number_of_edges(etype) == (
g1.number_of_edges(etype) + g2.number_of_edges(etype))
# Test relabeled nodes
for ntype in bg.ntypes:
assert list(F.asnumpy(bg.nodes(ntype))) == list(range(bg.number_of_nodes(ntype)))
# Test relabeled edges
src, dst = bg.all_edges(etype=('user', 'follows', 'user'))
assert list(F.asnumpy(src)) == [0, 1, 4, 5]
assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
src, dst = bg.all_edges(etype=('user', 'follows', 'developer'))
assert list(F.asnumpy(src)) == [0, 1, 4, 5]
assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
src, dst = bg.all_edges(etype='plays')
assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
# Test unbatching graphs
g3, g4 = dgl.unbatch_hetero(bg)
check_equivalence_between_heterographs(g1, g3)
check_equivalence_between_heterographs(g2, g4)
def test_batching_hetero_and_batched_hetero_topology():
"""Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph."""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
bg1 = dgl.batch_hetero([g1, g2])
g3 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1)],
('user', 'plays', 'game'): [(1, 0)]
})
bg2 = dgl.batch_hetero([bg1, g3])
assert bg2.ntypes == g3.ntypes
assert bg2.etypes == g3.etypes
assert bg2.canonical_etypes == g3.canonical_etypes
assert bg2.batch_size == 3
# Test number of nodes
for ntype in bg2.ntypes:
assert bg2.batch_num_nodes(ntype) == [
g1.number_of_nodes(ntype), g2.number_of_nodes(ntype), g3.number_of_nodes(ntype)]
assert bg2.number_of_nodes(ntype) == (
g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype) + g3.number_of_nodes(ntype))
# Test number of edges
for etype in bg2.etypes:
assert bg2.batch_num_edges(etype) == [
g1.number_of_edges(etype), g2.number_of_edges(etype), g3.number_of_edges(etype)]
assert bg2.number_of_edges(etype) == (
g1.number_of_edges(etype) + g2.number_of_edges(etype) + g3.number_of_edges(etype))
for etype in bg2.canonical_etypes:
assert bg2.batch_num_edges(etype) == [
g1.number_of_edges(etype), g2.number_of_edges(etype), g3.number_of_edges(etype)]
assert bg2.number_of_edges(etype) == (
g1.number_of_edges(etype) + g2.number_of_edges(etype) + g3.number_of_edges(etype))
# Test relabeled nodes
for ntype in bg2.ntypes:
assert list(F.asnumpy(bg2.nodes(ntype))) == list(range(bg2.number_of_nodes(ntype)))
# Test relabeled edges
src, dst = bg2.all_edges(etype='follows')
assert list(F.asnumpy(src)) == [0, 1, 3, 4, 6]
assert list(F.asnumpy(dst)) == [1, 2, 4, 5, 7]
src, dst = bg2.all_edges(etype='plays')
assert list(F.asnumpy(src)) == [0, 1, 3, 4, 7]
assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2]
# Test unbatching graphs
g4, g5, g6 = dgl.unbatch_hetero(bg2)
check_equivalence_between_heterographs(g1, g4)
check_equivalence_between_heterographs(g2, g5)
check_equivalence_between_heterographs(g3, g6)
def test_batched_features():
"""Test the features of batched DGLHeteroGraphs"""
g1 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
g1.nodes['game'].data['h1'] = F.tensor([[0.]])
g1.nodes['game'].data['h2'] = F.tensor([[1.]])
g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
g1.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])
g2 = dgl.heterograph({
('user', 'follows', 'user'): [(0, 1), (1, 2)],
('user', 'plays', 'game'): [(0, 0), (1, 0)]
})
g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
g2.nodes['game'].data['h1'] = F.tensor([[0.]])
g2.nodes['game'].data['h2'] = F.tensor([[1.]])
g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])
bg = dgl.batch_hetero([g1, g2],
node_attrs=ALL,
edge_attrs={
('user', 'follows', 'user'): 'h1',
('user', 'plays', 'game'): None
})
assert F.allclose(bg.nodes['user'].data['h1'],
F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']], dim=0))
assert F.allclose(bg.nodes['user'].data['h2'],
F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0))
assert F.allclose(bg.nodes['game'].data['h1'],
F.cat([g1.nodes['game'].data['h1'], g2.nodes['game'].data['h1']], dim=0))
assert F.allclose(bg.nodes['game'].data['h2'],
F.cat([g1.nodes['game'].data['h2'], g2.nodes['game'].data['h2']], dim=0))
assert F.allclose(bg.edges['follows'].data['h1'],
F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0))
assert 'h2' not in bg.edges['follows'].data.keys()
assert 'h1' not in bg.edges['plays'].data.keys()
# Test unbatching graphs
g3, g4 = dgl.unbatch_hetero(bg)
check_equivalence_between_heterographs(
g1, g3,
node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
edge_attrs={('user', 'follows', 'user'): ['h1']})
check_equivalence_between_heterographs(
g2, g4,
node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
edge_attrs={('user', 'follows', 'user'): ['h1']})
if __name__ == '__main__':
test_batching_hetero_topology()
test_batching_hetero_and_batched_hetero_topology()
test_batched_features()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment