"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "55de50921f89ec06ed51381514c2710cabab1d8e"
Unverified commit 3a79f021, authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Make nodes optional for sampling (#6993)

parent 365bb723
@@ -113,7 +113,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
  * given nodes and their indptr values.
  *
  * @param indptr The indptr tensor.
- * @param nodes The nodes to read from indptr
+ * @param nodes The nodes to read from indptr. If not provided, assumed to be
+ * equal to torch.arange(indptr.size(0) - 1).
  *
  * @return Tuple of tensors with values:
  * (indptr[nodes + 1] - indptr[nodes], indptr[nodes]), the returned indegrees
@@ -121,7 +122,7 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
  * on it gives the output indptr.
  */
 std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
-    torch::Tensor indptr, torch::Tensor nodes);
+    torch::Tensor indptr, torch::optional<torch::Tensor> nodes);
 
 /**
  * @brief Given the compacted sub_indptr tensor, edge type tensor and
...
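Reviewer note (not part of the diff): the contract documented above can be written as a few lines of plain libtorch, which may help when reading the CUDA implementation further down. The function name below and the use of index_select are illustrative assumptions, not the actual kernel.

#include <torch/torch.h>
#include <tuple>

// Reference semantics of SliceCSCIndptr with the new optional argument.
// When `nodes` is absent it behaves as if nodes == torch.arange(indptr.size(0) - 1).
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptrReference(
    torch::Tensor indptr, torch::optional<torch::Tensor> nodes) {
  if (!nodes.has_value()) {
    nodes = torch::arange(indptr.size(0) - 1, indptr.options());
  }
  auto sliced_indptr = indptr.index_select(0, nodes.value());  // indptr[nodes]
  auto next = indptr.index_select(0, nodes.value() + 1);       // indptr[nodes + 1]
  return {next - sliced_indptr, sliced_indptr};                // (in-degree, sliced indptr)
}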
@@ -19,7 +19,8 @@ namespace ops {
  *
  * @param indptr Index pointer array of the CSC.
  * @param indices Indices array of the CSC.
- * @param nodes The nodes from which to sample neighbors.
+ * @param nodes The nodes from which to sample neighbors. If not provided,
+ * assumed to be equal to torch.arange(indptr.size(0) - 1).
  * @param fanouts The number of edges to be sampled for each node with or
  * without considering edge types.
  *   - When the length is 1, it indicates that the fanout applies to all
@@ -49,9 +50,9 @@ namespace ops {
  * the sampled graph's information.
  */
 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
-    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
-    const std::vector<int64_t>& fanouts, bool replace, bool layer,
-    bool return_eids,
+    torch::Tensor indptr, torch::Tensor indices,
+    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
+    bool replace, bool layer, bool return_eids,
     torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
     torch::optional<torch::Tensor> probs_or_mask = torch::nullopt);
...
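Reviewer note (not part of the diff): a hypothetical call site showing the two ways the updated signature is meant to be used. The ops namespace shown above is assumed to live under graphbolt::, the tensor values are made up, and the CUDA path additionally expects the tensors to be on, or accessible from, the GPU.

#include <torch/torch.h>

void SampleNeighborsCallSketch() {
  // Tiny CSC graph: 4 nodes, 4 edges (illustrative values only).
  auto indptr = torch::tensor({0, 2, 3, 3, 4}, torch::kLong);
  auto indices = torch::tensor({1, 2, 0, 1}, torch::kLong);
  auto seeds = torch::tensor({0, 3}, torch::kLong);
  // As before: sample from an explicit seed-node tensor.
  auto seeded = graphbolt::ops::SampleNeighbors(
      indptr, indices, /*nodes=*/seeds, /*fanouts=*/{2},
      /*replace=*/false, /*layer=*/false, /*return_eids=*/false);
  // New: pass torch::nullopt to sample from every node of the graph.
  auto full = graphbolt::ops::SampleNeighbors(
      indptr, indices, /*nodes=*/torch::nullopt, /*fanouts=*/{2},
      /*replace=*/false, /*layer=*/false, /*return_eids=*/false);
}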
@@ -286,7 +286,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
   * @brief Sample neighboring edges of the given nodes and return the induced
   * subgraph.
   *
-  * @param nodes The nodes from which to sample neighbors.
+  * @param nodes The nodes from which to sample neighbors. If not provided,
+  * assumed to be equal to torch.arange(NumNodes()).
   * @param fanouts The number of edges to be sampled for each node with or
   * without considering edge types.
   *   - When the length is 1, it indicates that the fanout applies to all
@@ -317,7 +318,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
   * the sampled graph's information.
   */
  c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
-     const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+     torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
      bool replace, bool layer, bool return_eids,
      torch::optional<std::string> probs_name) const;
...
@@ -130,16 +130,18 @@ struct SegmentEndFunc {
 };
 
 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
-    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
-    const std::vector<int64_t>& fanouts, bool replace, bool layer,
-    bool return_eids, torch::optional<torch::Tensor> type_per_edge,
+    torch::Tensor indptr, torch::Tensor indices,
+    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
+    bool replace, bool layer, bool return_eids,
+    torch::optional<torch::Tensor> type_per_edge,
     torch::optional<torch::Tensor> probs_or_mask) {
   TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
   // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
   // are all resident on the GPU. If not, it is better to first extract them
   // before calling this function.
   auto allocator = cuda::GetAllocator();
-  auto num_rows = nodes.size(0);
+  auto num_rows =
+      nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
   auto fanouts_pinned = torch::empty(
       fanouts.size(),
       c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
@@ -166,34 +168,49 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
         DeviceReduce::Max, in_degree.data_ptr<index_t>(),
         max_in_degree.data_ptr<index_t>(), num_rows);
   }));
-  torch::optional<int64_t> num_edges_;
+  // Protect access to max_in_degree with a CUDAEvent
+  at::cuda::CUDAEvent max_in_degree_event;
+  max_in_degree_event.record();
+  torch::optional<int64_t> num_edges;
   torch::Tensor sub_indptr;
+  if (!nodes.has_value()) {
+    num_edges = indices.size(0);
+    sub_indptr = indptr;
+  }
   torch::optional<torch::Tensor> sliced_probs_or_mask;
   if (probs_or_mask.has_value()) {
-    torch::Tensor sliced_probs_or_mask_tensor;
-    std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
-        in_degree, sliced_indptr, probs_or_mask.value(), nodes,
-        indptr.size(0) - 2, num_edges_);
-    sliced_probs_or_mask = sliced_probs_or_mask_tensor;
-    num_edges_ = sliced_probs_or_mask_tensor.size(0);
+    if (nodes.has_value()) {
+      torch::Tensor sliced_probs_or_mask_tensor;
+      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
+          in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
+          indptr.size(0) - 2, num_edges);
+      sliced_probs_or_mask = sliced_probs_or_mask_tensor;
+      num_edges = sliced_probs_or_mask_tensor.size(0);
+    } else {
+      sliced_probs_or_mask = probs_or_mask;
+    }
   }
   if (fanouts.size() > 1) {
     torch::Tensor sliced_type_per_edge;
-    std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
-        in_degree, sliced_indptr, type_per_edge.value(), nodes,
-        indptr.size(0) - 2, num_edges_);
+    if (nodes.has_value()) {
+      std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
+          in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
+          indptr.size(0) - 2, num_edges);
+    } else {
+      sliced_type_per_edge = type_per_edge.value();
+    }
     std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
         sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
     num_rows = sliced_indptr.size(0);
-    num_edges_ = sliced_type_per_edge.size(0);
+    num_edges = sliced_type_per_edge.size(0);
   }
   // If sub_indptr was not computed in the two code blocks above:
-  if (!probs_or_mask.has_value() && fanouts.size() <= 1) {
+  if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
     sub_indptr = ExclusiveCumSum(in_degree);
   }
   auto coo_rows = ExpandIndptrImpl(
-      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges_);
-  const auto num_edges = coo_rows.size(0);
+      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
+  num_edges = coo_rows.size(0);
   const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
       static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
   auto output_indptr = torch::empty_like(sub_indptr);
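Reviewer note (not part of the diff): the reason the new nodes-less branch can reuse indptr directly as sub_indptr is that, for the all-nodes case, the exclusive cumulative sum of the in-degrees reproduces indptr. A small libtorch check of that identity, with illustrative values:

#include <torch/torch.h>

void SubIndptrEquivalenceSketch() {
  auto indptr = torch::tensor({0, 2, 3, 7, 9}, torch::kLong);
  auto in_degree = indptr.diff();  // {2, 1, 4, 2}: per-node in-degrees
  auto exclusive_cumsum = torch::cat(
      {torch::zeros({1}, indptr.options()), in_degree.cumsum(0)});
  TORCH_CHECK(torch::equal(exclusive_cumsum, indptr));  // identical to indptr
}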
@@ -233,9 +250,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
         auto num_sampled_edges =
             cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
-        // Find the smallest integer type to store the edge id offsets.
-        // ExpandIndptr or IndexSelectCSCImpl had synch inside, so it is safe to
-        // read max_in_degree now.
+        // Find the smallest integer type to store the edge id offsets. We synch
+        // the CUDAEvent so that the access is safe.
+        max_in_degree_event.synchronize();
         const int num_bits =
             cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
         std::array<int, 4> type_bits = {8, 16, 32, 64};
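Reviewer note (not part of the diff): the hunk above replaces the old "an earlier call already synchronized" assumption with an explicit event. The general pattern, sketched below in libtorch under the assumptions of a single current stream and int64 values (names are made up), is to record the event right after the asynchronous work that produces the value and synchronize it only immediately before the host reads the result, so kernel launches queued in between never block the host.

#include <ATen/cuda/CUDAEvent.h>
#include <torch/torch.h>

int64_t ReadDeviceMaxSketch(torch::Tensor values) {  // values: CUDA tensor, kLong
  // Async: reduce on the GPU and copy the scalar into pinned host memory.
  auto result = torch::empty(
      {}, torch::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  result.copy_(values.max(), /*non_blocking=*/true);
  at::cuda::CUDAEvent done;
  done.record();       // marks the point on the current stream where `result` is ready
  // ... more kernels can be queued here without blocking the host ...
  done.synchronize();  // host waits only for work recorded up to the event
  return result.item<int64_t>();
}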
@@ -255,12 +272,14 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
           // Using bfloat16 for random numbers works just as reliably as
           // float32 and provides around %30 percent speedup.
           using rnd_t = nv_bfloat16;
-          auto randoms = allocator.AllocateStorage<rnd_t>(num_edges);
-          auto randoms_sorted = allocator.AllocateStorage<rnd_t>(num_edges);
+          auto randoms =
+              allocator.AllocateStorage<rnd_t>(num_edges.value());
+          auto randoms_sorted =
+              allocator.AllocateStorage<rnd_t>(num_edges.value());
           auto edge_id_segments =
-              allocator.AllocateStorage<edge_id_t>(num_edges);
+              allocator.AllocateStorage<edge_id_t>(num_edges.value());
           auto sorted_edge_id_segments =
-              allocator.AllocateStorage<edge_id_t>(num_edges);
+              allocator.AllocateStorage<edge_id_t>(num_edges.value());
           AT_DISPATCH_INDEX_TYPES(
               indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                 using indices_t = index_t;
@@ -282,10 +301,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
                     layer ? indices.data_ptr<indices_t>() : nullptr;
                 const dim3 block(BLOCK_SIZE);
                 const dim3 grid(
-                    (num_edges + BLOCK_SIZE - 1) / BLOCK_SIZE);
+                    (num_edges.value() + BLOCK_SIZE - 1) /
+                    BLOCK_SIZE);
                 // Compute row and random number pairs.
                 CUDA_KERNEL_CALL(
-                    _ComputeRandoms, grid, block, 0, num_edges,
+                    _ComputeRandoms, grid, block, 0,
+                    num_edges.value(),
                     sliced_indptr.data_ptr<indptr_t>(),
                     sub_indptr.data_ptr<indptr_t>(),
                     coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
@@ -300,13 +321,13 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
                 CUB_CALL(
                     DeviceSegmentedSort::SortPairs, randoms.get(),
                     randoms_sorted.get(), edge_id_segments.get(),
-                    sorted_edge_id_segments.get(), num_edges, num_rows,
+                    sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                     sub_indptr.data_ptr<indptr_t>(),
                     sub_indptr.data_ptr<indptr_t>() + 1);
                 picked_eids = torch::empty(
                     static_cast<indptr_t>(num_sampled_edges),
-                    nodes.options().dtype(indptr.scalar_type()));
+                    sub_indptr.options());
                 // Need to sort the sampled edges only when fanouts.size() == 1
                 // since multiple fanout sampling case is automatically going to
@@ -385,9 +406,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
   torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
   if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
+  if (!nodes.has_value()) {
+    nodes = torch::arange(indptr.size(0) - 1, indices.options());
+  }
   return c10::make_intrusive<sampling::FusedSampledSubgraph>(
-      output_indptr, output_indices, nodes, torch::nullopt,
+      output_indptr, output_indices, nodes.value(), torch::nullopt,
       subgraph_reverse_edge_ids, output_type_per_edge);
 }
...
@@ -35,14 +35,16 @@ struct SliceFunc {
 // Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])
 std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
-    torch::Tensor indptr, torch::Tensor nodes) {
+    torch::Tensor indptr, torch::optional<torch::Tensor> nodes_optional) {
+  if (nodes_optional.has_value()) {
+    auto nodes = nodes_optional.value();
     const int64_t num_nodes = nodes.size(0);
     // Read indptr only once in case it is pinned and access is slow.
     auto sliced_indptr =
         torch::empty(num_nodes, nodes.options().dtype(indptr.scalar_type()));
     // compute in-degrees
-    auto in_degree =
-        torch::empty(num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
+    auto in_degree = torch::empty(
+        num_nodes + 1, nodes.options().dtype(indptr.scalar_type()));
     thrust::counting_iterator<int64_t> iota(0);
     AT_DISPATCH_INTEGRAL_TYPES(
         indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
@@ -59,6 +61,22 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
         }));
     }));
     return {in_degree, sliced_indptr};
+  } else {
+    const int64_t num_nodes = indptr.size(0) - 1;
+    auto sliced_indptr = indptr.slice(0, 0, num_nodes);
+    auto in_degree = torch::empty(
+        num_nodes + 2, indptr.options().dtype(indptr.scalar_type()));
+    AT_DISPATCH_INTEGRAL_TYPES(
+        indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
+          using indptr_t = scalar_t;
+          CUB_CALL(
+              DeviceAdjacentDifference::SubtractLeftCopy,
+              indptr.data_ptr<indptr_t>(), in_degree.data_ptr<indptr_t>(),
+              num_nodes + 1, cub::Difference{});
+        }));
+    in_degree = in_degree.slice(0, 1);
+    return {in_degree, sliced_indptr};
+  }
 }
 
 template <typename indptr_t, typename etype_t>
...
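Reviewer note (not part of the diff): in the new else branch, cub::DeviceAdjacentDifference::SubtractLeftCopy keeps the first input element unchanged in slot 0 and writes indptr[i] - indptr[i - 1] into the remaining slots, which is why the result is trimmed with in_degree.slice(0, 1) afterwards. The same computation expressed with plain tensor ops (illustration only, ignoring the extra trailing slot the real code reserves):

#include <torch/torch.h>
#include <tuple>

// What the nodes-less branch computes: degrees of all nodes plus indptr[:-1].
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptrAllNodes(
    torch::Tensor indptr) {
  const int64_t num_nodes = indptr.size(0) - 1;
  auto sliced_indptr = indptr.slice(0, 0, num_nodes);  // indptr[:-1]
  auto in_degree = indptr.diff();                      // indptr[1:] - indptr[:-1]
  return {in_degree, sliced_indptr};
}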
@@ -607,21 +607,28 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
 }
 
 c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
-    const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
+    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
     bool replace, bool layer, bool return_eids,
     torch::optional<std::string> probs_name) const {
-  torch::optional<torch::Tensor> probs_or_mask = torch::nullopt;
-  if (probs_name.has_value() && !probs_name.value().empty()) {
-    probs_or_mask = this->EdgeAttribute(probs_name);
-  }
-  if (!replace && utils::is_on_gpu(nodes) &&
-      utils::is_accessible_from_gpu(indptr_) &&
-      utils::is_accessible_from_gpu(indices_) &&
-      (!probs_or_mask.has_value() ||
-       utils::is_accessible_from_gpu(probs_or_mask.value())) &&
-      (!type_per_edge_.has_value() ||
-       utils::is_accessible_from_gpu(type_per_edge_.value()))) {
+  auto probs_or_mask = this->EdgeAttribute(probs_name);
+  // If nodes does not have a value, then we expect all arguments to be resident
+  // on the GPU. If nodes has a value, then we expect them to be accessible from
+  // GPU. This is required for the dispatch to work when CUDA is not available.
+  if (((!nodes.has_value() && utils::is_on_gpu(indptr_) &&
+        utils::is_on_gpu(indices_) &&
+        (!probs_or_mask.has_value() ||
+         utils::is_on_gpu(probs_or_mask.value())) &&
+        (!type_per_edge_.has_value() ||
+         utils::is_on_gpu(type_per_edge_.value()))) ||
+       (nodes.has_value() && utils::is_on_gpu(nodes.value()) &&
+        utils::is_accessible_from_gpu(indptr_) &&
+        utils::is_accessible_from_gpu(indices_) &&
+        (!probs_or_mask.has_value() ||
+         utils::is_accessible_from_gpu(probs_or_mask.value())) &&
+        (!type_per_edge_.has_value() ||
+         utils::is_accessible_from_gpu(type_per_edge_.value())))) &&
+      !replace) {
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
         c10::DeviceType::CUDA, "SampleNeighbors", {
           return ops::SampleNeighbors(
@@ -629,6 +636,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
               type_per_edge_, probs_or_mask);
         });
   }
+  TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU.");
   if (probs_or_mask.has_value()) {
     // Note probs will be passed as input for 'torch.multinomial' in deeper
@@ -645,7 +653,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
         static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
     SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
     return SampleNeighborsImpl(
-        nodes, return_eids,
+        nodes.value(), return_eids,
         GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
         GetPickFn(
             fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
@@ -653,7 +661,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
   } else {
     SamplerArgs<SamplerType::NEIGHBOR> args;
     return SampleNeighborsImpl(
-        nodes, return_eids,
+        nodes.value(), return_eids,
         GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
         GetPickFn(
             fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
...
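Reviewer note (not part of the diff): the enlarged condition above can be read as a single predicate. The refactoring below is only for readability and uses the same utils helpers the diff already calls; the function name and signature are made up.

#include <torch/torch.h>

bool CanUseGpuSampling(
    const torch::Tensor& indptr, const torch::Tensor& indices,
    const torch::optional<torch::Tensor>& nodes,
    const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& type_per_edge, bool replace) {
  if (replace) return false;  // the GPU path still lacks replacement sampling
  if (!nodes.has_value()) {
    // No seed nodes: every graph tensor must already live on the GPU.
    return utils::is_on_gpu(indptr) && utils::is_on_gpu(indices) &&
           (!probs_or_mask.has_value() ||
            utils::is_on_gpu(probs_or_mask.value())) &&
           (!type_per_edge.has_value() ||
            utils::is_on_gpu(type_per_edge.value()));
  }
  // Seed nodes given: they must be on the GPU; everything else only needs to
  // be accessible from it (for example, pinned host memory).
  return utils::is_on_gpu(nodes.value()) &&
         utils::is_accessible_from_gpu(indptr) &&
         utils::is_accessible_from_gpu(indices) &&
         (!probs_or_mask.has_value() ||
          utils::is_accessible_from_gpu(probs_or_mask.value())) &&
         (!type_per_edge.has_value() ||
          utils::is_accessible_from_gpu(type_per_edge.value()));
}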
@@ -597,6 +597,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
         return self._convert_to_sampled_subgraph(C_sampled_subgraph)
 
     def _check_sampler_arguments(self, nodes, fanouts, probs_name):
-        assert nodes.dim() == 1, "Nodes should be 1-D tensor."
-        assert nodes.dtype == self.indices.dtype, (
-            f"Data type of nodes must be consistent with "
+        if nodes is not None:
+            assert nodes.dim() == 1, "Nodes should be 1-D tensor."
+            assert nodes.dtype == self.indices.dtype, (
+                f"Data type of nodes must be consistent with "
...
@@ -1615,7 +1615,10 @@ def test_csc_sampling_graph_to_pinned_memory():
 @pytest.mark.parametrize("labor", [False, True])
 @pytest.mark.parametrize("is_pinned", [False, True])
-def test_sample_neighbors_homo(labor, is_pinned):
+@pytest.mark.parametrize("nodes", [None, True])
+def test_sample_neighbors_homo(labor, is_pinned, nodes):
+    if is_pinned and nodes is None:
+        pytest.skip("Optional nodes and is_pinned is not supported together.")
     """Original graph in COO:
     1   0   1   0   1
     1   0   1   1   0
@@ -1638,13 +1641,20 @@ def test_sample_neighbors_homo(labor, is_pinned):
     )
 
     # Generate subgraph via sample neighbors.
-    nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
+    if nodes:
+        nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
+    elif F._default_context_str != "gpu":
+        pytest.skip("Optional nodes is supported only for the GPU.")
     sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
     subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))
     # Verify in subgraph.
     sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)
     sampled_num = subgraph.sampled_csc.indices.size(0)
-    assert sampled_indptr_num == 4
-    assert sampled_num == 6
+    if nodes is None:
+        assert sampled_indptr_num == indptr.shape[0]
+        assert sampled_num == 10
+    else:
+        assert sampled_indptr_num == 4
+        assert sampled_num == 6
     assert subgraph.original_column_node_ids is None
...