Unverified Commit 1193e2e8 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Dataloader `num_workers > 0` fix. (#6924)

parent 87b08393
......@@ -274,9 +274,8 @@ FusedCSCSamplingGraph::GetState() const {
c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const {
if (utils::is_accessible_from_gpu(indptr_) &&
if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!type_per_edge_.has_value() ||
utils::is_accessible_from_gpu(type_per_edge_.value()))) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", {
......@@ -616,9 +615,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
probs_or_mask = this->EdgeAttribute(probs_name);
}
if (!replace && utils::is_accessible_from_gpu(indptr_) &&
if (!replace && utils::is_on_gpu(nodes) &&
utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!probs_or_mask.has_value() ||
utils::is_accessible_from_gpu(probs_or_mask.value())) &&
(!type_per_edge_.has_value() ||
......
......@@ -13,8 +13,7 @@ namespace graphbolt {
namespace ops {
torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
if (input.is_pinned() &&
(index.is_pinned() || index.device().type() == c10::DeviceType::CUDA)) {
if (utils::is_on_gpu(index) && input.is_pinned()) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "UVAIndexSelect",
{ return UVAIndexSelectImpl(input, index); });
......@@ -26,9 +25,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
TORCH_CHECK(
indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors");
if (utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices) &&
utils::is_accessible_from_gpu(nodes)) {
if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IndexSelectCSCImpl",
{ return IndexSelectCSCImpl(indptr, indices, nodes); });
......
......@@ -48,8 +48,7 @@ torch::Tensor IsInCPU(
torch::Tensor IsIn(
const torch::Tensor& elements, const torch::Tensor& test_elements) {
if (utils::is_accessible_from_gpu(elements) &&
utils::is_accessible_from_gpu(test_elements)) {
if (utils::is_on_gpu(elements) && utils::is_on_gpu(test_elements)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IsInOperation",
{ return ops::IsIn(elements, test_elements); });
......
......@@ -19,9 +19,8 @@ namespace sampling {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids) {
if (utils::is_accessible_from_gpu(src_ids) &&
utils::is_accessible_from_gpu(dst_ids) &&
utils::is_accessible_from_gpu(unique_dst_ids)) {
if (utils::is_on_gpu(src_ids) && utils::is_on_gpu(dst_ids) &&
utils::is_on_gpu(unique_dst_ids)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "unique_and_compact",
{ return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); });
......
......@@ -12,11 +12,18 @@
namespace graphbolt {
namespace utils {
/**
* @brief Checks whether the tensor is stored on the GPU.
*/
inline bool is_on_gpu(torch::Tensor tensor) {
return tensor.device().is_cuda();
}
/**
* @brief Checks whether the tensor is stored on the GPU or the pinned memory.
*/
inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA;
return is_on_gpu(tensor) || tensor.is_pinned();
}
/**
......
......@@ -9,9 +9,6 @@ from . import gb_test_utils
def test_DataLoader():
# https://pytorch.org/docs/master/notes/multiprocessing.html#cuda-in-multiprocessing
mp.set_start_method("spawn", force=True)
N = 40
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment