Unverified Commit d3483fe1 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Optimize hetero sampling. (#7223)

parent 2a00cd3d
......@@ -19,8 +19,10 @@ namespace ops {
*
* @param indptr Index pointer array of the CSC.
* @param indices Indices array of the CSC.
* @param nodes The nodes from which to sample neighbors. If not provided,
* @param seeds The nodes from which to sample neighbors. If not provided,
* assumed to be equal to torch.arange(indptr.size(0) - 1).
* @param seed_offsets The offsets of the given seeds,
* seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all
......@@ -45,6 +47,12 @@ namespace ops {
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
* a 1D tensor, with the number of elements equaling the total number of edges.
* @param node_type_to_id A dictionary mapping node type names to type IDs. The
* length of it is equal to the number of node types. The key is the node type
* name, and the value is the corresponding type ID.
* @param edge_type_to_id A dictionary mapping edge type names to type IDs. The
* length of it is equal to the number of edge types. The key is the edge type
* name, and the value is the corresponding type ID.
* @param random_seed The random seed for the sampler for layer=True.
* @param seed2_contribution The contribution of the second random seed, [0, 1)
* for layer=True.
......@@ -54,10 +62,16 @@ namespace ops {
*/
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id =
torch::nullopt,
torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id =
torch::nullopt,
torch::optional<torch::Tensor> random_seed = torch::nullopt,
float seed2_contribution = .0f);
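For illustration, a minimal sketch of the seeds/seed_offsets contract (plain PyTorch with hypothetical values, not part of this diff): seeds holds homogeneous node ids grouped by node type, and seed_offsets delimits each type's slice.

import torch

seeds = torch.tensor([0, 1, 3, 4])  # homogeneous ids, grouped by node type
seed_offsets = [0, 2, 4]            # type 0 -> seeds[0:2], type 1 -> seeds[2:4]

for ntype_id in range(len(seed_offsets) - 1):
    # seeds[seed_offsets[i] : seed_offsets[i + 1]] has node type i.
    per_type = seeds[seed_offsets[ntype_id] : seed_offsets[ntype_id + 1]]
    print(ntype_id, per_type)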
......
......@@ -298,8 +298,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Sample neighboring edges of the given nodes and return the induced
* subgraph.
*
* @param nodes The nodes from which to sample neighbors. If not provided,
* @param seeds The nodes from which to sample neighbors. If not provided,
* assumed to be equal to torch.arange(NumNodes()).
* @param seed_offsets The offsets of the given seeds,
* seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type id i.
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types.
* - When the length is 1, it indicates that the fanout applies to all
......@@ -333,9 +335,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* the sampled graph's information.
*/
c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;
......
......@@ -51,33 +51,39 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* graph.
* @param original_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
* @param etype_offsets Edge offsets for the sampled edges, which are sorted
* w.r.t. edge types.
*/
FusedSampledSubgraph(
torch::Tensor indptr, torch::Tensor indices,
torch::Tensor original_column_node_ids,
torch::optional<torch::Tensor> original_column_node_ids,
torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt)
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> etype_offsets = torch::nullopt)
: indptr(indptr),
indices(indices),
original_column_node_ids(original_column_node_ids),
original_row_node_ids(original_row_node_ids),
original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge) {}
type_per_edge(type_per_edge),
etype_offsets(etype_offsets) {}
FusedSampledSubgraph() = default;
/**
* @brief CSC format index pointer array, where the implicit node ids are
* already compacted, and the original ids are stored in the
* `original_column_node_ids` field.
* `original_column_node_ids` field. Its length is equal to:
* 1 + \sum_{etype} #seeds with dst_node_type(etype)
*/
torch::Tensor indptr;
/**
* @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the
* `original_row_node_ids` field.
* `original_row_node_ids` field. The indices are sorted w.r.t. their edge
* types for the heterogeneous case.
*/
torch::Tensor indices;
......@@ -86,10 +92,11 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* can be treated as a coordinated row and column pair, and this is the
* mapped ids of the column.
*
* @note This is required and the mapping relations can be inconsistent with
* column's.
* @note This is optional and the mapping relations can be inconsistent with
* column's. It can be missing when the sampling algorithm is called via a
* sliced sampled subgraph with a missing seeds argument.
*/
torch::Tensor original_column_node_ids;
torch::optional<torch::Tensor> original_column_node_ids;
/**
* @brief Row's reverse node ids in the original graph. A graph structure
......@@ -104,7 +111,8 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
/**
* @brief Reverse edge ids in the original graph, the edge with id
* `original_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed.
* subgraph. This is useful when edge features are needed. The edges are
* sorted w.r.t. their edge types for the heterogeneous case.
*/
torch::optional<torch::Tensor> original_edge_ids;
......@@ -112,8 +120,21 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
* @brief Type id of each edge, where type id is the corresponding index of
* edge types. The length of it is equal to the number of edges in the
* subgraph.
*
* @note This output is not created by the CUDA implementation as the edges
* are sorted w.r.t. edge types; one has to use etype_offsets to infer the edge
* type information. This field is going to be deprecated. It can be generated
* when needed by computing gb.expand_indptr(etype_offsets).
*/
torch::optional<torch::Tensor> type_per_edge;
/**
* @brief Offsets of each etype, such that
* type_per_edge[etype_offsets[i]: etype_offsets[i + 1]] == i.
* It has length equal to (1 + #etype), and the edges are guaranteed to be
* sorted w.r.t. their edge types.
*/
torch::optional<torch::Tensor> etype_offsets;
};
} // namespace sampling
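To make the etype_offsets contract concrete, a small sketch (plain PyTorch, hypothetical values) of how type_per_edge can be regenerated once the edges are sorted by type, which is what the comment above means by gb.expand_indptr(etype_offsets):

import torch

etype_offsets = torch.tensor([0, 3, 7])  # 2 etypes: 3 edges of type 0, 4 of type 1

# Equivalent of gb.expand_indptr(etype_offsets): repeat each type id by its count.
type_per_edge = torch.repeat_interleave(
    torch.arange(etype_offsets.numel() - 1), etype_offsets.diff()
)
print(type_per_edge)  # tensor([0, 0, 0, 1, 1, 1, 1])

# Since the edges are sorted by type, the edges of type i form the contiguous
# slice [etype_offsets[i], etype_offsets[i + 1]), e.g. in original_edge_ids:
original_edge_ids = torch.arange(etype_offsets[-1].item())
edges_of_type_1 = original_edge_ids[etype_offsets[1] : etype_offsets[2]]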
......
......@@ -25,6 +25,7 @@
#include <type_traits>
#include "../random.h"
#include "../utils.h"
#include "./common.h"
#include "./utils.h"
......@@ -183,19 +184,26 @@ struct SegmentEndFunc {
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id,
torch::optional<torch::Tensor> random_seed_tensor,
float seed2_contribution) {
// When seed_offsets.has_value() in the hetero case, we compute the outputs of
// sample_neighbors and _convert_to_sampled_subgraph in a fused manner so that
// _convert_to_sampled_subgraph only has to perform slices over the returned
// indptr and indices tensors to form CSC outputs for each edge type.
TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
// Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
// Assume that indptr, indices, seeds, type_per_edge and probs_or_mask
// are all resident on the GPU. If not, it is better to first extract them
// before calling this function.
auto allocator = cuda::GetAllocator();
auto num_rows =
nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
seeds.has_value() ? seeds.value().size(0) : indptr.size(0) - 1;
auto fanouts_pinned = torch::empty(
fanouts.size(),
c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
......@@ -210,7 +218,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
fanouts_device.get(), fanouts_pinned_ptr,
sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
cuda::GetCurrentStream()));
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, seeds);
auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
auto max_in_degree = torch::empty(
......@@ -227,16 +235,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
max_in_degree_event.record();
torch::optional<int64_t> num_edges;
torch::Tensor sub_indptr;
if (!nodes.has_value()) {
if (!seeds.has_value()) {
num_edges = indices.size(0);
sub_indptr = indptr;
}
torch::optional<torch::Tensor> sliced_probs_or_mask;
if (probs_or_mask.has_value()) {
if (nodes.has_value()) {
if (seeds.has_value()) {
torch::Tensor sliced_probs_or_mask_tensor;
std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
in_degree, sliced_indptr, probs_or_mask.value(), seeds.value(),
indptr.size(0) - 2, num_edges);
sliced_probs_or_mask = sliced_probs_or_mask_tensor;
num_edges = sliced_probs_or_mask_tensor.size(0);
......@@ -246,9 +254,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
if (fanouts.size() > 1) {
torch::Tensor sliced_type_per_edge;
if (nodes.has_value()) {
if (seeds.has_value()) {
std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
in_degree, sliced_indptr, type_per_edge.value(), seeds.value(),
indptr.size(0) - 2, num_edges);
} else {
sliced_type_per_edge = type_per_edge.value();
......@@ -259,7 +267,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
num_edges = sliced_type_per_edge.size(0);
}
// If sub_indptr was not computed in the two code blocks above:
if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
if (seeds.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
sub_indptr = ExclusiveCumSum(in_degree);
}
auto coo_rows = ExpandIndptrImpl(
......@@ -276,7 +284,6 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids;
torch::Tensor output_indices;
torch::optional<torch::Tensor> output_type_per_edge;
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
......@@ -507,39 +514,153 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
indices.data_ptr<indices_t>(),
output_indices.data_ptr<indices_t>());
}));
}));
if (type_per_edge) {
// output_type_per_edge = type_per_edge.gather(0, picked_eids);
// The commented out torch equivalent above does not work when
// type_per_edge is on pinned memory. That is why we have to
// reimplement it, similar to the indices gather operation above.
auto types = type_per_edge.value();
output_type_per_edge = torch::empty(
picked_eids.size(0),
picked_eids.options().dtype(types.scalar_type()));
auto index_type_per_edge_for_sampled_edges = [&] {
// The code behaves the same as:
// output_type_per_edge = type_per_edge.gather(0, picked_eids);
// The reimplementation is required because the torch equivalent does
// not work when type_per_edge is on pinned memory.
auto types = type_per_edge.value();
auto output = torch::empty(
picked_eids.size(0), picked_eids.options().dtype(types.scalar_type()));
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
using indptr_t = index_t;
AT_DISPATCH_INTEGRAL_TYPES(
types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
THRUST_CALL(
gather, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
types.data_ptr<scalar_t>(),
output_type_per_edge.value().data_ptr<scalar_t>());
types.data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
}));
}
}));
}));
return output;
};
torch::optional<torch::Tensor> output_type_per_edge;
torch::optional<torch::Tensor> edge_offsets;
if (type_per_edge && seed_offsets) {
const int64_t num_etypes =
edge_type_to_id.has_value() ? edge_type_to_id->size() : 1;
// If we performed homogeneous sampling on a hetero graph, we have to look
// at the type_per_edge of the sampled edges, determine the offsets of the
// different sampled etypes, and convert to the fused hetero indptr
// representation.
if (fanouts.size() == 1) {
output_type_per_edge = index_type_per_edge_for_sampled_edges();
torch::Tensor output_in_degree, sliced_output_indptr;
sliced_output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0) - 1);
std::tie(output_indptr, output_in_degree, sliced_output_indptr) =
SliceCSCIndptrHetero(
output_indptr, output_type_per_edge.value(), sliced_output_indptr,
num_etypes);
// We use num_rows to hold num_seeds * num_etypes, so it needs to be updated
// when sampling with a single fanout value on a heterogeneous graph.
num_rows = sliced_output_indptr.size(0);
}
// Here, we determine the dst node type of each given seed so that we can
// compute the output indptr space later.
std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);
for (auto& etype_and_id : edge_type_to_id.value()) {
auto etype = etype_and_id.key();
auto id = etype_and_id.value();
auto dst_type = utils::parse_dst_ntype_from_etype(etype);
etype_id_to_dst_ntype_id[id] = node_type_to_id->at(dst_type);
}
// For each edge type, we compute the start and end offsets to index into
// indptr to form the final output_indptr.
auto indptr_offsets = torch::empty(
num_etypes * 2,
c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
auto indptr_offsets_ptr = indptr_offsets.data_ptr<int64_t>();
// We compute the indptr offsets here. Right now, output_indptr has size
// #seeds * num_etypes + 1, so we can simply take slices of it to get the
// correct output indptr. The final output_indptr is the same as the current
// indptr except that some intermediate values are removed to change the node
// id space from all of the seed vertices to the node id space of the dst
// node type of each edge type.
for (int i = 0; i < num_etypes; i++) {
indptr_offsets_ptr[2 * i] = num_rows / num_etypes * i +
seed_offsets->at(etype_id_to_dst_ntype_id[i]);
indptr_offsets_ptr[2 * i + 1] =
num_rows / num_etypes * i +
seed_offsets->at(etype_id_to_dst_ntype_id[i] + 1);
}
auto permutation = torch::arange(
0, num_rows * num_etypes, num_etypes, output_indptr.options());
permutation =
permutation.remainder(num_rows) + permutation.div(num_rows, "floor");
// This permutation, when applied, sorts the sampled edges with respect to
// their edge types.
auto [output_in_degree, sliced_output_indptr] =
SliceCSCIndptr(output_indptr, permutation);
std::tie(output_indptr, picked_eids) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, picked_eids, permutation,
num_rows - 1, picked_eids.size(0));
edge_offsets = torch::empty(
num_etypes * 2, c10::TensorOptions()
.dtype(output_indptr.scalar_type())
.pinned_memory(true));
at::cuda::CUDAEvent edge_offsets_event;
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsEdgeOffsets", ([&] {
THRUST_CALL(
gather, indptr_offsets_ptr,
indptr_offsets_ptr + indptr_offsets.size(0),
output_indptr.data_ptr<index_t>(),
edge_offsets->data_ptr<index_t>());
}));
edge_offsets_event.record();
// The output_indices is permuted here.
std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
output_in_degree, sliced_output_indptr, output_indices, permutation,
num_rows - 1, output_indices.size(0));
std::vector<torch::Tensor> indptr_list;
for (int i = 0; i < num_etypes; i++) {
indptr_list.push_back(output_indptr.slice(
0, indptr_offsets_ptr[2 * i],
indptr_offsets_ptr[2 * i + 1] + (i == num_etypes - 1)));
}
// We form the final output indptr by concatenating pieces for different
// edge types.
output_indptr = torch::cat(indptr_list);
edge_offsets_event.synchronize();
// We read the edge_offsets here. They come in pairs, but we do not need the
// pairs, so we remove the duplicate information and turn the result into a
// real offsets array.
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsEdgeOffsetsCheck", ([&] {
auto edge_offsets_ptr = edge_offsets->data_ptr<index_t>();
TORCH_CHECK(edge_offsets_ptr[0] == 0, "edge_offsets is incorrect.");
for (int i = 1; i < num_etypes; i++) {
TORCH_CHECK(
edge_offsets_ptr[2 * i - 1] == edge_offsets_ptr[2 * i],
"edge_offsets is incorrect.");
}
TORCH_CHECK(
edge_offsets_ptr[2 * num_etypes - 1] == picked_eids.size(0),
"edge_offsets is incorrect.");
for (int i = 0; i < num_etypes; i++) {
edge_offsets_ptr[i + 1] = edge_offsets_ptr[2 * i + 1];
}
}));
edge_offsets = edge_offsets->slice(0, 0, num_etypes + 1);
} else {
// Convert output_indptr back to homo by discarding intermediate offsets.
output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
if (type_per_edge)
output_type_per_edge = index_type_per_edge_for_sampled_edges();
}
// Convert output_indptr back to homo by discarding intermediate offsets.
output_indptr =
output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
if (!nodes.has_value()) {
nodes = torch::arange(indptr.size(0) - 1, indices.options());
}
return c10::make_intrusive<sampling::FusedSampledSubgraph>(
output_indptr, output_indices, nodes.value(), torch::nullopt,
subgraph_reverse_edge_ids, output_type_per_edge);
output_indptr, output_indices, seeds, torch::nullopt,
subgraph_reverse_edge_ids, output_type_per_edge, edge_offsets);
}
} // namespace ops
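The permutation trick in the hetero branch above can be checked in isolation. A standalone sketch (plain PyTorch, hypothetical sizes; num_rows holds num_seeds * num_etypes as in the kernel): rows are produced seed-major, i.e. row j belongs to seed j / num_etypes and etype j % num_etypes, and the permutation reorders them etype-major, which is why the sampled edges come out sorted by edge type.

import torch

num_etypes, num_seeds = 3, 4
num_rows = num_seeds * num_etypes

permutation = torch.arange(0, num_rows * num_etypes, num_etypes)
permutation = permutation.remainder(num_rows) + permutation.div(
    num_rows, rounding_mode="floor"
)

# Applying the permutation groups the seed-major rows by their edge type.
etype_of_row = torch.arange(num_rows) % num_etypes
print(etype_of_row[permutation])
# tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])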
......
......@@ -617,23 +617,24 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
}
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
auto probs_or_mask = this->EdgeAttribute(probs_name);
// If nodes does not have a value, then we expect all arguments to be resident
// on the GPU. If nodes has a value, then we expect them to be accessible from
// If seeds does not have a value, then we expect all arguments to be resident
// on the GPU. If seeds has a value, then we expect them to be accessible from
// GPU. This is required for the dispatch to work when CUDA is not available.
if (((!nodes.has_value() && utils::is_on_gpu(indptr_) &&
if (((!seeds.has_value() && utils::is_on_gpu(indptr_) &&
utils::is_on_gpu(indices_) &&
(!probs_or_mask.has_value() ||
utils::is_on_gpu(probs_or_mask.value())) &&
(!type_per_edge_.has_value() ||
utils::is_on_gpu(type_per_edge_.value()))) ||
(nodes.has_value() && utils::is_on_gpu(nodes.value()) &&
(seeds.has_value() && utils::is_on_gpu(seeds.value()) &&
utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) &&
(!probs_or_mask.has_value() ||
......@@ -644,11 +645,12 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors(
indptr_, indices_, nodes, fanouts, replace, layer, return_eids,
type_per_edge_, probs_or_mask, random_seed, seed2_contribution);
indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
return_eids, type_per_edge_, probs_or_mask, node_type_to_id_,
edge_type_to_id_, random_seed, seed2_contribution);
});
}
TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU.");
TORCH_CHECK(seeds.has_value(), "seeds can not be None on the CPU.");
if (probs_or_mask.has_value()) {
// Note probs will be passed as input for 'torch.multinomial' in deeper
......@@ -667,7 +669,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl(
nodes.value(), return_eids,
seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
......@@ -686,7 +688,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
......@@ -695,7 +697,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
nodes.value(), return_eids,
seeds.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
......
......@@ -36,7 +36,8 @@ TORCH_LIBRARY(graphbolt, m) {
&FusedSampledSubgraph::original_column_node_ids)
.def_readwrite(
"original_edge_ids", &FusedSampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge);
.def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge)
.def_readwrite("etype_offsets", &FusedSampledSubgraph::etype_offsets);
m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray")
.def("index_select", &storage::OnDiskNpyArray::IndexSelect);
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
......
......@@ -26,6 +26,17 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return is_on_gpu(tensor) || tensor.is_pinned();
}
/**
* @brief Parses the destination node type from a given edge type triple
* separated by ":".
*/
inline std::string parse_dst_ntype_from_etype(std::string etype) {
auto first_seperator_it = std::find(etype.begin(), etype.end(), ':');
auto second_seperator_pos =
std::find(first_seperator_it + 1, etype.end(), ':') - etype.begin();
return etype.substr(second_seperator_pos + 1);
}
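A Python equivalent of this helper, for illustration (edge type keys in edge_type_to_id use the "src:rel:dst" triple format, e.g. "n1:e1:n2"):

def parse_dst_ntype_from_etype(etype: str) -> str:
    # The destination node type is the third field of the "src:rel:dst" triple.
    return etype.split(":")[2]

assert parse_dst_ntype_from_etype("n1:e1:n2") == "n2"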
/**
* @brief Retrieves the value of the tensor at the given index.
*
......
......@@ -146,7 +146,7 @@ def _sample_neighbors_graphbolt(
return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(
nodes, fanout, replace=replace, return_eids=return_eids
nodes, None, fanout, replace=replace, return_eids=return_eids
)
# 3. Map local node IDs to global node IDs.
......
......@@ -444,7 +444,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
)}
"""
if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes)
nodes, _ = self._convert_to_homogeneous_nodes(nodes)
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
......@@ -453,22 +453,28 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = []
homogeneous_node_offsets = [0]
homogeneous_timestamps = []
offset = self._node_type_offset_list
for ntype, ids in nodes.items():
ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append(ids + offset[ntype_id])
if timestamps is not None:
homogeneous_timestamps.append(timestamps[ntype])
for ntype, ntype_id in self.node_type_to_id.items():
ids = nodes.get(ntype, [])
if len(ids) > 0:
homogeneous_nodes.append(ids + offset[ntype_id])
if timestamps is not None:
homogeneous_timestamps.append(timestamps[ntype])
homogeneous_node_offsets.append(
homogeneous_node_offsets[-1] + len(ids)
)
if timestamps is not None:
return torch.cat(homogeneous_nodes), torch.cat(
homogeneous_timestamps
)
return torch.cat(homogeneous_nodes)
return torch.cat(homogeneous_nodes), homogeneous_node_offsets
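A worked example of the return value (hypothetical inputs: node_type_to_id = {"n1": 0, "n2": 1} and _node_type_offset_list = [0, 2, 5]); because the loop now visits every node type, types with no seeds still contribute an entry to the offsets:

import torch

node_type_to_id = {"n1": 0, "n2": 1}
node_type_offset = [0, 2, 5]  # n1 occupies ids [0, 2), n2 occupies [2, 5)
nodes = {"n2": torch.tensor([0, 2])}

homogeneous_nodes, offsets = [], [0]
for ntype, ntype_id in node_type_to_id.items():
    ids = nodes.get(ntype, [])
    if len(ids) > 0:
        # Shift per-type ids into the homogeneous id space.
        homogeneous_nodes.append(ids + node_type_offset[ntype_id])
    offsets.append(offsets[-1] + len(ids))

print(torch.cat(homogeneous_nodes))  # tensor([2, 4])
print(offsets)                       # [0, 0, 2] -- n1 contributed no seeds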
def _convert_to_sampled_subgraph(
self,
C_sampled_subgraph: torch.ScriptObject,
seed_offsets: Optional[list] = None,
) -> SampledSubgraphImpl:
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'SampledSubgraphImpl'."""
......@@ -477,6 +483,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
type_per_edge = C_sampled_subgraph.type_per_edge
column = C_sampled_subgraph.original_column_node_ids
original_edge_ids = C_sampled_subgraph.original_edge_ids
etype_offsets = C_sampled_subgraph.etype_offsets
if etype_offsets is not None:
etype_offsets = etype_offsets.tolist()
has_original_eids = (
self.edge_attributes is not None
......@@ -486,45 +495,78 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_edge_ids = torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID], original_edge_ids
)
if type_per_edge is None:
if type_per_edge is None and etype_offsets is None:
# The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
else:
# UVA sampling requires us to move node_type_offset to GPU.
self.node_type_offset = self.node_type_offset.to(column.device)
# 1. Find node types for each node in column.
node_types = (
torch.searchsorted(self.node_type_offset, column, right=True)
- 1
)
offset = self._node_type_offset_list
original_hetero_edge_ids = {}
sub_indices = {}
sub_indptr = {}
offset = self._node_type_offset_list
# 2. Loop over each node type.
for ntype, ntype_id in self.node_type_to_id.items():
# Get all nodes of a specific node type in column.
nids = torch.nonzero(node_types == ntype_id).view(-1)
nids_original_indptr = indptr[nids + 1]
if etype_offsets is None:
# UVA sampling requires us to move node_type_offset to GPU.
self.node_type_offset = self.node_type_offset.to(column.device)
# 1. Find node types for each node in column.
node_types = (
torch.searchsorted(
self.node_type_offset, column, right=True
)
- 1
)
for ntype, ntype_id in self.node_type_to_id.items():
# Get all nodes of a specific node type in column.
nids = torch.nonzero(node_types == ntype_id).view(-1)
nids_original_indptr = indptr[nids + 1]
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
if dst_ntype != ntype:
continue
# Get all edge ids of a specific edge type.
eids = torch.nonzero(type_per_edge == etype_id).view(-1)
src_ntype_id = self.node_type_to_id[src_ntype]
sub_indices[etype] = (
indices[eids] - offset[src_ntype_id]
)
cum_edges = torch.searchsorted(
eids, nids_original_indptr, right=False
)
sub_indptr[etype] = torch.cat(
(torch.tensor([0], device=indptr.device), cum_edges)
)
if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[
eids
]
else:
edge_offsets = [0]
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
if dst_ntype != ntype:
continue
# Get all edge ids of a specific edge type.
eids = torch.nonzero(type_per_edge == etype_id).view(-1)
src_ntype_id = self.node_type_to_id[src_ntype]
sub_indices[etype] = indices[eids] - offset[src_ntype_id]
cum_edges = torch.searchsorted(
eids, nids_original_indptr, right=False
)
sub_indptr[etype] = torch.cat(
(torch.tensor([0], device=indptr.device), cum_edges)
ntype_id = self.node_type_to_id[dst_ntype]
edge_offsets.append(
edge_offsets[-1]
+ seed_offsets[ntype_id + 1]
- seed_offsets[ntype_id]
)
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
ntype_id = self.node_type_to_id[dst_ntype]
sub_indptr_ = indptr[
edge_offsets[etype_id] : edge_offsets[etype_id + 1] + 1
]
sub_indptr[etype] = sub_indptr_ - sub_indptr_[0]
sub_indices[etype] = indices[
etype_offsets[etype_id] : etype_offsets[etype_id + 1]
]
if has_original_eids:
original_hetero_edge_ids[etype] = original_edge_ids[
eids
etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
]
src_ntype_id = self.node_type_to_id[src_ntype]
sub_indices[etype] -= offset[src_ntype_id]
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
sampled_csc = {
......@@ -541,7 +583,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def sample_neighbors(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
......@@ -551,7 +593,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters
----------
nodes: torch.Tensor or Dict[str, torch.Tensor]
seeds: torch.Tensor or Dict[str, torch.Tensor]
IDs of the given seed nodes.
- If `seeds` is a tensor: it means the graph is a homogeneous
graph, and the ids inside are homogeneous ids.
......@@ -615,21 +657,27 @@ class FusedCSCSamplingGraph(SamplingGraph):
indices=tensor([2]),
)}
"""
if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes)
return_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
seed_offsets = None
if isinstance(seeds, dict):
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None and hasattr(self, "_seed_offset_list"):
seed_offsets = self._seed_offset_list # pylint: disable=no-member
C_sampled_subgraph = self._sample_neighbors(
nodes,
seeds,
seed_offsets,
fanouts,
replace=replace,
probs_name=probs_name,
return_eids=return_eids,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
)
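A usage sketch tying the pieces together (graph values copied from the updated get_hetero_graph test below; illustrative, not canonical): dict seeds are converted to homogeneous ids plus seed_offsets, which are forwarded both to the C++ sampler and to _convert_to_sampled_subgraph.

import torch
import dgl.graphbolt as gb

graph = gb.fused_csc_sampling_graph(
    torch.LongTensor([0, 0, 2, 4, 6, 8, 10]),          # indptr
    torch.LongTensor([3, 5, 3, 4, 1, 2, 2, 1, 1, 2]),  # indices
    node_type_offset=torch.LongTensor([0, 1, 3, 6]),
    type_per_edge=torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]),
    node_type_to_id={"n1": 0, "n2": 1, "n3": 2},
    edge_type_to_id={"n2:e1:n3": 0, "n3:e2:n2": 1},
)
# Dict seeds: two n3 nodes; a fanout of -1 samples all neighbors.
subgraph = graph.sample_neighbors(
    {"n3": torch.LongTensor([0, 1])}, torch.LongTensor([-1])
)
print(subgraph.sampled_csc)  # per-edge-type CSC, keyed e.g. by "n2:e1:n3"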
def _check_sampler_arguments(self, nodes, fanouts, probs_name):
if nodes is not None:
......@@ -676,7 +724,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _sample_neighbors(
self,
nodes: torch.Tensor,
seeds: torch.Tensor,
seed_offsets: Optional[list],
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
......@@ -687,8 +736,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters
----------
nodes: torch.Tensor
seeds: torch.Tensor
IDs of the given seed nodes.
seed_offsets: list, optional
The offsets of the given seeds,
seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.
fanouts: torch.Tensor
The number of edges to be sampled for each node with or without
considering edge types.
......@@ -726,9 +778,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
The sampled C subgraph.
"""
# Ensure nodes is 1-D tensor.
self._check_sampler_arguments(nodes, fanouts, probs_name)
self._check_sampler_arguments(seeds, fanouts, probs_name)
return self._c_csc_graph.sample_neighbors(
nodes,
seeds,
seed_offsets,
fanouts.tolist(),
replace,
False, # is_labor
......@@ -740,7 +793,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def sample_layer_neighbors(
self,
nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
......@@ -754,7 +807,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
Parameters
----------
nodes: torch.Tensor or Dict[str, torch.Tensor]
seeds: torch.Tensor or Dict[str, torch.Tensor]
IDs of the given seed nodes.
- If `seeds` is a tensor: it means the graph is a homogeneous
graph, and the ids inside are homogeneous ids.
......@@ -844,10 +897,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
indices=tensor([2]),
)}
"""
if isinstance(nodes, dict):
nodes = self._convert_to_homogeneous_nodes(nodes)
self._check_sampler_arguments(nodes, fanouts, probs_name)
if random_seed is not None:
assert (
1 <= len(random_seed) <= 2
......@@ -856,12 +905,21 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert (
0 <= seed2_contribution <= 1
), "seed2_contribution should be in [0, 1]."
has_original_eids = (
self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes
)
seed_offsets = None
if isinstance(seeds, dict):
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None and hasattr(self, "_seed_offset_list"):
seed_offsets = self._seed_offset_list # pylint: disable=no-member
self._check_sampler_arguments(seeds, fanouts, probs_name)
C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
nodes,
seeds,
seed_offsets,
fanouts.tolist(),
replace,
True,
......@@ -870,7 +928,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
random_seed,
seed2_contribution,
)
return self._convert_to_sampled_subgraph(C_sampled_subgraph)
return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
)
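Layer-dependent (LABOR) sampling takes the same dict seeds; a brief sketch reusing the hypothetical graph from the earlier example, with a fixed random seed for reproducibility:

subgraph = graph.sample_layer_neighbors(
    {"n3": torch.LongTensor([0, 1])},
    torch.LongTensor([2]),
    random_seed=torch.LongTensor([42]),  # 1 or 2 entries are accepted
)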
def temporal_sample_neighbors(
self,
......
......@@ -46,44 +46,37 @@ class FetchInsubgraphData(Mapper):
def _fetch_per_layer_impl(self, minibatch, stream):
with torch.cuda.stream(self.stream):
index = minibatch._seed_nodes
if isinstance(index, dict):
for idx in index.values():
seeds = minibatch._seed_nodes
is_hetero = isinstance(seeds, dict)
if is_hetero:
for idx in seeds.values():
idx.record_stream(torch.cuda.current_stream())
index = self.graph._convert_to_homogeneous_nodes(index)
(
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else:
index.record_stream(torch.cuda.current_stream())
seeds.record_stream(torch.cuda.current_stream())
seed_offsets = None
def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)
return tensor
if self.graph.node_type_offset is None:
# sorting not needed.
minibatch._subgraph_seed_nodes = None
else:
index, original_positions = index.sort()
if (original_positions.diff() == 1).all().item():
# already sorted.
minibatch._subgraph_seed_nodes = None
else:
minibatch._subgraph_seed_nodes = record_stream(
original_positions.sort()[1]
)
index_select_csc_with_indptr = partial(
torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
)
indptr, indices = index_select_csc_with_indptr(
self.graph.indices, index, None
self.graph.indices, seeds, None
)
record_stream(indptr)
record_stream(indices)
output_size = len(indices)
if self.graph.type_per_edge is not None:
_, type_per_edge = index_select_csc_with_indptr(
self.graph.type_per_edge, index, output_size
self.graph.type_per_edge, seeds, output_size
)
record_stream(type_per_edge)
else:
......@@ -94,27 +87,22 @@ class FetchInsubgraphData(Mapper):
)
if probs_or_mask is not None:
_, probs_or_mask = index_select_csc_with_indptr(
probs_or_mask, index, output_size
probs_or_mask, seeds, output_size
)
record_stream(probs_or_mask)
else:
probs_or_mask = None
if self.graph.node_type_offset is not None:
node_type_offset = torch.searchsorted(
index, self.graph.node_type_offset
)
else:
node_type_offset = None
subgraph = fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=node_type_offset,
node_type_offset=self.graph.node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=self.graph.node_type_to_id,
edge_type_to_id=self.graph.edge_type_to_id,
)
if self.prob_name is not None and probs_or_mask is not None:
subgraph.edge_attributes = {self.prob_name: probs_or_mask}
subgraph._seed_offset_list = seed_offsets
minibatch.sampled_subgraphs.insert(0, subgraph)
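For reference, the fetch step above boils down to a pair of index_select_csc calls (signature as used in this file; output_size is None on the first call and reused afterwards), sketched as:

# sub_indptr, sub_indices = torch.ops.graphbolt.index_select_csc(
#     graph.csc_indptr, graph.indices, seeds, None
# )
# output_size = len(sub_indices)
# _, sub_type_per_edge = torch.ops.graphbolt.index_select_csc(
#     graph.csc_indptr, graph.type_per_edge, seeds, output_size
# )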
......@@ -152,14 +140,12 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
if hasattr(minibatch, key)
}
sampled_subgraph = getattr(subgraph, self.sampler_name)(
minibatch._subgraph_seed_nodes,
None,
self.fanout,
self.replace,
self.prob_name,
**kwargs,
)
delattr(minibatch, "_subgraph_seed_nodes")
sampled_subgraph.original_column_node_ids = minibatch._seed_nodes
minibatch.sampled_subgraphs[0] = sampled_subgraph
return minibatch
......
......@@ -15,10 +15,10 @@ def get_hetero_graph():
# [3, 5, 3, 4, 1, 2, 2, 1, 1, 2]
# [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
# num_nodes = 6, num_n1 = 1, num_n2 = 2, num_n3 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
ntypes = {"n1": 0, "n2": 1, "n3": 2}
etypes = {"n2:e1:n3": 0, "n3:e2:n2": 1}
indptr = torch.LongTensor([0, 0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([3, 5, 3, 4, 1, 2, 2, 1, 1, 2])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
edge_attributes = {
"weight": torch.FloatTensor(
......@@ -26,7 +26,7 @@ def get_hetero_graph():
),
"mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]),
}
node_type_offset = torch.LongTensor([0, 2, 5])
node_type_offset = torch.LongTensor([0, 1, 3, 6])
return gb.fused_csc_sampling_graph(
indptr,
indices,
......@@ -51,7 +51,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
itemset = gb.ItemSet(items, names=names)
graph = get_hetero_graph().to(F.ctx())
if hetero:
itemset = gb.ItemSetDict({"n2": itemset})
itemset = gb.ItemSetDict({"n3": itemset})
else:
graph.type_per_edge = None
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
......