Unverified Commit 7eca0612 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix attribute error when testing fused_csc_sampling_grpah. (#6842)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d6bf0387
......@@ -1108,7 +1108,7 @@ def test_temporal_sample_neighbors_homo(
node_timestamp_attr_name="timestamp" if use_node_timestamp else None,
edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None,
)
sampled_count = torch.diff(subgraph.node_pairs.indptr).tolist()
sampled_count = torch.diff(subgraph.sampled_csc.indptr).tolist()
available_neighbors = _get_available_neighbors()
for i, count in enumerate(sampled_count):
if not replace:
......@@ -1116,7 +1116,7 @@ def test_temporal_sample_neighbors_homo(
else:
expect_count = fanouts[0] if len(available_neighbors[i]) > 0 else 0
assert count == expect_count
sampled_neighbors = torch.split(subgraph.node_pairs.indices, sampled_count)
sampled_neighbors = torch.split(subgraph.sampled_csc.indices, sampled_count)
for i, neighbors in enumerate(sampled_neighbors):
assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))
......@@ -1376,7 +1376,7 @@ def test_temporal_sample_neighbors_hetero(
available_neighbors = _get_available_neighbors()
sampled_count = [0] * homo_seeds.numel()
sampled_neighbors = [[] for _ in range(homo_seeds.numel())]
for etype, csc in subgraph.node_pairs.items():
for etype, csc in subgraph.sampled_csc.items():
stype, _, _ = etype_str_to_tuple(etype)
ntype_offset = ntypes_to_offset[stype]
dest_nodes = per_etype_destination_nodes[etype]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment