Unverified Commit e34a5072 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Graphbolt][Feature] LABOR-0 implementation into sample_neighbors with tests (#5986)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent e7ff22f7
cmake_minimum_required(VERSION 3.5)
project(graphbolt C CXX)
set (CMAKE_CXX_STANDARD 17)
# Find PyTorch cmake files and PyTorch versions with the python interpreter
# $PYTHON_INTERP ("python3" or "python" if empty)
......
......@@ -16,6 +16,21 @@
namespace graphbolt {
namespace sampling {
// Selects which neighbor-sampling algorithm a Pick/SampleNeighborsImpl
// instantiation uses.
enum SamplerType { NEIGHBOR, LABOR };

// Per-sampler argument bundle; specialized below for each SamplerType.
template <SamplerType S>
struct SamplerArgs;

// Independent (uniform/weighted) neighbor sampling needs no extra state.
template <>
struct SamplerArgs<SamplerType::NEIGHBOR> {};

// State for LABOR (layer-neighbor) sampling, see arXiv:2210.13339.
template <>
struct SamplerArgs<SamplerType::LABOR> {
// The graph's indices tensor.
// NOTE(review): reference member — the referenced tensor must outlive this
// struct; only safe because args are consumed within one sampling call.
const torch::Tensor& indices;
// Random seed shared by all seed nodes of one layer so that their draws are
// correlated, which is what increases sampled-edge overlap.
int64_t random_seed;
// Number of nodes in the graph; used as the stride between the per-item
// random streams (see labor::jth_sorted_uniform_random's `c` argument).
int64_t num_nodes;
};
/**
* @brief A sampling oriented csc format graph.
*
......@@ -143,6 +158,9 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge
......@@ -156,7 +174,7 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const;
/**
......@@ -204,6 +222,13 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
const std::string& shared_memory_name);
private:
template <SamplerType S>
c10::intrusive_ptr<SampledSubgraph> SampleNeighborsImpl(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<S> args) const;
/**
* @brief Build a CSCSamplingGraph from shared memory tensors.
*
......@@ -298,10 +323,11 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*
* @return A tensor containing the picked neighbors.
*/
template <SamplerType S>
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
/**
* @brief Picks a specified number of neighbors for a node per edge type,
......@@ -330,11 +356,19 @@ torch::Tensor Pick(
*
* @return A tensor containing the picked neighbors.
*/
template <SamplerType S>
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
template <bool NonUniform, bool Replace, typename T = float>
torch::Tensor LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args);
} // namespace sampling
} // namespace graphbolt
......
......@@ -8,6 +8,8 @@
#include <graphbolt/serialize.h>
#include <torch/torch.h>
#include <cmath>
#include <limits>
#include <tuple>
#include <vector>
......@@ -129,22 +131,13 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
: torch::nullopt);
}
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
template <SamplerType S>
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool return_eids,
torch::optional<std::string> probs_name) const {
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<S> args) const {
const int64_t num_nodes = nodes.size(0);
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt;
if (probs_name.has_value() && !probs_name.value().empty()) {
probs_or_mask = edge_attributes_.value().at(probs_name.value());
// Note probs will be passed as input for 'torch.multinomial' in deeper
// stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
// avoid crashes, convert 'probs_or_mask' to 'float32' data type.
if (probs_or_mask.value().dtype() == torch::kBool ||
probs_or_mask.value().dtype() == torch::kFloat16) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
}
}
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
bool consider_etype = (fanouts.size() > 1);
......@@ -176,11 +169,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
if (consider_etype) {
picked_neighbors_per_node[i] = PickByEtype(
offset, num_neighbors, fanouts, replace, indptr_.options(),
type_per_edge_.value(), probs_or_mask);
type_per_edge_.value(), probs_or_mask, args);
} else {
picked_neighbors_per_node[i] = Pick(
offset, num_neighbors, fanouts[0], replace, indptr_.options(),
probs_or_mask);
probs_or_mask, args);
}
num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
......@@ -206,6 +199,34 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
subgraph_reverse_edge_ids, subgraph_type_per_edge);
}
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
    const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
    bool replace, bool layer, bool return_eids,
    torch::optional<std::string> probs_name) const {
  // Resolve the optional probability/mask edge attribute, if one was named.
  torch::optional<torch::Tensor> probs_or_mask = torch::nullopt;
  if (probs_name.has_value() && !probs_name.value().empty()) {
    auto attr = edge_attributes_.value().at(probs_name.value());
    // 'torch.multinomial' deeper in the stack rejects 'torch.half' and
    // 'torch.bool' inputs, so promote those dtypes to 'float32' up front.
    const auto dtype = attr.dtype();
    if (dtype == torch::kBool || dtype == torch::kFloat16) {
      attr = attr.to(torch::kFloat32);
    }
    probs_or_mask = attr;
  }
  if (!layer) {
    // Plain neighbor sampling: no extra sampler state is required.
    return SampleNeighborsImpl(
        nodes, fanouts, replace, return_eids, probs_or_mask,
        SamplerArgs<SamplerType::NEIGHBOR>{});
  }
  // LABOR sampling: a single random seed is shared across all seed nodes so
  // their samples are correlated (arXiv:2210.13339).
  const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
      static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
  SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
  return SampleNeighborsImpl(
      nodes, fanouts, replace, return_eids, probs_or_mask, args);
}
std::tuple<torch::Tensor, torch::Tensor>
CSCSamplingGraph::SampleNegativeEdgesUniform(
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
......@@ -423,10 +444,12 @@ inline torch::Tensor NonUniformPick(
return picked_neighbors;
}
torch::Tensor Pick(
template <>
torch::Tensor Pick<SamplerType::NEIGHBOR>(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask) {
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args) {
if (probs_or_mask.has_value()) {
return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask);
......@@ -435,11 +458,12 @@ torch::Tensor Pick(
}
}
template <SamplerType S>
torch::Tensor PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
std::vector<torch::Tensor> picked_neighbors(
fanouts.size(), torch::tensor({}, options));
int64_t etype_begin = offset;
......@@ -451,7 +475,7 @@ torch::Tensor PickByEtype(
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
TORCH_CHECK(
etype >= 0 && etype < fanouts.size(),
etype >= 0 && etype < (int64_t)fanouts.size(),
"Etype values exceed the number of fanouts.");
int64_t fanout = fanouts[etype];
auto etype_end_it = std::upper_bound(
......@@ -460,9 +484,9 @@ torch::Tensor PickByEtype(
etype_end = etype_end_it - type_per_edge_data;
// Do sampling for one etype.
if (fanout != 0) {
picked_neighbors[etype] = Pick(
picked_neighbors[etype] = Pick<S>(
etype_begin, etype_end - etype_begin, fanout, replace, options,
probs_or_mask);
probs_or_mask, args);
}
etype_begin = etype_end;
}
......@@ -471,5 +495,207 @@ torch::Tensor PickByEtype(
return torch::cat(picked_neighbors, 0);
}
/**
 * @brief Dispatches a LABOR pick to the right LaborPick instantiation based
 * on whether probabilities are present, their dtype, and replacement mode.
 */
template <>
torch::Tensor Pick<SamplerType::LABOR>(
    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
    const torch::TensorOptions& options,
    const torch::optional<torch::Tensor>& probs_or_mask,
    SamplerArgs<SamplerType::LABOR> args) {
  if (fanout == 0) return torch::tensor({}, options);
  if (!probs_or_mask.has_value()) {
    // Uniform sampling; the probability type defaults to float.
    if (replace) {
      return LaborPick<false, true>(
          offset, num_neighbors, fanout, options, probs_or_mask, args);
    }
    return LaborPick<false, false>(
        offset, num_neighbors, fanout, options, probs_or_mask, args);
  }
  // Non-uniform sampling: dispatch on the floating dtype of the probs.
  torch::Tensor result;
  AT_DISPATCH_FLOATING_TYPES(
      probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
        if (replace) {
          result = LaborPick<true, true, scalar_t>(
              offset, num_neighbors, fanout, options, probs_or_mask, args);
        } else {
          result = LaborPick<true, false, scalar_t>(
              offset, num_neighbors, fanout, options, probs_or_mask, args);
        }
      }));
  return result;
}
/**
 * @brief Divides `a` by `b` in place, writing +infinity when `b` is not
 * positive. Used to compute r_t / \pi_t, where a zero (or masked-out)
 * probability must push the edge past every finite sampling key.
 *
 * @param a In-out numerator; receives a / b or +infinity.
 * @param b Divisor (probability or mask value).
 */
template <typename T, typename U>
inline void safe_divide(T& a, U b) {
  // static_cast (not a C-style cast) makes the intentional narrowing
  // from the promoted quotient type back to T explicit and greppable.
  a = b > 0 ? static_cast<T>(a / b) : std::numeric_limits<T>::infinity();
}
/**
* @brief Perform uniform or non-uniform sampling of elements, depending on
* the template parameter NonUniform, and return the sampled indices.
*
* @param offset The starting edge ID for the connected neighbors of the sampled
* node.
* @param num_neighbors The number of neighbors to pick.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1.
* - When the value is -1, all neighbors will be chosen for sampling. It is
* equivalent to selecting all neighbors with non-zero probability when the
* fanout is >= the number of neighbors (and replacement is set to false).
* - When the value is a non-negative integer, it serves as a minimum
* threshold for selecting neighbors.
* @param options Tensor options specifying the desired data type of the result.
* @param probs_or_mask Optional tensor containing the (unnormalized)
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
* @param args Contains labor specific arguments.
*
* @return A tensor containing the picked neighbors.
*/
template <bool NonUniform, bool Replace, typename T>
inline torch::Tensor LaborPick(
    int64_t offset, int64_t num_neighbors, int64_t fanout,
    const torch::TensorOptions& options,
    const torch::optional<torch::Tensor>& probs_or_mask,
    SamplerArgs<SamplerType::LABOR> args) {
  // Clamp: fanout == -1 means "all neighbors"; otherwise pick at most
  // num_neighbors.
  fanout = fanout < 0 ? num_neighbors : std::min(fanout, num_neighbors);
  // Fast path: uniform sampling without replacement keeps every neighbor
  // when the fanout covers the whole neighborhood.
  if (!NonUniform && !Replace && fanout >= num_neighbors) {
    return torch::arange(offset, offset + num_neighbors, options);
  }
  // Backing storage for a max-heap of (random key, local neighbor index)
  // pairs: fanout * 2 int32 slots hold fanout 8-byte pairs.
  torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
  // Assuming max_degree of a vertex is <= 4 billion.
  auto heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(
      heap_tensor.data_ptr<int32_t>());
  // Pointer into the probs tensor at this node's first edge; null when
  // sampling uniformly.
  const T* local_probs_data =
      NonUniform ? probs_or_mask.value().data_ptr<T>() + offset : nullptr;
  AT_DISPATCH_INTEGRAL_TYPES(
      args.indices.scalar_type(), "LaborPickMain", ([&] {
        const scalar_t* local_indices_data =
            args.indices.data_ptr<scalar_t>() + offset;
        if constexpr (Replace) {
          // [Algorithm] @mfbalin
          // Use a max-heap to get rid of the big random numbers and filter the
          // smallest fanout of them. Implements arXiv:2210.13339 Section A.3.
          // Unlike sampling without replacement below, the same item can be
          // included fanout times in our sample. Thus, we sort and pick the
          // smallest fanout random numbers out of num_neighbors * fanout of
          // them. Each item has fanout many random numbers in the race and the
          // smallest fanout of them get picked. Instead of generating
          // fanout * num_neighbors random numbers and increase the complexity,
          // I devised an algorithm to generate the fanout numbers for an item
          // in a sorted manner on demand, meaning we continue generating random
          // numbers for an item only if it has been sampled that many times
          // already.
          // https://gist.github.com/mfbalin/096dcad5e3b1f6a59ff7ff2f9f541618
          //
          // [Complexity Analysis]
          // Will modify the heap at most linear in O(num_neighbors + fanout)
          // and each modification takes O(log(fanout)). So the total complexity
          // is O((fanout + num_neighbors) log(fanout)). It is possible to
          // decrease the logarithmic factor down to
          // O(log(min(fanout, num_neighbors))).
          // remaining[i] tracks the unconsumed interval of neighbor i's
          // sorted random stream; -1 marks "stream exhausted / pruned".
          torch::Tensor remaining =
              torch::ones({num_neighbors}, torch::kFloat32);
          float* rem_data = remaining.data_ptr<float>();
          auto heap_end = heap_data;
          const auto init_count = (num_neighbors + fanout - 1) / num_neighbors;
          // Draws neighbor i's j-th smallest key (for global index t) and
          // offers it to the heap. Returns true when the key cannot beat the
          // current heap top, i.e. no later (larger) draw for i can either.
          auto sample_neighbor_i_with_index_t_jth_time =
              [&](scalar_t t, int64_t j, uint32_t i) {
                auto rnd = labor::jth_sorted_uniform_random(
                    args.random_seed, t, args.num_nodes, j, rem_data[i],
                    fanout - j);  // r_t
                if constexpr (NonUniform) {
                  safe_divide(rnd, local_probs_data[i]);
                }  // r_t / \pi_t
                if (heap_end < heap_data + fanout) {
                  // Heap not full yet: always insert.
                  heap_end[0] = std::make_pair(rnd, i);
                  std::push_heap(heap_data, ++heap_end);
                  return false;
                } else if (rnd < heap_data[0].first) {
                  // Beats the current worst key: replace the heap top.
                  std::pop_heap(heap_data, heap_data + fanout);
                  heap_data[fanout - 1] = std::make_pair(rnd, i);
                  std::push_heap(heap_data, heap_data + fanout);
                  return false;
                } else {
                  // Prune neighbor i: its remaining draws only get larger.
                  rem_data[i] = -1;
                  return true;
                }
              };
          // Phase 1: seed the heap with init_count draws per neighbor.
          for (uint32_t i = 0; i < num_neighbors; ++i) {
            for (int64_t j = 0; j < init_count; j++) {
              const auto t = local_indices_data[i];
              sample_neighbor_i_with_index_t_jth_time(t, j, i);
            }
          }
          // Phase 2: keep drawing for still-live neighbors until pruned or
          // fanout draws have been made.
          for (uint32_t i = 0; i < num_neighbors; ++i) {
            if (rem_data[i] == -1) continue;
            const auto t = local_indices_data[i];
            for (int64_t j = init_count; j < fanout; ++j) {
              if (sample_neighbor_i_with_index_t_jth_time(t, j, i)) break;
            }
          }
        } else {
          // [Algorithm]
          // Use a max-heap to get rid of the big random numbers and filter the
          // smallest fanout of them. Implements arXiv:2210.13339 Section A.3.
          //
          // [Complexity Analysis]
          // the first for loop and std::make_heap runs in time O(fanouts).
          // The next for loop compares each random number to the current
          // minimum fanout numbers. For any given i, the probability that the
          // current random number will replace any number in the heap is fanout
          // / i. Summing from i=fanout to num_neighbors, we get f * (H_n -
          // H_f), where n is num_neighbors and f is fanout, H_f is \sum_j=1^f
          // 1/j. In the end H_n - H_f = O(log n/f), there are n - f iterations,
          // each heap operation takes time log f, so the total complexity is
          // O(f + (n - f)
          // + f log(n/f) log f) = O(n + f log(f) log(n/f)). If f << n (f is a
          // constant in almost all cases), then the average complexity is
          // O(num_neighbors).
          // Fill the heap with the first fanout keys.
          for (uint32_t i = 0; i < fanout; ++i) {
            const auto t = local_indices_data[i];
            auto rnd =
                labor::uniform_random<float>(args.random_seed, t);  // r_t
            if constexpr (NonUniform) {
              safe_divide(rnd, local_probs_data[i]);
            }  // r_t / \pi_t
            heap_data[i] = std::make_pair(rnd, i);
          }
          // When NonUniform and fanout == num_neighbors, the heap is never
          // queried below, so heapification can be skipped.
          if (!NonUniform || fanout < num_neighbors) {
            std::make_heap(heap_data, heap_data + fanout);
          }
          // Stream the remaining neighbors, keeping the fanout smallest keys.
          for (uint32_t i = fanout; i < num_neighbors; ++i) {
            const auto t = local_indices_data[i];
            auto rnd =
                labor::uniform_random<float>(args.random_seed, t);  // r_t
            if constexpr (NonUniform) {
              safe_divide(rnd, local_probs_data[i]);
            }  // r_t / \pi_t
            if (rnd < heap_data[0].first) {
              std::pop_heap(heap_data, heap_data + fanout);
              heap_data[fanout - 1] = std::make_pair(rnd, i);
              std::push_heap(heap_data, heap_data + fanout);
            }
          }
        }
      }));
  // Collect the winners. For non-uniform sampling an infinite key marks a
  // zero-probability (masked-out) edge — see safe_divide — and is skipped.
  int64_t num_sampled = 0;
  torch::Tensor picked_neighbors = torch::empty({fanout}, options);
  AT_DISPATCH_INTEGRAL_TYPES(
      picked_neighbors.scalar_type(), "LaborPickOutput", ([&] {
        scalar_t* picked_neighbors_data = picked_neighbors.data_ptr<scalar_t>();
        for (int64_t i = 0; i < fanout; ++i) {
          const auto [rnd, j] = heap_data[i];
          if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {
            // Convert the local neighbor index back to a global edge ID.
            picked_neighbors_data[num_sampled++] = offset + j;
          }
        }
      }));
  TORCH_CHECK(
      !Replace || num_sampled == fanout || num_sampled == 0,
      "Sampling with replacement should sample exactly fanout neighbors or 0!");
  // Trim to the actually-sampled prefix (narrow shares storage, no copy).
  return picked_neighbors.narrow(0, 0, num_sampled);
}
} // namespace sampling
} // namespace graphbolt
......@@ -72,6 +72,33 @@ class RandomEngine {
private:
pcg32 rng_;
};
namespace labor {
template <typename T>
inline T uniform_random(int64_t random_seed, int64_t t) {
pcg32 ng(random_seed, t);
std::uniform_real_distribution<T> uni;
return uni(ng);
}
/**
 * @brief Inverse CDF of the minimum of n i.i.d. uniform variables on
 * [0, rem]: maps a uniform sample u in [0, 1) to rem * (1 - (1 - u)^(1/n)).
 */
template <typename T>
inline T invcdf(T u, int64_t n, T rem) {
  constexpr T one = 1;
  // (1 - u)^(1/n) is the survival fraction of the remaining interval.
  const T survival = std::pow(one - u, one / n);
  return rem * (one - survival);
}
/**
 * @brief Generates, on demand, the j-th draw of item t's sorted stream of n
 * uniform random numbers; successive calls with j = 0, 1, ... yield
 * nondecreasing values. `rem` carries the unconsumed interval between calls
 * and is shrunk in place.
 */
template <typename T>
inline T jth_sorted_uniform_random(
    int64_t random_seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
  // Offset the stream id by j * c so each of item t's draws is independent.
  const auto u = uniform_random<T>(random_seed, t + j * c);
  // https://mathematica.stackexchange.com/a/256707
  const auto step = invcdf(u, n, rem);
  rem -= step;
  return 1 - rem;
}
}; // namespace labor
} // namespace graphbolt
#endif // GRAPHBOLT_RANDOM_H_
......@@ -248,6 +248,13 @@ class CSCSamplingGraph:
node_pairs[etype] = (hetero_row, hetero_column)
return SampledSubgraphImpl(node_pairs=node_pairs)
def _convert_to_homogeneous_nodes(self, nodes):
    """Map per-node-type heterogeneous IDs to homogeneous IDs.

    Each tensor of IDs in ``nodes`` is shifted by the offset of its node
    type, and the shifted tensors are concatenated into one 1-D tensor.
    """
    converted = [
        ids + self.node_type_offset[self.metadata.node_type_to_id[ntype]]
        for ntype, ids in nodes.items()
    ]
    return torch.cat(converted)
def sample_neighbors(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
......@@ -316,16 +323,8 @@ class CSCSamplingGraph:
defaultdict(<class 'list'>, {('n1', 'e1', 'n2'): (tensor([2]), \
tensor([1])), ('n1', 'e2', 'n3'): (tensor([3]), tensor([2]))})
"""
def convert_to_homogeneous_nodes(nodes):
homogeneous_nodes = []
for ntype, ids in nodes.items():
ntype_id = self.metadata.node_type_to_id[ntype]
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes)
if isinstance(nodes, dict):
nodes = convert_to_homogeneous_nodes(nodes)
nodes = self._convert_to_homogeneous_nodes(nodes)
C_sampled_subgraph = self._sample_neighbors(
nodes, fanouts, replace, False, probs_name
......@@ -333,6 +332,44 @@ class CSCSamplingGraph:
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _check_sampler_arguments(self, nodes, fanouts, probs_name):
    """Validate the arguments shared by the neighbor-sampling entry points.

    Parameters
    ----------
    nodes : torch.Tensor
        1-D tensor of seed node IDs.
    fanouts : torch.Tensor
        1-D tensor of fanout values; its length must be 1 or equal the
        number of edge types, and every value must be >= 0 or exactly -1.
    probs_name : str, optional
        Name of an edge attribute holding per-edge (unnormalized)
        probabilities or a boolean mask; if given, it must exist and be a
        1-D floating-point or boolean tensor with one entry per edge.

    Raises
    ------
    AssertionError
        If any argument violates the constraints above.
    """
    assert nodes.dim() == 1, "Nodes should be 1-D tensor."
    assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
    # A single fanout applies to all neighbors collectively; otherwise one
    # fanout per edge type is expected.
    expected_fanout_len = 1
    if self.metadata and self.metadata.edge_type_to_id:
        expected_fanout_len = len(self.metadata.edge_type_to_id)
    assert len(fanouts) in [
        expected_fanout_len,
        1,
    ], "Fanouts should have the same number of elements as etypes or \
        should have a length of 1."
    if fanouts.size(0) > 1:
        # Per-edge-type fanouts require the graph to carry edge types.
        assert (
            self.type_per_edge is not None
        ), "To perform sampling for each edge type (when the length of \
            `fanouts` > 1), the graph must include edge type information."
    assert torch.all(
        (fanouts >= 0) | (fanouts == -1)
    ), "Fanouts should consist of values that are either -1 or \
        greater than or equal to 0."
    if probs_name:
        assert (
            probs_name in self.edge_attributes
        ), f"Unknown edge attribute '{probs_name}'."
        probs_or_mask = self.edge_attributes[probs_name]
        assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
        assert (
            probs_or_mask.size(0) == self.num_edges
        ), "Probs should have the same number of elements as the number \
            of edges."
        assert probs_or_mask.dtype in [
            torch.bool,
            torch.float16,
            torch.bfloat16,
            torch.float32,
            torch.float64,
        ], "Probs should have a floating-point or boolean data type."
def _sample_neighbors(
self,
nodes: torch.Tensor,
......@@ -384,46 +421,75 @@ class CSCSamplingGraph:
The sampled C subgraph.
"""
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1
if self.metadata and self.metadata.edge_type_to_id:
expected_fanout_len = len(self.metadata.edge_type_to_id)
assert len(fanouts) in [
expected_fanout_len,
1,
], "Fanouts should have the same number of elements as etypes or \
should have a length of 1."
if fanouts.size(0) > 1:
assert (
self.type_per_edge is not None
), "To perform sampling for each edge type (when the length of \
`fanouts` > 1), the graph must include edge type information."
assert torch.all(
(fanouts >= 0) | (fanouts == -1)
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
if probs_name:
assert (
probs_name in self.edge_attributes
), f"Unknown edge attribute '{probs_name}'."
probs_or_mask = self.edge_attributes[probs_name]
assert probs_or_mask.dim() == 1, "Probs should be 1-D tensor."
assert (
probs_or_mask.size(0) == self.num_edges
), "Probs should have the same number of elements as the number \
of edges."
assert probs_or_mask.dtype in [
torch.bool,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
], "Probs should have a floating-point or boolean data type."
self._check_sampler_arguments(nodes, fanouts, probs_name)
return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace, return_eids, probs_name
nodes, fanouts.tolist(), replace, False, return_eids, probs_name
)
def sample_layer_neighbors(
    self,
    nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
    fanouts: torch.Tensor,
    replace: bool = False,
    probs_name: Optional[str] = None,
) -> SampledSubgraphImpl:
    """Sample neighboring edges via layer-neighbor (LABOR) sampling and
    return the induced subgraph.

    Implements arXiv:2210.13339, "Layer-Neighbor Sampling -- Defusing
    Neighborhood Explosion in GNNs".

    Parameters
    ----------
    nodes : torch.Tensor or Dict[str, torch.Tensor]
        IDs of the seed nodes. A plain tensor means the graph is
        homogeneous and the IDs are homogeneous; a dictionary maps each
        node type to its heterogeneous IDs.
    fanouts : torch.Tensor
        The number of edges to sample per node, with or without regard
        to edge types. A length-1 tensor applies to all neighbors of a
        node collectively, regardless of edge type; otherwise the length
        must equal the number of edge types, one fanout per type. Each
        value must be >= 0 or -1:
        - -1 selects all neighbors (equivalent to a fanout >= the number
          of neighbors with ``replace=False``);
        - a non-negative value acts as a minimum threshold for selecting
          neighbors.
    replace : bool
        Whether the sample is performed with replacement. If True, a
        value can be selected multiple times; otherwise each value can
        be selected only once.
    probs_name : str, optional
        Name of an edge attribute containing (unnormalized) probabilities
        for each neighboring edge of a node: a 1-D floating-point or
        boolean tensor with one entry per edge in the graph.

    Returns
    -------
    SampledSubgraphImpl
        The sampled subgraph.

    Examples
    --------
    TODO: Provide typical examples.
    """
    if isinstance(nodes, dict):
        nodes = self._convert_to_homogeneous_nodes(nodes)
    self._check_sampler_arguments(nodes, fanouts, probs_name)
    # The C++ call takes (nodes, fanouts, replace, layer, return_eids,
    # probs_name); layer=True selects LABOR, return_eids stays False.
    sampled = self._c_csc_graph.sample_neighbors(
        nodes, fanouts.tolist(), replace, True, False, probs_name
    )
    return self._convert_to_sampled_subgraph(sampled)
def sample_negative_edges_uniform(
self, edge_type, node_pairs, negative_ratio
):
......
......@@ -401,7 +401,8 @@ def test_sample_neighbors_homo():
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_sample_neighbors_hetero():
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor):
"""Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
("n2", "e2", "n1"):[0, 0, 1, 2], [0, 1, 1 ,0]
......@@ -436,7 +437,8 @@ def test_sample_neighbors_hetero():
# Generate subgraph via sample neighbors.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.tensor([-1, -1])
subgraph = graph.sample_neighbors(nodes, fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
expected_node_pairs = {
......@@ -478,8 +480,9 @@ def test_sample_neighbors_hetero():
([-1, -1], 2, 2),
],
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_fanouts(
fanouts, expected_sampled_num1, expected_sampled_num2
fanouts, expected_sampled_num1, expected_sampled_num2, labor
):
"""Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
......@@ -514,7 +517,8 @@ def test_sample_neighbors_fanouts(
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
fanouts = torch.LongTensor(fanouts)
subgraph = graph.sample_neighbors(nodes, fanouts)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
assert (
......@@ -590,8 +594,9 @@ def test_sample_neighbors_replace(
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs(replace, probs_name):
def test_sample_neighbors_probs(replace, labor, probs_name):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -619,7 +624,9 @@ def test_sample_neighbors_probs(replace, probs_name):
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(
nodes,
fanouts=torch.tensor([2]),
replace=replace,
......@@ -639,6 +646,7 @@ def test_sample_neighbors_probs(replace, probs_name):
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize(
"probs_or_mask",
[
......@@ -646,7 +654,7 @@ def test_sample_neighbors_probs(replace, probs_name):
torch.zeros(12, dtype=torch.bool),
],
)
def test_sample_neighbors_zero_probs(replace, probs_or_mask):
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
# Initialize data.
num_nodes = 5
num_edges = 12
......@@ -662,7 +670,8 @@ def test_sample_neighbors_zero_probs(replace, probs_or_mask):
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(
nodes,
fanouts=torch.tensor([5]),
replace=replace,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment