Unverified Commit 93990a90 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Refactor `NeighborSamplerImpl` (#7207)

parent f0c7efa9
...@@ -230,8 +230,72 @@ class FetcherAndSampler(MiniBatchTransformer): ...@@ -230,8 +230,72 @@ class FetcherAndSampler(MiniBatchTransformer):
super().__init__(datapipe) super().__init__(datapipe)
class NeighborSamplerImpl(SubgraphSampler):
# pylint: disable=abstract-method
"""Base class for NeighborSamplers."""
# pylint: disable=useless-super-delegation
def __init__(
self,
datapipe,
graph,
fanouts,
replace,
prob_name,
deduplicate,
sampler,
):
super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
)
@staticmethod
def _prepare(node_type_to_id, minibatch):
seeds = minibatch._seed_nodes
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(node_type_to_id.keys())
# Loop over different seeds to extract the device they are on.
device = None
dtype = None
for _, seed in seeds.items():
device = seed.device
dtype = seed.dtype
break
default_tensor = torch.tensor([], dtype=dtype, device=device)
seeds = {
ntype: seeds.get(ntype, default_tensor) for ntype in ntypes
}
minibatch._seed_nodes = seeds
minibatch.sampled_subgraphs = []
return minibatch
@staticmethod
def _set_input_nodes(minibatch):
minibatch.input_nodes = minibatch._seed_nodes
return minibatch
# pylint: disable=arguments-differ
def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
):
datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id)
)
for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name
)
datapipe = datapipe.compact_per_layer(deduplicate)
return datapipe.transform(self._set_input_nodes)
@functional_datapipe("sample_neighbor") @functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler): class NeighborSampler(NeighborSamplerImpl):
# pylint: disable=abstract-method # pylint: disable=abstract-method
"""Sample neighbor edges from a graph and return a subgraph. """Sample neighbor edges from a graph and return a subgraph.
...@@ -323,61 +387,20 @@ class NeighborSampler(SubgraphSampler): ...@@ -323,61 +387,20 @@ class NeighborSampler(SubgraphSampler):
replace=False, replace=False,
prob_name=None, prob_name=None,
deduplicate=True, deduplicate=True,
sampler=None,
): ):
if sampler is None:
sampler = graph.sample_neighbors
super().__init__( super().__init__(
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler datapipe,
) graph,
fanouts,
@staticmethod replace,
def _prepare(node_type_to_id, minibatch): prob_name,
seeds = minibatch._seed_nodes deduplicate,
# Enrich seeds with all node types. graph.sample_neighbors,
if isinstance(seeds, dict):
ntypes = list(node_type_to_id.keys())
# Loop over different seeds to extract the device they are on.
device = None
dtype = None
for _, seed in seeds.items():
device = seed.device
dtype = seed.dtype
break
default_tensor = torch.tensor([], dtype=dtype, device=device)
seeds = {
ntype: seeds.get(ntype, default_tensor) for ntype in ntypes
}
minibatch._seed_nodes = seeds
minibatch.sampled_subgraphs = []
return minibatch
@staticmethod
def _set_input_nodes(minibatch):
minibatch.input_nodes = minibatch._seed_nodes
return minibatch
# pylint: disable=arguments-differ
def sampling_stages(
self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
):
datapipe = datapipe.transform(
partial(self._prepare, graph.node_type_to_id)
) )
for fanout in reversed(fanouts):
# Convert fanout to tensor.
if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)])
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name
)
datapipe = datapipe.compact_per_layer(deduplicate)
return datapipe.transform(self._set_input_nodes)
@functional_datapipe("sample_layer_neighbor") @functional_datapipe("sample_layer_neighbor")
class LayerNeighborSampler(NeighborSampler): class LayerNeighborSampler(NeighborSamplerImpl):
# pylint: disable=abstract-method # pylint: disable=abstract-method
"""Sample layer neighbor edges from a graph and return a subgraph. """Sample layer neighbor edges from a graph and return a subgraph.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment