Unverified Commit 3d657dbf authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Define the interface of temporal neighbor sampling. (#6755)

parent f95e9df3
...@@ -321,6 +321,41 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -321,6 +321,41 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const; torch::optional<std::string> probs_name) const;
/**
 * @brief Sample neighboring edges of the given nodes with a temporal
 * constraint. If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is
 * given, the sampled neighbors or edges of an input node must have a
 * timestamp that is no later than that of the input node.
 *
 * @param input_nodes The nodes from which to sample neighbors.
 * @param input_nodes_timestamp The timestamps of the nodes in `input_nodes`.
 * @param fanouts The number of edges to be sampled for each node with or
 * without considering edge types, following the same rules as in
 * SampleNeighbors.
 * @param replace Boolean indicating whether the sample is performed with or
 * without replacement. If True, a value can be selected multiple times.
 * Otherwise, each value can be selected only once.
 * @param return_eids Boolean indicating whether edge IDs need to be returned,
 * typically used when edge features are required.
 * @param probs_name An optional string specifying the name of an edge
 * attribute, following the same rules as in SampleNeighbors.
 * @param node_timestamp_attr_name An optional string specifying the name of
 * the node attribute that contains the timestamp of nodes in the graph.
 * @param edge_timestamp_attr_name An optional string specifying the name of
 * the edge attribute that contains the timestamp of edges in the graph.
 *
 * @return An intrusive pointer to a FusedSampledSubgraph object containing
 * the sampled graph's information.
 */
c10::intrusive_ptr<FusedSampledSubgraph> TemporalSampleNeighbors(
const torch::Tensor& input_nodes,
const torch::Tensor& input_nodes_timestamp,
const std::vector<int64_t>& fanouts, bool replace, bool return_eids,
torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const;
/** /**
* @brief Sample negative edges by randomly choosing negative * @brief Sample negative edges by randomly choosing negative
* source-destination pairs according to a uniform distribution. For each edge * source-destination pairs according to a uniform distribution. For each edge
......
...@@ -571,6 +571,24 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -571,6 +571,24 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} }
} }
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighbors(
    const torch::Tensor& input_nodes,
    const torch::Tensor& input_nodes_timestamp,
    const std::vector<int64_t>& fanouts, bool replace, bool return_eids,
    torch::optional<std::string> probs_name,
    torch::optional<std::string> node_timestamp_attr_name,
    torch::optional<std::string> edge_timestamp_attr_name) const {
  // Interface-only placeholder: none of the arguments are consumed yet, and a
  // default-constructed subgraph handle is handed back to the caller.
  //
  // TODO(zhenkun): planned implementation steps.
  //   1. Get probs_or_mask.
  //   2. Get the timestamp attribute for nodes of the graph.
  //   3. Get the timestamp attribute for edges of the graph.
  //   4. GetTemporalNumPickFn (new implementation).
  //   5. GetTemporalPickFn (new implementation).
  //   6. Call SampleNeighborsImpl (old implementation).
  c10::intrusive_ptr<FusedSampledSubgraph> empty_subgraph;
  return empty_subgraph;
}
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
FusedCSCSamplingGraph::SampleNegativeEdgesUniform( FusedCSCSamplingGraph::SampleNegativeEdgesUniform(
const std::tuple<torch::Tensor, torch::Tensor>& node_pairs, const std::tuple<torch::Tensor, torch::Tensor>& node_pairs,
......
...@@ -49,6 +49,9 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -49,6 +49,9 @@ TORCH_LIBRARY(graphbolt, m) {
.def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes) .def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph) .def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors) .def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
.def(
"temporal_sample_neighbors",
&FusedCSCSamplingGraph::TemporalSampleNeighbors)
.def( .def(
"sample_negative_edges_uniform", "sample_negative_edges_uniform",
&FusedCSCSamplingGraph::SampleNegativeEdgesUniform) &FusedCSCSamplingGraph::SampleNegativeEdgesUniform)
......
...@@ -830,6 +830,84 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -830,6 +830,84 @@ class FusedCSCSamplingGraph(SamplingGraph):
else: else:
return self._convert_to_sampled_subgraph(C_sampled_subgraph) return self._convert_to_sampled_subgraph(C_sampled_subgraph)
def _temporal_sample_neighbors(
    self,
    nodes: torch.Tensor,
    input_nodes_timestamp: torch.Tensor,
    fanouts: torch.Tensor,
    replace: bool = False,
    probs_name: Optional[str] = None,
    node_timestamp_attr_name: Optional[str] = None,
    edge_timestamp_attr_name: Optional[str] = None,
) -> torch.ScriptObject:
    """Temporally sample neighboring edges of the given nodes and return the
    induced subgraph.

    If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
    the sampled neighbors or edges of an input node must have a timestamp
    that is no later than that of the input node.

    Parameters
    ----------
    nodes: torch.Tensor
        IDs of the given seed nodes.
    input_nodes_timestamp: torch.Tensor
        Timestamps of the given seed nodes.
    fanouts: torch.Tensor
        The number of edges to be sampled for each node with or without
        considering edge types.
          - When the length is 1, it indicates that the fanout applies to
            all neighbors of the node as a collective, regardless of the
            edge type.
          - Otherwise, the length should equal to the number of edge
            types, and each fanout value corresponds to a specific edge
            type of the nodes.
        The value of each fanout should be >= 0 or = -1.
          - When the value is -1, all neighbors (with non-zero probability,
            if weighted) will be sampled once regardless of replacement. It
            is equivalent to selecting all neighbors with non-zero
            probability when the fanout is >= the number of neighbors (and
            replace is set to false).
          - When the value is a non-negative integer, it serves as a
            minimum threshold for selecting neighbors.
    replace: bool
        Boolean indicating whether the sample is performed with or
        without replacement. If True, a value can be selected multiple
        times. Otherwise, each value can be selected only once.
    probs_name: str, optional
        An optional string specifying the name of an edge attribute. This
        attribute tensor should contain (unnormalized) probabilities
        corresponding to each neighboring edge of a node. It must be a 1D
        floating-point or boolean tensor, with the number of elements
        equalling the total number of edges.
    node_timestamp_attr_name: str, optional
        An optional string specifying the name of a node attribute that
        holds the node timestamps.
    edge_timestamp_attr_name: str, optional
        An optional string specifying the name of an edge attribute that
        holds the edge timestamps.

    Returns
    -------
    torch.classes.graphbolt.SampledSubgraph
        The sampled C subgraph.
    """
    # Argument validation is delegated to the shared sampler check.
    self._check_sampler_arguments(nodes, fanouts, probs_name)
    # Edge IDs are requested from the C++ side only when the graph carries
    # the original edge ID attribute.
    has_original_eids = (
        self.edge_attributes is not None
        and ORIGINAL_EDGE_ID in self.edge_attributes
    )
    # The positional arguments must mirror the C++ signature
    # TemporalSampleNeighbors(input_nodes, input_nodes_timestamp, fanouts,
    # replace, return_eids, probs_name, node_timestamp_attr_name,
    # edge_timestamp_attr_name). Unlike sample_neighbors it has no `layer`
    # flag, so no extra boolean may be passed between `replace` and
    # `return_eids`.
    return self._c_csc_graph.temporal_sample_neighbors(
        nodes,
        input_nodes_timestamp,
        fanouts.tolist(),
        replace,
        has_original_eids,
        probs_name,
        node_timestamp_attr_name,
        edge_timestamp_attr_name,
    )
def sample_negative_edges_uniform( def sample_negative_edges_uniform(
self, edge_type, node_pairs, negative_ratio self, edge_type, node_pairs, negative_ratio
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment