Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
3d657dbf
Unverified
Commit
3d657dbf
authored
Dec 19, 2023
by
czkkkkkk
Committed by
GitHub
Dec 19, 2023
Browse files
[Graphbolt] Define the interface of temporal neighbor sampling. (#6755)
parent
f95e9df3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
134 additions
and
0 deletions
+134
-0
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+35
-0
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+18
-0
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+3
-0
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+78
-0
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
3d657dbf
...
@@ -321,6 +321,41 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -321,6 +321,41 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
bool
replace
,
bool
layer
,
bool
return_eids
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
)
const
;
torch
::
optional
<
std
::
string
>
probs_name
)
const
;
/**
* @brief Sample neighboring edges of the given nodes with a temporal
* constraint. If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is
* given, the sampled neighbors or edges of an input node must have a
* timestamp that is no later than that of the input node.
*
* @param input_nodes The nodes from which to sample neighbors.
* @param input_nodes_timestamp The timestamp of the nodes.
* @param fanouts The number of edges to be sampled for each node with or
* without considering edge types, following the same rules as in
* SampleNeighbors.
* @param replace Boolean indicating whether the sample is performed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* @param probs_name An optional string specifying the name of an edge
* attribute, following the same rules as in SampleNeighbors.
* @param node_timestamp_attr_name An optional string specifying the name of
* the node attribute that contains the timestamp of nodes in the graph.
* @param edge_timestamp_attr_name An optional string specifying the name of
* the edge attribute that contains the timestamp of edges in the graph.
*
* @return An intrusive pointer to a FusedSampledSubgraph object containing
* the sampled graph's information.
*
*/
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
TemporalSampleNeighbors
(
const
torch
::
Tensor
&
input_nodes
,
const
torch
::
Tensor
&
input_nodes_timestamp
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
,
torch
::
optional
<
std
::
string
>
node_timestamp_attr_name
,
torch
::
optional
<
std
::
string
>
edge_timestamp_attr_name
)
const
;
/**
/**
* @brief Sample negative edges by randomly choosing negative
* @brief Sample negative edges by randomly choosing negative
* source-destination pairs according to a uniform distribution. For each edge
* source-destination pairs according to a uniform distribution. For each edge
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
3d657dbf
...
@@ -571,6 +571,24 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -571,6 +571,24 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
}
}
/**
 * @brief Interface stub for temporal neighbor sampling.
 *
 * The sampling logic is not implemented yet (see the TODO list in the body).
 * Every argument is currently ignored and a default-constructed
 * `c10::intrusive_ptr<FusedSampledSubgraph>` is returned.
 *
 * @param input_nodes The nodes from which to sample neighbors.
 * @param input_nodes_timestamp The timestamp of the nodes.
 * @param fanouts The number of edges to be sampled for each node.
 * @param replace Whether to sample with replacement.
 * @param return_eids Whether edge IDs need to be returned.
 * @param probs_name Optional name of an edge attribute.
 * @param node_timestamp_attr_name Optional name of the node attribute holding
 * node timestamps.
 * @param edge_timestamp_attr_name Optional name of the edge attribute holding
 * edge timestamps.
 *
 * @return A default-constructed intrusive pointer until the implementation
 * lands.
 */
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::TemporalSampleNeighbors(
    const torch::Tensor& input_nodes,
    const torch::Tensor& input_nodes_timestamp,
    const std::vector<int64_t>& fanouts, bool replace, bool return_eids,
    torch::optional<std::string> probs_name,
    torch::optional<std::string> node_timestamp_attr_name,
    torch::optional<std::string> edge_timestamp_attr_name) const {
  // TODO(zhenkun):
  // 1. Get probs_or_mask.
  // 2. Get the timestamp attribute for nodes of the graph
  // 3. Get the timestamp attribute for edges of the graph
  // 4. GetTemporalNumPickFn (New implementation)
  // 5. GetTemporalPickFn (New implementation)
  // 6. Call SampleNeighborsImpl (Old implementation)

  // Placeholder result: nothing is sampled yet, so hand back an empty handle.
  c10::intrusive_ptr<FusedSampledSubgraph> placeholder_subgraph;
  return placeholder_subgraph;
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
FusedCSCSamplingGraph
::
SampleNegativeEdgesUniform
(
FusedCSCSamplingGraph
::
SampleNegativeEdgesUniform
(
const
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>&
node_pairs
,
const
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>&
node_pairs
,
...
...
graphbolt/src/python_binding.cc
View file @
3d657dbf
...
@@ -49,6 +49,9 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -49,6 +49,9 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"set_edge_attributes"
,
&
FusedCSCSamplingGraph
::
SetEdgeAttributes
)
.
def
(
"set_edge_attributes"
,
&
FusedCSCSamplingGraph
::
SetEdgeAttributes
)
.
def
(
"in_subgraph"
,
&
FusedCSCSamplingGraph
::
InSubgraph
)
.
def
(
"in_subgraph"
,
&
FusedCSCSamplingGraph
::
InSubgraph
)
.
def
(
"sample_neighbors"
,
&
FusedCSCSamplingGraph
::
SampleNeighbors
)
.
def
(
"sample_neighbors"
,
&
FusedCSCSamplingGraph
::
SampleNeighbors
)
.
def
(
"temporal_sample_neighbors"
,
&
FusedCSCSamplingGraph
::
TemporalSampleNeighbors
)
.
def
(
.
def
(
"sample_negative_edges_uniform"
,
"sample_negative_edges_uniform"
,
&
FusedCSCSamplingGraph
::
SampleNegativeEdgesUniform
)
&
FusedCSCSamplingGraph
::
SampleNegativeEdgesUniform
)
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
3d657dbf
...
@@ -830,6 +830,84 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -830,6 +830,84 @@ class FusedCSCSamplingGraph(SamplingGraph):
else
:
else
:
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
return
self
.
_convert_to_sampled_subgraph
(
C_sampled_subgraph
)
def _temporal_sample_neighbors(
    self,
    nodes: torch.Tensor,
    input_nodes_timestamp: torch.Tensor,
    fanouts: torch.Tensor,
    replace: bool = False,
    probs_name: Optional[str] = None,
    node_timestamp_attr_name: Optional[str] = None,
    edge_timestamp_attr_name: Optional[str] = None,
) -> torch.ScriptObject:
    """Temporally sample neighboring edges of the given nodes and return
    the induced subgraph.

    If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
    the sampled neighbors or edges of an input node must have a timestamp
    that is no later than that of the input node.

    Parameters
    ----------
    nodes: torch.Tensor
        IDs of the given seed nodes.
    input_nodes_timestamp: torch.Tensor
        Timestamps of the given seed nodes.
    fanouts: torch.Tensor
        The number of edges to be sampled for each node with or without
        considering edge types.
          - When the length is 1, it indicates that the fanout applies to
            all neighbors of the node as a collective, regardless of the
            edge type.
          - Otherwise, the length should be equal to the number of edge
            types, and each fanout value corresponds to a specific edge
            type of the nodes.
        The value of each fanout should be >= 0 or = -1.
          - When the value is -1, all neighbors (with non-zero probability,
            if weighted) will be sampled once regardless of replacement. It
            is equivalent to selecting all neighbors with non-zero
            probability when the fanout is >= the number of neighbors (and
            replace is set to false).
          - When the value is a non-negative integer, it serves as a
            minimum threshold for selecting neighbors.
    replace: bool
        Boolean indicating whether the sample is performed with or
        without replacement. If True, a value can be selected multiple
        times. Otherwise, each value can be selected only once.
    probs_name: str, optional
        An optional string specifying the name of an edge attribute. This
        attribute tensor should contain (unnormalized) probabilities
        corresponding to each neighboring edge of a node. It must be a 1D
        floating-point or boolean tensor, with the number of elements
        equalling the total number of edges.
    node_timestamp_attr_name: str, optional
        An optional string specifying the name of a node attribute that
        holds the node timestamps.
    edge_timestamp_attr_name: str, optional
        An optional string specifying the name of an edge attribute that
        holds the edge timestamps.

    Returns
    -------
    torch.ScriptObject
        The sampled subgraph object returned by the C++ backend.
    """
    # Validate the sampler arguments with the shared helper before calling
    # into the C++ backend.
    self._check_sampler_arguments(nodes, fanouts, probs_name)
    # Edge IDs are only requested when the graph carries original edge IDs.
    has_original_eids = (
        self.edge_attributes is not None
        and ORIGINAL_EDGE_ID in self.edge_attributes
    )
    # The bound C++ method takes exactly eight arguments:
    #   (input_nodes, input_nodes_timestamp, fanouts, replace, return_eids,
    #    probs_name, node_timestamp_attr_name, edge_timestamp_attr_name).
    # Unlike `sample_neighbors`, it has no `layer` flag, so none is passed.
    return self._c_csc_graph.temporal_sample_neighbors(
        nodes,
        input_nodes_timestamp,
        fanouts.tolist(),
        replace,
        has_original_eids,
        probs_name,
        node_timestamp_attr_name,
        edge_timestamp_attr_name,
    )
def
sample_negative_edges_uniform
(
def
sample_negative_edges_uniform
(
self
,
edge_type
,
node_pairs
,
negative_ratio
self
,
edge_type
,
node_pairs
,
negative_ratio
):
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment