Unverified Commit 29c3b06d authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Change the temporal filter condition of temporal sampler. (#6893)

parent 55280b67
...@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Sample neighboring edges of the given nodes with a temporal * @brief Sample neighboring edges of the given nodes with a temporal
* constraint. If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is * constraint. If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is
* given, the sampled neighbors or edges of an input node must have a * given, the sampled neighbors or edges of an input node must have a
* timestamp that is no later than that of the input node. * timestamp that is smaller than that of the input node.
* *
* @param nodes The nodes from which to sample neighbors. * @param nodes The nodes from which to sample neighbors.
* @param input_nodes_timestamp The timestamp of the nodes. * @param input_nodes_timestamp The timestamp of the nodes.
......
...@@ -784,10 +784,10 @@ torch::Tensor TemporalMask( ...@@ -784,10 +784,10 @@ torch::Tensor TemporalMask(
if (node_timestamp.has_value()) { if (node_timestamp.has_value()) {
auto neighbor_timestamp = auto neighbor_timestamp =
node_timestamp.value().index_select(0, csc_indices.slice(0, l, r)); node_timestamp.value().index_select(0, csc_indices.slice(0, l, r));
mask &= neighbor_timestamp <= seed_timestamp; mask &= neighbor_timestamp < seed_timestamp;
} }
if (edge_timestamp.has_value()) { if (edge_timestamp.has_value()) {
mask &= edge_timestamp.value().slice(0, l, r) <= seed_timestamp; mask &= edge_timestamp.value().slice(0, l, r) < seed_timestamp;
} }
if (probs_or_mask.has_value()) { if (probs_or_mask.has_value()) {
mask &= probs_or_mask.value().slice(0, l, r) != 0; mask &= probs_or_mask.value().slice(0, l, r) != 0;
......
...@@ -773,8 +773,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -773,8 +773,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
subgraph. subgraph.
If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given, If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
the sampled neighbors or edges of an input node must have a timestamp the sampled neighbor or edge of an input node must have a timestamp
that is no later than that of the input node. that is smaller than that of the input node.
Parameters Parameters
---------- ----------
......
...@@ -898,12 +898,12 @@ def test_temporal_sample_neighbors_homo( ...@@ -898,12 +898,12 @@ def test_temporal_sample_neighbors_homo(
neighbor = indices[j].item() neighbor = indices[j].item()
if ( if (
use_node_timestamp use_node_timestamp
and (node_timestamp[neighbor] > seed_timestamp[i]).item() and (node_timestamp[neighbor] >= seed_timestamp[i]).item()
): ):
continue continue
if ( if (
use_edge_timestamp use_edge_timestamp
and (edge_timestamp[j] > seed_timestamp[i]).item() and (edge_timestamp[j] >= seed_timestamp[i]).item()
): ):
continue continue
neighbors.append(neighbor) neighbors.append(neighbor)
...@@ -1035,13 +1035,13 @@ def test_temporal_sample_neighbors_hetero( ...@@ -1035,13 +1035,13 @@ def test_temporal_sample_neighbors_hetero(
if ( if (
use_node_timestamp use_node_timestamp
and ( and (
node_timestamp[neighbor] > homo_seed_timestamp[i] node_timestamp[neighbor] >= homo_seed_timestamp[i]
).item() ).item()
): ):
continue continue
if ( if (
use_edge_timestamp use_edge_timestamp
and (edge_timestamp[j] > homo_seed_timestamp[i]).item() and (edge_timestamp[j] >= homo_seed_timestamp[i]).item()
): ):
continue continue
neighbors.append(neighbor) neighbors.append(neighbor)
......
...@@ -525,7 +525,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type): ...@@ -525,7 +525,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
graph.edge_attributes = { graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx()) "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
} }
items = (items, torch.randint(0, 10, (3,))) items = (items, torch.randint(1, 10, (3,)))
names = (names, "timestamp") names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names) itemset = gb.ItemSet(items, names=names)
...@@ -583,7 +583,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type): ...@@ -583,7 +583,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
graph.edge_attributes = { graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx()) "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
} }
items = (items, torch.randint(0, 10, (2,))) items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp") names = (names, "timestamp")
itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)}) itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment