"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "da65da5b95d733f24db94e17ce835ff25718c02c"
Unverified Commit badeaf19 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Pipelined sampling optimization (#7039)

parent 4b265390
"""Base types and utilities for Graph Bolt."""
from collections import deque
from dataclasses import dataclass
import torch
......@@ -14,6 +15,10 @@ __all__ = [
"etype_str_to_tuple",
"etype_tuple_to_str",
"CopyTo",
"FutureWaiter",
"Waiter",
"Bufferer",
"EndMarker",
"isin",
"index_select",
"expand_indptr",
......@@ -247,6 +252,76 @@ class CopyTo(IterDataPipe):
yield data
@functional_datapipe("mark_end")
class EndMarker(IterDataPipe):
    """Identity datapipe used to mark the end of a pipeline; a no-op."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        # Pass every item through untouched.
        for item in self.datapipe:
            yield item
@functional_datapipe("buffer")
class Bufferer(IterDataPipe):
    """Holds back up to ``buffer_size`` items before yielding them.

    Smooths out upstream latency spikes by keeping a bounded queue of
    already-fetched samples between the producer and the consumer.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    buffer_size : int, optional
        The size of the buffer which stores the fetched samples. If data coming
        from datapipe has latency spikes, consider setting to a higher value.
        Default is 1.

    Raises
    ------
    ValueError
        If ``buffer_size`` is not a positive integer.
    """

    def __init__(self, datapipe, buffer_size=1):
        self.datapipe = datapipe
        if buffer_size <= 0:
            raise ValueError(
                "'buffer_size' is required to be a positive integer."
            )
        self.buffer = deque(maxlen=buffer_size)

    def __iter__(self):
        queue = self.buffer
        for item in self.datapipe:
            # Once the queue is full, release the oldest item before
            # admitting the new one.
            if len(queue) == queue.maxlen:
                yield queue.popleft()
            queue.append(item)
        # Upstream is exhausted; drain whatever remains buffered.
        while queue:
            yield queue.popleft()
@functional_datapipe("wait")
class Waiter(IterDataPipe):
    """Blocks on each item's ``wait()`` method before yielding it."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for item in self.datapipe:
            # Block until the item's pending work has completed, then
            # hand the item itself downstream.
            item.wait()
            yield item
@functional_datapipe("wait_future")
class FutureWaiter(IterDataPipe):
    """Resolves each future-like item and yields the unwrapped result."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        # `result()` blocks until each item's value is available; only the
        # unwrapped value is passed downstream.
        yield from (item.result() for item in self.datapipe)
@dataclass
class CSCFormatBase:
r"""Basic class representing data in Compressed Sparse Column (CSC) format.
......
"""Graph Bolt DataLoaders"""
from collections import deque
from concurrent.futures import ThreadPoolExecutor
import torch
import torch.utils.data
......@@ -9,6 +9,7 @@ import torchdata.datapipes as dp
from .base import CopyTo
from .feature_fetcher import FeatureFetcher
from .impl.neighbor_sampler import SamplePerLayer
from .internal import datapipe_graph_to_adjlist
from .item_sampler import ItemSampler
......@@ -16,8 +17,6 @@ from .item_sampler import ItemSampler
__all__ = [
"DataLoader",
"Awaiter",
"Bufferer",
]
......@@ -40,61 +39,6 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
return datapipe_graph
class EndMarker(dp.iter.IterDataPipe):
    """Identity datapipe used to mark the end of a pipeline; a no-op."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        # Pass every item through untouched.
        for item in self.datapipe:
            yield item
class Bufferer(dp.iter.IterDataPipe):
    """Holds back up to ``buffer_size`` items before yielding them.

    Smooths out upstream latency spikes by keeping a bounded queue of
    already-fetched samples between the producer and the consumer.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    buffer_size : int, optional
        The size of the buffer which stores the fetched samples. If data coming
        from datapipe has latency spikes, consider setting to a higher value.
        Default is 1.

    Raises
    ------
    ValueError
        If ``buffer_size`` is not a positive integer.
    """

    def __init__(self, datapipe, buffer_size=1):
        self.datapipe = datapipe
        if buffer_size <= 0:
            raise ValueError(
                "'buffer_size' is required to be a positive integer."
            )
        self.buffer = deque(maxlen=buffer_size)

    def __iter__(self):
        queue = self.buffer
        for item in self.datapipe:
            # Once the queue is full, release the oldest item before
            # admitting the new one.
            if len(queue) == queue.maxlen:
                yield queue.popleft()
            queue.append(item)
        # Upstream is exhausted; drain whatever remains buffered.
        while queue:
            yield queue.popleft()
class Awaiter(dp.iter.IterDataPipe):
    """Blocks on each item's ``wait()`` method before yielding it."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for item in self.datapipe:
            # Block until the item's pending work has completed, then
            # hand the item itself downstream.
            item.wait()
            yield item
class MultiprocessingWrapper(dp.iter.IterDataPipe):
"""Wraps a datapipe with multiprocessing.
......@@ -156,6 +100,10 @@ class DataLoader(torch.utils.data.DataLoader):
If True, the data loader will overlap the UVA feature fetcher operations
with the rest of operations by using an alternative CUDA stream. Default
is True.
overlap_graph_fetch : bool, optional
If True, the data loader will overlap the UVA graph fetching operations
with the rest of operations by using an alternative CUDA stream. Default
is False.
max_uva_threads : int, optional
Limits the number of CUDA threads used for UVA copies so that the rest
of the computations can run simultaneously with it. Setting it to a too
......@@ -170,6 +118,7 @@ class DataLoader(torch.utils.data.DataLoader):
num_workers=0,
persistent_workers=True,
overlap_feature_fetch=True,
overlap_graph_fetch=False,
max_uva_threads=6144,
):
# Multiprocessing requires two modifications to the datapipe:
......@@ -179,7 +128,7 @@ class DataLoader(torch.utils.data.DataLoader):
# 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
datapipe = EndMarker(datapipe)
datapipe = datapipe.mark_end()
datapipe_graph = dp_utils.traverse_dps(datapipe)
# (1) Insert minibatch distribution.
......@@ -223,7 +172,25 @@ class DataLoader(torch.utils.data.DataLoader):
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
feature_fetcher,
Awaiter(Bufferer(feature_fetcher, buffer_size=1)),
feature_fetcher.buffer(1).wait(),
)
if (
overlap_graph_fetch
and num_workers == 0
and torch.cuda.is_available()
):
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
samplers = dp_utils.find_dps(
datapipe_graph,
SamplePerLayer,
)
executor = ThreadPoolExecutor(max_workers=1)
for sampler in samplers:
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(_get_uva_stream(), executor, 1),
)
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
......
"""Neighbor subgraph samplers for GraphBolt."""
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import torch
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..minibatch_transformer import MiniBatchTransformer
from ..subgraph_sampler import SubgraphSampler
from .fused_csc_sampling_graph import fused_csc_sampling_graph
from .sampled_subgraph_impl import SampledSubgraphImpl
__all__ = ["NeighborSampler", "LayerNeighborSampler"]
__all__ = [
"NeighborSampler",
"LayerNeighborSampler",
"SamplePerLayer",
"SamplePerLayerFromFetchedSubgraph",
"FetchInsubgraphData",
]
@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(Mapper):
    """Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object.

    If the provided sample_per_layer_obj has a valid prob_name, then it reads
    the probabilities of all the fetched edges. Furthermore, if the
    type_per_edge tensor exists in the underlying graph, then the types of all
    the fetched edges are read as well.

    Parameters
    ----------
    datapipe : DataPipe
        Upstream datapipe yielding minibatches with ``_seed_nodes`` set.
    sample_per_layer_obj : SamplePerLayer
        Sampler whose bound graph and ``prob_name`` drive the fetch.
    stream : torch.cuda.Stream, optional
        Side CUDA stream on which the fetch work is issued; ``None`` disables
        stream-based overlapping.
    executor : concurrent.futures.Executor, optional
        Executor running the fetch asynchronously; a single-worker
        ThreadPoolExecutor is created when not provided.
    """

    def __init__(
        self, datapipe, sample_per_layer_obj, stream=None, executor=None
    ):
        super().__init__(datapipe, self._fetch_per_layer)
        # The sampler is a bound method of the graph; recover the graph itself.
        self.graph = sample_per_layer_obj.sampler.__self__
        self.prob_name = sample_per_layer_obj.prob_name
        self.stream = stream
        if executor is None:
            self.executor = ThreadPoolExecutor(max_workers=1)
        else:
            self.executor = executor

    def _fetch_per_layer_impl(self, minibatch, stream):
        # Runs on the executor thread. `stream` is the caller's current
        # stream (or None); GPU work below is issued on `self.stream`.
        with torch.cuda.stream(self.stream):
            index = minibatch._seed_nodes
            if isinstance(index, dict):
                # Heterogeneous seeds: flatten to homogeneous node ids.
                index = self.graph._convert_to_homogeneous_nodes(index)

            # Sort seeds; searchsorted below requires a sorted index.
            index, original_positions = index.sort()
            if (original_positions.diff() == 1).all().item():  # is_sorted
                # Seeds were already in order; no remapping needed later.
                minibatch._subgraph_seed_nodes = None
            else:
                # Keep the permutation so the sampler can restore the
                # original seed order.
                minibatch._subgraph_seed_nodes = original_positions
            index.record_stream(torch.cuda.current_stream())
            index_select_csc_with_indptr = partial(
                torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
            )

            def record_stream(tensor):
                # Mark side-stream output as in use by the caller's stream so
                # the caching allocator does not reuse its memory early.
                if stream is not None and tensor.is_cuda:
                    tensor.record_stream(stream)

            # Gather the rows of the CSC structure selected by the seeds.
            indptr, indices = index_select_csc_with_indptr(
                self.graph.indices, index, None
            )
            record_stream(indptr)
            record_stream(indices)
            output_size = len(indices)
            if self.graph.type_per_edge is not None:
                # Fetch edge types with the same row selection.
                _, type_per_edge = index_select_csc_with_indptr(
                    self.graph.type_per_edge, index, output_size
                )
                record_stream(type_per_edge)
            else:
                type_per_edge = None
            if self.graph.edge_attributes is not None:
                probs_or_mask = self.graph.edge_attributes.get(
                    self.prob_name, None
                )
                if probs_or_mask is not None:
                    # Fetch sampling probabilities/masks for the same rows.
                    _, probs_or_mask = index_select_csc_with_indptr(
                        probs_or_mask, index, output_size
                    )
                    record_stream(probs_or_mask)
            else:
                probs_or_mask = None
            if self.graph.node_type_offset is not None:
                # Recompute per-type offsets within the sorted seed tensor.
                node_type_offset = torch.searchsorted(
                    index, self.graph.node_type_offset
                )
            else:
                node_type_offset = None
            subgraph = fused_csc_sampling_graph(
                indptr,
                indices,
                node_type_offset=node_type_offset,
                type_per_edge=type_per_edge,
                node_type_to_id=self.graph.node_type_to_id,
                edge_type_to_id=self.graph.edge_type_to_id,
            )
            if self.prob_name is not None and probs_or_mask is not None:
                subgraph.edge_attributes = {self.prob_name: probs_or_mask}

            # Prepend so downstream per-layer stages find it at position 0.
            minibatch.sampled_subgraphs.insert(0, subgraph)

            if self.stream is not None:
                # Consumers call minibatch.wait() to synchronize with the
                # event recorded at the end of the fetch.
                minibatch.wait = torch.cuda.current_stream().record_event().wait

        return minibatch

    def _fetch_per_layer(self, minibatch):
        current_stream = None
        if self.stream is not None:
            # Make the side stream wait for work already queued on the
            # caller's stream before reading the seeds.
            current_stream = torch.cuda.current_stream()
            self.stream.wait_stream(current_stream)
        # Returns a Future; a downstream wait_future() stage unwraps it.
        return self.executor.submit(
            self._fetch_per_layer_impl, minibatch, current_stream
        )
@functional_datapipe("sample_per_layer_from_fetched_subgraph")
class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
    """Sample neighbor edges from a graph for a single layer."""

    def __init__(self, datapipe, sample_per_layer_obj):
        super().__init__(datapipe, self._sample_per_layer_from_fetched_subgraph)
        # Copy the sampling configuration off the original sampler.
        self.sampler_name = sample_per_layer_obj.sampler.__name__
        for attribute in ("fanout", "replace", "prob_name"):
            setattr(self, attribute, getattr(sample_per_layer_obj, attribute))

    def _sample_per_layer_from_fetched_subgraph(self, minibatch):
        # The prefetched subgraph sits at the front of sampled_subgraphs.
        fetched = minibatch.sampled_subgraphs[0]
        sample_fn = getattr(fetched, self.sampler_name)
        result = sample_fn(
            minibatch._subgraph_seed_nodes,
            self.fanout,
            self.replace,
            self.prob_name,
        )
        result.original_column_node_ids = minibatch._seed_nodes
        # The temporary seed permutation is consumed; drop it.
        delattr(minibatch, "_subgraph_seed_nodes")
        minibatch.sampled_subgraphs[0] = result
        return minibatch
@functional_datapipe("sample_per_layer")
......@@ -72,6 +206,19 @@ class CompactPerLayer(MiniBatchTransformer):
return minibatch
@functional_datapipe("fetch_and_sample")
class FetcherAndSampler(MiniBatchTransformer):
    """Overlapped graph sampling operation replacement."""

    def __init__(self, sampler, stream, executor, buffer_size):
        # Asynchronously fetch the induced subgraph for each minibatch,
        # buffer the in-flight futures, resolve them, then sample.
        fetcher = sampler.datapipe.fetch_insubgraph_data(
            sampler, stream, executor
        )
        resolved = fetcher.buffer(buffer_size).wait_future().wait()
        super().__init__(
            resolved.sample_per_layer_from_fetched_subgraph(sampler)
        )
@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
# pylint: disable=abstract-method
......@@ -173,7 +320,8 @@ class NeighborSampler(SubgraphSampler):
datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
)
def _prepare(self, node_type_to_id, minibatch):
@staticmethod
def _prepare(node_type_to_id, minibatch):
seeds = minibatch._seed_nodes
# Enrich seeds with all node types.
if isinstance(seeds, dict):
......
......@@ -29,10 +29,10 @@ class MiniBatchTransformer(Mapper):
def __init__(
self,
datapipe,
transformer,
transformer=None,
):
super().__init__(datapipe, self._transformer)
self.transformer = transformer
self.transformer = transformer or self._identity
def _transformer(self, minibatch):
minibatch = self.transformer(minibatch)
......@@ -40,3 +40,7 @@ class MiniBatchTransformer(Mapper):
minibatch, (MiniBatch,)
), "The transformer output should be an instance of MiniBatch"
return minibatch
@staticmethod
def _identity(minibatch):
return minibatch
......@@ -46,11 +46,7 @@ class SubgraphSampler(MiniBatchTransformer):
datapipe = datapipe.transform(self._preprocess)
datapipe = self.sampling_stages(datapipe, *args, **kwargs)
datapipe = datapipe.transform(self._postprocess)
super().__init__(datapipe, self._identity)
@staticmethod
def _identity(minibatch):
return minibatch
super().__init__(datapipe)
@staticmethod
def _postprocess(minibatch):
......
import unittest
from functools import partial
import backend as F
import dgl
import dgl.graphbolt as gb
import pytest
import torch
def get_hetero_graph():
    """Build a small fixed heterogeneous graph fixture in CSC form.

    Edge list in COO form:
        src:   [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
        dst:   [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
        etype: [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
    num_nodes = 5 with num_n1 = 2 and num_n2 = 3.
    """
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    # Nodes [0, 2) are "n1"; nodes [2, 5) are "n2".
    node_type_offset = torch.LongTensor([0, 2, 5])
    edge_weights = torch.FloatTensor(
        [2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5]
    )
    edge_mask = torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1])
    return gb.fused_csc_sampling_graph(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id={"n1": 0, "n2": 1},
        edge_type_to_id={"n1:e1:n2": 0, "n2:e2:n1": 1},
        edge_attributes={"weight": edge_weights, "mask": edge_mask},
    )
@unittest.skipIf(F._default_context_str != "gpu", reason="Enabled only on GPU.")
@pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
def test_NeighborSampler_GraphFetch(hetero, prob_name):
    # Verifies the pipelined fetch-then-sample path produces identical
    # minibatches to the baseline SamplePerLayer path, by seeding the RNG
    # identically before running each pipeline and comparing reprs.
    items = torch.arange(3)
    names = "seed_nodes"
    itemset = gb.ItemSet(items, names=names)
    graph = get_hetero_graph().to(F.ctx())
    if hetero:
        itemset = gb.ItemSetDict({"n2": itemset})
    else:
        # Homogeneous variant: strip the edge-type tensor.
        graph.type_per_edge = None
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    fanout = torch.LongTensor([2])
    # Shared preprocessing stages used by both pipelines below.
    datapipe = item_sampler.map(gb.SubgraphSampler._preprocess)
    datapipe = datapipe.map(
        partial(gb.NeighborSampler._prepare, graph.node_type_to_id)
    )
    sample_per_layer = gb.SamplePerLayer(
        datapipe, graph.sample_neighbors, fanout, False, prob_name
    )
    compact_per_layer = sample_per_layer.compact_per_layer(True)
    # Baseline: direct per-layer sampling with a fixed seed.
    gb.seed(123)
    expected_results = list(compact_per_layer)
    # Pipelined path: fetch the insubgraph first, then sample from it.
    datapipe = gb.FetchInsubgraphData(datapipe, sample_per_layer)
    datapipe = datapipe.wait_future()
    datapipe = gb.SamplePerLayerFromFetchedSubgraph(datapipe, sample_per_layer)
    datapipe = datapipe.compact_per_layer(True)
    # Re-seed so both pipelines draw the same random numbers.
    gb.seed(123)
    new_results = list(datapipe)
    assert len(expected_results) == len(new_results)
    for a, b in zip(expected_results, new_results):
        assert repr(a) == repr(b)
......@@ -47,11 +47,21 @@ def test_DataLoader():
F._default_context_str != "gpu",
reason="This test requires the GPU.",
)
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
@pytest.mark.parametrize(
"sampler_name", ["NeighborSampler", "LayerNeighborSampler"]
)
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_graph_fetch", [True, False])
def test_gpu_sampling_DataLoader(
sampler_name,
enable_feature_fetch,
overlap_feature_fetch,
overlap_graph_fetch,
):
N = 40
B = 4
num_layers = 2
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to(
F.ctx()
......@@ -68,10 +78,10 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
datapipe = datapipe.copy_to(F.ctx(), extra_attrs=["seed_nodes"])
datapipe = dgl.graphbolt.NeighborSampler(
datapipe = getattr(dgl.graphbolt, sampler_name)(
datapipe,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
fanouts=[torch.LongTensor([2]) for _ in range(num_layers)],
)
if enable_feature_fetch:
datapipe = dgl.graphbolt.FeatureFetcher(
......@@ -81,14 +91,18 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
)
dataloader = dgl.graphbolt.DataLoader(
datapipe, overlap_feature_fetch=overlap_feature_fetch
datapipe,
overlap_feature_fetch=overlap_feature_fetch,
overlap_graph_fetch=overlap_graph_fetch,
)
bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
if overlap_graph_fetch:
bufferer_awaiter_cnt += num_layers
datapipe = dataloader.dataset
datapipe_graph = dp_utils.traverse_dps(datapipe)
awaiters = dp_utils.find_dps(
datapipe_graph,
dgl.graphbolt.Awaiter,
dgl.graphbolt.Waiter,
)
assert len(awaiters) == bufferer_awaiter_cnt
bufferers = dp_utils.find_dps(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment