Unverified Commit 22516834 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Temporal]Fix tensor option mismatch issue (#7346)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-21-218.ap-northeast-1.compute.internal>
parent 697a9977
...@@ -873,11 +873,16 @@ std::pair<bool, std::vector<int64_t>> FastTemporalPick( ...@@ -873,11 +873,16 @@ std::pair<bool, std::vector<int64_t>> FastTemporalPick(
continue; continue;
} }
if (node_timestamp.has_value()) { if (node_timestamp.has_value()) {
int64_t neighbor_id = bool flag = true;
utils::GetValueByIndex<int64_t>(csc_indices, edge_id); AT_DISPATCH_INDEX_TYPES(
if (utils::GetValueByIndex<int64_t>( csc_indices.scalar_type(), "CheckNodeTimeStamp", ([&] {
node_timestamp.value(), neighbor_id) >= timestamp) int64_t neighbor_id =
continue; utils::GetValueByIndex<index_t>(csc_indices, edge_id);
if (utils::GetValueByIndex<int64_t>(
node_timestamp.value(), neighbor_id) >= timestamp)
flag = false;
}));
if (!flag) continue;
} }
if (edge_timestamp.has_value() && if (edge_timestamp.has_value() &&
utils::GetValueByIndex<int64_t>(edge_timestamp.value(), edge_id) >= utils::GetValueByIndex<int64_t>(edge_timestamp.value(), edge_id) >=
......
...@@ -95,11 +95,18 @@ class TemporalNeighborSampler(SubgraphSampler): ...@@ -95,11 +95,18 @@ class TemporalNeighborSampler(SubgraphSampler):
), "seeds_timestamp must be provided for temporal neighbor sampling." ), "seeds_timestamp must be provided for temporal neighbor sampling."
subgraphs = [] subgraphs = []
num_layers = len(self.fanouts) num_layers = len(self.fanouts)
# Enrich seeds with all node types. # Enrich seeds with all node types. Ensure that the dtype and device
# remain consistent with those of the existing seeds.
if isinstance(seeds, dict): if isinstance(seeds, dict):
first_val = next(iter(seeds.items()))[1]
ntypes = list(self.graph.node_type_to_id.keys()) ntypes = list(self.graph.node_type_to_id.keys())
seeds = { seeds = {
ntype: seeds.get(ntype, torch.LongTensor([])) ntype: seeds.get(
ntype,
torch.tensor(
[], dtype=first_val.dtype, device=first_val.device
),
)
for ntype in ntypes for ntype in ntypes
} }
seeds_timestamp = { seeds_timestamp = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment