"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "89efc607ef539ca1d9321dfcb8b0a8879e3d6c45"
Unverified Commit 61504ec5 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Extend temporal sampling to labor (#6816)

parent 333ce36c
......@@ -335,6 +335,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge
......@@ -351,8 +354,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighbors(
const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool return_eids,
torch::optional<std::string> probs_name,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const;
......
......@@ -437,6 +437,7 @@ auto GetPickFn(
};
}
template <SamplerType S>
auto GetTemporalPickFn(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
const std::vector<int64_t>& fanouts, bool replace,
......@@ -444,30 +445,31 @@ auto GetTemporalPickFn(
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp) {
return [&seed_timestamp, &csc_indices, &fanouts, replace, &options,
&type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp](
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
auto picked_data_ptr) {
// If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node with no regard of edge
// types.
if (fanouts.size() > 1) {
return TemporalPickByEtype(
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
fanouts, replace, options, type_per_edge.value(), probs_or_mask,
node_timestamp, edge_timestamp, picked_data_ptr);
} else {
int64_t num_sampled = TemporalPick(
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
fanouts[0], replace, options, probs_or_mask, node_timestamp,
edge_timestamp, picked_data_ptr);
if (type_per_edge) {
std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
}
return num_sampled;
}
};
const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args) {
return
[&seed_timestamp, &csc_indices, &fanouts, replace, &options,
&type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp, args](
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
auto picked_data_ptr) {
// If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node with no regard of edge
// types.
if (fanouts.size() > 1) {
return TemporalPickByEtype(
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
fanouts, replace, options, type_per_edge.value(), probs_or_mask,
node_timestamp, edge_timestamp, args, picked_data_ptr);
} else {
int64_t num_sampled = TemporalPick(
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
fanouts[0], replace, options, probs_or_mask, node_timestamp,
edge_timestamp, args, picked_data_ptr);
if (type_per_edge.has_value()) {
std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
}
return num_sampled;
}
};
}
template <typename NumPickFn, typename PickFn>
......@@ -664,8 +666,8 @@ c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighbors(
const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool return_eids,
torch::optional<std::string> probs_name,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool return_eids, torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const {
// 1. Get probs_or_mask.
......@@ -684,15 +686,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
// 3. Get the timestamp attribute for edges of the graph
auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name);
// 4. Call SampleNeighborsImpl
return SampleNeighborsImpl(
input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp,
edge_timestamp));
if (layer) {
const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
SamplerArgs<SamplerType::LABOR> args{indices_, random_seed, NumNodes()};
return SampleNeighborsImpl(
input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp,
edge_timestamp, args));
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp,
edge_timestamp, args));
}
}
std::tuple<torch::Tensor, torch::Tensor>
......@@ -1130,11 +1148,12 @@ static torch::Tensor NonUniformPickOp(
template <typename PickedType>
inline int64_t NonUniformPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::TensorOptions& options, const torch::Tensor& probs_or_mask,
PickedType* picked_data_ptr) {
auto local_probs =
probs_or_mask.value().slice(0, offset, offset + num_neighbors);
probs_or_mask.size(0) > num_neighbors
? probs_or_mask.slice(0, offset, offset + num_neighbors)
: probs_or_mask;
auto picked_indices = NonUniformPickOp(local_probs, fanout, replace);
auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();
for (int i = 0; i < picked_indices.numel(); ++i) {
......@@ -1152,7 +1171,7 @@ int64_t Pick(
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
if (probs_or_mask.has_value()) {
return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask,
offset, num_neighbors, fanout, replace, options, probs_or_mask.value(),
picked_data_ptr);
} else {
return UniformPick(
......@@ -1160,14 +1179,14 @@ int64_t Pick(
}
}
template <typename PickedType>
template <SamplerType S, typename PickedType>
int64_t TemporalPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,
bool replace, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,
......@@ -1180,13 +1199,20 @@ int64_t TemporalPick(
} else {
masked_prob = mask.to(torch::kFloat32);
}
auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace);
auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();
for (int i = 0; i < picked_indices.numel(); ++i) {
picked_data_ptr[i] =
static_cast<PickedType>(picked_indices_ptr[i]) + offset;
if constexpr (S == SamplerType::NEIGHBOR) {
auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace);
auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();
for (int i = 0; i < picked_indices.numel(); ++i) {
picked_data_ptr[i] =
static_cast<PickedType>(picked_indices_ptr[i]) + offset;
}
return picked_indices.numel();
}
if constexpr (S == SamplerType::LABOR) {
return Pick(
offset, num_neighbors, fanout, replace, options, masked_prob, args,
picked_data_ptr);
}
return picked_indices.numel();
}
template <SamplerType S, typename PickedType>
......@@ -1226,7 +1252,7 @@ int64_t PickByEtype(
return pick_offset;
}
template <typename PickedType>
template <SamplerType S, typename PickedType>
int64_t TemporalPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
......@@ -1234,7 +1260,7 @@ int64_t TemporalPickByEtype(
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
int64_t etype_begin = offset;
int64_t etype_end = offset;
......@@ -1258,7 +1284,7 @@ int64_t TemporalPickByEtype(
int64_t picked_count = TemporalPick(
seed_timestamp, csc_indices, seed_offset, etype_begin,
etype_end - etype_begin, fanout, replace, options,
probs_or_mask, node_timestamp, edge_timestamp,
probs_or_mask, node_timestamp, edge_timestamp, args,
picked_data_ptr + pick_offset);
pick_offset += picked_count;
}
......@@ -1278,8 +1304,8 @@ int64_t Pick(
if (probs_or_mask.has_value()) {
if (fanout < 0) {
return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask,
picked_data_ptr);
offset, num_neighbors, fanout, replace, options,
probs_or_mask.value(), picked_data_ptr);
} else {
int64_t picked_count;
AT_DISPATCH_FLOATING_TYPES(
......@@ -1365,6 +1391,9 @@ inline int64_t LaborPick(
const ProbsType* local_probs_data =
NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset
: nullptr;
if (NonUniform && probs_or_mask.value().size(0) <= num_neighbors) {
local_probs_data -= offset;
}
AT_DISPATCH_INTEGRAL_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] {
const scalar_t* local_indices_data =
......
......@@ -860,6 +860,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
input_nodes_timestamp,
fanouts.tolist(),
replace,
False,
has_original_eids,
probs_name,
node_timestamp_attr_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment