Unverified commit b9cf36c3, authored by Muhammed Fatih BALIN, committed by GitHub
Browse files

[GraphBolt][CUDA] Add a `node_type_offset_list` property. (#6886)

parent 557a8f81
...@@ -92,7 +92,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -92,7 +92,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
{'N0': 2, 'N1': 3} {'N0': 2, 'N1': 3}
""" """
offset = self.node_type_offset offset = self._node_type_offset_list
# Homogenous. # Homogenous.
if offset is None or self.node_type_to_id is None: if offset is None or self.node_type_to_id is None:
...@@ -101,7 +101,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -101,7 +101,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Heterogenous # Heterogenous
else: else:
num_nodes_per_type = { num_nodes_per_type = {
_type: (offset[_idx + 1] - offset[_idx]).item() _type: (offset[_idx + 1] - offset[_idx])
for _type, _idx in self.node_type_to_id.items() for _type, _idx in self.node_type_to_id.items()
} }
...@@ -197,7 +197,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -197,7 +197,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
@property @property
def node_type_offset(self) -> Optional[torch.Tensor]: def node_type_offset(self) -> Optional[torch.Tensor]:
"""Returns the node type offset tensor if present. """Returns the node type offset tensor if present. Do not modify the
returned tensor in place.
Returns Returns
------- -------
...@@ -212,12 +213,39 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -212,12 +213,39 @@ class FusedCSCSamplingGraph(SamplingGraph):
""" """
return self._c_csc_graph.node_type_offset() return self._c_csc_graph.node_type_offset()
@property
def _node_type_offset_list(self) -> Optional[list]:
    """Returns the node type offset as a cached Python list if present.

    The list is built from `node_type_offset` via `tolist()` on first
    access and stored in `_node_type_offset_cached_list`; later accesses
    return the stored list until that attribute is reset to None. Do not
    modify the returned list in place, as it is the cached object itself.

    Returns
    -------
    list or None
        If present, returns a 1D integer list of shape
        `(num_node_types + 1,)`. The list is in ascending order as nodes
        of the same type have continuous IDs, and larger node IDs are
        paired with larger node type IDs. The first value is 0 and last
        value is the number of nodes. And nodes with IDs between
        `_node_type_offset_list[i]~_node_type_offset_list[i+1]` are of
        type id 'i'.
    """
    # getattr with a default also covers instances on which the cache
    # attribute was never created.
    cached = getattr(self, "_node_type_offset_cached_list", None)
    if cached is None:
        offset = self.node_type_offset
        # Convert before storing, so the cache attribute only ever holds
        # None or a finished list, never the intermediate offset object.
        cached = None if offset is None else offset.tolist()
        self._node_type_offset_cached_list = cached
    return cached
@node_type_offset.setter @node_type_offset.setter
def node_type_offset( def node_type_offset(
self, node_type_offset: Optional[torch.Tensor] self, node_type_offset: Optional[torch.Tensor]
) -> None: ) -> None:
"""Sets the node type offset tensor if present.""" """Sets the node type offset tensor if present."""
self._c_csc_graph.set_node_type_offset(node_type_offset) self._c_csc_graph.set_node_type_offset(node_type_offset)
self._node_type_offset_cached_list = None
@property @property
def type_per_edge(self) -> Optional[torch.Tensor]: def type_per_edge(self) -> Optional[torch.Tensor]:
...@@ -387,11 +415,10 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -387,11 +415,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None): def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = [] homogeneous_nodes = []
homogeneous_timestamps = [] homogeneous_timestamps = []
offset = self._node_type_offset_list
for ntype, ids in nodes.items(): for ntype, ids in nodes.items():
ntype_id = self.node_type_to_id[ntype] ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append( homogeneous_nodes.append(ids + offset[ntype_id])
ids + self.node_type_offset[ntype_id].item()
)
if timestamps is not None: if timestamps is not None:
homogeneous_timestamps.append(timestamps[ntype]) homogeneous_timestamps.append(timestamps[ntype])
if timestamps is not None: if timestamps is not None:
...@@ -424,6 +451,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -424,6 +451,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# The sampled graph is already a homogeneous graph. # The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices) sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
else: else:
# UVA sampling requires us to move node_type_offset to GPU.
self.node_type_offset = self.node_type_offset.to(column.device) self.node_type_offset = self.node_type_offset.to(column.device)
# 1. Find node types for each nodes in column. # 1. Find node types for each nodes in column.
node_types = ( node_types = (
...@@ -434,6 +462,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -434,6 +462,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids = {} original_hetero_edge_ids = {}
sub_indices = {} sub_indices = {}
sub_indptr = {} sub_indptr = {}
offset = self._node_type_offset_list
# 2. For loop each node type. # 2. For loop each node type.
for ntype, ntype_id in self.node_type_to_id.items(): for ntype, ntype_id in self.node_type_to_id.items():
# Get all nodes of a specific node type in column. # Get all nodes of a specific node type in column.
...@@ -446,9 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -446,9 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Get all edge ids of a specific edge type. # Get all edge ids of a specific edge type.
eids = torch.nonzero(type_per_edge == etype_id).view(-1) eids = torch.nonzero(type_per_edge == etype_id).view(-1)
src_ntype_id = self.node_type_to_id[src_ntype] src_ntype_id = self.node_type_to_id[src_ntype]
sub_indices[etype] = ( sub_indices[etype] = indices[eids] - offset[src_ntype_id]
indices[eids] - self.node_type_offset[src_ntype_id]
)
cum_edges = torch.searchsorted( cum_edges = torch.searchsorted(
eids, nids_original_indptr, right=False eids, nids_original_indptr, right=False
) )
...@@ -882,9 +909,9 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -882,9 +909,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
negative sampling by edge type." negative sampling by edge type."
_, _, dst_node_type = etype_str_to_tuple(edge_type) _, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.node_type_to_id[dst_node_type] dst_node_type_id = self.node_type_to_id[dst_node_type]
offset = self._node_type_offset_list
max_node_id = ( max_node_id = (
self.node_type_offset[dst_node_type_id + 1] offset[dst_node_type_id + 1] - offset[dst_node_type_id]
- self.node_type_offset[dst_node_type_id]
) )
else: else:
max_node_id = self.total_num_nodes max_node_id = self.total_num_nodes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment