"graphbolt/vscode:/vscode.git/clone" did not exist on "1547bd931d17cd1da144a6d38bb687c0f2c3b364"
Unverified Commit f95e9df3 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add seed_id for NumPickFn and PickFn in neighbor sampling to...

[Graphbolt] Add seed_id for NumPickFn and PickFn in neighbor sampling to support temporal sampling. (#6769)
parent c6abbb13
......@@ -325,10 +325,11 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
*
* @return A lambda function (int64_t offset, int64_t num_neighbors) ->
* torch::Tensor, which takes offset (the starting edge ID of the given node)
* and num_neighbors (number of neighbors) as params and returns the pick number
* of the given node.
* @return A lambda function (int64_t seed_offset, int64_t offset, int64_t
* num_neighbors) -> torch::Tensor, which takes seed offset (the offset of the
* seed to sample), offset (the starting edge ID of the given node) and
* num_neighbors (number of neighbors) as params and returns the pick number of
* the given node.
*/
auto GetNumPickFn(
const std::vector<int64_t>& fanouts, bool replace,
......@@ -337,7 +338,7 @@ auto GetNumPickFn(
// If fanouts.size() > 1, returns the total number of all edge types of the
// given node.
return [&fanouts, replace, &probs_or_mask, &type_per_edge](
int64_t offset, int64_t num_neighbors) {
int64_t seed_offset, int64_t offset, int64_t num_neighbors) {
if (fanouts.size() > 1) {
return NumPickByEtype(
fanouts, replace, type_per_edge.value(), probs_or_mask, offset,
......@@ -365,11 +366,11 @@ auto GetNumPickFn(
* equal to the number of edges in the graph.
* @param args Contains sampling algorithm specific arguments.
*
* @return A lambda function: (int64_t offset, int64_t num_neighbors,
* PickedType* picked_data_ptr) -> torch::Tensor, which takes offset (the
* starting edge ID of the given node) and num_neighbors (number of neighbors)
* as params and puts the picked neighbors at the address specified by
* picked_data_ptr.
* @return A lambda function: (int64_t seed_offset, int64_t offset, int64_t
* num_neighbors, PickedType* picked_data_ptr) -> torch::Tensor, which takes
* seed_offset (the offset of the seed to sample), offset (the starting edge ID
* of the given node) and num_neighbors (number of neighbors) as params and puts
* the picked neighbors at the address specified by picked_data_ptr.
*/
template <SamplerType S>
auto GetPickFn(
......@@ -378,7 +379,8 @@ auto GetPickFn(
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args](
int64_t offset, int64_t num_neighbors, auto picked_data_ptr) {
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
auto picked_data_ptr) {
// If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node with no regard of edge
// types.
......@@ -444,7 +446,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0
? 0
: num_pick_fn(offset, num_neighbors);
: num_pick_fn(i, offset, num_neighbors);
}
});
......@@ -479,7 +481,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
const auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) {
auto actual_picked_count = pick_fn(
offset, num_neighbors,
i, offset, num_neighbors,
picked_eids_data_ptr + picked_offset);
TORCH_CHECK(
actual_picked_count == picked_number,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment