You need to sign in or sign up before continuing.
Unverified commit a1a3ce89, authored by Jing Zhu, committed by GitHub
Browse files

[SpotTarget]Edge sampler with excluding edges adjacent to low-degree nodes (#5893)


Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: ZYH <31127200+Tonyzhou98@users.noreply.github.com>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent b53f9365
...@@ -9,5 +9,6 @@ from .neighbor_sampler import * ...@@ -9,5 +9,6 @@ from .neighbor_sampler import *
from .shadow import * from .shadow import *
if F.get_preferred_backend() == "pytorch": if F.get_preferred_backend() == "pytorch":
from .spot_target import *
from .dataloader import * from .dataloader import *
from .dist_dataloader import * from .dist_dataloader import *
...@@ -328,7 +328,7 @@ def find_exclude_eids( ...@@ -328,7 +328,7 @@ def find_exclude_eids(
---------- ----------
g : DGLGraph g : DGLGraph
The graph. The graph.
exclude_mode : str, optional exclude :
Can be either of the following, Can be either of the following,
None (default) None (default)
...@@ -535,7 +535,7 @@ def as_edge_prediction_sampler( ...@@ -535,7 +535,7 @@ def as_edge_prediction_sampler(
edge IDs to exclude from neighborhood. The argument will be either a tensor edge IDs to exclude from neighborhood. The argument will be either a tensor
for homogeneous graphs or a dict of edge types and tensors for heterogeneous for homogeneous graphs or a dict of edge types and tensors for heterogeneous
graphs. graphs.
exclude : str, optional exclude : Union[str, callable], optional
Whether and how to exclude dependencies related to the sampled edges in the Whether and how to exclude dependencies related to the sampled edges in the
minibatch. Possible values are minibatch. Possible values are
......
"""SpotTarget: Target edge excluder for link prediction"""
import torch
from .base import find_exclude_eids
class SpotTarget(object):
    """Callable excluder that only excludes seed edges touching low-degree nodes.

    Besides excluding all the edges or given edges in the edge sampler
    ``dgl.dataloading.as_edge_prediction_sampler`` in link prediction training,
    this excluder restricts the exclusion to the seed edges that are incident
    to low-degree nodes. A seed edge is selected for exclusion when at least
    one of its two endpoints has an in-degree strictly smaller than
    ``degree_threshold``; seed edges whose endpoints both have in-degree
    greater than or equal to the threshold are left in the graph. The
    motivation for excluding target edges incident to low-degree nodes is
    described in this paper: https://arxiv.org/abs/2306.00899

    Parameters
    ----------
    g : DGLGraph
        The graph.
    exclude : Union[str, callable]
        Whether and how to exclude dependencies related to the selected
        low-degree seed edges in the minibatch. Possible values are

        * ``self``, for excluding the edges in the current minibatch.

        * ``reverse_id``, for excluding not only the edges in the current
          minibatch but also their reverse edges according to the ID mapping
          in the argument :attr:`reverse_eids`.

        * ``reverse_types``, for excluding not only the edges in the current
          minibatch but also their reverse edges stored in another type
          according to the argument :attr:`reverse_etypes`.

        * User-defined exclusion rule. It is a callable with edges in the
          current minibatch as a single argument and should return the edges
          to be excluded.
    degree_threshold : int
        The threshold of node in-degrees. If the source or the destination
        node of a seed edge has an in-degree smaller than
        ``degree_threshold``, this edge will be excluded from the graph.
    reverse_eids : Tensor or dict[etype, Tensor], optional
        A tensor of reverse edge ID mapping. The i-th element indicates the
        ID of the i-th edge's reverse edge.

        If the graph is heterogeneous, this argument requires a dictionary of
        edge types and the reverse edge ID mapping tensors.
    reverse_etypes : dict[etype, etype], optional
        The mapping from the original edge types to their reverse edge types.

    Examples
    --------
    .. code:: python

        low_degree_excluder = SpotTarget(
            g,
            exclude="reverse_id",
            degree_threshold=10,
            reverse_eids=reverse_eids,
        )
        sampler = as_edge_prediction_sampler(
            sampler,
            exclude=low_degree_excluder,
            negative_sampler=negative_sampler.Uniform(1),
        )
    """

    def __init__(
        self,
        g,
        exclude,
        degree_threshold=10,
        reverse_eids=None,
        reverse_etypes=None,
    ):
        self.g = g
        self.exclude = exclude
        self.degree_threshold = degree_threshold
        self.reverse_eids = reverse_eids
        self.reverse_etypes = reverse_etypes

    def __call__(self, seed_edges):
        """Return the edges to exclude for one minibatch of seed edges.

        Parameters
        ----------
        seed_edges : Tensor
            Edge IDs of the current minibatch.
            NOTE(review): the result of the degree comparison is used to index
            ``seed_edges`` directly, so a dict of per-type seed edges
            (heterogeneous graphs) does not look supported here -- confirm.

        Returns
        -------
        The result of ``find_exclude_eids`` applied to the low-degree subset
        of ``seed_edges`` with the configured ``exclude`` mode and reverse
        mappings.
        """
        g = self.g
        src, dst = g.find_edges(seed_edges)
        # Both endpoints are measured by in-degree; an edge is "low-degree"
        # when the smaller of the two falls below the threshold.
        min_degree = torch.min(g.in_degrees(src), g.in_degrees(dst))
        low_degree_edges = seed_edges[min_degree < self.degree_threshold]
        return find_exclude_eids(
            g,
            low_degree_edges,
            self.exclude,
            self.reverse_eids,
            self.reverse_etypes,
        )
from collections.abc import Mapping
import dgl
import numpy as np
import pytest
import torch
def _create_homogeneous():
    """Build a random homogeneous graph in which every seed edge has a reverse.

    The graph has 200 nodes and 2000 edges: edges ``0..999`` are random
    ``s[i] -> d[i]`` pairs and edges ``1000..1999`` are their reverses
    ``d[i] -> s[i]``, so the returned ``reverse_eids`` mapping
    (``i <-> i + 1000``) refers to edges that actually exist in the graph.

    Returns
    -------
    g : DGLGraph
        The bidirected random graph.
    reverse_eids : Tensor
        Length-2000 tensor whose i-th entry is the ID of edge i's reverse.
    seed_edges : Tensor
        IDs ``0..999`` of the forward edges, used as training seeds.
    """
    num_nodes, num_pairs = 200, 1000
    s = torch.randint(0, num_nodes, (num_pairs,))
    d = torch.randint(0, num_nodes, (num_pairs,))
    # Append the reversed pairs after the forward ones so that edge
    # ``i + num_pairs`` is the reverse of edge ``i``; without them the
    # reverse IDs below would point at edges the graph does not contain.
    g = dgl.graph(
        (torch.cat([s, d]), torch.cat([d, s])), num_nodes=num_nodes
    )
    reverse_eids = torch.cat(
        [torch.arange(num_pairs, 2 * num_pairs), torch.arange(0, num_pairs)]
    )
    seed_edges = torch.arange(0, num_pairs)
    return g, reverse_eids, seed_edges
def _find_edges_to_exclude(g, pair_eids, degree_threshold):
    """Compute the edge IDs the test expects to be excluded for a minibatch.

    Parameters
    ----------
    g : DGLGraph
        Graph providing ``find_edges`` and ``in_degrees``.
    pair_eids : Tensor
        Edge IDs of the positive pairs in the minibatch.
    degree_threshold : int
        An edge is selected when the smaller in-degree of its two endpoints
        is strictly below this value.

    Returns
    -------
    Tensor
        The selected edge IDs followed by the same IDs shifted by 1000.
    """
    heads, tails = g.find_edges(pair_eids)
    smaller_degree = torch.min(g.in_degrees(heads), g.in_degrees(tails))
    selected = pair_eids[smaller_degree < degree_threshold]
    # The fixed offset of 1000 mirrors the fixture's reverse-ID mapping
    # (edge i <-> edge i + 1000).
    return torch.cat([selected, selected + 1000])
@pytest.mark.parametrize("degree_threshold", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("batch_size", [1, 10, 50])
def test_spot_target_excludes(degree_threshold, batch_size):
    """Check that SpotTarget keeps the expected edges out of sampled blocks.

    For at most the first 11 minibatches, the edge IDs computed by
    ``_find_edges_to_exclude`` must not appear among the edge IDs of the
    first sampled block.
    """
    g, reverse_eids, seed_edges = _create_homogeneous()

    excluder = dgl.dataloading.SpotTarget(
        g,
        exclude="reverse_id",
        degree_threshold=degree_threshold,
        reverse_eids=reverse_eids,
    )
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        dgl.dataloading.MultiLayerFullNeighborSampler(1),
        exclude=excluder,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(1),
    )
    dataloader = dgl.dataloading.DataLoader(
        g, seed_edges, sampler, batch_size=batch_size
    )

    def to_numpy(tensor):
        return tensor.cpu().numpy()

    for step, (_, pair_graph, _, blocks) in enumerate(dataloader):
        # The sampler may hand back either a list of blocks or a single one.
        first_block = blocks[0] if isinstance(blocks, list) else blocks
        pair_eids = pair_graph.edata[dgl.EID]
        block_eids = first_block.edata[dgl.EID]

        expected_excluded = _find_edges_to_exclude(
            g, pair_eids, degree_threshold
        )
        if expected_excluded is None:
            continue

        expected_excluded = dgl.utils.recursive_apply(
            expected_excluded, to_numpy
        )
        block_eids = dgl.utils.recursive_apply(block_eids, to_numpy)

        if not isinstance(expected_excluded, Mapping):
            assert not np.isin(expected_excluded, block_eids).any()
        else:
            for etype in expected_excluded:
                assert not np.isin(
                    expected_excluded[etype], block_eids[etype]
                ).any()

        # Checking a handful of minibatches is enough; stop after the 11th.
        if step == 10:
            break
if __name__ == "__main__":
    # Allow running a single representative configuration without pytest.
    test_spot_target_excludes(batch_size=10, degree_threshold=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment