Unverified Commit 95d62394 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Use `diff()` to calculate the differences for simplicity (#6884)

parent e9162491
...@@ -227,13 +227,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids): ...@@ -227,13 +227,9 @@ def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
original_row_node_ids, dim=0, index=indices original_row_node_ids, dim=0, index=indices
) )
if original_column_node_ids is not None: if original_column_node_ids is not None:
indptr = original_column_node_ids.repeat_interleave( indptr = original_column_node_ids.repeat_interleave(indptr.diff())
indptr[1:] - indptr[:-1]
)
else: else:
indptr = torch.arange(len(indptr) - 1).repeat_interleave( indptr = torch.arange(len(indptr) - 1).repeat_interleave(indptr.diff())
indptr[1:] - indptr[:-1]
)
return (indices, indptr) return (indices, indptr)
......
...@@ -1402,7 +1402,7 @@ def test_from_dglgraph_homogeneous(): ...@@ -1402,7 +1402,7 @@ def test_from_dglgraph_homogeneous():
dgl_g, is_homogeneous=True, include_original_edge_id=True dgl_g, is_homogeneous=True, include_original_edge_id=True
) )
# Get the COO representation of the FusedCSCSamplingGraph. # Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1] num_columns = gb_g.csc_indptr.diff()
rows = gb_g.indices rows = gb_g.indices
columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns) columns = torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
...@@ -1456,11 +1456,11 @@ def test_from_dglgraph_heterogeneous(): ...@@ -1456,11 +1456,11 @@ def test_from_dglgraph_heterogeneous():
# `reverse_node_id` is used to map the node id in FusedCSCSamplingGraph to the # `reverse_node_id` is used to map the node id in FusedCSCSamplingGraph to the
# node id in Hetero-DGLGraph. # node id in Hetero-DGLGraph.
num_ntypes = gb_g.node_type_offset[1:] - gb_g.node_type_offset[:-1] num_ntypes = gb_g.node_type_offset.diff()
reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes]) reverse_node_id = torch.cat([torch.arange(num) for num in num_ntypes])
# Get the COO representation of the FusedCSCSamplingGraph. # Get the COO representation of the FusedCSCSamplingGraph.
num_columns = gb_g.csc_indptr[1:] - gb_g.csc_indptr[:-1] num_columns = gb_g.csc_indptr.diff()
rows = reverse_node_id[gb_g.indices] rows = reverse_node_id[gb_g.indices]
columns = reverse_node_id[ columns = reverse_node_id[
torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns) torch.arange(gb_g.total_num_nodes).repeat_interleave(num_columns)
......
...@@ -664,10 +664,7 @@ def check_dgl_blocks_hetero(minibatch, blocks): ...@@ -664,10 +664,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
edges = block.edges(etype=etype) edges = block.edges(etype=etype)
dst_ndoes = torch.arange( dst_ndoes = torch.arange(
0, len(sampled_csc[i][relation].indptr) - 1 0, len(sampled_csc[i][relation].indptr) - 1
).repeat_interleave( ).repeat_interleave(sampled_csc[i][relation].indptr.diff())
sampled_csc[i][relation].indptr[1:]
- sampled_csc[i][relation].indptr[:-1]
)
assert torch.equal(edges[0], sampled_csc[i][relation].indices) assert torch.equal(edges[0], sampled_csc[i][relation].indices)
assert torch.equal(edges[1], dst_ndoes) assert torch.equal(edges[1], dst_ndoes)
assert torch.equal( assert torch.equal(
...@@ -676,10 +673,7 @@ def check_dgl_blocks_hetero(minibatch, blocks): ...@@ -676,10 +673,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation)) edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
dst_ndoes = torch.arange( dst_ndoes = torch.arange(
0, len(sampled_csc[0][reverse_relation].indptr) - 1 0, len(sampled_csc[0][reverse_relation].indptr) - 1
).repeat_interleave( ).repeat_interleave(sampled_csc[0][reverse_relation].indptr.diff())
sampled_csc[0][reverse_relation].indptr[1:]
- sampled_csc[0][reverse_relation].indptr[:-1]
)
assert torch.equal(edges[0], sampled_csc[0][reverse_relation].indices) assert torch.equal(edges[0], sampled_csc[0][reverse_relation].indices)
assert torch.equal(edges[1], dst_ndoes) assert torch.equal(edges[1], dst_ndoes)
assert torch.equal( assert torch.equal(
...@@ -704,9 +698,7 @@ def check_dgl_blocks_homo(minibatch, blocks): ...@@ -704,9 +698,7 @@ def check_dgl_blocks_homo(minibatch, blocks):
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
dst_ndoes = torch.arange( dst_ndoes = torch.arange(
0, len(sampled_csc[i].indptr) - 1 0, len(sampled_csc[i].indptr) - 1
).repeat_interleave( ).repeat_interleave(sampled_csc[i].indptr.diff())
sampled_csc[i].indptr[1:] - sampled_csc[i].indptr[:-1]
)
assert torch.equal(block.edges()[0], sampled_csc[i].indices), print( assert torch.equal(block.edges()[0], sampled_csc[i].indices), print(
block.edges() block.edges()
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment