Unverified Commit bc14829f authored by mszarma's avatar mszarma Committed by GitHub
Browse files

[Feature] Exclude edges in sample_neighbors (#2971)



* [Feature] Exclude edges in sample_neighbors

Extending sample_neighbors and sample_frontier
API to support exclude_edges parameter.

exclude_edges support tensor and dict data
Feature enable excluding certain edges
during neighborhood sampling
Exclude_edges contains EID's of edges
which will be excluded
during neighbor picking for seed nodes.

Added test case for heterograph and homograph
RFC issue id: 2944

* compatibility

* fix

* fix
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
parent ac9261b2
......@@ -30,6 +30,8 @@ namespace sampling {
* \param dir Edge direction.
* \param probability A vector of 1D float arrays, indicating the transition probability of
* each edge by edge type. An empty float array assumes uniform transition.
* \param exclude_edges Edges IDs of each type which will be excluded during sampling.
* The vector length must be equal to the number of edges types. Empty array is allowed.
* \param replace If true, sample with replacement.
* \return Sampled neighborhoods as a graph. The return graph has the same schema as the
* original one.
......@@ -40,6 +42,7 @@ HeteroSubgraph SampleNeighbors(
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const std::vector<FloatArray>& probability,
const std::vector<IdArray>& exclude_edges,
bool replace = true);
/*!
......
......@@ -238,6 +238,14 @@ class BlockSampler(object):
a CUDA context if multiprocessing is not used in the dataloader (e.g.,
num_workers is 0). If this is None, the sampled blocks will be stored
on the same device as the input graph.
exclude_edges_in_frontier : bool, default False
If True, the :func:`sample_frontier` method will receive an argument
:attr:`exclude_eids` containing the edge IDs from the original graph to exclude.
The :func:`sample_frontier` method must return a graph that does not contain
the edges corresponding to the excluded edges. No additional postprocessing
will be done.
Otherwise, the edges will be removed *after* :func:`sample_frontier` returns.
Notes
-----
......@@ -250,7 +258,37 @@ class BlockSampler(object):
self.return_eids = return_eids
self.set_output_context(output_ctx)
def sample_frontier(self, block_id, g, seed_nodes):
# This is really a hack working around the lack of GPU-based neighbor sampling
# with edge exclusion.
@classmethod
def exclude_edges_in_frontier(cls, g):
"""Returns whether the sampler will exclude edges in :func:`sample_frontier`.
If this method returns True, the method :func:`sample_frontier` will receive an
argument :attr:`exclude_eids` from :func:`sample_blocks`. :func:`sample_frontier`
is then responsible for removing those edges.
If this method returns False, :func:`sample_blocks` will be responsible for
removing the edges.
When subclassing :class:`BlockSampler`, this method should return True when you
would like to remove the excluded edges in your :func:`sample_frontier` method.
By default this method returns False.
Parameters
----------
g : DGLGraph
The original graph
Returns
-------
bool
Whether :func:`sample_frontier` will receive an argument :attr:`exclude_eids`.
"""
return False
def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None):
"""Generate the frontier given the destination nodes.
The subclasses should override this function.
......@@ -266,6 +304,11 @@ class BlockSampler(object):
If the graph only has one node type, one can just specify a single tensor
of node IDs.
exclude_eids: Tensor or dict
Edge IDs to exclude during sampling neighbors for the seed nodes.
This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
If a single tensor is given, the graph must only have one type of nodes.
Returns
-------
......@@ -307,7 +350,6 @@ class BlockSampler(object):
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
blocks = []
eid_excluder = _create_eid_excluder(exclude_eids, self.output_device)
if isinstance(g, DistGraph):
# TODO:(nv-dlasalle) dist graphs may not have an associated graph,
......@@ -324,7 +366,12 @@ class BlockSampler(object):
for ntype, nodes in seed_nodes_in.items()}
else:
seed_nodes_in = seed_nodes_in.to(graph_device)
frontier = self.sample_frontier(block_id, g, seed_nodes_in)
if self.exclude_edges_in_frontier:
frontier = self.sample_frontier(
block_id, g, seed_nodes_in, exclude_eids=exclude_eids)
else:
frontier = self.sample_frontier(block_id, g, seed_nodes_in)
if self.output_device is not None:
frontier = frontier.to(self.output_device)
......@@ -338,8 +385,10 @@ class BlockSampler(object):
# Removing edges from the frontier for link prediction training falls
# into the category of frontier postprocessing
if eid_excluder is not None:
eid_excluder(frontier)
if not self.exclude_edges_in_frontier:
eid_excluder = _create_eid_excluder(exclude_eids, self.output_device)
if eid_excluder is not None:
eid_excluder(frontier)
block = transform.to_block(frontier, seed_nodes_out)
if self.return_eids:
......
......@@ -71,7 +71,11 @@ class MultiLayerNeighborSampler(BlockSampler):
self.fanout_arrays = []
self.prob_arrays = None
def sample_frontier(self, block_id, g, seed_nodes):
@classmethod
def exclude_edges_in_frontier(cls, g):
return not isinstance(g, distributed.DistGraph) and g.device == F.cpu()
def sample_frontier(self, block_id, g, seed_nodes, exclude_eids=None):
fanout = self.fanouts[block_id]
if isinstance(g, distributed.DistGraph):
if fanout is None:
......@@ -97,7 +101,7 @@ class MultiLayerNeighborSampler(BlockSampler):
frontier = sampling.sample_neighbors(
g, seed_nodes, self.fanout_arrays[block_id],
replace=self.replace, prob=self.prob_arrays)
replace=self.replace, prob=self.prob_arrays, exclude_edges=exclude_eids)
return frontier
def _build_prob_arrays(self, g):
......
......@@ -144,7 +144,7 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
return ret
def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
copy_ndata=True, copy_edata=True, _dist_training=False):
copy_ndata=True, copy_edata=True, _dist_training=False, exclude_edges=None):
"""Sample neighboring edges of the given nodes and return the induced subgraph.
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
......@@ -186,6 +186,11 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
to sum up to one). Otherwise, the result will be undefined.
If :attr:`prob` is not None, GPU sampling is not supported.
exclude_edges: tensor or dict
Edge IDs to exclude during sampling neighbors for the seed nodes.
This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
If a single tensor is given, the graph must only have one type of nodes.
replace : bool, optional
If True, sample with replacement.
copy_ndata: bool, optional
......@@ -249,6 +254,30 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3)
>>> sg.edges(order='eid')
(tensor([1, 2, 0, 1]), tensor([0, 0, 1, 1]))
To exclude certain EID's during sampling for the seed nodes:
>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))
>>> g_edges = g.all_edges(form='all')``
(tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))
>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3, exclude_edges=[0, 1, 2])
>>> sg.all_edges(form='all')
(tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))
>>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3])
tensor([False, False, False])
>>> g = dgl.heterograph({
... ('drug', 'interacts', 'drug'): ([0, 0, 1, 1, 3, 2], [1, 2, 0, 1, 2, 0]),
... ('drug', 'interacts', 'gene'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]),
... ('drug', 'treats', 'disease'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])})
>>> g_edges = g.all_edges(form='all', etype=('drug', 'interacts', 'drug'))
(tensor([0, 0, 1, 1, 3, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5]))
>>> excluded_edges = {('drug', 'interacts', 'drug'): g_edges[2][:3]}
>>> sg = dgl.sampling.sample_neighbors(g, {'drug':[0, 1]}, 3, exclude_edges=excluded_edges)
>>> sg.all_edges(form='all', etype=('drug', 'interacts', 'drug'))
(tensor([2, 1]), tensor([0, 1]), tensor([0, 1]))
>>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3],etype=('drug', 'interacts', 'drug'))
tensor([False, False, False])
"""
if not isinstance(nodes, dict):
if len(g.ntypes) > 1:
......@@ -290,8 +319,21 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False,
else:
prob_arrays.append(nd.array([], ctx=nd.cpu()))
excluded_edges_all_t = []
if exclude_edges is not None:
if not isinstance(exclude_edges, dict):
if len(g.etypes) > 1:
raise DGLError("Must specify etype type when the graph is not homogeneous.")
exclude_edges = {g.canonical_etypes[0] : exclude_edges}
exclude_edges = utils.prepare_tensor_dict(g, exclude_edges, 'edges')
for etype in g.canonical_etypes:
if etype in exclude_edges:
excluded_edges_all_t.append(F.to_dgl_nd(exclude_edges[etype]))
else:
excluded_edges_all_t.append(nd.array([], ctx=nd.cpu()))
subgidx = _CAPI_DGLSampleNeighbors(g._graph, nodes_all_types, fanout_array,
edge_dir, prob_arrays, replace)
edge_dir, prob_arrays, excluded_edges_all_t, replace)
induced_edges = subgidx.induced_edges
ret = DGLHeteroGraph(subgidx.graph, g.ntypes, g.etypes)
......
......@@ -17,12 +17,57 @@ using namespace dgl::aten;
namespace dgl {
namespace sampling {
HeteroSubgraph ExcludeCertainEdges(
const HeteroSubgraph& sg,
const std::vector<IdArray>& exclude_edges) {
HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr();
std::vector<IdArray> remain_induced_edges(hg_view->NumEdgeTypes());
std::vector<IdArray> remain_edges(hg_view->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < hg_view->NumEdgeTypes(); ++etype) {
IdArray edge_ids = Range(0,
sg.induced_edges[etype]->shape[0],
sg.induced_edges[etype]->dtype.bits,
sg.induced_edges[etype]->ctx);
if (exclude_edges[etype].GetSize() == 0) {
remain_edges[etype] = edge_ids;
remain_induced_edges[etype] = sg.induced_edges[etype];
continue;
}
ATEN_ID_TYPE_SWITCH(hg_view->DataType(), IdType, {
IdType* idx_data = edge_ids.Ptr<IdType>();
IdType* induced_edges_data = sg.induced_edges[etype].Ptr<IdType>();
const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
std::sort(exclude_edges[etype].Ptr<IdType>(),
exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
IdType outId = 0;
for (IdType i = 0; i != sg.induced_edges[etype]->shape[0]; ++i) {
if (!std::binary_search(exclude_edges_data,
exclude_edges_data + exclude_edges_len,
induced_edges_data[i])) {
induced_edges_data[outId] = induced_edges_data[i];
idx_data[outId] = idx_data[i];
++outId;
}
}
remain_edges[etype] = aten::IndexSelect(edge_ids, 0, outId);
remain_induced_edges[etype] = aten::IndexSelect(sg.induced_edges[etype], 0, outId);
});
}
HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true);
subg.induced_edges = std::move(remain_induced_edges);
return subg;
}
HeteroSubgraph SampleNeighbors(
const HeteroGraphPtr hg,
const std::vector<IdArray>& nodes,
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const std::vector<FloatArray>& prob,
const std::vector<IdArray>& exclude_edges,
bool replace) {
// sanity check
......@@ -101,6 +146,9 @@ HeteroSubgraph SampleNeighbors(
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges);
if (!exclude_edges.empty()) {
return ExcludeCertainEdges(ret, exclude_edges);
}
return ret;
}
......@@ -382,7 +430,8 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
const auto& fanouts = fanouts_array.ToVector<int64_t>();
const std::string dir_str = args[3];
const auto& prob = ListValueToVector<FloatArray>(args[4]);
const bool replace = args[5];
const auto& exclude_edges = ListValueToVector<IdArray>(args[5]);
const bool replace = args[6];
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
......@@ -390,7 +439,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighbors(
hg.sptr(), nodes, fanouts, dir, prob, replace);
hg.sptr(), nodes, fanouts, dir, prob, exclude_edges, replace);
*rv = HeteroSubgraphRef(subg);
});
......
......@@ -3,6 +3,7 @@ import backend as F
import numpy as np
import unittest
from collections import defaultdict
import pytest
def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
traces = F.asnumpy(traces)
......@@ -819,6 +820,98 @@ def test_sample_neighbors_etype_homogeneous():
csc_g, seeds, dgl.ETYPE, 5, edge_dir='out', replace=True)
check_num2(subg.edges()[0], True)
@pytest.mark.parametrize('dtype', ['int32', 'int64'])
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_exclude_edges_heteroG(dtype):
d_i_d_u_nodes = F.zerocopy_from_numpy(np.unique(np.random.randint(300, size=100, dtype=dtype)))
d_i_d_v_nodes = F.zerocopy_from_numpy(np.random.randint(25, size=d_i_d_u_nodes.shape, dtype=dtype))
d_i_g_u_nodes = F.zerocopy_from_numpy(np.unique(np.random.randint(300, size=100, dtype=dtype)))
d_i_g_v_nodes = F.zerocopy_from_numpy(np.random.randint(25, size=d_i_g_u_nodes.shape, dtype=dtype))
d_t_d_u_nodes = F.zerocopy_from_numpy(np.unique(np.random.randint(300, size=100, dtype=dtype)))
d_t_d_v_nodes = F.zerocopy_from_numpy(np.random.randint(25, size=d_t_d_u_nodes.shape, dtype=dtype))
g = dgl.heterograph({
('drug', 'interacts', 'drug'): (d_i_d_u_nodes, d_i_d_v_nodes),
('drug', 'interacts', 'gene'): (d_i_g_u_nodes, d_i_g_v_nodes),
('drug', 'treats', 'disease'): (d_t_d_u_nodes, d_t_d_v_nodes)
})
(U, V, EID) = (0, 1, 2)
nd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)
nd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)
did_b_idx = np.random.randint(low=1, high=24, dtype=dtype)
did_e_idx = np.random.randint(low=25, high=49, dtype=dtype)
sampled_amount = np.random.randint(low=1, high=10, dtype=dtype)
drug_i_drug_edges = g.all_edges(form='all', etype=('drug','interacts','drug'))
excluded_d_i_d_edges = drug_i_drug_edges[EID][did_b_idx:did_e_idx]
sampled_drug_node = drug_i_drug_edges[V][nd_b_idx:nd_e_idx]
did_excluded_nodes_U = drug_i_drug_edges[U][did_b_idx:did_e_idx]
did_excluded_nodes_V = drug_i_drug_edges[V][did_b_idx:did_e_idx]
nd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)
nd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)
dig_b_idx = np.random.randint(low=1, high=24, dtype=dtype)
dig_e_idx = np.random.randint(low=25, high=49, dtype=dtype)
drug_i_gene_edges = g.all_edges(form='all', etype=('drug','interacts','gene'))
excluded_d_i_g_edges = drug_i_gene_edges[EID][dig_b_idx:dig_e_idx]
dig_excluded_nodes_U = drug_i_gene_edges[U][dig_b_idx:dig_e_idx]
dig_excluded_nodes_V = drug_i_gene_edges[V][dig_b_idx:dig_e_idx]
sampled_gene_node = drug_i_gene_edges[V][nd_b_idx:nd_e_idx]
nd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)
nd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)
dtd_b_idx = np.random.randint(low=1, high=24, dtype=dtype)
dtd_e_idx = np.random.randint(low=25, high=49, dtype=dtype)
drug_t_dis_edges = g.all_edges(form='all', etype=('drug','treats','disease'))
excluded_d_t_d_edges = drug_t_dis_edges[EID][dtd_b_idx:dtd_e_idx]
dtd_excluded_nodes_U = drug_t_dis_edges[U][dtd_b_idx:dtd_e_idx]
dtd_excluded_nodes_V = drug_t_dis_edges[V][dtd_b_idx:dtd_e_idx]
sampled_disease_node = drug_t_dis_edges[V][nd_b_idx:nd_e_idx]
excluded_edges = {('drug', 'interacts', 'drug'): excluded_d_i_d_edges,
('drug', 'interacts', 'gene'): excluded_d_i_g_edges,
('drug', 'treats', 'disease'): excluded_d_t_d_edges
}
sg = dgl.sampling.sample_neighbors(g, {'drug': sampled_drug_node,
'gene': sampled_gene_node,
'disease': sampled_disease_node},
sampled_amount, exclude_edges=excluded_edges)
assert not np.any(F.asnumpy(sg.has_edges_between(did_excluded_nodes_U,did_excluded_nodes_V,
etype=('drug','interacts','drug'))))
assert not np.any(F.asnumpy(sg.has_edges_between(dig_excluded_nodes_U,dig_excluded_nodes_V,
etype=('drug','interacts','gene'))))
assert not np.any(F.asnumpy(sg.has_edges_between(dtd_excluded_nodes_U,dtd_excluded_nodes_V,
etype=('drug','treats','disease'))))
@pytest.mark.parametrize('dtype', ['int32', 'int64'])
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_exclude_edges_homoG(dtype):
u_nodes = F.zerocopy_from_numpy(np.unique(np.random.randint(300,size=100, dtype=dtype)))
v_nodes = F.zerocopy_from_numpy(np.random.randint(25, size=u_nodes.shape, dtype=dtype))
g = dgl.graph((u_nodes, v_nodes))
(U, V, EID) = (0, 1, 2)
nd_b_idx = np.random.randint(low=1,high=24, dtype=dtype)
nd_e_idx = np.random.randint(low=25,high=49, dtype=dtype)
b_idx = np.random.randint(low=1,high=24, dtype=dtype)
e_idx = np.random.randint(low=25,high=49, dtype=dtype)
sampled_amount = np.random.randint(low=1,high=10, dtype=dtype)
g_edges = g.all_edges(form='all')
excluded_edges = g_edges[EID][b_idx:e_idx]
sampled_node = g_edges[V][nd_b_idx:nd_e_idx]
excluded_nodes_U = g_edges[U][b_idx:e_idx]
excluded_nodes_V = g_edges[V][b_idx:e_idx]
sg = dgl.sampling.sample_neighbors(g, sampled_node,
sampled_amount, exclude_edges=excluded_edges)
assert not np.any(F.asnumpy(sg.has_edges_between(excluded_nodes_U,excluded_nodes_V)))
if __name__ == '__main__':
test_sample_neighbors_etype_homogeneous()
......@@ -831,3 +924,5 @@ if __name__ == '__main__':
test_sample_neighbors_with_0deg()
test_sample_neighbors_biased_homogeneous()
test_sample_neighbors_biased_bipartite()
test_sample_neighbors_exclude_edges_heteroG('int32')
test_sample_neighbors_exclude_edges_homoG('int32')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment