"src/vscode:/vscode.git/clone" did not exist on "cc73c60cbcdc67861e7bd105d3ab838d31803496"
Unverified Commit b9cf36c3 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Add a `node_type_offset_list` property. (#6886)

parent 557a8f81
......@@ -92,7 +92,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
{'N0': 2, 'N1': 3}
"""
offset = self.node_type_offset
offset = self._node_type_offset_list
# Homogenous.
if offset is None or self.node_type_to_id is None:
......@@ -101,7 +101,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Heterogenous
else:
num_nodes_per_type = {
_type: (offset[_idx + 1] - offset[_idx]).item()
_type: (offset[_idx + 1] - offset[_idx])
for _type, _idx in self.node_type_to_id.items()
}
......@@ -197,7 +197,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
@property
def node_type_offset(self) -> Optional[torch.Tensor]:
"""Returns the node type offset tensor if present.
"""Returns the node type offset tensor if present. Do not modify the
returned tensor in place.
Returns
-------
......@@ -212,12 +213,39 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""
return self._c_csc_graph.node_type_offset()
@property
def _node_type_offset_list(self) -> Optional[list]:
"""Returns the node type offset list if present.
Returns
-------
list or None
If present, returns a 1D integer list of shape
`(num_node_types + 1,)`. The list is in ascending order as nodes
of the same type have continuous IDs, and larger node IDs are
paired with larger node type IDs. The first value is 0 and last
value is the number of nodes. And nodes with IDs between
`node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.
"""
if (
not hasattr(self, "_node_type_offset_cached_list")
or self._node_type_offset_cached_list is None
):
self._node_type_offset_cached_list = self.node_type_offset
if self._node_type_offset_cached_list is not None:
self._node_type_offset_cached_list = (
self._node_type_offset_cached_list.tolist()
)
return self._node_type_offset_cached_list
@node_type_offset.setter
def node_type_offset(
self, node_type_offset: Optional[torch.Tensor]
) -> None:
"""Sets the node type offset tensor if present."""
self._c_csc_graph.set_node_type_offset(node_type_offset)
self._node_type_offset_cached_list = None
@property
def type_per_edge(self) -> Optional[torch.Tensor]:
......@@ -387,11 +415,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
homogeneous_nodes = []
homogeneous_timestamps = []
offset = self._node_type_offset_list
for ntype, ids in nodes.items():
ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append(
ids + self.node_type_offset[ntype_id].item()
)
homogeneous_nodes.append(ids + offset[ntype_id])
if timestamps is not None:
homogeneous_timestamps.append(timestamps[ntype])
if timestamps is not None:
......@@ -424,6 +451,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
else:
# UVA sampling requires us to move node_type_offset to GPU.
self.node_type_offset = self.node_type_offset.to(column.device)
# 1. Find node types for each nodes in column.
node_types = (
......@@ -434,6 +462,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids = {}
sub_indices = {}
sub_indptr = {}
offset = self._node_type_offset_list
# 2. For loop each node type.
for ntype, ntype_id in self.node_type_to_id.items():
# Get all nodes of a specific node type in column.
......@@ -446,9 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Get all edge ids of a specific edge type.
eids = torch.nonzero(type_per_edge == etype_id).view(-1)
src_ntype_id = self.node_type_to_id[src_ntype]
sub_indices[etype] = (
indices[eids] - self.node_type_offset[src_ntype_id]
)
sub_indices[etype] = indices[eids] - offset[src_ntype_id]
cum_edges = torch.searchsorted(
eids, nids_original_indptr, right=False
)
......@@ -882,9 +909,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
negative sampling by edge type."
_, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.node_type_to_id[dst_node_type]
offset = self._node_type_offset_list
max_node_id = (
self.node_type_offset[dst_node_type_id + 1]
- self.node_type_offset[dst_node_type_id]
offset[dst_node_type_id + 1] - offset[dst_node_type_id]
)
else:
max_node_id = self.total_num_nodes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment