"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "5b4f79d9ba8cbeeb8d6f0fbba3ba5757b718888b"
Unverified Commit 58d98e03 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[Graphbolt] Utilize pre-allocation in sampling (#6132)

parent f0d8ca1e
......@@ -353,28 +353,22 @@ int64_t NumPickByEtype(
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
*
* @return A tensor containing the picked neighbors.
* @param picked_data_ptr The destination address where the picked neighbors
* should be put. Enough memory space should be allocated in advance.
*/
template <SamplerType S>
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
template <>
torch::Tensor Pick<SamplerType::NEIGHBOR>(
template <typename PickedType>
void Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args);
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
template <>
torch::Tensor Pick<SamplerType::LABOR>(
template <typename PickedType>
void Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args);
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
/**
* @brief Picks a specified number of neighbors for a node per edge type,
......@@ -400,22 +394,25 @@ torch::Tensor Pick<SamplerType::LABOR>(
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
*
* @return A tensor containing the picked neighbors.
* @param picked_data_ptr The destination address where the picked neighbors
* should be put. Enough memory space should be allocated in advance.
*/
template <SamplerType S>
torch::Tensor PickByEtype(
template <SamplerType S, typename PickedType>
void PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);
template <bool NonUniform, bool Replace, typename T = float>
torch::Tensor LaborPick(
template <
bool NonUniform, bool Replace, typename ProbsType = float,
typename PickedType>
void LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args);
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
} // namespace sampling
} // namespace graphbolt
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment