Unverified commit e8be515c, authored by yxy235, committed by GitHub
Browse files

[GraphBolt] Optimize `_convert_to_sampled_subgraph`. (#6867)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 3200b88b
"""CSC format sampling graph.""" """CSC format sampling graph."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from collections import defaultdict
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
import torch import torch
...@@ -425,75 +424,47 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -425,75 +424,47 @@ class FusedCSCSamplingGraph(SamplingGraph):
# The sampled graph is already a homogeneous graph. # The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices) sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
else: else:
# The sampled graph is a fused homogenized graph, which need to be self.node_type_offset = self.node_type_offset.to(column.device)
# converted to heterogeneous graphs. # 1. Find node types for each node in column.
# Pre-calculate the number of each etype node_types = (
num = {} torch.searchsorted(self.node_type_offset, column, right=True)
for etype in type_per_edge:
num[etype.item()] = num.get(etype.item(), 0) + 1
# Preallocate
subgraph_indice_position = {}
subgraph_indice = {}
subgraph_indptr = {}
node_edge_type = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),),
dtype=indices.dtype,
device=indices.device,
)
if has_original_eids:
original_hetero_edge_ids[etype] = torch.empty(
(num.get(etype_id, 0),),
dtype=original_edge_ids.dtype,
device=original_edge_ids.device,
)
subgraph_indptr[etype] = [0]
subgraph_indice_position[etype] = 0
# Preprocessing saves the type of seed_nodes as the edge type
# of dst_ntype.
_, _, dst_ntype = etype_str_to_tuple(etype)
dst_ntype_id = self.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs
for i, seed in enumerate(column):
l = indptr[i].item()
r = indptr[i + 1].item()
node_type = (
torch.searchsorted(
self.node_type_offset, seed, right=True
).item()
- 1 - 1
) )
for etype, etype_id in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype) original_hetero_edge_ids = {}
sub_indices = {}
sub_indptr = {}
# 2. Loop over each node type.
for ntype, ntype_id in self.node_type_to_id.items():
# Get all nodes of a specific node type in column.
nids = torch.nonzero(node_types == ntype_id).view(-1)
nids_original_indptr = indptr[nids + 1]
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
if dst_ntype != ntype:
continue
# Get all edge ids of a specific edge type.
eids = torch.nonzero(type_per_edge == etype_id).view(-1)
src_ntype_id = self.node_type_to_id[src_ntype] src_ntype_id = self.node_type_to_id[src_ntype]
num_edges = torch.searchsorted( sub_indices[etype] = (
type_per_edge[l:r], etype_id, right=True indices[eids] - self.node_type_offset[src_ntype_id]
).item() )
end = num_edges + l cum_edges = torch.searchsorted(
subgraph_indptr[etype].append( eids, nids_original_indptr, right=False
subgraph_indptr[etype][-1] + num_edges
) )
offset = subgraph_indice_position[etype] sub_indptr[etype] = torch.cat(
subgraph_indice_position[etype] += num_edges (torch.tensor([0], device=indptr.device), cum_edges)
subgraph_indice[etype][offset : offset + num_edges] = (
indices[l:end] - self.node_type_offset[src_ntype_id]
) )
if has_original_eids: if has_original_eids:
original_hetero_edge_ids[etype][ original_hetero_edge_ids[etype] = original_edge_ids[
offset : offset + num_edges eids
] = original_edge_ids[l:end] ]
l = end
if has_original_eids: if has_original_eids:
original_edge_ids = original_hetero_edge_ids original_edge_ids = original_hetero_edge_ids
sampled_csc = { sampled_csc = {
etype: CSCFormatBase( etype: CSCFormatBase(
indptr=torch.tensor( indptr=sub_indptr[etype],
subgraph_indptr[etype], device=indptr.device indices=sub_indices[etype],
),
indices=subgraph_indice[etype],
) )
for etype in self.edge_type_to_id.keys() for etype in self.edge_type_to_id.keys()
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment