Unverified Commit 1193e2e8 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Dataloader `num_workers > 0` fix. (#6924)

parent 87b08393
...@@ -274,9 +274,8 @@ FusedCSCSamplingGraph::GetState() const { ...@@ -274,9 +274,8 @@ FusedCSCSamplingGraph::GetState() const {
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph( c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const { const torch::Tensor& nodes) const {
if (utils::is_accessible_from_gpu(indptr_) && if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) && utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!type_per_edge_.has_value() || (!type_per_edge_.has_value() ||
utils::is_accessible_from_gpu(type_per_edge_.value()))) { utils::is_accessible_from_gpu(type_per_edge_.value()))) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", { GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", {
...@@ -616,9 +615,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -616,9 +615,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
probs_or_mask = this->EdgeAttribute(probs_name); probs_or_mask = this->EdgeAttribute(probs_name);
} }
if (!replace && utils::is_accessible_from_gpu(indptr_) && if (!replace && utils::is_on_gpu(nodes) &&
utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) && utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!probs_or_mask.has_value() || (!probs_or_mask.has_value() ||
utils::is_accessible_from_gpu(probs_or_mask.value())) && utils::is_accessible_from_gpu(probs_or_mask.value())) &&
(!type_per_edge_.has_value() || (!type_per_edge_.has_value() ||
......
...@@ -13,8 +13,7 @@ namespace graphbolt { ...@@ -13,8 +13,7 @@ namespace graphbolt {
namespace ops { namespace ops {
torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) { torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
if (input.is_pinned() && if (utils::is_on_gpu(index) && input.is_pinned()) {
(index.is_pinned() || index.device().type() == c10::DeviceType::CUDA)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "UVAIndexSelect", c10::DeviceType::CUDA, "UVAIndexSelect",
{ return UVAIndexSelectImpl(input, index); }); { return UVAIndexSelectImpl(input, index); });
...@@ -26,9 +25,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC( ...@@ -26,9 +25,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) { torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
TORCH_CHECK( TORCH_CHECK(
indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors"); indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors");
if (utils::is_accessible_from_gpu(indptr) && if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices) && utils::is_accessible_from_gpu(indices)) {
utils::is_accessible_from_gpu(nodes)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IndexSelectCSCImpl", c10::DeviceType::CUDA, "IndexSelectCSCImpl",
{ return IndexSelectCSCImpl(indptr, indices, nodes); }); { return IndexSelectCSCImpl(indptr, indices, nodes); });
......
...@@ -48,8 +48,7 @@ torch::Tensor IsInCPU( ...@@ -48,8 +48,7 @@ torch::Tensor IsInCPU(
torch::Tensor IsIn( torch::Tensor IsIn(
const torch::Tensor& elements, const torch::Tensor& test_elements) { const torch::Tensor& elements, const torch::Tensor& test_elements) {
if (utils::is_accessible_from_gpu(elements) && if (utils::is_on_gpu(elements) && utils::is_on_gpu(test_elements)) {
utils::is_accessible_from_gpu(test_elements)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IsInOperation", c10::DeviceType::CUDA, "IsInOperation",
{ return ops::IsIn(elements, test_elements); }); { return ops::IsIn(elements, test_elements); });
......
...@@ -19,9 +19,8 @@ namespace sampling { ...@@ -19,9 +19,8 @@ namespace sampling {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact( std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids, const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids) { const torch::Tensor unique_dst_ids) {
if (utils::is_accessible_from_gpu(src_ids) && if (utils::is_on_gpu(src_ids) && utils::is_on_gpu(dst_ids) &&
utils::is_accessible_from_gpu(dst_ids) && utils::is_on_gpu(unique_dst_ids)) {
utils::is_accessible_from_gpu(unique_dst_ids)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE( GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "unique_and_compact", c10::DeviceType::CUDA, "unique_and_compact",
{ return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); }); { return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); });
......
...@@ -12,11 +12,18 @@ ...@@ -12,11 +12,18 @@
namespace graphbolt { namespace graphbolt {
namespace utils { namespace utils {
/**
* @brief Checks whether the tensor is stored on the GPU.
*/
inline bool is_on_gpu(torch::Tensor tensor) {
return tensor.device().is_cuda();
}
/** /**
* @brief Checks whether the tensor is stored on the GPU or the pinned memory. * @brief Checks whether the tensor is stored on the GPU or the pinned memory.
*/ */
inline bool is_accessible_from_gpu(torch::Tensor tensor) { inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA; return is_on_gpu(tensor) || tensor.is_pinned();
} }
/** /**
......
...@@ -9,9 +9,6 @@ from . import gb_test_utils ...@@ -9,9 +9,6 @@ from . import gb_test_utils
def test_DataLoader(): def test_DataLoader():
# https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing
mp.set_start_method("spawn", force=True)
N = 40 N = 40
B = 4 B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes") itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment