"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "1b42732ced07861b810f77ecf3fc8ce63ce465e8"
Unverified Commit 58d98e03 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[Graphbolt] Utilize pre-allocation in sampling (#6132)

parent f0d8ca1e
...@@ -353,28 +353,22 @@ int64_t NumPickByEtype( ...@@ -353,28 +353,22 @@ int64_t NumPickByEtype(
* probabilities associated with each neighboring edge of a node in the original * probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements * graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph. * equal to the number of edges in the graph.
* * @param picked_data_ptr The destination address where the picked neighbors
* @return A tensor containing the picked neighbors. * should be put. Enough memory space should be allocated in advance.
*/ */
template <SamplerType S> template <typename PickedType>
torch::Tensor Pick( void Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
template <>
torch::Tensor Pick<SamplerType::NEIGHBOR>(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args); SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
template <> template <typename PickedType>
torch::Tensor Pick<SamplerType::LABOR>( void Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args); SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
/** /**
* @brief Picks a specified number of neighbors for a node per edge type, * @brief Picks a specified number of neighbors for a node per edge type,
...@@ -400,22 +394,25 @@ torch::Tensor Pick<SamplerType::LABOR>( ...@@ -400,22 +394,25 @@ torch::Tensor Pick<SamplerType::LABOR>(
* probabilities associated with each neighboring edge of a node in the original * probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements * graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph. * equal to the number of edges in the graph.
* * @param picked_data_ptr The destination address where the picked neighbors
* @return A tensor containing the picked neighbors. * should be put. Enough memory space should be allocated in advance.
*/ */
template <SamplerType S> template <SamplerType S, typename PickedType>
torch::Tensor PickByEtype( void PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts, int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options, bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args); const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);
template <bool NonUniform, bool Replace, typename T = float> template <
torch::Tensor LaborPick( bool NonUniform, bool Replace, typename ProbsType = float,
typename PickedType>
void LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args); SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment