Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
29c3b06d
Unverified
Commit
29c3b06d
authored
Jan 04, 2024
by
czkkkkkk
Committed by
GitHub
Jan 04, 2024
Browse files
[Graphbolt] Change the temporal filter condition of temporal sampler. (#6893)
parent
55280b67
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
11 additions
and
11 deletions
+11
-11
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+1
-1
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+2
-2
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+2
-2
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+4
-4
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+2
-2
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
29c3b06d
...
@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Sample neighboring edges of the given nodes with a temporal
* @brief Sample neighboring edges of the given nodes with a temporal
* constraint. If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is
* constraint. If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is
* given, the sampled neighbors or edges of an input node must have a
* given, the sampled neighbors or edges of an input node must have a
* timestamp that is
no lat
er than that of the input node.
* timestamp that is
small
er than that of the input node.
*
*
* @param nodes The nodes from which to sample neighbors.
* @param nodes The nodes from which to sample neighbors.
* @param input_nodes_timestamp The timestamp of the nodes.
* @param input_nodes_timestamp The timestamp of the nodes.
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
29c3b06d
...
@@ -784,10 +784,10 @@ torch::Tensor TemporalMask(
...
@@ -784,10 +784,10 @@ torch::Tensor TemporalMask(
if
(
node_timestamp
.
has_value
())
{
if
(
node_timestamp
.
has_value
())
{
auto
neighbor_timestamp
=
auto
neighbor_timestamp
=
node_timestamp
.
value
().
index_select
(
0
,
csc_indices
.
slice
(
0
,
l
,
r
));
node_timestamp
.
value
().
index_select
(
0
,
csc_indices
.
slice
(
0
,
l
,
r
));
mask
&=
neighbor_timestamp
<
=
seed_timestamp
;
mask
&=
neighbor_timestamp
<
seed_timestamp
;
}
}
if
(
edge_timestamp
.
has_value
())
{
if
(
edge_timestamp
.
has_value
())
{
mask
&=
edge_timestamp
.
value
().
slice
(
0
,
l
,
r
)
<
=
seed_timestamp
;
mask
&=
edge_timestamp
.
value
().
slice
(
0
,
l
,
r
)
<
seed_timestamp
;
}
}
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
mask
&=
probs_or_mask
.
value
().
slice
(
0
,
l
,
r
)
!=
0
;
mask
&=
probs_or_mask
.
value
().
slice
(
0
,
l
,
r
)
!=
0
;
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
29c3b06d
...
@@ -773,8 +773,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
...
@@ -773,8 +773,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
subgraph.
subgraph.
If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
the sampled neighbor
s
or edge
s
of an input node must have a timestamp
the sampled neighbor or edge of an input node must have a timestamp
that is
no lat
er than that of the input node.
that is
small
er than that of the input node.
Parameters
Parameters
----------
----------
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
29c3b06d
...
@@ -898,12 +898,12 @@ def test_temporal_sample_neighbors_homo(
...
@@ -898,12 +898,12 @@ def test_temporal_sample_neighbors_homo(
neighbor
=
indices
[
j
].
item
()
neighbor
=
indices
[
j
].
item
()
if
(
if
(
use_node_timestamp
use_node_timestamp
and
(
node_timestamp
[
neighbor
]
>
seed_timestamp
[
i
]).
item
()
and
(
node_timestamp
[
neighbor
]
>
=
seed_timestamp
[
i
]).
item
()
):
):
continue
continue
if
(
if
(
use_edge_timestamp
use_edge_timestamp
and
(
edge_timestamp
[
j
]
>
seed_timestamp
[
i
]).
item
()
and
(
edge_timestamp
[
j
]
>
=
seed_timestamp
[
i
]).
item
()
):
):
continue
continue
neighbors
.
append
(
neighbor
)
neighbors
.
append
(
neighbor
)
...
@@ -1035,13 +1035,13 @@ def test_temporal_sample_neighbors_hetero(
...
@@ -1035,13 +1035,13 @@ def test_temporal_sample_neighbors_hetero(
if
(
if
(
use_node_timestamp
use_node_timestamp
and
(
and
(
node_timestamp
[
neighbor
]
>
homo_seed_timestamp
[
i
]
node_timestamp
[
neighbor
]
>
=
homo_seed_timestamp
[
i
]
).
item
()
).
item
()
):
):
continue
continue
if
(
if
(
use_edge_timestamp
use_edge_timestamp
and
(
edge_timestamp
[
j
]
>
homo_seed_timestamp
[
i
]).
item
()
and
(
edge_timestamp
[
j
]
>
=
homo_seed_timestamp
[
i
]).
item
()
):
):
continue
continue
neighbors
.
append
(
neighbor
)
neighbors
.
append
(
neighbor
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
29c3b06d
...
@@ -525,7 +525,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
...
@@ -525,7 +525,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
graph
.
edge_attributes
=
{
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
3
,)))
items
=
(
items
,
torch
.
randint
(
1
,
10
,
(
3
,)))
names
=
(
names
,
"timestamp"
)
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
...
@@ -583,7 +583,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
...
@@ -583,7 +583,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
graph
.
edge_attributes
=
{
graph
.
edge_attributes
=
{
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
"timestamp"
:
torch
.
zeros
(
graph
.
indices
.
numel
()).
to
(
F
.
ctx
())
}
}
items
=
(
items
,
torch
.
randint
(
0
,
10
,
(
2
,)))
items
=
(
items
,
torch
.
randint
(
1
,
10
,
(
2
,)))
names
=
(
names
,
"timestamp"
)
names
=
(
names
,
"timestamp"
)
itemset
=
gb
.
ItemSetDict
({
"n2"
:
gb
.
ItemSet
(
items
,
names
=
names
)})
itemset
=
gb
.
ItemSetDict
({
"n2"
:
gb
.
ItemSet
(
items
,
names
=
names
)})
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment