Unverified Commit 7eca0612 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix attribute error when testing fused_csc_sampling_grpah. (#6842)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d6bf0387
...@@ -1108,7 +1108,7 @@ def test_temporal_sample_neighbors_homo( ...@@ -1108,7 +1108,7 @@ def test_temporal_sample_neighbors_homo(
node_timestamp_attr_name="timestamp" if use_node_timestamp else None, node_timestamp_attr_name="timestamp" if use_node_timestamp else None,
edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None, edge_timestamp_attr_name="timestamp" if use_edge_timestamp else None,
) )
sampled_count = torch.diff(subgraph.node_pairs.indptr).tolist() sampled_count = torch.diff(subgraph.sampled_csc.indptr).tolist()
available_neighbors = _get_available_neighbors() available_neighbors = _get_available_neighbors()
for i, count in enumerate(sampled_count): for i, count in enumerate(sampled_count):
if not replace: if not replace:
...@@ -1116,7 +1116,7 @@ def test_temporal_sample_neighbors_homo( ...@@ -1116,7 +1116,7 @@ def test_temporal_sample_neighbors_homo(
else: else:
expect_count = fanouts[0] if len(available_neighbors[i]) > 0 else 0 expect_count = fanouts[0] if len(available_neighbors[i]) > 0 else 0
assert count == expect_count assert count == expect_count
sampled_neighbors = torch.split(subgraph.node_pairs.indices, sampled_count) sampled_neighbors = torch.split(subgraph.sampled_csc.indices, sampled_count)
for i, neighbors in enumerate(sampled_neighbors): for i, neighbors in enumerate(sampled_neighbors):
assert set(neighbors.tolist()).issubset(set(available_neighbors[i])) assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))
...@@ -1376,7 +1376,7 @@ def test_temporal_sample_neighbors_hetero( ...@@ -1376,7 +1376,7 @@ def test_temporal_sample_neighbors_hetero(
available_neighbors = _get_available_neighbors() available_neighbors = _get_available_neighbors()
sampled_count = [0] * homo_seeds.numel() sampled_count = [0] * homo_seeds.numel()
sampled_neighbors = [[] for _ in range(homo_seeds.numel())] sampled_neighbors = [[] for _ in range(homo_seeds.numel())]
for etype, csc in subgraph.node_pairs.items(): for etype, csc in subgraph.sampled_csc.items():
stype, _, _ = etype_str_to_tuple(etype) stype, _, _ = etype_str_to_tuple(etype)
ntype_offset = ntypes_to_offset[stype] ntype_offset = ntypes_to_offset[stype]
dest_nodes = per_etype_destination_nodes[etype] dest_nodes = per_etype_destination_nodes[etype]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment