"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "bf8bb58f60863466e5254bfa6ee2ad15f2384acb"
Unverified Commit f95e9df3 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add seed_id for NumPickFn and PickFn in neighbor sampling to...

[Graphbolt] Add seed_id for NumPickFn and PickFn in neighbor sampling to support temporal sampling. (#6769)
parent c6abbb13
...@@ -325,10 +325,11 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph( ...@@ -325,10 +325,11 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
* graph. It must be a 1D floating-point tensor with the number of elements * graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph. * equal to the number of edges in the graph.
* *
* @return A lambda function (int64_t offset, int64_t num_neighbors) -> * @return A lambda function (int64_t seed_offset, int64_t offset, int64_t
* torch::Tensor, which takes offset (the starting edge ID of the given node) * num_neighbors) -> torch::Tensor, which takes seed offset (the offset of the
* and num_neighbors (number of neighbors) as params and returns the pick number * seed to sample), offset (the starting edge ID of the given node) and
* of the given node. * num_neighbors (number of neighbors) as params and returns the pick number of
* the given node.
*/ */
auto GetNumPickFn( auto GetNumPickFn(
const std::vector<int64_t>& fanouts, bool replace, const std::vector<int64_t>& fanouts, bool replace,
...@@ -337,7 +338,7 @@ auto GetNumPickFn( ...@@ -337,7 +338,7 @@ auto GetNumPickFn(
// If fanouts.size() > 1, returns the total number of all edge types of the // If fanouts.size() > 1, returns the total number of all edge types of the
// given node. // given node.
return [&fanouts, replace, &probs_or_mask, &type_per_edge]( return [&fanouts, replace, &probs_or_mask, &type_per_edge](
int64_t offset, int64_t num_neighbors) { int64_t seed_offset, int64_t offset, int64_t num_neighbors) {
if (fanouts.size() > 1) { if (fanouts.size() > 1) {
return NumPickByEtype( return NumPickByEtype(
fanouts, replace, type_per_edge.value(), probs_or_mask, offset, fanouts, replace, type_per_edge.value(), probs_or_mask, offset,
...@@ -365,11 +366,11 @@ auto GetNumPickFn( ...@@ -365,11 +366,11 @@ auto GetNumPickFn(
* equal to the number of edges in the graph. * equal to the number of edges in the graph.
* @param args Contains sampling algorithm specific arguments. * @param args Contains sampling algorithm specific arguments.
* *
* @return A lambda function: (int64_t offset, int64_t num_neighbors, * @return A lambda function: (int64_t seed_offset, int64_t offset, int64_t
* PickedType* picked_data_ptr) -> torch::Tensor, which takes offset (the * num_neighbors, PickedType* picked_data_ptr) -> torch::Tensor, which takes
* starting edge ID of the given node) and num_neighbors (number of neighbors) * seed_offset (the offset of the seed to sample), offset (the starting edge ID
* as params and puts the picked neighbors at the address specified by * of the given node) and num_neighbors (number of neighbors) as params and puts
* picked_data_ptr. * the picked neighbors at the address specified by picked_data_ptr.
*/ */
template <SamplerType S> template <SamplerType S>
auto GetPickFn( auto GetPickFn(
...@@ -378,7 +379,8 @@ auto GetPickFn( ...@@ -378,7 +379,8 @@ auto GetPickFn(
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) { const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args]( return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args](
int64_t offset, int64_t num_neighbors, auto picked_data_ptr) { int64_t seed_offset, int64_t offset, int64_t num_neighbors,
auto picked_data_ptr) {
// If fanouts.size() > 1, perform sampling for each edge type of each // If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node with no regard of edge // node; otherwise just sample once for each node with no regard of edge
// types. // types.
...@@ -444,7 +446,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -444,7 +446,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
num_picked_neighbors_data_ptr[i + 1] = num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0 num_neighbors == 0
? 0 ? 0
: num_pick_fn(offset, num_neighbors); : num_pick_fn(i, offset, num_neighbors);
} }
}); });
...@@ -479,7 +481,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -479,7 +481,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
const auto picked_offset = subgraph_indptr_data_ptr[i]; const auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) { if (picked_number > 0) {
auto actual_picked_count = pick_fn( auto actual_picked_count = pick_fn(
offset, num_neighbors, i, offset, num_neighbors,
picked_eids_data_ptr + picked_offset); picked_eids_data_ptr + picked_offset);
TORCH_CHECK( TORCH_CHECK(
actual_picked_count == picked_number, actual_picked_count == picked_number,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment