Unverified Commit 95d62394 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Use `diff()` to calculate the differences for simplicity (#6884)

parent e9162491
......@@ -227,13 +227,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
original_row_node_ids, dim=0, index=indices
)
if original_column_node_ids is not None:
indptr = original_column_node_ids.repeat_interleave(
indptr[1:] - indptr[:-1]
)
indptr = original_column_node_ids.repeat_interleave(indptr.diff())
else:
indptr = torch.arange(len(indptr) - 1).repeat_interleave(
indptr[1:] - indptr[:-1]
)
indptr = torch.arange(len(indptr) - 1).repeat_interleave(indptr.diff())
return (indices, indptr)
......
......@@ -1402,7 +1402,7 @@ def test_from_dglgraph_homogeneous():
dgl_g, is_homogeneous=True, include_original_edge_id=True
)
# Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
num_columns = gb_g.csc_indptr.diff()
rows = gb_g.indices
columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
......@@ -1456,11 +1456,11 @@ def test_from_dglgraph_heterogeneous():
# `reverse_node_id` is used to map the node id in FusedCSCSamplingGraph to the
# node id in Hetero-DGLGraph.
num_ntypes = gb_g.node_type_offset[1:] - gb_g.node_type_offset[:-1]
num_ntypes = gb_g.node_type_offset.diff()
reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes])
# Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1]
num_columns = gb_g.csc_indptr.diff()
rows = reverse_node_id[gb_g.indices]
columns = reverse_node_id[
torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
......
......@@ -664,10 +664,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
edges = block.edges(etype=etype)
dst_ndoes = torch.arange(
0, len(sampled_csc[i][relation].indptr) - 1
).repeat_interleave(
sampled_csc[i][relation].indptr[1:]
- sampled_csc[i][relation].indptr[:-1]
)
).repeat_interleave(sampled_csc[i][relation].indptr.diff())
assert torch.equal(edges[0], sampled_csc[i][relation].indices)
assert torch.equal(edges[1], dst_ndoes)
assert torch.equal(
......@@ -676,10 +673,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
dst_ndoes = torch.arange(
0, len(sampled_csc[0][reverse_relation].indptr) - 1
).repeat_interleave(
sampled_csc[0][reverse_relation].indptr[1:]
- sampled_csc[0][reverse_relation].indptr[:-1]
)
).repeat_interleave(sampled_csc[0][reverse_relation].indptr.diff())
assert torch.equal(edges[0], sampled_csc[0][reverse_relation].indices)
assert torch.equal(edges[1], dst_ndoes)
assert torch.equal(
......@@ -704,9 +698,7 @@ def check_dgl_blocks_homo(minibatch, blocks):
for i, block in enumerate(blocks):
dst_ndoes = torch.arange(
0, len(sampled_csc[i].indptr) - 1
).repeat_interleave(
sampled_csc[i].indptr[1:] - sampled_csc[i].indptr[:-1]
)
).repeat_interleave(sampled_csc[i].indptr.diff())
assert torch.equal(block.edges()[0], sampled_csc[i].indices), print(
block.edges()
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment