"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "45572c24856ec631551dd0754960b155c137feeb"
Unverified Commit e5b92d2b authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Remove overlap graph variable hacks. (#7263)

parent d4a6f8a0
...@@ -290,6 +290,35 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -290,6 +290,35 @@ class FusedCSCSamplingGraph(SamplingGraph):
self._c_csc_graph.set_node_type_offset(node_type_offset) self._c_csc_graph.set_node_type_offset(node_type_offset)
self._node_type_offset_cached_list = None self._node_type_offset_cached_list = None
@property
def _indptr_node_type_offset_list(self) -> Optional[list]:
"""Returns the indptr node type offset list which presents the column id
space when it does not match the global id space. It is useful when we
slice a subgraph from another FusedCSCSamplingGraph.
Returns
-------
list or None
If present, returns a 1D integer list of shape
`(num_node_types + 1,)`. The list is in ascending order as nodes
of the same type have continuous IDs, and larger node IDs are
paired with larger node type IDs. The first value is 0 and last
value is the number of nodes. And nodes with IDs between
`node_type_offset_[i]~node_type_offset_[i+1]` are of type id 'i'.
"""
return (
self._indptr_node_type_offset_list_
if hasattr(self, "_indptr_node_type_offset_list_")
else None
)
@_indptr_node_type_offset_list.setter
def _indptr_node_type_offset_list(
self, indptr_node_type_offset_list: Optional[torch.Tensor]
):
"""Sets the indptr node type offset list if present."""
self._indptr_node_type_offset_list_ = indptr_node_type_offset_list
@property @property
def type_per_edge(self) -> Optional[torch.Tensor]: def type_per_edge(self) -> Optional[torch.Tensor]:
"""Returns the edge type tensor if present. """Returns the edge type tensor if present.
...@@ -665,8 +694,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -665,8 +694,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
seed_offsets = None seed_offsets = None
if isinstance(seeds, dict): if isinstance(seeds, dict):
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds) seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None and hasattr(self, "_seed_offset_list"): elif seeds is None:
seed_offsets = self._seed_offset_list # pylint: disable=no-member seed_offsets = self._indptr_node_type_offset_list
C_sampled_subgraph = self._sample_neighbors( C_sampled_subgraph = self._sample_neighbors(
seeds, seeds,
seed_offsets, seed_offsets,
...@@ -914,8 +943,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -914,8 +943,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
seed_offsets = None seed_offsets = None
if isinstance(seeds, dict): if isinstance(seeds, dict):
seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds) seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
elif seeds is None and hasattr(self, "_seed_offset_list"): elif seeds is None:
seed_offsets = self._seed_offset_list # pylint: disable=no-member seed_offsets = self._indptr_node_type_offset_list
self._check_sampler_arguments(seeds, fanouts, probs_name) self._check_sampler_arguments(seeds, fanouts, probs_name)
C_sampled_subgraph = self._c_csc_graph.sample_neighbors( C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
seeds, seeds,
......
...@@ -102,9 +102,9 @@ class FetchInsubgraphData(Mapper): ...@@ -102,9 +102,9 @@ class FetchInsubgraphData(Mapper):
) )
if self.prob_name is not None and probs_or_mask is not None: if self.prob_name is not None and probs_or_mask is not None:
subgraph.edge_attributes = {self.prob_name: probs_or_mask} subgraph.edge_attributes = {self.prob_name: probs_or_mask}
subgraph._seed_offset_list = seed_offsets
minibatch.sampled_subgraphs.insert(0, subgraph) subgraph._indptr_node_type_offset_list = seed_offsets
minibatch._sliced_sampling_graph = subgraph
if self.stream is not None: if self.stream is not None:
minibatch.wait = torch.cuda.current_stream().record_event().wait minibatch.wait = torch.cuda.current_stream().record_event().wait
...@@ -133,7 +133,8 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer): ...@@ -133,7 +133,8 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
self.prob_name = sample_per_layer_obj.prob_name self.prob_name = sample_per_layer_obj.prob_name
def _sample_per_layer_from_fetched_subgraph(self, minibatch): def _sample_per_layer_from_fetched_subgraph(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0] subgraph = minibatch._sliced_sampling_graph
delattr(minibatch, "_sliced_sampling_graph")
kwargs = { kwargs = {
key[1:]: getattr(minibatch, key) key[1:]: getattr(minibatch, key)
for key in ["_random_seed", "_seed2_contribution"] for key in ["_random_seed", "_seed2_contribution"]
...@@ -146,7 +147,7 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer): ...@@ -146,7 +147,7 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
self.prob_name, self.prob_name,
**kwargs, **kwargs,
) )
minibatch.sampled_subgraphs[0] = sampled_subgraph minibatch.sampled_subgraphs.insert(0, sampled_subgraph)
return minibatch return minibatch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment