Unverified Commit d3483fe1 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Optimize hetero sampling. (#7223)

parent 2a00cd3d
...@@ -19,8 +19,10 @@ namespace ops { ...@@ -19,8 +19,10 @@ namespace ops {
* *
* @param indptr Index pointer array of the CSC. * @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC. * @param indices Indices array of the CSC.
* @param nodes The nodes from which to sample neighbors. If not provided, * @param seeds The nodes from which to sample neighbors. If not provided,
* assumed to be equal to torch.arange(indptr.size(0) - 1). * assumed to be equal to torch.arange(indptr.size(0) - 1).
* @param seed_offsets The offsets of the given seeds,
* seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.
* @param fanouts The number of edges to be sampled for each node with or * @param fanouts The number of edges to be sampled for each node with or
* without considering edge types. * without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all * - When the length is 1, it indicates that the fanout applies to all
...@@ -45,6 +47,12 @@ namespace ops { ...@@ -45,6 +47,12 @@ namespace ops {
* @param probs_or_mask An optional tensor with (unnormalized) probabilities * @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be * corresponding to each neighboring edge of a node. It must be
* a 1D tensor, with the number of elements equaling the total number of edges. * a 1D tensor, with the number of elements equaling the total number of edges.
* @param node_type_to_id A dictionary mapping node type names to type IDs. The
* length of it is equal to the number of node types. The key is the node type
* name, and the value is the corresponding type ID.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs. The
* length of it is equal to the number of edge types. The key is the edge type
* name, and the value is the corresponding type ID.
* @param random_seed The random seed for the sampler for layer=True. * @param random_seed The random seed for the sampler for layer=True.
* @param seed2_contribution The contribution of the second random seed, [0, 1) * @param seed2_contribution The contribution of the second random seed, [0, 1)
* for layer=True. * for layer=True.
...@@ -54,10 +62,16 @@ namespace ops { ...@@ -54,10 +62,16 @@ namespace ops {
*/ */
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> seeds,
bool replace, bool layer, bool return_eids, torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt, torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt, torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id =
torch::nullopt,
torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id =
torch::nullopt,
torch::optional<torch::Tensor> random_seed = torch::nullopt, torch::optional<torch::Tensor> random_seed = torch::nullopt,
float seed2_contribution = .0f); float seed2_contribution = .0f);
......
...@@ -298,8 +298,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -298,8 +298,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Sample neighboring edges of the given nodes and return the induced * @brief Sample neighboring edges of the given nodes and return the induced
* subgraph. * subgraph.
* *
* @param nodes The nodes from which to sample neighbors. If not provided, * @param seeds The nodes from which to sample neighbors. If not provided,
* assumed to be equal to torch.arange(NumNodes()). * assumed to be equal to torch.arange(NumNodes()).
* @param seed_offsets The offsets of the given seeds,
* seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type id i.
* @param fanouts The number of edges to be sampled for each node with or * @param fanouts The number of edges to be sampled for each node with or
* without considering edge types. * without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all * - When the length is 1, it indicates that the fanout applies to all
...@@ -333,9 +335,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -333,9 +335,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* the sampled graph's information. * the sampled graph's information.
*/ */
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> seeds,
bool replace, bool layer, bool return_eids, torch::optional<std::vector<int64_t>> seed_offsets,
torch::optional<std::string> probs_name, const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed, torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const; double seed2_contribution) const;
......
...@@ -51,33 +51,39 @@ struct FusedSampledSubgraph : torch::CustomClassHolder { ...@@ -51,33 +51,39 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* graph. * graph.
* @param original_edge_ids Reverse edge ids in the original graph. * @param original_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge. * @param type_per_edge Type id of each edge.
* @param etype_offsets Edge offsets for the sampled edges for the sampled
* edges that are sorted w.r.t. edge types.
*/ */
FusedSampledSubgraph( FusedSampledSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
torch::Tensor original_column_node_ids, torch::optional<torch::Tensor> original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt, torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> original_edge_ids = torch::nullopt, torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt) torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> etype_offsets = torch::nullopt)
: indptr(indptr), : indptr(indptr),
indices(indices), indices(indices),
original_column_node_ids(original_column_node_ids), original_column_node_ids(original_column_node_ids),
original_row_node_ids(original_row_node_ids), original_row_node_ids(original_row_node_ids),
original_edge_ids(original_edge_ids), original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge) {} type_per_edge(type_per_edge),
etype_offsets(etype_offsets) {}
FusedSampledSubgraph() = default; FusedSampledSubgraph() = default;
/** /**
* @brief CSC format index pointer array, where the implicit node ids are * @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the * already compacted. And the original ids are stored in the
* `original_column_node_ids` field. * `original_column_node_ids` field. Its length is equal to:
* 1 + \sum_{etype} #seeds with dst_node_type(etype)
*/ */
torch::Tensor indptr; torch::Tensor indptr;
/** /**
* @brief CSC format index array, where the node ids can be compacted ids or * @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the * original ids. If compacted, the original ids are stored in the
* `original_row_node_ids` field. * `original_row_node_ids` field. The indices are sorted w.r.t. their edge
* types for the heterogenous case.
*/ */
torch::Tensor indices; torch::Tensor indices;
...@@ -86,10 +92,11 @@ struct FusedSampledSubgraph : torch::CustomClassHolder { ...@@ -86,10 +92,11 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* can be treated as a coordinated row and column pair, and this is the the * can be treated as a coordinated row and column pair, and this is the the
* mapped ids of the column. * mapped ids of the column.
* *
* @note This is required and the mapping relations can be inconsistent with * @note This is optional and the mapping relations can be inconsistent with
* column's. * column's. It can be missing when the sampling algorithm is called via a
* sliced sampled subgraph with missing seeds argument.
*/ */
torch::Tensor original_column_node_ids; torch::optional<torch::Tensor> original_column_node_ids;
/** /**
* @brief Row's reverse node ids in the original graph. A graph structure * @brief Row's reverse node ids in the original graph. A graph structure
...@@ -104,7 +111,8 @@ struct FusedSampledSubgraph : torch::CustomClassHolder { ...@@ -104,7 +111,8 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
/** /**
* @brief Reverse edge ids in the original graph, the edge with id * @brief Reverse edge ids in the original graph, the edge with id
* `original_edge_ids[i]` in the original graph is mapped to `i` in this * `original_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed. * subgraph. This is useful when edge features are needed. The edges are
* sorted w.r.t. their edge types for the heterogenous case.
*/ */
torch::optional<torch::Tensor> original_edge_ids; torch::optional<torch::Tensor> original_edge_ids;
...@@ -112,8 +120,21 @@ struct FusedSampledSubgraph : torch::CustomClassHolder { ...@@ -112,8 +120,21 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* @brief Type id of each edge, where type id is the corresponding index of * @brief Type id of each edge, where type id is the corresponding index of
* edge types. The length of it is equal to the number of edges in the * edge types. The length of it is equal to the number of edges in the
* subgraph. * subgraph.
*
* @note This output is not created by the CUDA implementation as the edges
* are sorted w.r.t edge types, one has to use etype_offsets to infer the edge
* type information. This field is going to be deprecated. It can be generated
* when needed by computing gb.expand_indptr(etype_offsets).
*/ */
torch::optional<torch::Tensor> type_per_edge; torch::optional<torch::Tensor> type_per_edge;
/**
* @brief Offsets of each etype,
* type_per_edge[etype_offsets[i]: etype_offsets[i + 1]] == i
* It has length equal to (1 + #etype), and the edges are guaranteed to be
* sorted w.r.t. their edge types.
*/
torch::optional<torch::Tensor> etype_offsets;
}; };
} // namespace sampling } // namespace sampling
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <type_traits> #include <type_traits>
#include "../random.h" #include "../random.h"
#include "../utils.h"
#include "./common.h" #include "./common.h"
#include "./utils.h" #include "./utils.h"
...@@ -183,19 +184,26 @@ struct SegmentEndFunc { ...@@ -183,19 +184,26 @@ struct SegmentEndFunc {
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> seeds,
bool replace, bool layer, bool return_eids, torch::optional<std::vector<int64_t>> seed_offsets,
torch::optional<torch::Tensor> type_per_edge, const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask, torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id,
torch::optional<torch::Tensor> random_seed_tensor, torch::optional<torch::Tensor> random_seed_tensor,
float seed2_contribution) { float seed2_contribution) {
// When seed_offsets.has_value() in the hetero case, we compute the output of
// sample_neighbors _convert_to_sampled_subgraph in a fused manner so that
// _convert_to_sampled_subgraph only has to perform slices over the returned
// indptr and indices tensors to form CSC outputs for each edge type.
TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!"); TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask // Assume that indptr, indices, seeds, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them // are all resident on the GPU. If not, it is better to first extract them
// before calling this function. // before calling this function.
auto allocator = cuda::GetAllocator(); auto allocator = cuda::GetAllocator();
auto num_rows = auto num_rows =
nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1; seeds.has_value() ? seeds.value().size(0) : indptr.size(0) - 1;
auto fanouts_pinned = torch::empty( auto fanouts_pinned = torch::empty(
fanouts.size(), fanouts.size(),
c10::TensorOptions().dtype(torch::kLong).pinned_memory(true)); c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
...@@ -210,7 +218,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -210,7 +218,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
fanouts_device.get(), fanouts_pinned_ptr, fanouts_device.get(), fanouts_pinned_ptr,
sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice, sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
cuda::GetCurrentStream())); cuda::GetCurrentStream()));
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes); auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, seeds);
auto in_degree = std::get<0>(in_degree_and_sliced_indptr); auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr); auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
auto max_in_degree = torch::empty( auto max_in_degree = torch::empty(
...@@ -227,16 +235,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -227,16 +235,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
max_in_degree_event.record(); max_in_degree_event.record();
torch::optional<int64_t> num_edges; torch::optional<int64_t> num_edges;
torch::Tensor sub_indptr; torch::Tensor sub_indptr;
if (!nodes.has_value()) { if (!seeds.has_value()) {
num_edges = indices.size(0); num_edges = indices.size(0);
sub_indptr = indptr; sub_indptr = indptr;
} }
torch::optional<torch::Tensor> sliced_probs_or_mask; torch::optional<torch::Tensor> sliced_probs_or_mask;
if (probs_or_mask.has_value()) { if (probs_or_mask.has_value()) {
if (nodes.has_value()) { if (seeds.has_value()) {
torch::Tensor sliced_probs_or_mask_tensor; torch::Tensor sliced_probs_or_mask_tensor;
std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl( std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(), in_degree, sliced_indptr, probs_or_mask.value(), seeds.value(),
indptr.size(0) - 2, num_edges); indptr.size(0) - 2, num_edges);
sliced_probs_or_mask = sliced_probs_or_mask_tensor; sliced_probs_or_mask = sliced_probs_or_mask_tensor;
num_edges = sliced_probs_or_mask_tensor.size(0); num_edges = sliced_probs_or_mask_tensor.size(0);
...@@ -246,9 +254,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -246,9 +254,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
} }
if (fanouts.size() > 1) { if (fanouts.size() > 1) {
torch::Tensor sliced_type_per_edge; torch::Tensor sliced_type_per_edge;
if (nodes.has_value()) { if (seeds.has_value()) {
std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl( std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
in_degree, sliced_indptr, type_per_edge.value(), nodes.value(), in_degree, sliced_indptr, type_per_edge.value(), seeds.value(),
indptr.size(0) - 2, num_edges); indptr.size(0) - 2, num_edges);
} else { } else {
sliced_type_per_edge = type_per_edge.value(); sliced_type_per_edge = type_per_edge.value();
...@@ -259,7 +267,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -259,7 +267,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
num_edges = sliced_type_per_edge.size(0); num_edges = sliced_type_per_edge.size(0);
} }
// If sub_indptr was not computed in the two code blocks above: // If sub_indptr was not computed in the two code blocks above:
if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) { if (seeds.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
sub_indptr = ExclusiveCumSum(in_degree); sub_indptr = ExclusiveCumSum(in_degree);
} }
auto coo_rows = ExpandIndptrImpl( auto coo_rows = ExpandIndptrImpl(
...@@ -276,7 +284,6 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -276,7 +284,6 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto output_indptr = torch::empty_like(sub_indptr); auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids; torch::Tensor picked_eids;
torch::Tensor output_indices; torch::Tensor output_indices;
torch::optional<torch::Tensor> output_type_per_edge;
AT_DISPATCH_INDEX_TYPES( AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] { indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
...@@ -507,39 +514,153 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -507,39 +514,153 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
indices.data_ptr<indices_t>(), indices.data_ptr<indices_t>(),
output_indices.data_ptr<indices_t>()); output_indices.data_ptr<indices_t>());
})); }));
}));
if (type_per_edge) { auto index_type_per_edge_for_sampled_edges = [&] {
// output_type_per_edge = type_per_edge.gather(0, picked_eids); // The code behaves same as:
// The commented out torch equivalent above does not work when // output_type_per_edge = type_per_edge.gather(0, picked_eids);
// type_per_edge is on pinned memory. That is why, we have to // The reimplementation is required due to the torch equivalent does
// reimplement it, similar to the indices gather operation above. // not work when type_per_edge is on pinned memory
auto types = type_per_edge.value(); auto types = type_per_edge.value();
output_type_per_edge = torch::empty( auto output = torch::empty(
picked_eids.size(0), picked_eids.size(0), picked_eids.options().dtype(types.scalar_type()));
picked_eids.options().dtype(types.scalar_type())); AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
using indptr_t = index_t;
AT_DISPATCH_INTEGRAL_TYPES( AT_DISPATCH_INTEGRAL_TYPES(
types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] { types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
THRUST_CALL( THRUST_CALL(
gather, picked_eids.data_ptr<indptr_t>(), gather, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0), picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
types.data_ptr<scalar_t>(), types.data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
output_type_per_edge.value().data_ptr<scalar_t>());
})); }));
} }));
})); return output;
};
torch::optional<torch::Tensor> output_type_per_edge;
torch::optional<torch::Tensor> edge_offsets;
if (type_per_edge && seed_offsets) {
const int64_t num_etypes =
edge_type_to_id.has_value() ? edge_type_to_id->size() : 1;
// If we performed homogenous sampling on hetero graph, we have to look at
// type_per_edge of sampled edges and determine the offsets of different
// sampled etypes and convert to fused hetero indptr representation.
if (fanouts.size() == 1) {
output_type_per_edge = index_type_per_edge_for_sampled_edges();
torch::Tensor output_in_degree, sliced_output_indptr;
sliced_output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0) - 1);
std::tie(output_indptr, output_in_degree, sliced_output_indptr) =
SliceCSCIndptrHetero(
output_indptr, output_type_per_edge.value(), sliced_output_indptr,
num_etypes);
// We use num_rows to hold num_seeds * num_etypes. So, it needs to be
// updated when sampling with a single fanout value when the graph is
// heterogenous.
num_rows = sliced_output_indptr.size(0);
}
// Here, we check what are the dst node types for the given seeds so that
// we can compute the output indptr space later.
std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);
for (auto& etype_and_id : edge_type_to_id.value()) {
auto etype = etype_and_id.key();
auto id = etype_and_id.value();
auto dst_type = utils::parse_dst_ntype_from_etype(etype);
etype_id_to_dst_ntype_id[id] = node_type_to_id->at(dst_type);
}
// For each edge type, we compute the start and end offsets to index into
// indptr to form the final output_indptr.
auto indptr_offsets = torch::empty(
num_etypes * 2,
c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
auto indptr_offsets_ptr = indptr_offsets.data_ptr<int64_t>();
// We compute the indptr offsets here, right now, output_indptr is of size
// # seeds * num_etypes + 1. We can simply take slices to get correct output
// indptr. The final output_indptr is same as current indptr except that
// some intermediate values are removed to change the node ids space from
// all of the seed vertices to the node id space of the dst node type of
// each edge type.
for (int i = 0; i < num_etypes; i++) {
indptr_offsets_ptr[2 * i] = num_rows / num_etypes * i +
seed_offsets->at(etype_id_to_dst_ntype_id[i]);
indptr_offsets_ptr[2 * i + 1] =
num_rows / num_etypes * i +
seed_offsets->at(etype_id_to_dst_ntype_id[i] + 1);
}
auto permutation = torch::arange(
0, num_rows * num_etypes, num_etypes, output_indptr.options());
permutation =
permutation.remainder(num_rows) + permutation.div(num_rows, "floor");
// This permutation, when applied sorts the sampled edges with respect to
// edge types.
auto [output_in_degree, sliced_output_indptr] =
SliceCSCIndptr(output_indptr, permutation);
std::tie(output_indptr, picked_eids) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, picked_eids, permutation,
num_rows - 1, picked_eids.size(0));
edge_offsets = torch::empty(
num_etypes * 2, c10::TensorOptions()
.dtype(output_indptr.scalar_type())
.pinned_memory(true));
at::cuda::CUDAEvent edge_offsets_event;
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsEdgeOffsets", ([&] {
THRUST_CALL(
gather, indptr_offsets_ptr,
indptr_offsets_ptr + indptr_offsets.size(0),
output_indptr.data_ptr<index_t>(),
edge_offsets->data_ptr<index_t>());
}));
edge_offsets_event.record();
// The output_indices is permuted here.
std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, output_indices, permutation,
num_rows - 1, output_indices.size(0));
std::vector<torch::Tensor> indptr_list;
for (int i = 0; i < num_etypes; i++) {
indptr_list.push_back(output_indptr.slice(
0, indptr_offsets_ptr[2 * i],
indptr_offsets_ptr[2 * i + 1] + (i == num_etypes - 1)));
}
// We form the final output indptr by concatenating pieces for different
// edge types.
output_indptr = torch::cat(indptr_list);
edge_offsets_event.synchronize();
// We read the edge_offsets here, they are in pairs but we don't need it to
// be in pairs. So we remove the duplicate information from it and turn it
// into a real offsets array.
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsEdgeOffsetsCheck", ([&] {
auto edge_offsets_ptr = edge_offsets->data_ptr<index_t>();
TORCH_CHECK(edge_offsets_ptr[0] == 0, "edge_offsets is incorrect.");
for (int i = 1; i < num_etypes; i++) {
TORCH_CHECK(
edge_offsets_ptr[2 * i - 1] == edge_offsets_ptr[2 * i],
"edge_offsets is incorrect.");
}
TORCH_CHECK(
edge_offsets_ptr[2 * num_etypes - 1] == picked_eids.size(0),
"edge_offsets is incorrect.");
for (int i = 0; i < num_etypes; i++) {
edge_offsets_ptr[i + 1] = edge_offsets_ptr[2 * i + 1];
}
}));
edge_offsets = edge_offsets->slice(0, 0, num_etypes + 1);
} else {
// Convert output_indptr back to homo by discarding intermediate offsets.
output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
if (type_per_edge)
output_type_per_edge = index_type_per_edge_for_sampled_edges();
}
// Convert output_indptr back to homo by discarding intermediate offsets.
output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt; torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids); if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
if (!nodes.has_value()) {
nodes = torch::arange(indptr.size(0) - 1, indices.options());
}
return c10::make_intrusive<sampling::FusedSampledSubgraph>( return c10::make_intrusive<sampling::FusedSampledSubgraph>(
output_indptr, output_indices, nodes.value(), torch::nullopt, output_indptr, output_indices, seeds, torch::nullopt,
subgraph_reverse_edge_ids, output_type_per_edge); subgraph_reverse_edge_ids, output_type_per_edge, edge_offsets);
} }
} // namespace ops } // namespace ops
......
...@@ -617,23 +617,24 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -617,23 +617,24 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
} }
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts, torch::optional<torch::Tensor> seeds,
bool replace, bool layer, bool return_eids, torch::optional<std::vector<int64_t>> seed_offsets,
torch::optional<std::string> probs_name, const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed, torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const { double seed2_contribution) const {
auto probs_or_mask = this->EdgeAttribute(probs_name); auto probs_or_mask = this->EdgeAttribute(probs_name);
// If nodes does not have a value, then we expect all arguments to be resident // If seeds does not have a value, then we expect all arguments to be resident
// on the GPU. If nodes has a value, then we expect them to be accessible from // on the GPU. If seeds has a value, then we expect them to be accessible from
// GPU. This is required for the dispatch to work when CUDA is not available. // GPU. This is required for the dispatch to work when CUDA is not available.
if (((!nodes.has_value() && utils::is_on_gpu(indptr_) && if (((!seeds.has_value() && utils::is_on_gpu(indptr_) &&
utils::is_on_gpu(indices_) && utils::is_on_gpu(indices_) &&
(!probs_or_mask.has_value() || (!probs_or_mask.has_value() ||
utils::is_on_gpu(probs_or_mask.value())) && utils::is_on_gpu(probs_or_mask.value())) &&
(!type_per_edge_.has_value() || (!type_per_edge_.has_value() ||
utils::is_on_gpu(type_per_edge_.value()))) || utils::is_on_gpu(type_per_edge_.value()))) ||
(nodes.has_value() && utils::is_on_gpu(nodes.value()) && (seeds.has_value() && utils::is_on_gpu(seeds.value()) &&
utils::is_accessible_from_gpu(indptr_) && utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) && utils::is_accessible_from_gpu(indices_) &&
(!probs_or_mask.has_value() || (!probs_or_mask.has_value() ||
...@@ -644,11 +645,12 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -644,11 +645,12 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "SampleNeighbors", { c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors( return ops::SampleNeighbors(
indptr_, indices_, nodes, fanouts, replace, layer, return_eids, indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
type_per_edge_, probs_or_mask, random_seed, seed2_contribution); return_eids, type_per_edge_, probs_or_mask, node_type_to_id_,
edge_type_to_id_, random_seed, seed2_contribution);
}); });
} }
TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU."); TORCH_CHECK(seeds.has_value(), "Nodes can not be None on the CPU.");
if (probs_or_mask.has_value()) { if (probs_or_mask.has_value()) {
// Note probs will be passed as input for 'torch.multinomial' in deeper // Note probs will be passed as input for 'torch.multinomial' in deeper
...@@ -667,7 +669,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -667,7 +669,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
{random_seed.value(), static_cast<float>(seed2_contribution)}, {random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()}; NumNodes()};
return SampleNeighborsImpl( return SampleNeighborsImpl(
nodes.value(), return_eids, seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn( GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, fanouts, replace, indptr_.options(), type_per_edge_,
...@@ -686,7 +688,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -686,7 +688,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} }
}(); }();
return SampleNeighborsImpl( return SampleNeighborsImpl(
nodes.value(), return_eids, seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn( GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, fanouts, replace, indptr_.options(), type_per_edge_,
...@@ -695,7 +697,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -695,7 +697,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} else { } else {
SamplerArgs<SamplerType::NEIGHBOR> args; SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl( return SampleNeighborsImpl(
nodes.value(), return_eids, seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn( GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask, fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
......
...@@ -36,7 +36,8 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -36,7 +36,8 @@ TORCH_LIBRARY(graphbolt, m) {
&FusedSampledSubgraph::original_column_node_ids) &FusedSampledSubgraph::original_column_node_ids)
.def_readwrite( .def_readwrite(
"original_edge_ids", &FusedSampledSubgraph::original_edge_ids) "original_edge_ids", &FusedSampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge); .def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge)
.def_readwrite("etype_offsets", &FusedSampledSubgraph::etype_offsets);
m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray") m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray")
.def("index_select", &storage::OnDiskNpyArray::IndexSelect); .def("index_select", &storage::OnDiskNpyArray::IndexSelect);
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph") m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
......
...@@ -26,6 +26,17 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) { ...@@ -26,6 +26,17 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return is_on_gpu(tensor) || tensor.is_pinned(); return is_on_gpu(tensor) || tensor.is_pinned();
} }
/**
* @brief Parses the destination node type from a given edge type triple
* seperated with ":".
*/
inline std::string parse_dst_ntype_from_etype(std::string etype) {
auto first_seperator_it = std::find(etype.begin(), etype.end(), ':');
auto second_seperator_pos =
std::find(first_seperator_it + 1, etype.end(), ':') - etype.begin();
return etype.substr(second_seperator_pos + 1);
}
/** /**
* @brief Retrieves the value of the tensor at the given index. * @brief Retrieves the value of the tensor at the given index.
* *
......
...@@ -146,7 +146,7 @@ def _sample_neighbors_graphbolt( ...@@ -146,7 +146,7 @@ def _sample_neighbors_graphbolt(
return_eids = g.edge_attributes is not None and EID in g.edge_attributes return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors( subgraph = g._sample_neighbors(
nodes, fanout, replace=replace, return_eids=return_eids nodes, None, fanout, replace=replace, return_eids=return_eids
) )
# 3. Map local node IDs to global node IDs. # 3. Map local node IDs to global node IDs.
......
...@@ -444,7 +444,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -444,7 +444,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
)} )}
""" """
if isinstance(nodes, dict): if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes) nodes, _ = self._convert_to_homogeneous_nodes(nodes)
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
...@@ -453,22 +453,28 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -453,22 +453,28 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None): def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = [] homogeneous_nodes = []
homogeneous_node_offsets = [0]
homogeneous_timestamps = [] homogeneous_timestamps = []
offset = self._node_type_offset_list offset = self._node_type_offset_list
for ntype, ids in nodes.items(): for ntype, ntype_id in self.node_type_to_id.items():
ntype_id = self.node_type_to_id[ntype] ids = nodes.get(ntype, [])
homogeneous_nodes.append(ids + offset[ntype_id]) if len(ids) > 0:
if timestamps is not None: homogeneous_nodes.append(ids + offset[ntype_id])
homogeneous_timestamps.append(timestamps[ntype]) if timestamps is not None:
homogeneous_timestamps.append(timestamps[ntype])
homogeneous_node_offsets.append(
homogeneous_node_offsets[-1] + len(ids)
)
if timestamps is not None: if timestamps is not None:
return torch.cat(homogeneous_nodes), torch.cat( return torch.cat(homogeneous_nodes), torch.cat(
homogeneous_timestamps homogeneous_timestamps
) )
return torch.cat(homogeneous_nodes) return torch.cat(homogeneous_nodes), homogeneous_node_offsets
def _convert_to_sampled_subgraph( def _convert_to_sampled_subgraph(
self, self,
C_sampled_subgraph: torch.ScriptObject, C_sampled_subgraph: torch.ScriptObject,
seed_offsets: Optional[list] = None,
) -> SampledSubgraphImpl: ) -> SampledSubgraphImpl:
"""An internal function used to convert a fused homogeneous sampled """An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'SampledSubgraphImpl'.""" subgraph to general struct 'SampledSubgraphImpl'."""
...@@ -477,6 +483,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -477,6 +483,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
type_per_edge = C_sampled_subgraph.type_per_edge type_per_edge = C_sampled_subgraph.type_per_edge
column = C_sampled_subgraph.original_column_node_ids column = C_sampled_subgraph.original_column_node_ids
original_edge_ids = C_sampled_subgraph.original_edge_ids original_edge_ids = C_sampled_subgraph.original_edge_ids
etype_offsets = C_sampled_subgraph.etype_offsets
if etype_offsets is not None:
etype_offsets = etype_offsets.tolist()
has_original_eids = ( has_original_eids = (
self.edge_attributes is not None self.edge_attributes is not None
...@@ -486,45 +495,78 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -486,45 +495,78 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_edge_ids = torch.ops.graphbolt.index_select( original_edge_ids = torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID], original_edge_ids self.edge_attributes[ORIGINAL_EDGE_ID], original_edge_ids
) )
if type_per_edge is None: if type_per_edge is None and etype_offsets is None:
# The sampled graph is already a homogeneous graph. # The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices) sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
else: else:
# UVA sampling requires us to move node_type_offset to GPU. offset = self._node_type_offset_list
self.node_type_offset = self.node_type_offset.to(column.device)
# 1. Find node types for each nodes in column.
node_types = (
torch.searchsorted(self.node_type_offset, column, right=True)
- 1
)
original_hetero_edge_ids = {} original_hetero_edge_ids = {}
sub_indices = {} sub_indices = {}
sub_indptr = {} sub_indptr = {}
offset = self._node_type_offset_list if etype_offsets is None:
# 2. For loop each node type. # UVA sampling requires us to move node_type_offset to GPU.
for ntype, ntype_id in self.node_type_to_id.items(): self.node_type_offset = self.node_type_offset.to(column.device)
# Get all nodes of a specific node type in column. # 1. Find node types for each nodes in column.
nids = torch.nonzero(node_types == ntype_id).view(-1) node_types = (
nids_original_indptr = indptr[nids + 1] torch.searchsorted(
self.node_type_offset, column, right=True
)
- 1
)
for ntype, ntype_id in self.node_type_to_id.items():
# Get all nodes of a specific node type in column.
nids = torch.nonzero(node_types == ntype_id).view(-1)
nids_original_indptr = indptr[nids + 1]
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
if dst_ntype != ntype:
continue
# Get all edge ids of a specific edge type.
eids = torch.nonzero(type_per_edge == etype_id).view(-1)
src_ntype_id = self.node_type_to_id[src_ntype]
sub_indices[etype] = (
indices[eids] - offset[src_ntype_id]
)
cum_edges = torch.searchsorted(
eids, nids_original_indptr, right=False
)
sub_indptr[etype] = torch.cat(
(torch.tensor([0], device=indptr.device), cum_edges)
)
if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[
eids
]
else:
edge_offsets = [0]
for etype, etype_id in self.edge_type_to_id.items(): for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype) src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
if dst_ntype != ntype: ntype_id = self.node_type_to_id[dst_ntype]
continue edge_offsets.append(
# Get all edge ids of a specific edge type. edge_offsets[-1]
eids = torch.nonzero(type_per_edge == etype_id).view(-1) + seed_offsets[ntype_id + 1]
src_ntype_id = self.node_type_to_id[src_ntype] - seed_offsets[ntype_id]
sub_indices[etype] = indices[eids] - offset[src_ntype_id]
cum_edges = torch.searchsorted(
eids, nids_original_indptr, right=False
)
sub_indptr[etype] = torch.cat(
(torch.tensor([0], device=indptr.device), cum_edges)
) )
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
ntype_id = self.node_type_to_id[dst_ntype]
sub_indptr_ = indptr[
edge_offsets[etype_id] : edge_offsets[etype_id + 1] + 1
]
sub_indptr[etype] = sub_indptr_ - sub_indptr_[0]
sub_indices[etype] = indices[
etype_offsets[etype_id] : etype_offsets[etype_id + 1]
]
if has_original_eids: if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[ original_hetero_edge_ids[etype] = original_edge_ids[
eids etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
] ]
src_ntype_id = self.node_type_to_id[src_ntype]
sub_indices[etype] -= offset[src_ntype_id]
if has_original_eids: if has_original_eids:
original_edge_ids = original_hetero_edge_ids original_edge_ids = original_hetero_edge_ids
sampled_csc = { sampled_csc = {
...@@ -541,7 +583,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -541,7 +583,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def sample_neighbors( def sample_neighbors(
self, self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
...@@ -551,7 +593,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -551,7 +593,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters Parameters
---------- ----------
nodes: torch.Tensor or Dict[str, torch.Tensor] seeds: torch.Tensor or Dict[str, torch.Tensor]
IDs of the given seed nodes. IDs of the given seed nodes.
- If `nodes` is a tensor: It means the graph is homogeneous - If `nodes` is a tensor: It means the graph is homogeneous
graph, and ids inside are homogeneous ids. graph, and ids inside are homogeneous ids.
...@@ -615,21 +657,27 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -615,21 +657,27 @@ class FusedCSCSamplingGraph(SamplingGraph):
indices=tensor([2]), indices=tensor([2]),
)} )}
""" """
if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes)
return_eids = ( return_eids = (
self.edge_attributes is not None self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes and ORIGINAL_EDGE_ID in self.edge_attributes
) )
seed_offsets = None
if isinstance(seeds, dict):
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None and hasattr(self, "_seed_offset_list"):
seed_offsets = self._seed_offset_list # pylint: disable=no-member
C_sampled_subgraph = self._sample_neighbors( C_sampled_subgraph = self._sample_neighbors(
nodes, seeds,
seed_offsets,
fanouts, fanouts,
replace=replace, replace=replace,
probs_name=probs_name, probs_name=probs_name,
return_eids=return_eids, return_eids=return_eids,
) )
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
)
def _check_sampler_arguments(self, nodes, fanouts, probs_name): def _check_sampler_arguments(self, nodes, fanouts, probs_name):
if nodes is not None: if nodes is not None:
...@@ -676,7 +724,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -676,7 +724,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _sample_neighbors( def _sample_neighbors(
self, self,
nodes: torch.Tensor, seeds: torch.Tensor,
seed_offsets: Optional[list],
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
...@@ -687,8 +736,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -687,8 +736,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters Parameters
---------- ----------
nodes: torch.Tensor seeds: torch.Tensor
IDs of the given seed nodes. IDs of the given seed nodes.
seeds_offsets: list, optional
The offsets of the given seeds,
seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.
fanouts: torch.Tensor fanouts: torch.Tensor
The number of edges to be sampled for each node with or without The number of edges to be sampled for each node with or without
considering edge types. considering edge types.
...@@ -726,9 +778,10 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -726,9 +778,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
The sampled C subgraph. The sampled C subgraph.
""" """
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
self._check_sampler_arguments(nodes, fanouts, probs_name) self._check_sampler_arguments(seeds, fanouts, probs_name)
return self._c_csc_graph.sample_neighbors( return self._c_csc_graph.sample_neighbors(
nodes, seeds,
seed_offsets,
fanouts.tolist(), fanouts.tolist(),
replace, replace,
False, # is_labor False, # is_labor
...@@ -740,7 +793,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -740,7 +793,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def sample_layer_neighbors( def sample_layer_neighbors(
self, self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]], seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
probs_name: Optional[str] = None, probs_name: Optional[str] = None,
...@@ -754,7 +807,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -754,7 +807,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters Parameters
---------- ----------
nodes: torch.Tensor or Dict[str, torch.Tensor] seeds: torch.Tensor or Dict[str, torch.Tensor]
IDs of the given seed nodes. IDs of the given seed nodes.
- If `nodes` is a tensor: It means the graph is homogeneous - If `nodes` is a tensor: It means the graph is homogeneous
graph, and ids inside are homogeneous ids. graph, and ids inside are homogeneous ids.
...@@ -844,10 +897,6 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -844,10 +897,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
indices=tensor([2]), indices=tensor([2]),
)} )}
""" """
if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes)
self._check_sampler_arguments(nodes, fanouts, probs_name)
if random_seed is not None: if random_seed is not None:
assert ( assert (
1 <= len(random_seed) <= 2 1 <= len(random_seed) <= 2
...@@ -856,12 +905,21 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -856,12 +905,21 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert ( assert (
0 <= seed2_contribution <= 1 0 <= seed2_contribution <= 1
), "seed2_contribution should be in [0, 1]." ), "seed2_contribution should be in [0, 1]."
has_original_eids = ( has_original_eids = (
self.edge_attributes is not None self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes and ORIGINAL_EDGE_ID in self.edge_attributes
) )
seed_offsets = None
if isinstance(seeds, dict):
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None and hasattr(self, "_seed_offset_list"):
seed_offsets = self._seed_offset_list # pylint: disable=no-member
self._check_sampler_arguments(seeds, fanouts, probs_name)
C_sampled_subgraph = self._c_csc_graph.sample_neighbors( C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
nodes, seeds,
seed_offsets,
fanouts.tolist(), fanouts.tolist(),
replace, replace,
True, True,
...@@ -870,7 +928,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -870,7 +928,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
random_seed, random_seed,
seed2_contribution, seed2_contribution,
) )
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
)
def temporal_sample_neighbors( def temporal_sample_neighbors(
self, self,
......
...@@ -46,44 +46,37 @@ class FetchInsubgraphData(Mapper): ...@@ -46,44 +46,37 @@ class FetchInsubgraphData(Mapper):
def _fetch_per_layer_impl(self, minibatch, stream): def _fetch_per_layer_impl(self, minibatch, stream):
with torch.cuda.stream(self.stream): with torch.cuda.stream(self.stream):
index = minibatch._seed_nodes seeds = minibatch._seed_nodes
if isinstance(index, dict): is_hetero = isinstance(seeds, dict)
for idx in index.values(): if is_hetero:
for idx in seeds.values():
idx.record_stream(torch.cuda.current_stream()) idx.record_stream(torch.cuda.current_stream())
index = self.graph._convert_to_homogeneous_nodes(index) (
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else: else:
index.record_stream(torch.cuda.current_stream()) seeds.record_stream(torch.cuda.current_stream())
seed_offsets = None
def record_stream(tensor): def record_stream(tensor):
if stream is not None and tensor.is_cuda: if stream is not None and tensor.is_cuda:
tensor.record_stream(stream) tensor.record_stream(stream)
return tensor return tensor
if self.graph.node_type_offset is None:
# sorting not needed.
minibatch._subgraph_seed_nodes = None
else:
index, original_positions = index.sort()
if (original_positions.diff() == 1).all().item():
# already sorted.
minibatch._subgraph_seed_nodes = None
else:
minibatch._subgraph_seed_nodes = record_stream(
original_positions.sort()[1]
)
index_select_csc_with_indptr = partial( index_select_csc_with_indptr = partial(
torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
) )
indptr, indices = index_select_csc_with_indptr( indptr, indices = index_select_csc_with_indptr(
self.graph.indices, index, None self.graph.indices, seeds, None
) )
record_stream(indptr) record_stream(indptr)
record_stream(indices) record_stream(indices)
output_size = len(indices) output_size = len(indices)
if self.graph.type_per_edge is not None: if self.graph.type_per_edge is not None:
_, type_per_edge = index_select_csc_with_indptr( _, type_per_edge = index_select_csc_with_indptr(
self.graph.type_per_edge, index, output_size self.graph.type_per_edge, seeds, output_size
) )
record_stream(type_per_edge) record_stream(type_per_edge)
else: else:
...@@ -94,27 +87,22 @@ class FetchInsubgraphData(Mapper): ...@@ -94,27 +87,22 @@ class FetchInsubgraphData(Mapper):
) )
if probs_or_mask is not None: if probs_or_mask is not None:
_, probs_or_mask = index_select_csc_with_indptr( _, probs_or_mask = index_select_csc_with_indptr(
probs_or_mask, index, output_size probs_or_mask, seeds, output_size
) )
record_stream(probs_or_mask) record_stream(probs_or_mask)
else: else:
probs_or_mask = None probs_or_mask = None
if self.graph.node_type_offset is not None:
node_type_offset = torch.searchsorted(
index, self.graph.node_type_offset
)
else:
node_type_offset = None
subgraph = fused_csc_sampling_graph( subgraph = fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=self.graph.node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=self.graph.node_type_to_id, node_type_to_id=self.graph.node_type_to_id,
edge_type_to_id=self.graph.edge_type_to_id, edge_type_to_id=self.graph.edge_type_to_id,
) )
if self.prob_name is not None and probs_or_mask is not None: if self.prob_name is not None and probs_or_mask is not None:
subgraph.edge_attributes = {self.prob_name: probs_or_mask} subgraph.edge_attributes = {self.prob_name: probs_or_mask}
subgraph._seed_offset_list = seed_offsets
minibatch.sampled_subgraphs.insert(0, subgraph) minibatch.sampled_subgraphs.insert(0, subgraph)
...@@ -152,14 +140,12 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer): ...@@ -152,14 +140,12 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
if hasattr(minibatch, key) if hasattr(minibatch, key)
} }
sampled_subgraph = getattr(subgraph, self.sampler_name)( sampled_subgraph = getattr(subgraph, self.sampler_name)(
minibatch._subgraph_seed_nodes, None,
self.fanout, self.fanout,
self.replace, self.replace,
self.prob_name, self.prob_name,
**kwargs, **kwargs,
) )
delattr(minibatch, "_subgraph_seed_nodes")
sampled_subgraph.original_column_node_ids = minibatch._seed_nodes
minibatch.sampled_subgraphs[0] = sampled_subgraph minibatch.sampled_subgraphs[0] = sampled_subgraph
return minibatch return minibatch
......
...@@ -15,10 +15,10 @@ def get_hetero_graph(): ...@@ -15,10 +15,10 @@ def get_hetero_graph():
# [2, 4, 2, 3, 0, 1, 1, 0, 0, 1] # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type. # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1, "n3": 2}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n2:e1:n3": 0, "n3:e2:n2": 1}
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([3, 5, 3, 4, 1, 2, 2, 1, 1, 2])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
edge_attributes = { edge_attributes = {
"weight": torch.FloatTensor( "weight": torch.FloatTensor(
...@@ -26,7 +26,7 @@ def get_hetero_graph(): ...@@ -26,7 +26,7 @@ def get_hetero_graph():
), ),
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]), "mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]),
} }
node_type_offset = torch.LongTensor([0, 2, 5]) node_type_offset = torch.LongTensor([0, 1, 3, 6])
return gb.fused_csc_sampling_graph( return gb.fused_csc_sampling_graph(
indptr, indptr,
indices, indices,
...@@ -51,7 +51,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted): ...@@ -51,7 +51,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
itemset = gb.ItemSet(items, names=names) itemset = gb.ItemSet(items, names=names)
graph = get_hetero_graph().to(F.ctx()) graph = get_hetero_graph().to(F.ctx())
if hetero: if hetero:
itemset = gb.ItemSetDict({"n2": itemset}) itemset = gb.ItemSetDict({"n3": itemset})
else: else:
graph.type_per_edge = None graph.type_per_edge = None
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment