"...transforms/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "72dcc1705d290a2b6084d0fed97e6f52c670ae79"
Unverified Commit 61504ec5 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Extend temporal sampling to labor (#6816)

parent 333ce36c
...@@ -335,6 +335,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -335,6 +335,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is preformed with or * @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times. * without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once. * Otherwise, each value can be selected only once.
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned, * @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required. * typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge * @param probs_name An optional string specifying the name of an edge
...@@ -351,8 +354,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -351,8 +354,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighbors( c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighbors(
const torch::Tensor& input_nodes, const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp, const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool return_eids, const std::vector<int64_t>& fanouts, bool replace, bool layer,
torch::optional<std::string> probs_name, bool return_eids, torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name, torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const; torch::optional<std::string> edge_timestamp_attr_name) const;
......
...@@ -437,6 +437,7 @@ auto GetPickFn( ...@@ -437,6 +437,7 @@ auto GetPickFn(
}; };
} }
template <SamplerType S>
auto GetTemporalPickFn( auto GetTemporalPickFn(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, torch::Tensor seed_timestamp, torch::Tensor csc_indices,
const std::vector<int64_t>& fanouts, bool replace, const std::vector<int64_t>& fanouts, bool replace,
...@@ -444,30 +445,31 @@ auto GetTemporalPickFn( ...@@ -444,30 +445,31 @@ auto GetTemporalPickFn(
const torch::optional<torch::Tensor>& type_per_edge, const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp, const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp) { const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args) {
return [&seed_timestamp, &csc_indices, &fanouts, replace, &options, return
&type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp]( [&seed_timestamp, &csc_indices, &fanouts, replace, &options,
int64_t seed_offset, int64_t offset, int64_t num_neighbors, &type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp, args](
auto picked_data_ptr) { int64_t seed_offset, int64_t offset, int64_t num_neighbors,
// If fanouts.size() > 1, perform sampling for each edge type of each auto picked_data_ptr) {
// node; otherwise just sample once for each node with no regard of edge // If fanouts.size() > 1, perform sampling for each edge type of each
// types. // node; otherwise just sample once for each node with no regard of edge
if (fanouts.size() > 1) { // types.
return TemporalPickByEtype( if (fanouts.size() > 1) {
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, return TemporalPickByEtype(
fanouts, replace, options, type_per_edge.value(), probs_or_mask, seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
node_timestamp, edge_timestamp, picked_data_ptr); fanouts, replace, options, type_per_edge.value(), probs_or_mask,
} else { node_timestamp, edge_timestamp, args, picked_data_ptr);
int64_t num_sampled = TemporalPick( } else {
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, int64_t num_sampled = TemporalPick(
fanouts[0], replace, options, probs_or_mask, node_timestamp, seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
edge_timestamp, picked_data_ptr); fanouts[0], replace, options, probs_or_mask, node_timestamp,
if (type_per_edge) { edge_timestamp, args, picked_data_ptr);
std::sort(picked_data_ptr, picked_data_ptr + num_sampled); if (type_per_edge.has_value()) {
} std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
return num_sampled; }
} return num_sampled;
}; }
};
} }
template <typename NumPickFn, typename PickFn> template <typename NumPickFn, typename PickFn>
...@@ -664,8 +666,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> ...@@ -664,8 +666,8 @@ c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighbors( FusedCSCSamplingGraph::TemporalSampleNeighbors(
const torch::Tensor& input_nodes, const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp, const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool return_eids, const std::vector<int64_t>& fanouts, bool replace, bool layer,
torch::optional<std::string> probs_name, bool return_eids, torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name, torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const { torch::optional<std::string> edge_timestamp_attr_name) const {
// 1. Get probs_or_mask. // 1. Get probs_or_mask.
...@@ -684,15 +686,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( ...@@ -684,15 +686,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
// 3. Get the timestamp attribute for edges of the graph // 3. Get the timestamp attribute for edges of the graph
auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name); auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name);
// 4. Call SampleNeighborsImpl // 4. Call SampleNeighborsImpl
return SampleNeighborsImpl( if (layer) {
input_nodes, return_eids, const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
GetTemporalNumPickFn( static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
input_nodes_timestamp, this->indices_, fanouts, replace, SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp), return SampleNeighborsImpl(
GetTemporalPickFn( input_nodes, return_eids,
input_nodes_timestamp, this->indices_, fanouts, replace, GetTemporalNumPickFn(
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp, input_nodes_timestamp, this->indices_, fanouts, replace,
edge_timestamp)); type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp,
edge_timestamp, args));
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp,
edge_timestamp, args));
}
} }
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
...@@ -1130,11 +1148,12 @@ static torch::Tensor NonUniformPickOp( ...@@ -1130,11 +1148,12 @@ static torch::Tensor NonUniformPickOp(
template <typename PickedType> template <typename PickedType>
inline int64_t NonUniformPick( inline int64_t NonUniformPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options, const torch::TensorOptions& options, const torch::Tensor& probs_or_mask,
const torch::optional<torch::Tensor>& probs_or_mask,
PickedType* picked_data_ptr) { PickedType* picked_data_ptr) {
auto local_probs = auto local_probs =
probs_or_mask.value().slice(0, offset, offset + num_neighbors); probs_or_mask.size(0) > num_neighbors
? probs_or_mask.slice(0, offset, offset + num_neighbors)
: probs_or_mask;
auto picked_indices = NonUniformPickOp(local_probs, fanout, replace); auto picked_indices = NonUniformPickOp(local_probs, fanout, replace);
auto picked_indices_ptr = picked_indices.data_ptr<int64_t>(); auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();
for (int i = 0; i < picked_indices.numel(); ++i) { for (int i = 0; i < picked_indices.numel(); ++i) {
...@@ -1152,7 +1171,7 @@ int64_t Pick( ...@@ -1152,7 +1171,7 @@ int64_t Pick(
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) { SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
if (probs_or_mask.has_value()) { if (probs_or_mask.has_value()) {
return NonUniformPick( return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask, offset, num_neighbors, fanout, replace, options, probs_or_mask.value(),
picked_data_ptr); picked_data_ptr);
} else { } else {
return UniformPick( return UniformPick(
...@@ -1160,14 +1179,14 @@ int64_t Pick( ...@@ -1160,14 +1179,14 @@ int64_t Pick(
} }
} }
template <typename PickedType> template <SamplerType S, typename PickedType>
int64_t TemporalPick( int64_t TemporalPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout, int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,
bool replace, const torch::TensorOptions& options, bool replace, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp, const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,
PickedType* picked_data_ptr) { PickedType* picked_data_ptr) {
auto mask = TemporalMask( auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices, utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,
...@@ -1180,13 +1199,20 @@ int64_t TemporalPick( ...@@ -1180,13 +1199,20 @@ int64_t TemporalPick(
} else { } else {
masked_prob = mask.to(torch::kFloat32); masked_prob = mask.to(torch::kFloat32);
} }
auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace); if constexpr (S == SamplerType::NEIGHBOR) {
auto picked_indices_ptr = picked_indices.data_ptr<int64_t>(); auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace);
for (int i = 0; i < picked_indices.numel(); ++i) { auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();
picked_data_ptr[i] = for (int i = 0; i < picked_indices.numel(); ++i) {
static_cast<PickedType>(picked_indices_ptr[i]) + offset; picked_data_ptr[i] =
static_cast<PickedType>(picked_indices_ptr[i]) + offset;
}
return picked_indices.numel();
}
if constexpr (S == SamplerType::LABOR) {
return Pick(
offset, num_neighbors, fanout, replace, options, masked_prob, args,
picked_data_ptr);
} }
return picked_indices.numel();
} }
template <SamplerType S, typename PickedType> template <SamplerType S, typename PickedType>
...@@ -1226,7 +1252,7 @@ int64_t PickByEtype( ...@@ -1226,7 +1252,7 @@ int64_t PickByEtype(
return pick_offset; return pick_offset;
} }
template <typename PickedType> template <SamplerType S, typename PickedType>
int64_t TemporalPickByEtype( int64_t TemporalPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t seed_offset, int64_t offset, int64_t num_neighbors,
...@@ -1234,7 +1260,7 @@ int64_t TemporalPickByEtype( ...@@ -1234,7 +1260,7 @@ int64_t TemporalPickByEtype(
const torch::TensorOptions& options, const torch::Tensor& type_per_edge, const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp, const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,
PickedType* picked_data_ptr) { PickedType* picked_data_ptr) {
int64_t etype_begin = offset; int64_t etype_begin = offset;
int64_t etype_end = offset; int64_t etype_end = offset;
...@@ -1258,7 +1284,7 @@ int64_t TemporalPickByEtype( ...@@ -1258,7 +1284,7 @@ int64_t TemporalPickByEtype(
int64_t picked_count = TemporalPick( int64_t picked_count = TemporalPick(
seed_timestamp, csc_indices, seed_offset, etype_begin, seed_timestamp, csc_indices, seed_offset, etype_begin,
etype_end - etype_begin, fanout, replace, options, etype_end - etype_begin, fanout, replace, options,
probs_or_mask, node_timestamp, edge_timestamp, probs_or_mask, node_timestamp, edge_timestamp, args,
picked_data_ptr + pick_offset); picked_data_ptr + pick_offset);
pick_offset += picked_count; pick_offset += picked_count;
} }
...@@ -1278,8 +1304,8 @@ int64_t Pick( ...@@ -1278,8 +1304,8 @@ int64_t Pick(
if (probs_or_mask.has_value()) { if (probs_or_mask.has_value()) {
if (fanout < 0) { if (fanout < 0) {
return NonUniformPick( return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask, offset, num_neighbors, fanout, replace, options,
picked_data_ptr); probs_or_mask.value(), picked_data_ptr);
} else { } else {
int64_t picked_count; int64_t picked_count;
AT_DISPATCH_FLOATING_TYPES( AT_DISPATCH_FLOATING_TYPES(
...@@ -1365,6 +1391,9 @@ inline int64_t LaborPick( ...@@ -1365,6 +1391,9 @@ inline int64_t LaborPick(
const ProbsType* local_probs_data = const ProbsType* local_probs_data =
NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset
: nullptr; : nullptr;
if (NonUniform && probs_or_mask.value().size(0) <= num_neighbors) {
local_probs_data -= offset;
}
AT_DISPATCH_INTEGRAL_TYPES( AT_DISPATCH_INTEGRAL_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] { args.indices.scalar_type(), "LaborPickMain", ([&] {
const scalar_t* local_indices_data = const scalar_t* local_indices_data =
......
...@@ -860,6 +860,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -860,6 +860,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
input_nodes_timestamp, input_nodes_timestamp,
fanouts.tolist(), fanouts.tolist(),
replace, replace,
False,
has_original_eids, has_original_eids,
probs_name, probs_name,
node_timestamp_attr_name, node_timestamp_attr_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment