"vscode:/vscode.git/clone" did not exist on "f08d2a8a4c7ae03de982585c7dc8a47259e553fb"
Unverified Commit e42c7fcd authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Implement Temporal Neighbor Sampling. (#6784)

parent 8a8f2b00
...@@ -508,12 +508,28 @@ int64_t NumPick( ...@@ -508,12 +508,28 @@ int64_t NumPick(
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset, const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors); int64_t num_neighbors);
/**
 * @brief Calculates how many neighbors will be picked for one seed node when
 * only temporally valid edges are eligible.
 *
 * An edge in [offset, offset + num_neighbors) is counted as valid when every
 * provided constraint holds: the neighbor's node timestamp is not greater than
 * the seed's timestamp, the edge timestamp is not greater than the seed's
 * timestamp, and the edge's entry in probs_or_mask is non-zero.
 *
 * @param seed_timestamp Tensor of seed timestamps; the element at seed_offset
 * is read as an int64 value.
 * @param csc_indics Tensor used to look up the neighbor node id of each edge
 * position.
 * @param fanout Number of neighbors to pick; -1 means pick every valid
 * neighbor.
 * @param replace Whether sampling is done with replacement.
 * @param probs_or_mask Optional per-edge tensor; edges whose entry is zero are
 * excluded.
 * @param node_timestamp Optional per-node timestamps.
 * @param edge_timestamp Optional per-edge timestamps.
 * @param seed_offset Position of the seed inside seed_timestamp.
 * @param offset First edge position of the seed's neighborhood.
 * @param num_neighbors Number of edges in the seed's neighborhood.
 *
 * @return The number of valid neighbors if that number is zero or fanout is
 * -1; otherwise fanout when replace is true, else min(fanout, number of valid
 * neighbors).
 */
int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
bool replace, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors);
int64_t NumPickByEtype( int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset, const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors); int64_t num_neighbors);
/**
 * @brief Calculates how many neighbors will be picked for one seed node when
 * sampling per edge type under temporal constraints.
 *
 * The edge range [offset, offset + num_neighbors) is split into runs of equal
 * edge type, TemporalNumPick is evaluated on each run with fanouts[etype], and
 * the per-run results are summed.
 *
 * @note Runs are located with std::upper_bound, so the edge types inside the
 * range are expected to be sorted in ascending order.
 *
 * @param seed_timestamp Tensor of seed timestamps; the element at seed_offset
 * is read as an int64 value.
 * @param csc_indices Tensor used to look up the neighbor node id of each edge
 * position.
 * @param fanouts Number of neighbors to pick for each edge type, indexed by
 * edge type id.
 * @param replace Whether sampling is done with replacement.
 * @param type_per_edge Integral tensor holding the edge type of each edge.
 * Every value met in the range must lie in [0, fanouts.size()); otherwise an
 * error is raised through TORCH_CHECK.
 * @param probs_or_mask Optional per-edge tensor; edges whose entry is zero are
 * excluded.
 * @param node_timestamp Optional per-node timestamps.
 * @param edge_timestamp Optional per-edge timestamps.
 * @param seed_offset Position of the seed inside seed_timestamp.
 * @param offset First edge position of the seed's neighborhood.
 * @param num_neighbors Number of edges in the seed's neighborhood.
 *
 * @return The total number of neighbors to pick over all edge types.
 */
int64_t TemporalNumPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors);
/** /**
* @brief Picks a specified number of neighbors for a node, starting from the * @brief Picks a specified number of neighbors for a node, starting from the
* given offset and having the specified number of neighbors. * given offset and having the specified number of neighbors.
...@@ -562,6 +578,16 @@ int64_t Pick( ...@@ -562,6 +578,16 @@ int64_t Pick(
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr); SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
/**
 * @brief Picks neighbors for one seed node among the temporally valid edges
 * in [offset, offset + num_neighbors).
 *
 * Edges that violate a provided timestamp constraint, or whose probs_or_mask
 * entry is zero, receive zero weight. The remaining edges are weighted by
 * probs_or_mask when it is given and equally otherwise. The selection itself
 * is delegated to NonUniformPickOp.
 *
 * @tparam PickedType Element type of the output buffer.
 * @param seed_timestamp Tensor of seed timestamps; the element at seed_offset
 * is read as an int64 value.
 * @param csc_indices Tensor used to look up the neighbor node id of each edge
 * position.
 * @param seed_offset Position of the seed inside seed_timestamp.
 * @param offset First edge position of the seed's neighborhood.
 * @param num_neighbors Number of edges in the seed's neighborhood.
 * @param fanout Number of neighbors to pick, forwarded to NonUniformPickOp.
 * @param replace Whether sampling is done with replacement.
 * @param options Tensor options; currently not used by the implementation.
 * @param probs_or_mask Optional per-edge weights or mask.
 * @param node_timestamp Optional per-node timestamps.
 * @param edge_timestamp Optional per-edge timestamps.
 * @param picked_data_ptr Output buffer that receives the picked edge
 * positions (local pick index plus offset).
 *
 * @return The number of entries written to picked_data_ptr.
 */
template <typename PickedType>
int64_t TemporalPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,
bool replace, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
PickedType* picked_data_ptr);
/** /**
* @brief Picks a specified number of neighbors for a node per edge type, * @brief Picks a specified number of neighbors for a node per edge type,
* starting from the given offset and having the specified number of neighbors. * starting from the given offset and having the specified number of neighbors.
...@@ -597,6 +623,17 @@ int64_t PickByEtype( ...@@ -597,6 +623,17 @@ int64_t PickByEtype(
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args, const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr); PickedType* picked_data_ptr);
/**
 * @brief Picks neighbors for one seed node per edge type under temporal
 * constraints.
 *
 * The edge range [offset, offset + num_neighbors) is split into runs of equal
 * edge type. For every run whose fanout is non-zero, TemporalPick is invoked
 * and its picks are appended to the output buffer.
 *
 * @note Runs are located with std::upper_bound, so the edge types inside the
 * range are expected to be sorted in ascending order.
 *
 * @tparam PickedType Element type of the output buffer.
 * @param seed_timestamp Tensor of seed timestamps; the element at seed_offset
 * is read as an int64 value.
 * @param csc_indices Tensor used to look up the neighbor node id of each edge
 * position.
 * @param seed_offset Position of the seed inside seed_timestamp.
 * @param offset First edge position of the seed's neighborhood.
 * @param num_neighbors Number of edges in the seed's neighborhood.
 * @param fanouts Number of neighbors to pick for each edge type, indexed by
 * edge type id.
 * @param replace Whether sampling is done with replacement.
 * @param options Tensor options, forwarded to TemporalPick.
 * @param type_per_edge Integral tensor holding the edge type of each edge.
 * Every value met in the range must lie in [0, fanouts.size()); otherwise an
 * error is raised through TORCH_CHECK.
 * @param probs_or_mask Optional per-edge weights or mask.
 * @param node_timestamp Optional per-node timestamps.
 * @param edge_timestamp Optional per-edge timestamps.
 * @param picked_data_ptr Output buffer that receives the picked edge
 * positions.
 *
 * @return The total number of entries written to picked_data_ptr.
 */
template <typename PickedType>
int64_t TemporalPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
PickedType* picked_data_ptr);
template < template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType, bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize = 1024> int StackSize = 1024>
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "./random.h" #include "./random.h"
#include "./shared_memory_helper.h" #include "./shared_memory_helper.h"
#include "./utils.h"
namespace { namespace {
torch::optional<torch::Dict<std::string, torch::Tensor>> TensorizeDict( torch::optional<torch::Dict<std::string, torch::Tensor>> TensorizeDict(
...@@ -349,6 +350,31 @@ auto GetNumPickFn( ...@@ -349,6 +350,31 @@ auto GetNumPickFn(
}; };
} }
/**
 * @brief Get a lambda function which counts the number of neighbors to be
 * sampled for a seed node under temporal constraints.
 *
 * @param seed_timestamp Timestamps of the seed nodes, indexed by the lambda's
 * `seed_offset` argument.
 * @param csc_indices Tensor used to look up the neighbor node id of each edge
 * position.
 * @param fanouts Number of neighbors to pick. With a single entry the same
 * fanout applies to all edges; with several entries the fanout is chosen per
 * edge type.
 * @param replace Whether sampling is done with replacement.
 * @param type_per_edge Optional edge type of each edge. It must hold a value
 * when `fanouts.size() > 1`.
 * @param probs_or_mask Optional per-edge weights or mask.
 * @param node_timestamp Optional per-node timestamps.
 * @param edge_timestamp Optional per-edge timestamps.
 *
 * @return A lambda `(seed_offset, offset, num_neighbors) -> int64_t` returning
 * the number of neighbors to pick for the given seed and edge range.
 *
 * @note `fanouts` and the optional tensors are captured by reference, so the
 * caller must keep them alive for as long as the returned lambda is used.
 */
auto GetTemporalNumPickFn(
    torch::Tensor seed_timestamp, torch::Tensor csc_indices,
    const std::vector<int64_t>& fanouts, bool replace,
    const torch::optional<torch::Tensor>& type_per_edge,
    const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& node_timestamp,
    const torch::optional<torch::Tensor>& edge_timestamp) {
  // `seed_timestamp` and `csc_indices` are by-value parameters of this
  // function. Capturing them by reference would leave the returned closure
  // with references to parameters whose lifetime may already be over, so the
  // (cheap, reference-counted) tensor handles are captured by value instead.
  return [seed_timestamp, csc_indices, &fanouts, replace, &probs_or_mask,
          &type_per_edge, &node_timestamp, &edge_timestamp](
             int64_t seed_offset, int64_t offset, int64_t num_neighbors) {
    // If fanouts.size() > 1, returns the total number of all edge types of the
    // given node.
    if (fanouts.size() > 1) {
      return TemporalNumPickByEtype(
          seed_timestamp, csc_indices, fanouts, replace, type_per_edge.value(),
          probs_or_mask, node_timestamp, edge_timestamp, seed_offset, offset,
          num_neighbors);
    } else {
      return TemporalNumPick(
          seed_timestamp, csc_indices, fanouts[0], replace, probs_or_mask,
          node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors);
    }
  };
}
/** /**
* @brief Get a lambda function which contains the sampling process. * @brief Get a lambda function which contains the sampling process.
* *
...@@ -400,6 +426,39 @@ auto GetPickFn( ...@@ -400,6 +426,39 @@ auto GetPickFn(
}; };
} }
/**
 * @brief Get a lambda function which performs temporal neighbor sampling for
 * a seed node.
 *
 * @param seed_timestamp Timestamps of the seed nodes, indexed by the lambda's
 * `seed_offset` argument.
 * @param csc_indices Tensor used to look up the neighbor node id of each edge
 * position.
 * @param fanouts Number of neighbors to pick. With a single entry the same
 * fanout applies to all edges; with several entries sampling is done per edge
 * type.
 * @param replace Whether sampling is done with replacement.
 * @param options Tensor options forwarded to the pick functions.
 * @param type_per_edge Optional edge type of each edge. It must hold a value
 * when `fanouts.size() > 1`.
 * @param probs_or_mask Optional per-edge weights or mask.
 * @param node_timestamp Optional per-node timestamps.
 * @param edge_timestamp Optional per-edge timestamps.
 *
 * @return A lambda `(seed_offset, offset, num_neighbors, picked_data_ptr)`
 * which writes the picked edge positions to `picked_data_ptr` and returns how
 * many were written.
 *
 * @note `fanouts` and the optional tensors are captured by reference, so the
 * caller must keep them alive for as long as the returned lambda is used.
 */
auto GetTemporalPickFn(
    torch::Tensor seed_timestamp, torch::Tensor csc_indices,
    const std::vector<int64_t>& fanouts, bool replace,
    const torch::TensorOptions& options,
    const torch::optional<torch::Tensor>& type_per_edge,
    const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& node_timestamp,
    const torch::optional<torch::Tensor>& edge_timestamp) {
  // `seed_timestamp` and `csc_indices` are by-value parameters of this
  // function, and `options` is typically bound to a temporary at the call
  // site. Capturing any of them by reference can leave the returned closure
  // with dangling references, so they are captured by value instead (tensor
  // handles are reference-counted and TensorOptions is a small value type).
  return [seed_timestamp, csc_indices, &fanouts, replace, options,
          &type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp](
             int64_t seed_offset, int64_t offset, int64_t num_neighbors,
             auto picked_data_ptr) {
    // If fanouts.size() > 1, perform sampling for each edge type of each
    // node; otherwise just sample once for each node with no regard of edge
    // types.
    if (fanouts.size() > 1) {
      return TemporalPickByEtype(
          seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
          fanouts, replace, options, type_per_edge.value(), probs_or_mask,
          node_timestamp, edge_timestamp, picked_data_ptr);
    } else {
      int64_t num_sampled = TemporalPick(
          seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
          fanouts[0], replace, options, probs_or_mask, node_timestamp,
          edge_timestamp, picked_data_ptr);
      // When the graph carries per-edge types, return the picked positions in
      // ascending order.
      // NOTE(review): presumably this keeps the picks grouped by edge type
      // for the consumer -- confirm against SampleNeighborsImpl.
      if (type_per_edge) {
        std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
      }
      return num_sampled;
    }
  };
}
template <typename NumPickFn, typename PickFn> template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph> c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl( FusedCSCSamplingGraph::SampleNeighborsImpl(
...@@ -579,14 +638,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( ...@@ -579,14 +638,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
torch::optional<std::string> probs_name, torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name, torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const { torch::optional<std::string> edge_timestamp_attr_name) const {
// TODO(zhenkun):
// 1. Get probs_or_mask. // 1. Get probs_or_mask.
auto probs_or_mask = this->EdgeAttribute(probs_name);
if (probs_name.has_value()) {
// Note probs will be passed as input for 'torch.multinomial' in deeper
// stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
// avoid crashes, convert 'probs_or_mask' to 'float32' data type.
if (probs_or_mask.value().dtype() == torch::kBool ||
probs_or_mask.value().dtype() == torch::kFloat16) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
}
}
// 2. Get the timestamp attribute for nodes of the graph // 2. Get the timestamp attribute for nodes of the graph
auto node_timestamp = this->NodeAttribute(node_timestamp_attr_name);
// 3. Get the timestamp attribute for edges of the graph // 3. Get the timestamp attribute for edges of the graph
// 4. GetTemporalNumPickFn (New implementation) auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name);
// 5. GetTemporalPickFn (New implementation) // 4. Call SampleNeighborsImpl
// 6. Call SampleNeighborsImpl (Old implementation) return SampleNeighborsImpl(
return c10::intrusive_ptr<FusedSampledSubgraph>(); input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp,
edge_timestamp));
} }
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
...@@ -669,6 +745,43 @@ int64_t NumPick( ...@@ -669,6 +745,43 @@ int64_t NumPick(
return replace ? fanout : std::min(fanout, num_valid_neighbors); return replace ? fanout : std::min(fanout, num_valid_neighbors);
} }
/**
 * @brief Builds a boolean mask over an edge range that marks which edges may
 * be sampled for a seed with the given timestamp.
 *
 * @param seed_timestamp Timestamp of the seed node.
 * @param csc_indices Tensor used to look up the neighbor node id of each edge
 * position.
 * @param probs_or_mask Optional per-edge tensor; edges whose entry equals zero
 * are masked out.
 * @param node_timestamp Optional per-node timestamps; edges whose neighbor
 * has a timestamp greater than the seed's are masked out.
 * @param edge_timestamp Optional per-edge timestamps; edges with a timestamp
 * greater than the seed's are masked out.
 * @param edge_range Half-open range [first, second) of edge positions.
 *
 * @return A boolean tensor with `second - first` elements; an element is true
 * only if the corresponding edge passes every provided constraint.
 */
torch::Tensor TemporalMask(
    int64_t seed_timestamp, torch::Tensor csc_indices,
    const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& node_timestamp,
    const torch::optional<torch::Tensor>& edge_timestamp,
    std::pair<int64_t, int64_t> edge_range) {
  const int64_t first_edge = edge_range.first;
  const int64_t last_edge = edge_range.second;
  // Start with every edge allowed, then narrow the mask down constraint by
  // constraint.
  torch::Tensor allowed = torch::ones({last_edge - first_edge}, torch::kBool);
  if (node_timestamp.has_value()) {
    // Gather the timestamp of the neighbor node behind each edge.
    const auto neighbor_ids = csc_indices.slice(0, first_edge, last_edge);
    const auto neighbor_time = node_timestamp.value().index_select(0, neighbor_ids);
    allowed &= neighbor_time <= seed_timestamp;
  }
  if (edge_timestamp.has_value()) {
    const auto edge_time = edge_timestamp.value().slice(0, first_edge, last_edge);
    allowed &= edge_time <= seed_timestamp;
  }
  if (probs_or_mask.has_value()) {
    const auto edge_weight = probs_or_mask.value().slice(0, first_edge, last_edge);
    allowed &= edge_weight != 0;
  }
  return allowed;
}
/**
 * @brief Counts how many neighbors will be picked for one seed node when only
 * temporally valid edges in [offset, offset + num_neighbors) are eligible.
 *
 * @return The number of valid neighbors if that number is zero or fanout is
 * -1; otherwise fanout when replace is true, else min(fanout, number of valid
 * neighbors).
 */
int64_t TemporalNumPick(
    torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout,
    bool replace, const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& node_timestamp,
    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
    int64_t offset, int64_t num_neighbors) {
  // Timestamp of the seed this neighborhood belongs to.
  const int64_t seed_time =
      utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset);
  const auto allowed = TemporalMask(
      seed_time, csc_indics, probs_or_mask, node_timestamp, edge_timestamp,
      {offset, offset + num_neighbors});
  // Summing the boolean mask yields the number of eligible edges.
  const int64_t valid_count =
      utils::GetValueByIndex<int64_t>(allowed.sum(), 0);
  // Nothing eligible, or "pick everything" was requested.
  if (valid_count == 0 || fanout == -1) {
    return valid_count;
  }
  if (replace) {
    return fanout;
  }
  return std::min(fanout, valid_count);
}
int64_t NumPickByEtype( int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace, const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge, const torch::Tensor& type_per_edge,
...@@ -699,6 +812,40 @@ int64_t NumPickByEtype( ...@@ -699,6 +812,40 @@ int64_t NumPickByEtype(
return total_count; return total_count;
} }
/**
 * @brief Counts how many neighbors will be picked for one seed node when
 * sampling per edge type under temporal constraints.
 *
 * The edge range [offset, offset + num_neighbors) is walked run by run, where
 * a run is a maximal stretch of edges sharing one edge type; the per-run
 * counts from TemporalNumPick are accumulated.
 *
 * @return The total number of neighbors to pick over all edge types.
 */
int64_t TemporalNumPickByEtype(
    torch::Tensor seed_timestamp, torch::Tensor csc_indices,
    const std::vector<int64_t>& fanouts, bool replace,
    const torch::Tensor& type_per_edge,
    const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& node_timestamp,
    const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
    int64_t offset, int64_t num_neighbors) {
  const int64_t range_end = offset + num_neighbors;
  const int64_t num_etypes = static_cast<int64_t>(fanouts.size());
  int64_t picked_total = 0;
  AT_DISPATCH_INTEGRAL_TYPES(
      type_per_edge.scalar_type(), "TemporalNumPickFnByEtype", ([&] {
        const scalar_t* etypes = type_per_edge.data_ptr<scalar_t>();
        for (int64_t run_begin = offset; run_begin < range_end;) {
          const scalar_t etype = etypes[run_begin];
          TORCH_CHECK(
              etype >= 0 && etype < num_etypes,
              "Etype values exceed the number of fanouts.");
          // Locate the end of the run of edges that share this edge type.
          const int64_t run_end =
              std::upper_bound(
                  etypes + run_begin, etypes + range_end, etype) -
              etypes;
          // Count the picks for this single edge type.
          picked_total += TemporalNumPick(
              seed_timestamp, csc_indices, fanouts[etype], replace,
              probs_or_mask, node_timestamp, edge_timestamp, seed_offset,
              run_begin, run_end - run_begin);
          run_begin = run_end;
        }
      }));
  return picked_total;
}
/** /**
* @brief Perform uniform sampling of elements and return the sampled indices. * @brief Perform uniform sampling of elements and return the sampled indices.
* *
...@@ -983,6 +1130,35 @@ int64_t Pick( ...@@ -983,6 +1130,35 @@ int64_t Pick(
} }
} }
/**
 * @brief Picks neighbors for one seed node among the temporally valid edges
 * in [offset, offset + num_neighbors).
 *
 * @tparam PickedType Element type of the output buffer.
 * @param seed_timestamp Tensor of seed timestamps; the element at seed_offset
 * is read as an int64 value.
 * @param csc_indices Tensor used to look up the neighbor node id of each edge
 * position.
 * @param seed_offset Position of the seed inside seed_timestamp.
 * @param offset First edge position of the seed's neighborhood.
 * @param num_neighbors Number of edges in the seed's neighborhood.
 * @param fanout Number of neighbors to pick, forwarded to NonUniformPickOp.
 * @param replace Whether sampling is done with replacement.
 * @param options Tensor options; not used by this function.
 * @param probs_or_mask Optional per-edge weights or mask.
 * @param node_timestamp Optional per-node timestamps.
 * @param edge_timestamp Optional per-edge timestamps.
 * @param picked_data_ptr Output buffer that receives the picked edge
 * positions.
 *
 * @return The number of entries written to picked_data_ptr.
 */
template <typename PickedType>
int64_t TemporalPick(
    torch::Tensor seed_timestamp, torch::Tensor csc_indices,
    int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,
    bool replace, const torch::TensorOptions& options,
    const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& node_timestamp,
    const torch::optional<torch::Tensor>& edge_timestamp,
    PickedType* picked_data_ptr) {
  auto mask = TemporalMask(
      utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,
      probs_or_mask, node_timestamp, edge_timestamp,
      {offset, offset + num_neighbors});
  // Give temporally invalid edges zero weight. Valid edges keep their
  // user-provided weight, or get an equal weight when none was provided.
  torch::Tensor masked_prob;
  if (probs_or_mask.has_value()) {
    masked_prob =
        probs_or_mask.value().slice(0, offset, offset + num_neighbors) * mask;
  } else {
    masked_prob = mask.to(torch::kFloat32);
  }
  auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace);
  // Read the element count once and index with int64_t so that the loop
  // counter has the same width as Tensor::numel().
  const int64_t num_picked = picked_indices.numel();
  const int64_t* picked_indices_ptr = picked_indices.data_ptr<int64_t>();
  for (int64_t i = 0; i < num_picked; ++i) {
    // Picked indices are relative to the slice; shift them back to absolute
    // edge positions.
    picked_data_ptr[i] =
        static_cast<PickedType>(picked_indices_ptr[i]) + offset;
  }
  return num_picked;
}
template <SamplerType S, typename PickedType> template <SamplerType S, typename PickedType>
int64_t PickByEtype( int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts, int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
...@@ -1020,6 +1196,48 @@ int64_t PickByEtype( ...@@ -1020,6 +1196,48 @@ int64_t PickByEtype(
return pick_offset; return pick_offset;
} }
/**
 * @brief Picks neighbors for one seed node per edge type under temporal
 * constraints.
 *
 * The edge range [offset, offset + num_neighbors) is walked run by run, where
 * a run is a maximal stretch of edges sharing one edge type. Runs whose
 * fanout is zero are skipped; for every other run TemporalPick appends its
 * picks to the output buffer.
 *
 * @return The total number of entries written to picked_data_ptr.
 */
template <typename PickedType>
int64_t TemporalPickByEtype(
    torch::Tensor seed_timestamp, torch::Tensor csc_indices,
    int64_t seed_offset, int64_t offset, int64_t num_neighbors,
    const std::vector<int64_t>& fanouts, bool replace,
    const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
    const torch::optional<torch::Tensor>& probs_or_mask,
    const torch::optional<torch::Tensor>& node_timestamp,
    const torch::optional<torch::Tensor>& edge_timestamp,
    PickedType* picked_data_ptr) {
  const int64_t range_end = offset + num_neighbors;
  const int64_t num_etypes = static_cast<int64_t>(fanouts.size());
  // Number of entries written to the output buffer so far.
  int64_t num_written = 0;
  AT_DISPATCH_INTEGRAL_TYPES(
      type_per_edge.scalar_type(), "TemporalPickByEtype", ([&] {
        const scalar_t* etypes = type_per_edge.data_ptr<scalar_t>();
        for (int64_t run_begin = offset; run_begin < range_end;) {
          const scalar_t etype = etypes[run_begin];
          TORCH_CHECK(
              etype >= 0 && etype < num_etypes,
              "Etype values exceed the number of fanouts.");
          const int64_t fanout = fanouts[etype];
          // Locate the end of the run of edges that share this edge type.
          const int64_t run_end =
              std::upper_bound(
                  etypes + run_begin, etypes + range_end, etype) -
              etypes;
          // Sample within this single edge type unless it is disabled.
          if (fanout != 0) {
            PickedType* const out = picked_data_ptr + num_written;
            num_written += TemporalPick(
                seed_timestamp, csc_indices, seed_offset, run_begin,
                run_end - run_begin, fanout, replace, options, probs_or_mask,
                node_timestamp, edge_timestamp, out);
          }
          run_begin = run_end;
        }
      }));
  return num_written;
}
template <typename PickedType> template <typename PickedType>
int64_t Pick( int64_t Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
......
...@@ -19,6 +19,29 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) { ...@@ -19,6 +19,29 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA; return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA;
} }
/**
 * @brief Reads a single element of a tensor by its flat index.
 *
 * @note The element is read through the data pointer of `tensor.contiguous()`,
 * so a non-contiguous tensor is copied to a contiguous one first.
 *
 * @tparam T The element type used to access the tensor data.
 * @param tensor The tensor to read from.
 * @param index The flat index of the element; it must satisfy
 * 0 <= index < tensor.numel(), otherwise an error is raised through
 * TORCH_CHECK.
 *
 * @return T The element stored at the given index.
 */
template <typename T>
T GetValueByIndex(const torch::Tensor& tensor, int64_t index) {
  const bool in_range = index >= 0 && index < tensor.numel();
  TORCH_CHECK(
      in_range,
      "The index should be within the range of the tensor, but got index ",
      index, " and tensor size ", tensor.numel());
  // Keep the contiguous tensor alive in a named local while its buffer is
  // being read.
  const auto dense = tensor.contiguous();
  const T* values = dense.data_ptr<T>();
  return values[index];
}
} // namespace utils } // namespace utils
} // namespace graphbolt } // namespace graphbolt
......
...@@ -439,11 +439,18 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -439,11 +439,18 @@ class FusedCSCSamplingGraph(SamplingGraph):
node_pairs=node_pairs, original_edge_ids=original_edge_ids node_pairs=node_pairs, original_edge_ids=original_edge_ids
) )
def _convert_to_homogeneous_nodes(self, nodes): def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = [] homogeneous_nodes = []
homogeneous_timestamps = []
for ntype, ids in nodes.items(): for ntype, ids in nodes.items():
ntype_id = self.node_type_to_id[ntype] ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id]) homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
if timestamps is not None:
homogeneous_timestamps.append(timestamps[ntype])
if timestamps is not None:
return torch.cat(homogeneous_nodes), torch.cat(
homogeneous_timestamps
)
return torch.cat(homogeneous_nodes) return torch.cat(homogeneous_nodes)
def _convert_to_sampled_subgraph( def _convert_to_sampled_subgraph(
...@@ -830,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -830,7 +837,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
else: else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _temporal_sample_neighbors( def temporal_sample_neighbors(
self, self,
nodes: torch.Tensor, nodes: torch.Tensor,
input_nodes_timestamp: torch.Tensor, input_nodes_timestamp: torch.Tensor,
...@@ -887,26 +894,39 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -887,26 +894,39 @@ class FusedCSCSamplingGraph(SamplingGraph):
Returns Returns
------- -------
torch.classes.graphbolt.SampledSubgraph FusedSampledSubgraphImpl
The sampled C subgraph. The sampled subgraph.
""" """
if isinstance(nodes, dict):
nodes, input_nodes_timestamp = self._convert_to_homogeneous_nodes(
nodes, input_nodes_timestamp
)
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
self._check_sampler_arguments(nodes, fanouts, probs_name) self._check_sampler_arguments(nodes, fanouts, probs_name)
has_original_eids = ( has_original_eids = (
self.edge_attributes is not None self.edge_attributes is not None
and ORIGINAL_EDGE_ID in self.edge_attributes and ORIGINAL_EDGE_ID in self.edge_attributes
) )
return self._c_csc_graph.temporal_sample_neighbors( C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors(
nodes, nodes,
input_nodes_timestamp, input_nodes_timestamp,
fanouts.tolist(), fanouts.tolist(),
replace, replace,
False,
has_original_eids, has_original_eids,
probs_name, probs_name,
node_timestamp_attr_name, node_timestamp_attr_name,
edge_timestamp_attr_name, edge_timestamp_attr_name,
) )
# Broadcast the input nodes' timestamp to the sampled neighbors.
sampled_count = torch.diff(C_sampled_subgraph.indptr)
neighbors_timestamp = input_nodes_timestamp.repeat_interleave(
sampled_count
)
return (
self._convert_to_sampled_subgraph(C_sampled_subgraph),
neighbors_timestamp,
)
def sample_negative_edges_uniform( def sample_negative_edges_uniform(
self, edge_type, node_pairs, negative_ratio self, edge_type, node_pairs, negative_ratio
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment