Unverified Commit 29c3b06d authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Change the temporal filter condition of temporal sampler. (#6893)

parent 55280b67
......@@ -325,7 +325,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Sample neighboring edges of the given nodes with a temporal
* constraint. If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is
* given, the sampled neighbors or edges of an input node must have a
* timestamp that is no later than that of the input node.
* timestamp that is smaller than that of the input node.
*
* @param nodes The nodes from which to sample neighbors.
* @param input_nodes_timestamp The timestamp of the nodes.
......
......@@ -784,10 +784,10 @@ torch::Tensor TemporalMask(
if (node_timestamp.has_value()) {
auto neighbor_timestamp =
node_timestamp.value().index_select(0, csc_indices.slice(0, l, r));
mask &= neighbor_timestamp <= seed_timestamp;
mask &= neighbor_timestamp < seed_timestamp;
}
if (edge_timestamp.has_value()) {
mask &= edge_timestamp.value().slice(0, l, r) <= seed_timestamp;
mask &= edge_timestamp.value().slice(0, l, r) < seed_timestamp;
}
if (probs_or_mask.has_value()) {
mask &= probs_or_mask.value().slice(0, l, r) != 0;
......
......@@ -773,8 +773,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
subgraph.
If `node_timestamp_attr_name` or `edge_timestamp_attr_name` is given,
the sampled neighbors or edges of an input node must have a timestamp
that is no later than that of the input node.
the sampled neighbor or edge of an input node must have a timestamp
that is smaller than that of the input node.
Parameters
----------
......
......@@ -898,12 +898,12 @@ def test_temporal_sample_neighbors_homo(
neighbor = indices[j].item()
if (
use_node_timestamp
and (node_timestamp[neighbor] > seed_timestamp[i]).item()
and (node_timestamp[neighbor] >= seed_timestamp[i]).item()
):
continue
if (
use_edge_timestamp
and (edge_timestamp[j] > seed_timestamp[i]).item()
and (edge_timestamp[j] >= seed_timestamp[i]).item()
):
continue
neighbors.append(neighbor)
......@@ -1035,13 +1035,13 @@ def test_temporal_sample_neighbors_hetero(
if (
use_node_timestamp
and (
node_timestamp[neighbor] > homo_seed_timestamp[i]
node_timestamp[neighbor] >= homo_seed_timestamp[i]
).item()
):
continue
if (
use_edge_timestamp
and (edge_timestamp[j] > homo_seed_timestamp[i]).item()
and (edge_timestamp[j] >= homo_seed_timestamp[i]).item()
):
continue
neighbors.append(neighbor)
......
......@@ -525,7 +525,7 @@ def test_SubgraphSampler_without_dedpulication_Homo(sampler_type):
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(0, 10, (3,)))
items = (items, torch.randint(1, 10, (3,)))
names = (names, "timestamp")
itemset = gb.ItemSet(items, names=names)
......@@ -583,7 +583,7 @@ def test_SubgraphSampler_without_dedpulication_Hetero(sampler_type):
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
}
items = (items, torch.randint(0, 10, (2,)))
items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp")
itemset = gb.ItemSetDict({"n2": gb.ItemSet(items, names=names)})
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment