"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "b9f74ff3d643fd2b579bfc24132e34f3bc7bf777"
Unverified Commit f960468f authored by Zihao Ye, committed by GitHub

[feature] APIs for manually set batch_num_nodes and batch_num_edges (#2430)



* wip

* udp

* upd

* fix typo

* lint

* lint

* upd

* upd

* lint

* lint

* upd

* upd

* tftest

* fix

* fallback numpy

* fix tf

* docstring
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent d3c16455
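In short, the new ``set_batch_num_nodes`` / ``set_batch_num_edges`` setters let a user attach batching information to a graph that was not produced by ``dgl.batch``. A minimal sketch of the intended usage, assuming the PyTorch backend and mirroring the docstring examples added below:

>>> import dgl
>>> import torch
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
>>> g.set_batch_num_nodes(torch.tensor([3, 3]))
>>> g.set_batch_num_edges(torch.tensor([3, 3]))
>>> len(dgl.unbatch(g))  # the graph now unbatches into two 3-node graphs
2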
......@@ -86,6 +86,8 @@ operators for computing graph-level representation for both single and batched g
batch
unbatch
set_batch_num_nodes
set_batch_num_edges
readout_nodes
readout_edges
sum_nodes
......
......@@ -259,7 +259,7 @@ def stack(seq, dim):
def split(input, sizes_or_sections, dim):
return tf.split(input, sizes_or_sections, axis=dim)
return [copy_to(_, input.device) for _ in tf.split(input, sizes_or_sections, axis=dim)]
def repeat(input, repeats, dim):
......
......@@ -1257,7 +1257,80 @@ class DGLHeteroGraph(object):
return self._batch_num_nodes[ntype]
def set_batch_num_nodes(self, val):
"""TBD"""
"""Manually set the number of nodes for each graph in the batch with the specified node
type.
Parameters
----------
val : Tensor or Mapping[str, Tensor]
The dictionary storing the number of nodes for each graph in the batch for all node types.
If the graph has only one node type, ``val`` can also be a single array indicating the
number of nodes per graph in the batch.
Notes
-----
This API is always used together with ``set_batch_num_edges`` to specify the batching
information of a graph. It does not check the correspondence between the graph structure
and the batching information, so the user must guarantee that there are no cross-graph
edges in the batch.
Examples
--------
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
Create a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
Manually set batch information.
>>> g.set_batch_num_nodes(torch.tensor([3, 3]))
>>> g.set_batch_num_edges(torch.tensor([3, 3]))
Unbatch the graph.
>>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={}), Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={})]
Create a heterogeneous graph.
>>> hg = dgl.heterograph({
... ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]),
... ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])})
Manually set batch information.
>>> hg.set_batch_num_nodes({
... 'user': torch.tensor([3, 3]),
... 'game': torch.tensor([2, 2]),
... 'developer': torch.tensor([2, 2])})
>>> hg.set_batch_num_edges({
... ('user', 'plays', 'game'): torch.tensor([3, 3]),
... ('developer', 'develops', 'game'): torch.tensor([2, 2])})
Unbatch the graph.
>>> g1, g2 = dgl.unbatch(hg)
>>> g1
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
>>> g2
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
See Also
--------
set_batch_num_edges
batch
unbatch
"""
if not isinstance(val, Mapping):
if len(self.ntypes) != 1:
raise DGLError('Must provide a dictionary when there are multiple node types.')
......@@ -1326,7 +1399,80 @@ class DGLHeteroGraph(object):
return self._batch_num_edges[etype]
def set_batch_num_edges(self, val):
"""TBD"""
"""Manually set the number of edges for each graph in the batch with the specified edge
type.
Parameters
----------
val : Tensor or Mapping[str, Tensor]
The dictionary storing the number of edges for each graph in the batch for all edge types.
If the graph has only one edge type, ``val`` can also be a single array indicating the
number of edges per graph in the batch.
Notes
-----
This API is always used together with ``set_batch_num_nodes`` to specify the batching
information of a graph. It does not check the correspondence between the graph structure
and the batching information, so the user must guarantee that there are no cross-graph
edges in the batch.
Examples
--------
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
Create a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
Manually set batch information.
>>> g.set_batch_num_nodes(torch.tensor([3, 3]))
>>> g.set_batch_num_edges(torch.tensor([3, 3]))
Unbatch the graph.
>>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={}), Graph(num_nodes=3, num_edges=3,
ndata_schemes={}
edata_schemes={})]
Create a heterogeneous graph.
>>> hg = dgl.heterograph({
... ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]),
... ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])})
Manually set batch information.
>>> hg.set_batch_num_nodes({
... 'user': torch.tensor([3, 3]),
... 'game': torch.tensor([2, 2]),
... 'developer': torch.tensor([2, 2])})
>>> hg.set_batch_num_edges(
... {('user', 'plays', 'game'): torch.tensor([3, 3]),
... ('developer', 'develops', 'game'): torch.tensor([2, 2])})
Unbatch the graph.
>>> g1, g2 = dgl.unbatch(hg)
>>> g1
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
>>> g2
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
See Also
--------
set_batch_num_nodes
batch
unbatch
"""
if not isinstance(val, Mapping):
if len(self.etypes) != 1:
raise DGLError('Must provide a dictionary when there are multiple edge types.')
......
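Because the setters deliberately skip consistency checks (see the Notes sections above), a user-side sanity check can be helpful before relying on manually assigned batch information. A minimal sketch for the homogeneous case, assuming the PyTorch backend; ``check_batch_info`` is a hypothetical helper, not part of this patch:

import dgl
import torch

def check_batch_info(g, batch_num_nodes, batch_num_edges):
    # Totals must add up to the graph's node/edge counts; DGL does not verify
    # this when set_batch_num_nodes / set_batch_num_edges is called.
    assert int(batch_num_nodes.sum()) == g.num_nodes()
    assert int(batch_num_edges.sum()) == g.num_edges()
    # No cross-graph edges: both endpoints of every edge must fall into the
    # same bucket defined by the cumulative node offsets.
    offsets = torch.cumsum(batch_num_nodes, 0)
    src, dst = g.edges()
    assert bool((torch.bucketize(src, offsets, right=True) ==
                 torch.bucketize(dst, offsets, right=True)).all())

g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
check_batch_info(g, torch.tensor([3, 3]), torch.tensor([3, 3]))
g.set_batch_num_nodes(torch.tensor([3, 3]))
g.set_batch_num_edges(torch.tensor([3, 3]))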
import dgl
import numpy as np
import backend as F
import unittest
from test_utils import parametrize_dtype
......@@ -206,13 +207,83 @@ def test_batch_no_edge(idtype):
g3.add_nodes(1) # no edges
g = dgl.batch([g1, g3, g2]) # should not throw an error
def _get_subgraph_batch_info(keys, induced_indices_arr, batch_num_objs):
"""Internal function to compute batch information for subgraphs.
Parameters
----------
keys : List[str]
The node/edge type keys.
induced_indices_arr : List[Tensor]
The induced node/edge index tensor for all node/edge types.
batch_num_objs : Tensor
Number of nodes/edges for each graph in the original batch.
Returns
-------
Mapping[str, Tensor]
A dictionary mapping each node/edge type key to the recomputed
``batch_num_objs`` array of the corresponding subgraph.
"""
bucket_offset = np.expand_dims(np.cumsum(F.asnumpy(batch_num_objs), 0), -1) # (num_bkts, 1)
ret = {}
for key, induced_indices in zip(keys, induced_indices_arr):
# NOTE(Zihao): this implementation is not efficient and we can replace it with
# binary search in the future.
induced_indices = np.expand_dims(F.asnumpy(induced_indices), 0) # (1, num_nodes)
new_offset = np.sum((induced_indices < bucket_offset), 1) # (num_bkts,)
# start_offset = [0] + [new_offset[i-1] for i in range(1, n_bkts)]
start_offset = np.concatenate([np.zeros((1,)), new_offset[:-1]], 0)
new_batch_num_objs = new_offset - start_offset
ret[key] = F.tensor(new_batch_num_objs, dtype=F.dtype(batch_num_objs))
return ret
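As the NOTE above suggests, the dense (num_bkts, num_nodes) comparison could be replaced with a binary search. A minimal sketch of that alternative using ``np.searchsorted``; the helper name is hypothetical and not part of this patch:

def _get_subgraph_batch_info_bsearch(keys, induced_indices_arr, batch_num_objs):
    # Same contract as _get_subgraph_batch_info, but counts the induced indices
    # per bucket with a binary search instead of a broadcasted comparison.
    offsets = np.cumsum(F.asnumpy(batch_num_objs))  # (num_bkts,)
    ret = {}
    for key, induced_indices in zip(keys, induced_indices_arr):
        idx = np.sort(F.asnumpy(induced_indices))
        # Count of induced indices strictly below each bucket boundary,
        # matching the '<' comparison in the original helper.
        new_offset = np.searchsorted(idx, offsets, side='left')
        start_offset = np.concatenate([np.zeros((1,), dtype=new_offset.dtype), new_offset[:-1]])
        ret[key] = F.tensor(new_offset - start_offset, dtype=F.dtype(batch_num_objs))
    return ret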
@parametrize_dtype
def test_set_batch_info(idtype):
ctx = F.ctx()
g1 = dgl.rand_graph(30, 100).astype(idtype).to(ctx)
g2 = dgl.rand_graph(40, 200).astype(idtype).to(ctx)
bg = dgl.batch([g1, g2])
batch_num_nodes = F.astype(bg.batch_num_nodes(), idtype)
batch_num_edges = F.astype(bg.batch_num_edges(), idtype)
# test homogeneous node subgraph
sg_n = dgl.node_subgraph(bg, list(range(10, 20)) + list(range(50, 60)))
induced_nodes = sg_n.ndata['_ID']
induced_edges = sg_n.edata['_ID']
new_batch_num_nodes = _get_subgraph_batch_info(bg.ntypes, [induced_nodes], batch_num_nodes)
new_batch_num_edges = _get_subgraph_batch_info(bg.canonical_etypes, [induced_edges], batch_num_edges)
sg_n.set_batch_num_nodes(new_batch_num_nodes)
sg_n.set_batch_num_edges(new_batch_num_edges)
subg_n1, subg_n2 = dgl.unbatch(sg_n)
subg1 = dgl.node_subgraph(g1, list(range(10, 20)))
subg2 = dgl.node_subgraph(g2, list(range(20, 30)))
assert subg_n1.num_edges() == subg1.num_edges()
assert subg_n2.num_edges() == subg2.num_edges()
# test homogeneous edge subgraph
sg_e = dgl.edge_subgraph(bg, list(range(40, 70)) + list(range(150, 200)), preserve_nodes=True)
induced_nodes = sg_e.ndata['_ID']
induced_edges = sg_e.edata['_ID']
new_batch_num_nodes = _get_subgraph_batch_info(bg.ntypes, [induced_nodes], batch_num_nodes)
new_batch_num_edges = _get_subgraph_batch_info(bg.canonical_etypes, [induced_edges], batch_num_edges)
sg_e.set_batch_num_nodes(new_batch_num_nodes)
sg_e.set_batch_num_edges(new_batch_num_edges)
subg_e1, subg_e2 = dgl.unbatch(sg_e)
subg1 = dgl.edge_subgraph(g1, list(range(40, 70)), preserve_nodes=True)
subg2 = dgl.edge_subgraph(g2, list(range(50, 100)), preserve_nodes=True)
assert subg_e1.num_nodes() == subg1.num_nodes()
assert subg_e2.num_nodes() == subg2.num_nodes()
if __name__ == '__main__':
test_batch_unbatch()
test_batch_unbatch1()
test_batch_unbatch_frame()
#test_batch_unbatch()
#test_batch_unbatch1()
#test_batch_unbatch_frame()
#test_batch_unbatch2()
#test_batched_edge_ordering()
#test_batch_send_then_recv()
#test_batch_send_and_recv()
#test_batch_propagate()
#test_batch_no_edge()
test_set_batch_info(F.int32)