Unverified Commit 528b041c authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Overlap feature fetcher (#6954)

parent 173257b3
...@@ -52,6 +52,7 @@ file(GLOB BOLT_SRC ${BOLT_DIR}/*.cc) ...@@ -52,6 +52,7 @@ file(GLOB BOLT_SRC ${BOLT_DIR}/*.cc)
if(USE_CUDA) if(USE_CUDA)
file(GLOB BOLT_CUDA_SRC file(GLOB BOLT_CUDA_SRC
${BOLT_DIR}/cuda/*.cu ${BOLT_DIR}/cuda/*.cu
${BOLT_DIR}/cuda/*.cc
) )
list(APPEND BOLT_SRC ${BOLT_CUDA_SRC}) list(APPEND BOLT_SRC ${BOLT_CUDA_SRC})
if(DEFINED ENV{CUDAARCHS}) if(DEFINED ENV{CUDAARCHS})
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <numeric> #include <numeric>
#include "./common.h" #include "./common.h"
#include "./max_uva_threads.h"
#include "./utils.h" #include "./utils.h"
namespace graphbolt { namespace graphbolt {
...@@ -122,17 +123,23 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) { ...@@ -122,17 +123,23 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
if (aligned_feature_size == 1) { if (aligned_feature_size == 1) {
// Use a single thread to process each output row to avoid wasting threads. // Use a single thread to process each output row to avoid wasting threads.
const int num_threads = cuda::FindNumThreads(return_len); const int num_threads = cuda::FindNumThreads(return_len);
const int num_blocks = (return_len + num_threads - 1) / num_threads; const int num_blocks =
(std::min(return_len, cuda::max_uva_threads.value_or(1 << 20)) +
num_threads - 1) /
num_threads;
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr, IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr,
input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr); input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
} else { } else {
dim3 block(512, 1); constexpr int BLOCK_SIZE = 512;
dim3 block(BLOCK_SIZE, 1);
while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) { while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {
block.x >>= 1; block.x >>= 1;
block.y <<= 1; block.y <<= 1;
} }
const dim3 grid((return_len + block.y - 1) / block.y); const dim3 grid(std::min(
(return_len + block.y - 1) / block.y,
cuda::max_uva_threads.value_or(1 << 20) / BLOCK_SIZE));
if (aligned_feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) { if (aligned_feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
// When feature size is smaller than GPU cache line size, use unaligned // When feature size is smaller than GPU cache line size, use unaligned
// version for less SM usage, which is more resource efficient. // version for less SM usage, which is more resource efficient.
......
/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/max_uva_threads.cc
 * @brief Max uva threads variable setter function.
 */
#include "./max_uva_threads.h"

namespace graphbolt {
namespace cuda {

// Stores the requested thread limit into the process-wide `max_uva_threads`
// optional declared in max_uva_threads.h. UVA kernels consult this value when
// sizing their launch grids. Plain unsynchronized write — NOTE(review): call
// this before launching work that reads it; no memory fence is performed here.
void set_max_uva_threads(int64_t count) { max_uva_threads = count; }

}  // namespace cuda
}  // namespace graphbolt
/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/max_uva_threads.h
 * @brief Max uva threads variable declaration.
 */
#ifndef GRAPHBOLT_MAX_UVA_THREADS_H_
#define GRAPHBOLT_MAX_UVA_THREADS_H_

#include <cstdint>
#include <optional>

namespace graphbolt {
namespace cuda {

/**
 * @brief Set a limit on the number of CUDA threads for UVA accesses.
 *
 * Left unset (std::nullopt) by default; readers fall back to their own
 * built-in limit via `max_uva_threads.value_or(...)` (the index_select
 * implementation uses 1 << 20). `inline` so a single instance is shared
 * across all translation units that include this header.
 */
inline std::optional<int64_t> max_uva_threads;

/** @brief Assign a new value to `max_uva_threads`. Exposed to Python via
 *  the `graphbolt.set_max_uva_threads` Torch library binding. */
void set_max_uva_threads(int64_t count);

}  // namespace cuda
}  // namespace graphbolt

#endif  // GRAPHBOLT_MAX_UVA_THREADS_H_
...@@ -9,6 +9,9 @@ ...@@ -9,6 +9,9 @@
#include <graphbolt/serialize.h> #include <graphbolt/serialize.h>
#include <graphbolt/unique_and_compact.h> #include <graphbolt/unique_and_compact.h>
#ifdef GRAPHBOLT_USE_CUDA
#include "./cuda/max_uva_threads.h"
#endif
#include "./index_select.h" #include "./index_select.h"
#include "./random.h" #include "./random.h"
...@@ -75,6 +78,9 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -75,6 +78,9 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("index_select", &ops::IndexSelect); m.def("index_select", &ops::IndexSelect);
m.def("index_select_csc", &ops::IndexSelectCSC); m.def("index_select_csc", &ops::IndexSelectCSC);
m.def("set_seed", &RandomEngine::SetManualSeed); m.def("set_seed", &RandomEngine::SetManualSeed);
#ifdef GRAPHBOLT_USE_CUDA
m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
#endif
} }
} // namespace sampling } // namespace sampling
......
"""Graph Bolt DataLoaders""" """Graph Bolt DataLoaders"""
from queue import Queue
import torch
import torch.utils.data import torch.utils.data
import torchdata.dataloader2.graph as dp_utils import torchdata.dataloader2.graph as dp_utils
import torchdata.datapipes as dp import torchdata.datapipes as dp
...@@ -35,6 +38,62 @@ def _find_and_wrap_parent( ...@@ -35,6 +38,62 @@ def _find_and_wrap_parent(
) )
class EndMarker(dp.iter.IterDataPipe):
    """Used to mark the end of a datapipe and is a no-op."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        # Pass every item through unchanged; this stage only serves as a
        # named anchor that graph-rewriting utilities can locate.
        yield from self.datapipe
class Bufferer(dp.iter.IterDataPipe):
    """Buffers items before yielding them.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    buffer_size : int, optional
        The size of the buffer which stores the fetched samples. If data
        coming from datapipe has latency spikes, consider passing a high
        value. Default is 2.

    Raises
    ------
    ValueError
        If ``buffer_size`` is not a positive integer.
    """

    def __init__(self, datapipe, buffer_size=2):
        self.datapipe = datapipe
        if buffer_size <= 0:
            raise ValueError(
                "'buffer_size' is required to be a positive integer."
            )
        self.buffer_size = buffer_size

    def __iter__(self):
        # Create a fresh queue for every iteration. Keeping the queue as
        # instance state (as a single shared Queue) would replay items left
        # over from a previously interrupted epoch into the next one.
        buffer = Queue(self.buffer_size)
        for data in self.datapipe:
            if not buffer.full():
                # Still warming up: absorb items without yielding.
                buffer.put(data)
            else:
                # Steady state: yield the oldest item, keep the new one.
                return_data = buffer.get()
                buffer.put(data)
                yield return_data
        # Source exhausted — drain whatever is still buffered.
        while not buffer.empty():
            yield buffer.get()
class Awaiter(dp.iter.IterDataPipe):
    """Calls the wait function of all items."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for item in self.datapipe:
            # Block until the item's asynchronous work has finished before
            # handing it downstream.
            item.wait()
            yield item
class MultiprocessingWrapper(dp.iter.IterDataPipe): class MultiprocessingWrapper(dp.iter.IterDataPipe):
"""Wraps a datapipe with multiprocessing. """Wraps a datapipe with multiprocessing.
...@@ -64,6 +123,14 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe): ...@@ -64,6 +123,14 @@ class MultiprocessingWrapper(dp.iter.IterDataPipe):
yield from self.dataloader yield from self.dataloader
# There needs to be a single instance of the uva_stream, if it is created
# multiple times, it leads to multiple CUDA memory pools and memory leaks.
def _get_uva_stream():
    """Return the process-wide high-priority CUDA stream used for UVA copies,
    creating it lazily on first use."""
    stream = getattr(_get_uva_stream, "stream", None)
    if stream is None:
        stream = torch.cuda.Stream(priority=-1)
        _get_uva_stream.stream = stream
    return stream
class DataLoader(torch.utils.data.DataLoader): class DataLoader(torch.utils.data.DataLoader):
"""Multiprocessing DataLoader. """Multiprocessing DataLoader.
...@@ -84,9 +151,26 @@ class DataLoader(torch.utils.data.DataLoader): ...@@ -84,9 +151,26 @@ class DataLoader(torch.utils.data.DataLoader):
If True, the data loader will not shut down the worker processes after a If True, the data loader will not shut down the worker processes after a
dataset has been consumed once. This allows to maintain the workers dataset has been consumed once. This allows to maintain the workers
instances alive. instances alive.
overlap_feature_fetch : bool, optional
If True, the data loader will overlap the UVA feature fetcher operations
with the rest of operations by using an alternative CUDA stream. Default
is True.
max_uva_threads : int, optional
Limits the number of CUDA threads used for UVA copies so that the rest
of the computations can run simultaneously with it. Setting it to a too
high value will limit the amount of overlap while setting it too low may
cause the PCI-e bandwidth to not get fully utilized. Manually tuned
default is 6144, meaning around 3-4 Streaming Multiprocessors.
""" """
def __init__(self, datapipe, num_workers=0, persistent_workers=True): def __init__(
self,
datapipe,
num_workers=0,
persistent_workers=True,
overlap_feature_fetch=True,
max_uva_threads=6144,
):
# Multiprocessing requires two modifications to the datapipe: # Multiprocessing requires two modifications to the datapipe:
# #
# 1. Insert a stage after ItemSampler to distribute the # 1. Insert a stage after ItemSampler to distribute the
...@@ -94,6 +178,7 @@ class DataLoader(torch.utils.data.DataLoader): ...@@ -94,6 +178,7 @@ class DataLoader(torch.utils.data.DataLoader):
# 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader. # of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
datapipe = EndMarker(datapipe)
datapipe_graph = dp_utils.traverse_dps(datapipe) datapipe_graph = dp_utils.traverse_dps(datapipe)
datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph) datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
...@@ -122,7 +207,35 @@ class DataLoader(torch.utils.data.DataLoader): ...@@ -122,7 +207,35 @@ class DataLoader(torch.utils.data.DataLoader):
persistent_workers=persistent_workers, persistent_workers=persistent_workers,
) )
# (3) Cut datapipe at CopyTo and wrap with prefetcher. This enables the # (3) Overlap UVA feature fetching by buffering and using an alternative
# stream.
if (
overlap_feature_fetch
and num_workers == 0
and torch.cuda.is_available()
):
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
feature_fetchers = dp_utils.find_dps(
datapipe_graph,
FeatureFetcher,
)
for feature_fetcher in feature_fetchers:
feature_fetcher.stream = _get_uva_stream()
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
EndMarker,
Bufferer,
buffer_size=2,
)
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
EndMarker,
Awaiter,
)
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# data pipeline up to the CopyTo operation to run in a separate thread. # data pipeline up to the CopyTo operation to run in a separate thread.
_find_and_wrap_parent( _find_and_wrap_parent(
datapipe_graph, datapipe_graph,
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
from typing import Dict from typing import Dict
import torch
from torch.utils.data import functional_datapipe from torch.utils.data import functional_datapipe
from .base import etype_tuple_to_str from .base import etype_tuple_to_str
...@@ -52,8 +54,9 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -52,8 +54,9 @@ class FeatureFetcher(MiniBatchTransformer):
self.feature_store = feature_store self.feature_store = feature_store
self.node_feature_keys = node_feature_keys self.node_feature_keys = node_feature_keys
self.edge_feature_keys = edge_feature_keys self.edge_feature_keys = edge_feature_keys
self.stream = None
def _read(self, data): def _read_data(self, data, stream):
""" """
Fill in the node/edge features field in data. Fill in the node/edge features field in data.
...@@ -77,6 +80,12 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -77,6 +80,12 @@ class FeatureFetcher(MiniBatchTransformer):
) or isinstance(self.edge_feature_keys, Dict) ) or isinstance(self.edge_feature_keys, Dict)
# Read Node features. # Read Node features.
input_nodes = data.node_ids() input_nodes = data.node_ids()
def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)
return tensor
if self.node_feature_keys and input_nodes is not None: if self.node_feature_keys and input_nodes is not None:
if is_heterogeneous: if is_heterogeneous:
for type_name, feature_names in self.node_feature_keys.items(): for type_name, feature_names in self.node_feature_keys.items():
...@@ -86,19 +95,23 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -86,19 +95,23 @@ class FeatureFetcher(MiniBatchTransformer):
for feature_name in feature_names: for feature_name in feature_names:
node_features[ node_features[
(type_name, feature_name) (type_name, feature_name)
] = self.feature_store.read( ] = record_stream(
"node", self.feature_store.read(
type_name, "node",
feature_name, type_name,
nodes, feature_name,
nodes,
)
) )
else: else:
for feature_name in self.node_feature_keys: for feature_name in self.node_feature_keys:
node_features[feature_name] = self.feature_store.read( node_features[feature_name] = record_stream(
"node", self.feature_store.read(
None, "node",
feature_name, None,
input_nodes, feature_name,
input_nodes,
)
) )
# Read Edge features. # Read Edge features.
if self.edge_feature_keys and num_layers > 0: if self.edge_feature_keys and num_layers > 0:
...@@ -124,19 +137,37 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -124,19 +137,37 @@ class FeatureFetcher(MiniBatchTransformer):
for feature_name in feature_names: for feature_name in feature_names:
edge_features[i][ edge_features[i][
(type_name, feature_name) (type_name, feature_name)
] = self.feature_store.read( ] = record_stream(
"edge", type_name, feature_name, edges self.feature_store.read(
"edge", type_name, feature_name, edges
)
) )
else: else:
for feature_name in self.edge_feature_keys: for feature_name in self.edge_feature_keys:
edge_features[i][ edge_features[i][feature_name] = record_stream(
feature_name self.feature_store.read(
] = self.feature_store.read( "edge",
"edge", None,
None, feature_name,
feature_name, original_edge_ids,
original_edge_ids, )
) )
data.set_node_features(node_features) data.set_node_features(node_features)
data.set_edge_features(edge_features) data.set_edge_features(edge_features)
return data return data
def _read(self, data):
    # Fetch features for `data`, optionally on a dedicated CUDA stream.
    #
    # When `self.stream` is set (the overlap-feature-fetch path wired up by
    # the DataLoader), the reads in `_read_data` run on that side stream so
    # they can overlap with work on the current stream. The attached
    # `data.wait` callback lets a downstream `Awaiter` stage block until
    # the side-stream work is complete.
    current_stream = None
    if self.stream is not None:
        current_stream = torch.cuda.current_stream()
        # Make the side stream wait for everything already enqueued on the
        # current stream before touching `data`'s tensors.
        self.stream.wait_stream(current_stream)
    # `torch.cuda.stream(None)` is a no-op context, so the no-overlap path
    # simply runs `_read_data` on whatever stream is already current.
    with torch.cuda.stream(self.stream):
        # `current_stream` is forwarded so fetched tensors can be
        # record_stream'd against it inside `_read_data`.
        data = self._read_data(data, current_stream)
        if self.stream is not None:
            # Recorded on the side stream (current inside this context);
            # completes once all feature reads enqueued above have run.
            event = torch.cuda.current_stream().record_event()

            def _wait():
                event.wait()

            # NOTE(review): `data.wait` is only set on the overlap path;
            # the Awaiter stage is likewise only inserted on that path.
            data.wait = _wait
        return data
import unittest
import backend as F import backend as F
import dgl import dgl
import dgl.graphbolt import dgl.graphbolt
import pytest
import torch import torch
import torch.multiprocessing as mp
from . import gb_test_utils from . import gb_test_utils
...@@ -37,3 +39,44 @@ def test_DataLoader(): ...@@ -37,3 +39,44 @@ def test_DataLoader():
num_workers=4, num_workers=4,
) )
assert len(list(dataloader)) == N // B assert len(list(dataloader)) == N // B
@unittest.skipIf(
    F._default_context_str != "gpu",
    reason="This test requires the GPU.",
)
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch):
    """Smoke-test the GPU DataLoader with the UVA overlap path on and off."""
    num_seeds = 40
    batch_size = 4
    itemset = dgl.graphbolt.ItemSet(
        torch.arange(num_seeds), names="seed_nodes"
    )
    graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to(
        F.ctx()
    )
    # Two pinned node features so the UVA fetch path is exercised.
    features = {
        ("node", None, "a"): dgl.graphbolt.TorchBasedFeature(
            torch.randn(200, 4, pin_memory=True)
        ),
        ("node", None, "b"): dgl.graphbolt.TorchBasedFeature(
            torch.randn(200, 4, pin_memory=True)
        ),
    }
    feature_store = dgl.graphbolt.BasicFeatureStore(features)

    datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=batch_size)
    datapipe = datapipe.copy_to(F.ctx(), extra_attrs=["seed_nodes"])
    datapipe = dgl.graphbolt.NeighborSampler(
        datapipe,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    datapipe = dgl.graphbolt.FeatureFetcher(
        datapipe,
        feature_store,
        ["a", "b"],
    )

    dataloader = dgl.graphbolt.DataLoader(
        datapipe, overlap_feature_fetch=overlap_feature_fetch
    )
    # Every seed node must appear in exactly one full batch.
    assert len(list(dataloader)) == num_seeds // batch_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment