Unverified Commit d3c16455 authored by Minjie Wang, committed by GitHub

[Performance][Test] New low memory implementation for RGCN and related regression tests (#2468)

* WIP

* finish lowmem impl; benchmarking

* wip

* wip

* fix benchmarks

* fix bug in searchsorted

* update readme

* update

* lint and ogb dependency

* add flags to to_homogeneous

* fix docstring

* address comments
parent 80de340d
......@@ -18,7 +18,10 @@ To run all benchmarks locally, build the project first and then run:
asv run -n -e --python=same --verbose
```
Note that local run will not produce any benchmark results on disk.
**Due to an ASV restriction, `--python=same` will not write any benchmark results
to disk. It does not support specifying branches and commits either; those are only
available under ASV's managed environment.**
To change the device for benchmarking, set the `DGL_BENCH_DEVICE` environment variable.
Any valid PyTorch device strings are allowed.
......@@ -26,14 +29,29 @@ Any valid PyTorch device strings are allowed.
export DGL_BENCH_DEVICE=cuda:0
```
DGL runs all benchmarks automatically in docker container. To run all benchmarks in docker,
use the `publish.sh` script. It accepts two arguments, a name specifying the identity of
the test machine and a device name.
To select which benchmark to run, use the `--bench` flag. For example,
```bash
bash publish.sh dev-machine cuda:0
asv run -n -e --python=same --verbose --bench model_acc.bench_gat
```
Run in docker locally
---
DGL runs all benchmarks automatically in a docker container. To run benchmarks in docker locally,
* Git commit your local changes. There is no need to push to the remote repository.
* To compare commits from different branches, change the `"branches"` list in `asv.conf.json`.
The default is `"HEAD"`, which is the last commit of the current branch. For example, to
compare your proposed changes with the master branch, set it to `["HEAD", "master"]`.
If your workspace is a forked repository, make sure your local master is synced with
the upstream.
* Use the `publish.sh` script. It accepts two arguments: a name specifying the identity of
the test machine, and a device name. For example,
```bash
bash publish.sh dev-machine cuda:0
```
The script will output two folders `results` and `html`. The `html` folder contains the
generated static web pages. View it by:
......@@ -41,6 +59,8 @@ generated static web pages. View it by:
asv preview
```
Please see `publish.sh` for more information on how it works and how to modify it according
to your needs.
Adding a new benchmark suite
---
......@@ -104,15 +124,8 @@ def track_time(l, u):
Tips
----
* Feed flags `-e --verbose` to `asv run` to print out stderr and more information. Use `--bench` flag
to run specific benchmarks, e.g., `--bench bench_gat`.
* Feed flags `-e --verbose` to `asv run` to print out stderr and more information.
* When running benchmarks locally (e.g., with `--python=same`), ASV will not write results to disk
so `asv publish` will not generate plots.
* When running benchmarks in docker, ASV will pull the codes from remote and build them in conda
environment. The repository to pull is determined by `origin`, so it works with forked repository.
The branches are configured in `asv.conf.json`. If you wish to test the impact of your local source
code changes on performance in docker, remember to before running `publish.sh`:
- Commit your local changes and push it to remote `origin`.
- Add the corresponding branch to `asv.conf.json`.
* Try to make your benchmarks compatible with all the versions being tested.
* For ogbn datasets, put the dataset into `/tmp/dataset/`.
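
For reference, here is a minimal benchmark suite in the style of the existing ones. This is a sketch only: the `utils` decorators are the ones used by the suites in this folder, while the import path and the timed workload are illustrative.
```python
import time

import torch

import utils  # assumption: the suite-local helper module used by the existing benchmarks


@utils.benchmark('time')
@utils.parametrize('size', [100, 1000])
def track_time(size):
    # Illustrative workload: time a dense matmul of the given size.
    a = torch.randn(size, size)
    tic = time.time()
    a @ a
    return time.time() - tic
```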
......@@ -27,7 +27,7 @@
],
// List of branches to benchmark. If not provided, defaults to "master"
// (for git) or "default" (for mercurial).
"branches": ["HEAD"], // for git
"branches": ["HEAD", "master"], // for git
// The DVCS being used. If not set, it will be automatically
// determined from "repo" by looking at the protocol in the URL
// (if remote), or by looking for special directories, such as
......
import numpy as np
import dgl
from dgl.nn.pytorch import RelGraphConv
import torch
......@@ -14,21 +15,22 @@ class RGCN(nn.Module):
num_rels,
num_bases,
num_hidden_layers,
dropout):
dropout,
low_mem):
super(RGCN, self).__init__()
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConv(num_nodes, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=True))
low_mem=low_mem))
# h2h
for i in range(num_hidden_layers):
self.layers.append(RelGraphConv(n_hidden, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=True))
low_mem=low_mem))
# o2h
self.layers.append(RelGraphConv(n_hidden, num_classes, num_rels, "basis",
num_bases, activation=None, low_mem=True))
num_bases, activation=None, low_mem=low_mem))
def forward(self, g, h, r, norm):
for layer in self.layers:
......@@ -46,7 +48,9 @@ def evaluate(model, g, feats, edge_type, edge_norm, labels, idx):
@utils.benchmark('acc')
@utils.parametrize('data', ['aifb', 'mutag'])
def track_acc(data):
@utils.parametrize('lowmem', [True, False])
@utils.parametrize('use_type_count', [True, False])
def track_acc(data, lowmem, use_type_count):
# args
if data == 'aifb':
num_bases = -1
......@@ -87,10 +91,15 @@ def track_acc(data):
if ntype == category:
category_id = i
g = dgl.to_homogeneous(g, edata=['norm']).to(device)
if use_type_count:
g, _, edge_type = dgl.to_homogeneous(g, edata=['norm'], return_count=True)
g = g.to(device)
else:
g = dgl.to_homogeneous(g, edata=['norm']).to(device)
edge_type = g.edata.pop(dgl.ETYPE).long()
num_nodes = g.number_of_nodes()
edge_norm = g.edata['norm']
edge_type = g.edata[dgl.ETYPE].long()
# find out the target node ids in g
target_idx = torch.where(g.ndata[dgl.NTYPE] == category_id)[0]
......@@ -109,7 +118,8 @@ def track_acc(data):
num_rels,
num_bases,
0,
0).to(device)
0,
lowmem).to(device)
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-2,
......
import time
import numpy as np
import dgl
from dgl.nn.pytorch import RelGraphConv
import torch
......@@ -15,21 +16,22 @@ class RGCN(nn.Module):
num_rels,
num_bases,
num_hidden_layers,
dropout):
dropout,
lowmem):
super(RGCN, self).__init__()
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConv(num_nodes, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=True))
low_mem=lowmem))
# h2h
for i in range(num_hidden_layers):
self.layers.append(RelGraphConv(n_hidden, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=True))
low_mem=lowmem))
# o2h
self.layers.append(RelGraphConv(n_hidden, num_classes, num_rels, "basis",
num_bases, activation=None, low_mem=True))
num_bases, activation=None, low_mem=lowmem))
def forward(self, g, h, r, norm):
for layer in self.layers:
......@@ -37,8 +39,10 @@ class RGCN(nn.Module):
return h
@utils.benchmark('time', 3600)
@utils.parametrize('data', ['aifb', 'am'])
def track_time(data):
@utils.parametrize('data', ['aifb'])
@utils.parametrize('lowmem', [True, False])
@utils.parametrize('use_type_count', [True, False])
def track_time(data, lowmem, use_type_count):
# args
if data == 'aifb':
num_bases = -1
......@@ -77,10 +81,15 @@ def track_time(data):
if ntype == category:
category_id = i
g = dgl.to_homogeneous(g, edata=['norm']).to(device)
if use_type_count:
g, _, edge_type = dgl.to_homogeneous(g, edata=['norm'], return_count=True)
g = g.to(device)
else:
g = dgl.to_homogeneous(g, edata=['norm']).to(device)
edge_type = g.edata.pop(dgl.ETYPE).long()
num_nodes = g.number_of_nodes()
edge_norm = g.edata['norm']
edge_type = g.edata[dgl.ETYPE].long()
# find out the target node ids in g
target_idx = torch.where(g.ndata[dgl.NTYPE] == category_id)[0]
......@@ -99,7 +108,8 @@ def track_time(data):
num_rels,
num_bases,
0,
0).to(device)
0,
lowmem).to(device)
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-2,
......
......@@ -5,7 +5,7 @@ set -e
. /opt/conda/etc/profile.d/conda.sh
pip install -r /asv/torch_gpu_pip.txt
pip install pandas rdflib
pip install pandas rdflib ogb
# install
pushd python
......
#!/bin/bash
# The script launches a docker container to run ASV benchmarks. We use the same docker
# image as our CI (i.e., dgllib/dgl-ci-gpu:conda). It performs the following steps:
#
# 1. Start a docker container with the given machine name. The machine name will be
# displayed on the generated website.
# 2. Copy `.git` into the container. It allows ASV to determine the repository information
# such as commit hash, branches, etc.
# 3. Copy this folder into the container including the ASV configuration file `asv.conf.json`.
# This means any changes to the files in this folder do not
# require a git commit. By contrast, to correctly benchmark your changes to the core
# library (e.g., "python/dgl"), you must call git commit first.
# 4. It then calls the `run.sh` script inside the container, which invokes `asv run`.
# You can change the command, e.g., to specify which benchmarks to run or to add flags.
# 5. After benchmarking, it copies the generated `results` and `html` folders back to
# the host machine.
#
if [ $# -eq 2 ]; then
MACHINE=$1
DEVICE=$2
......
......@@ -16,6 +16,6 @@ export DGL_BENCH_DEVICE=$DEVICE
pushd $ROOT/benchmarks
cat asv.conf.json
asv machine --yes
asv run
asv run -e
asv publish
popd
......@@ -564,16 +564,21 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE,
return to_heterogeneous(G, ntypes, etypes, ntype_field=ntype_field,
etype_field=etype_field, metagraph=metagraph)
def to_homogeneous(G, ndata=None, edata=None):
def to_homogeneous(G, ndata=None, edata=None, store_type=True, return_count=False):
"""Convert a heterogeneous graph to a homogeneous graph and return.
Node and edge types of the input graph are stored as the ``dgl.NTYPE``
and ``dgl.ETYPE`` features in the returned graph.
By default, the function stores the node and edge types of the input graph as
the ``dgl.NTYPE`` and ``dgl.ETYPE`` features in the returned graph.
Each feature is an integer representing the type id, determined by the
:meth:`DGLGraph.get_ntype_id` and :meth:`DGLGraph.get_etype_id` methods.
One can omit them by specifying ``store_type=False``.
The function also stores the original node/edge IDs as the ``dgl.NID``
and ``dgl.EID`` features in the returned graph.
The resulting graph assigns contiguous IDs to nodes and edges of the same type
(i.e., nodes of the first type have IDs ``0`` to ``G.num_nodes(G.ntypes[0]) - 1``;
nodes of the second type come after; and so on and so forth). Therefore, a more
memory-efficient format for type information is an integer list, where the i^th
element is the number of nodes/edges of the i^th type. One can choose this format by
specifying ``return_count=True``.
Parameters
----------
......@@ -589,11 +594,31 @@ def to_homogeneous(G, ndata=None, edata=None):
:attr:`edata`, it concatenates ``G.edges[T].data[feat]`` across all edge types ``T``.
As a result, the feature ``feat`` of all edge types should have the same shape and
data type. By default, the returned graph will not have any edge features.
store_type : bool, optional
If True, store type information as the ``dgl.NTYPE`` and ``dgl.ETYPE`` features
in the returned graph.
return_count : bool, optional
If True, return type information as an integer list; the i^th element corresponds to
the number of nodes/edges of the i^th type.
Returns
-------
DGLGraph
A homogeneous graph.
ntype_count : list[int], optional
Number of nodes of each type. Returned when ``return_count`` is True.
etype_count : list[int], optional
Number of edges of each type. Returned when ``return_count`` is True.
Notes
-----
* Calculating type information may introduce a noticeable cost. Setting both ``store_type``
and ``return_count`` to False avoids such cost if the type information is not needed.
Otherwise, DGL recommends ``store_type=False`` together with ``return_count=True``
because the integer-list format is more memory-efficient.
* The ``ntype_count`` and ``etype_count`` lists can help speed up some operations.
See :class:`~dgl.nn.pytorch.conv.RelGraphConv` for such an example.
Examples
--------
......@@ -633,18 +658,25 @@ def to_homogeneous(G, ndata=None, edata=None):
offset_per_ntype = np.insert(np.cumsum(num_nodes_per_ntype), 0, 0)
srcs = []
dsts = []
etype_ids = []
eids = []
ntype_ids = []
nids = []
eids = []
if store_type:
ntype_ids = []
etype_ids = []
if return_count:
ntype_count = []
etype_count = []
total_num_nodes = 0
for ntype_id, ntype in enumerate(G.ntypes):
num_nodes = G.number_of_nodes(ntype)
total_num_nodes += num_nodes
# Type ID is always in int64
ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, F.cpu()))
nids.append(F.arange(0, num_nodes, G.idtype))
if store_type:
# Type ID is always in int64
ntype_ids.append(F.full_1d(num_nodes, ntype_id, F.int64, G.device))
if return_count:
ntype_count.append(num_nodes)
nids.append(F.arange(0, num_nodes, G.idtype, G.device))
for etype_id, etype in enumerate(G.canonical_etypes):
srctype, _, dsttype = etype
......@@ -652,9 +684,12 @@ def to_homogeneous(G, ndata=None, edata=None):
num_edges = len(src)
srcs.append(src + int(offset_per_ntype[G.get_ntype_id(srctype)]))
dsts.append(dst + int(offset_per_ntype[G.get_ntype_id(dsttype)]))
# Type ID is always in int64
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, F.cpu()))
eids.append(F.arange(0, num_edges, G.idtype))
if store_type:
# Type ID is always in int64
etype_ids.append(F.full_1d(num_edges, etype_id, F.int64, G.device))
if return_count:
etype_count.append(num_edges)
eids.append(F.arange(0, num_edges, G.idtype, G.device))
retg = graph((F.cat(srcs, 0), F.cat(dsts, 0)), num_nodes=total_num_nodes,
idtype=G.idtype, device=G.device)
......@@ -671,13 +706,16 @@ def to_homogeneous(G, ndata=None, edata=None):
if comb_ef is not None:
retg.edata.update(comb_ef)
# assign node type and ID mapping fields.
retg.ndata[NTYPE] = F.copy_to(F.cat(ntype_ids, 0), G.device)
retg.ndata[NID] = F.copy_to(F.cat(nids, 0), G.device)
retg.edata[ETYPE] = F.copy_to(F.cat(etype_ids, 0), G.device)
retg.edata[EID] = F.copy_to(F.cat(eids, 0), G.device)
retg.ndata[NID] = F.cat(nids, 0)
retg.edata[EID] = F.cat(eids, 0)
if store_type:
retg.ndata[NTYPE] = F.cat(ntype_ids, 0)
retg.edata[ETYPE] = F.cat(etype_ids, 0)
return retg
if return_count:
return retg, ntype_count, etype_count
else:
return retg
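
To make the new flags concrete, here is a hedged usage sketch; the toy heterograph is illustrative and not part of this patch:
```python
import dgl
import torch as th

# A toy heterograph with two node types and two edge types (illustrative).
hg = dgl.heterograph({
    ('user', 'follows', 'user'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('user', 'plays', 'game'): (th.tensor([0, 2]), th.tensor([0, 1])),
})

# Default behavior: type IDs are stored as the dgl.NTYPE / dgl.ETYPE features.
g = dgl.to_homogeneous(hg)
print(g.ndata[dgl.NTYPE])  # int64 tensor with one type ID per node

# Memory-friendly alternative: skip the type features and get per-type counts.
g, ntype_count, etype_count = dgl.to_homogeneous(
    hg, store_type=False, return_count=True)
print(ntype_count, etype_count)  # e.g., [2, 3] and [2, 2] for this toy graph
```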
def to_homo(G):
"""Convert the given heterogeneous graph to a homogeneous graph.
......
"""Torch Module for Relational graph convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import functools
import numpy as np
import torch as th
from torch import nn
from .... import function as fn
from .. import utils
from ....base import DGLError
from .... import edge_subgraph
class RelGraphConv(nn.Module):
r"""
Description
-----------
Relational graph convolution layer.
r"""Relational graph convolution layer.
Relational graph convolution is introduced in "`Modeling Relational Data with Graph
Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
......@@ -185,8 +184,22 @@ class RelGraphConv(nn.Module):
self.dropout = nn.Dropout(dropout)
def basis_message_func(self, edges):
"""Message function for basis regularizer"""
def basis_message_func(self, edges, etypes):
"""Message function for basis regularizer.
Parameters
----------
edges : dgl.EdgeBatch
Input to DGL message UDF.
etypes : torch.Tensor or list[int]
Edge type data. Could be either:
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``low_mem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``low_mem == True``.
"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
......@@ -196,59 +209,91 @@ class RelGraphConv(nn.Module):
else:
weight = self.weight
# calculate msg @ W_r before put msg into edge
# if src is th.int64 we expect it is an index select
if edges.src['h'].dtype != th.int64 and self.low_mem:
etypes = th.unique(edges.data['type'])
msg = th.empty((edges.src['h'].shape[0], self.out_feat),
device=edges.src['h'].device)
for etype in etypes:
loc = edges.data['type'] == etype
w = weight[etype]
src = edges.src['h'][loc]
sub_msg = th.matmul(src, w)
msg[loc] = sub_msg
h = edges.src['h']
device = h.device
if h.dtype == th.int64 and h.ndim == 1:
# Each element is the node's ID. Use index select: weight[etypes, h, :]
# The following is a faster version of it.
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = weight.view(-1, weight.shape[2])
flatidx = etypes * weight.shape[1] + h
msg = weight.index_select(0, flatidx)
elif self.low_mem:
# A more memory-friendly implementation.
# Calculate msg @ W_r before putting msg onto the edges.
assert isinstance(etypes, list)
h_t = th.split(h, etypes)
msg = []
for etype in range(self.num_rels):
if h_t[etype].shape[0] == 0:
continue
msg.append(th.matmul(h_t[etype], weight[etype]))
msg = th.cat(msg)
else:
# put W_r into edges then do msg @ W_r
msg = utils.bmm_maybe_select(edges.src['h'], weight, edges.data['type'])
# Use batched matmul
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = weight.index_select(0, etypes)
msg = th.bmm(h.unsqueeze(1), weight).squeeze()
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
def bdd_message_func(self, edges):
"""Message function for block-diagonal-decomposition regularizer"""
if edges.src['h'].dtype == th.int64 and len(edges.src['h'].shape) == 1:
def bdd_message_func(self, edges, etypes):
"""Message function for block-diagonal-decomposition regularizer.
Parameters
----------
edges : dgl.EdgeBatch
Input to DGL message UDF.
etypes : torch.Tensor or list[int]
Edge type data. Could be either:
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``low_mem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``low_mem == True``.
"""
h = edges.src['h']
device = h.device
if h.dtype == th.int64 and h.ndim == 1:
raise TypeError('Block decomposition does not allow integer ID feature.')
# calculate msg @ W_r before put msg into edge
if self.low_mem:
etypes = th.unique(edges.data['type'])
msg = th.empty((edges.src['h'].shape[0], self.out_feat),
device=edges.src['h'].device)
for etype in etypes:
loc = edges.data['type'] == etype
w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
src = edges.src['h'][loc].view(-1, self.num_bases, self.submat_in)
sub_msg = th.einsum('abc,bcd->abd', src, w)
sub_msg = sub_msg.reshape(-1, self.out_feat)
msg[loc] = sub_msg
# A more memory-friendly implementation.
# Calculate msg @ W_r before putting msg onto the edges.
assert isinstance(etypes, list)
h_t = th.split(h, etypes)
msg = []
for etype in range(self.num_rels):
if h_t[etype].shape[0] == 0:
continue
tmp_w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
tmp_h = h_t[etype].view(-1, self.num_bases, self.submat_in)
msg.append(th.einsum('abc,bcd->abd', tmp_h, tmp_w).reshape(-1, self.out_feat))
msg = th.cat(msg)
else:
weight = self.weight.index_select(0, edges.data['type']).view(
# Use batched matmul
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = self.weight.index_select(0, etypes).view(
-1, self.submat_in, self.submat_out)
node = edges.src['h'].view(-1, 1, self.submat_in)
node = h.view(-1, 1, self.submat_in)
msg = th.bmm(node, weight).view(-1, self.out_feat)
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
def forward(self, g, feat, etypes, norm=None):
"""
Description
-----------
Forward computation
"""Forward computation.
Parameters
----------
......@@ -260,8 +305,14 @@ class RelGraphConv(nn.Module):
* :math:`(|V|, D)` dense tensor
* :math:`(|V|,)` int64 vector, representing the categorical values of each
node. It then treats the input feature as a one-hot encoding feature.
etypes : torch.Tensor
Edge type tensor. Shape: :math:`(|E|,)`
etypes : torch.Tensor or list[int]
Edge type data. Could be either
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``low_mem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``low_mem == True``.
norm : torch.Tensor
Optional edge normalizer tensor. Shape: :math:`(|E|, 1)`.
......@@ -269,17 +320,44 @@ class RelGraphConv(nn.Module):
-------
torch.Tensor
New node features.
Notes
-----
Under the ``low_mem`` mode, DGL will sort the graph based on the edge types
and compute message passing one type at a time. DGL recommends sorting the
graph beforehand (and caching it if possible) and providing the ``etypes``
argument in the integer-list format. Use DGL's :func:`~dgl.to_homogeneous` API
to get a sorted homogeneous graph from a heterogeneous graph. Pass ``return_count=True``
to it to get the ``etypes`` in the integer-list format.
"""
if isinstance(etypes, th.Tensor):
if len(etypes) != g.num_edges():
raise DGLError('"etypes" tensor must have length equal to the number of edges'
' in the graph. But got {} and {}.'.format(
len(etypes), g.num_edges()))
if self.low_mem and not (feat.dtype == th.int64 and feat.ndim == 1):
# Low-mem optimization is not enabled for node ID input. When enabled,
# it first sorts the graph based on the edge types (the sorting will not
# change the node IDs). It then converts the etypes tensor to an integer
# list, where each element is the number of edges of the type.
# Sort the graph based on the etypes
sorted_etypes, index = th.sort(etypes)
g = edge_subgraph(g, index, preserve_nodes=True)
# Convert etypes into an integer list of per-type edge counts.
pos = _searchsorted(sorted_etypes, th.arange(self.num_rels, device=g.device))
num = th.tensor([len(etypes)], device=g.device)
etypes = (th.cat([pos[1:], num]) - pos).tolist()
with g.local_scope():
g.srcdata['h'] = feat
g.edata['type'] = etypes
if norm is not None:
g.edata['norm'] = norm
if self.self_loop:
loop_message = utils.matmul_maybe_select(feat[:g.number_of_dst_nodes()],
self.loop_weight)
# message passing
g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
g.update_all(functools.partial(self.message_func, etypes=etypes),
fn.sum(msg='msg', out='h'))
# apply bias and activation
node_repr = g.dstdata['h']
if self.layer_norm:
......@@ -292,3 +370,14 @@ class RelGraphConv(nn.Module):
node_repr = self.activation(node_repr)
node_repr = self.dropout(node_repr)
return node_repr
_TORCH_HAS_SEARCHSORTED = getattr(th, 'searchsorted', None)
def _searchsorted(sorted_sequence, values):
# th.searchsorted was introduced in PyTorch 1.6.0
if _TORCH_HAS_SEARCHSORTED:
return th.searchsorted(sorted_sequence, values)
else:
device = values.device
return th.from_numpy(np.searchsorted(sorted_sequence.cpu().numpy(),
values.cpu().numpy())).to(device)
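
To make the recommendation in the ``forward`` notes concrete, here is a hedged end-to-end sketch of the preferred path (a sorted homogeneous graph plus integer-list `etypes`); the toy graph and feature sizes are illustrative:
```python
import dgl
import torch as th
from dgl.nn.pytorch import RelGraphConv

# A toy heterograph (illustrative). to_homogeneous returns a graph whose
# edges are sorted by their type IDs.
hg = dgl.heterograph({
    ('user', 'follows', 'user'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('user', 'plays', 'game'): (th.tensor([0, 2]), th.tensor([0, 1])),
})
g, _, etype_count = dgl.to_homogeneous(hg, return_count=True)

conv = RelGraphConv(10, 8, num_rels=len(etype_count), regularizer='basis',
                    num_bases=2, low_mem=True)
h = th.randn(g.num_nodes(), 10)
# Passing the integer-list etypes skips the per-forward sort inside the layer.
h_new = conv(g, h, etype_count)
```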
......@@ -258,7 +258,7 @@ def edge_subgraph(graph, edges, preserve_nodes=False):
--------
node_subgraph
"""
if graph.is_block:
if graph.is_block and not preserve_nodes:
raise DGLError('Extracting subgraph from a block graph is not allowed.')
if not isinstance(edges, Mapping):
assert len(graph.canonical_etypes) == 1, \
......
......@@ -1081,6 +1081,36 @@ def test_convert(idtype):
assert hg.device == g.device
assert g.number_of_nodes() == 5
@parametrize_dtype
def test_to_homo2(idtype):
# test that the resulting homogeneous graph has nodes and edges sorted by their types
hg = create_test_heterograph(idtype)
g = dgl.to_homogeneous(hg)
ntypes = F.asnumpy(g.ndata[dgl.NTYPE])
etypes = F.asnumpy(g.edata[dgl.ETYPE])
p = 0
for tid, ntype in enumerate(hg.ntypes):
num_nodes = hg.num_nodes(ntype)
for i in range(p, p + num_nodes):
assert ntypes[i] == tid
p += num_nodes
p = 0
for tid, etype in enumerate(hg.canonical_etypes):
num_edges = hg.num_edges(etype)
for i in range(p, p + num_edges):
assert etypes[i] == tid
p += num_edges
# test store_type=False
g = dgl.to_homogeneous(hg, store_type=False)
assert dgl.NTYPE not in g.ndata
assert dgl.ETYPE not in g.edata
# test return_count=True
g, ntype_count, etype_count = dgl.to_homogeneous(hg, return_count=True)
for i, count in enumerate(ntype_count):
assert count == hg.num_nodes(hg.ntypes[i])
for i, count in enumerate(etype_count):
assert count == hg.num_edges(hg.canonical_etypes[i])
@parametrize_dtype
def test_metagraph_reachable(idtype):
g = create_test_heterograph(idtype)
......
......@@ -369,6 +369,87 @@ def test_rgcn():
assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
def test_rgcn_sorted():
ctx = F.ctx()
etype = []
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
g = g.to(F.ctx())
# 5 etypes
R = 5
etype = [200, 200, 200, 200, 200]
B = 2
I = 10
O = 8
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
rgc_basis_low.weight = rgc_basis.weight
rgc_basis_low.w_comp = rgc_basis.w_comp
rgc_basis_low.loop_weight = rgc_basis.loop_weight
h = th.randn((100, I)).to(ctx)
r = etype
h_new = rgc_basis(g, h, r)
h_new_low = rgc_basis_low(g, h, r)
assert list(h_new.shape) == [100, O]
assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
rgc_bdd_low.weight = rgc_bdd.weight
rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
h = th.randn((100, I)).to(ctx)
r = etype
h_new = rgc_bdd(g, h, r)
h_new_low = rgc_bdd_low(g, h, r)
assert list(h_new.shape) == [100, O]
assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
# with norm
norm = th.zeros((g.number_of_edges(), 1)).to(ctx)
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
rgc_basis_low.weight = rgc_basis.weight
rgc_basis_low.w_comp = rgc_basis.w_comp
rgc_basis_low.loop_weight = rgc_basis.loop_weight
h = th.randn((100, I)).to(ctx)
r = etype
h_new = rgc_basis(g, h, r, norm)
h_new_low = rgc_basis_low(g, h, r, norm)
assert list(h_new.shape) == [100, O]
assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
rgc_bdd = nn.RelGraphConv(I, O, R, "bdd", B).to(ctx)
rgc_bdd_low = nn.RelGraphConv(I, O, R, "bdd", B, low_mem=True).to(ctx)
rgc_bdd_low.weight = rgc_bdd.weight
rgc_bdd_low.loop_weight = rgc_bdd.loop_weight
h = th.randn((100, I)).to(ctx)
r = etype
h_new = rgc_bdd(g, h, r, norm)
h_new_low = rgc_bdd_low(g, h, r, norm)
assert list(h_new.shape) == [100, O]
assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
# id input
rgc_basis = nn.RelGraphConv(I, O, R, "basis", B).to(ctx)
rgc_basis_low = nn.RelGraphConv(I, O, R, "basis", B, low_mem=True).to(ctx)
rgc_basis_low.weight = rgc_basis.weight
rgc_basis_low.w_comp = rgc_basis.w_comp
rgc_basis_low.loop_weight = rgc_basis.loop_weight
h = th.randint(0, I, (100,)).to(ctx)
r = etype
h_new = rgc_basis(g, h, r)
h_new_low = rgc_basis_low(g, h, r)
assert list(h_new.shape) == [100, O]
assert list(h_new_low.shape) == [100, O]
assert F.allclose(h_new, h_new_low)
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
def test_gat_conv(g, idtype):
......