Unverified Commit aad3bd04, authored by Xin Yao and committed by GitHub

[Bugfix] Fix empty tensors being treated as pinned (#5005)

* fix empty tensors being treated as pinned

* avoid calling cudaHostGetDevicePointer on nullptr

* update empty array

* add a comment
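
Background: PinData() skips empty tensors (nullptr data or zero bytes), but PinContainer() still marked the container as pinned, so IsPinned() could return true for a buffer that was never registered with cudaHostRegister(). Downstream code would then call cudaHostGetDevicePointer() on an unregistered (possibly null) pointer. A standalone sketch of that failing call pattern, not DGL code, assuming only the CUDA runtime API:

// Sketch of the failure mode this commit fixes: cudaHostGetDevicePointer()
// is only valid for pointers that were actually registered with
// cudaHostRegister(); an empty tensor's buffer never is.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  void* host_ptr = nullptr;  // what an empty tensor's data pointer may look like
  void* dev_ptr = nullptr;
  // Mirrors the old code path: asking for the device alias of a buffer that
  // was never pinned. Expect an error (e.g. cudaErrorInvalidValue).
  cudaError_t err = cudaHostGetDevicePointer(&dev_ptr, host_ptr, 0);
  printf("cudaHostGetDevicePointer: %s\n", cudaGetErrorString(err));
  return 0;
}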
parent e28f0781
...
@@ -309,7 +309,9 @@
   });
 
 #define CHECK_VALID_CONTEXT(VAR1, VAR2)                                 \
-  CHECK(((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned())              \
+  CHECK(                                                                \
+      ((VAR1)->ctx == (VAR2)->ctx) || (VAR1).IsPinned() ||              \
+      ((VAR1).NumElements() == 0)) /* Let empty arrays pass */          \
       << "Expected " << (#VAR2) << "(" << (VAR2)->ctx << ")"            \
       << " to have the same device "                                    \
       << "context as " << (#VAR1) << "(" << (VAR1)->ctx << "). "        \
...
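For readability, the updated macro condition expressed as a hypothetical helper function (assuming DGL's NDArray API as used in the macro above; the real code stays a macro so the CHECK message can print the variable names and contexts):

#include <dgl/runtime/ndarray.h>

// Hypothetical equivalent of the updated CHECK_VALID_CONTEXT condition.
inline bool HasValidContext(
    const dgl::runtime::NDArray& a, const dgl::runtime::NDArray& b) {
  return a->ctx == b->ctx       // same device context
      || a.IsPinned()           // pinned host memory is UVA-accessible
      || a.NumElements() == 0;  // empty arrays carry no data to touch
}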
...
@@ -149,8 +149,9 @@ class DeviceAPI {
    *
    * @param ptr The host memory pointer to be pinned.
    * @param nbytes The size to be pinned.
+   * @return false when pinning an empty tensor. true otherwise.
    */
-  DGL_DLL virtual void PinData(void* ptr, size_t nbytes);
+  DGL_DLL virtual bool PinData(void* ptr, size_t nbytes);
 
   /**
    * @brief Unpin host memory using cudaHostUnregister().
...
...
@@ -16,12 +16,6 @@ namespace impl {
 
 template <DGLDeviceType XPU, typename DType, typename IdType>
 NDArray IndexSelect(NDArray array, IdArray index) {
-  cudaStream_t stream = runtime::getCurrentCUDAStream();
-  const DType* array_data = array.Ptr<DType>();
-  if (array.IsPinned()) {
-    CUDA_CALL(cudaHostGetDevicePointer(&array_data, array.Ptr<DType>(), 0));
-  }
-  const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
   const int64_t len = index->shape[0];
   int64_t num_feat = 1;
...
@@ -33,9 +27,13 @@ NDArray IndexSelect(NDArray array, IdArray index) {
 
   // use index->ctx for pinned array
   NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
-  if (len == 0) return ret;
+  if (len == 0 || arr_len * num_feat == 0) return ret;
   DType* ret_data = static_cast<DType*>(ret->data);
 
+  const DType* array_data = static_cast<DType*>(cuda::GetDevicePointer(array));
+  const IdType* idx_data = static_cast<IdType*>(index->data);
+  cudaStream_t stream = runtime::getCurrentCUDAStream();
+
   if (num_feat == 1) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
...
...
@@ -43,19 +43,13 @@ NDArray CSRGetData(
   BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
       << "DType does not match row's dtype.";
-  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
-  const IdType* indices_data = csr.indices.Ptr<IdType>();
-  const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
-  if (csr.is_pinned) {
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
-    if (CSRHasData(csr)) {
-      CUDA_CALL(
-          cudaHostGetDevicePointer(&data_data, csr.data.Ptr<IdType>(), 0));
-    }
-  }
+  const IdType* indptr_data =
+      static_cast<IdType*>(cuda::GetDevicePointer(csr.indptr));
+  const IdType* indices_data =
+      static_cast<IdType*>(cuda::GetDevicePointer(csr.indices));
+  const IdType* data_data =
+      CSRHasData(csr) ? static_cast<IdType*>(cuda::GetDevicePointer(csr.data))
+                      : nullptr;
   // TODO(minjie): use binary search for sorted csr
   CUDA_KERNEL_CALL(
...
...
@@ -13,10 +13,11 @@
 #include "../../array/cuda/atomic.cuh"
 #include "../../runtime/cuda/cuda_common.h"
 #include "./dgl_cub.cuh"
+#include "./utils.h"
 
-using namespace dgl::aten::cuda;
-
 namespace dgl {
+using namespace cuda;
+using namespace aten::cuda;
 namespace aten {
 namespace impl {
...
@@ -248,16 +249,11 @@ COOMatrix _CSRRowWiseSamplingUniform(
   IdType* const out_cols = static_cast<IdType*>(picked_col->data);
   IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
 
-  const IdType* in_ptr = mat.indptr.Ptr<IdType>();
-  const IdType* in_cols = mat.indices.Ptr<IdType>();
-  const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
-  if (mat.is_pinned) {
-    CUDA_CALL(cudaHostGetDevicePointer(&in_ptr, mat.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(cudaHostGetDevicePointer(&in_cols, mat.indices.Ptr<IdType>(), 0));
-    if (CSRHasData(mat)) {
-      CUDA_CALL(cudaHostGetDevicePointer(&data, mat.data.Ptr<IdType>(), 0));
-    }
-  }
+  const IdType* in_ptr = static_cast<IdType*>(GetDevicePointer(mat.indptr));
+  const IdType* in_cols = static_cast<IdType*>(GetDevicePointer(mat.indices));
+  const IdType* data = CSRHasData(mat)
+                           ? static_cast<IdType*>(GetDevicePointer(mat.data))
+                           : nullptr;
 
   // compute degree
   IdType* out_deg = static_cast<IdType*>(
...
...
@@ -21,9 +21,9 @@
 static_assert(
     CUB_VERSION >= 101700, "Require CUB >= 1.17 to use DeviceSegmentedSort");
 
-using namespace dgl::aten::cuda;
-
 namespace dgl {
+using namespace cuda;
+using namespace aten::cuda;
 namespace aten {
 namespace impl {
...
@@ -496,18 +496,12 @@ COOMatrix _CSRRowWiseSampling(
   IdType* const out_cols = static_cast<IdType*>(picked_col->data);
   IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
 
-  const IdType* in_ptr = mat.indptr.Ptr<IdType>();
-  const IdType* in_cols = mat.indices.Ptr<IdType>();
-  const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
-  const FloatType* prob_data = prob.Ptr<FloatType>();
-  if (mat.is_pinned) {
-    CUDA_CALL(cudaHostGetDevicePointer(&in_ptr, mat.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(cudaHostGetDevicePointer(&in_cols, mat.indices.Ptr<IdType>(), 0));
-    if (CSRHasData(mat)) {
-      CUDA_CALL(cudaHostGetDevicePointer(&data, mat.data.Ptr<IdType>(), 0));
-    }
-    CUDA_CALL(cudaHostGetDevicePointer(&prob_data, prob.Ptr<FloatType>(), 0));
-  }
+  const IdType* in_ptr = static_cast<IdType*>(GetDevicePointer(mat.indptr));
+  const IdType* in_cols = static_cast<IdType*>(GetDevicePointer(mat.indices));
+  const IdType* data = CSRHasData(mat)
+                           ? static_cast<IdType*>(GetDevicePointer(mat.data))
+                           : nullptr;
+  const FloatType* prob_data = static_cast<FloatType*>(GetDevicePointer(prob));
 
   // compute degree
   // out_deg: the size of each row in the sampled matrix
...
...
@@ -19,6 +19,7 @@
 namespace dgl {
 
 using runtime::NDArray;
+using namespace cuda;
 
 namespace aten {
 namespace impl {
...
@@ -61,14 +62,10 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
   const int nt = dgl::cuda::FindNumThreads(rstlen);
   const int nb = (rstlen + nt - 1) / nt;
   const IdType* data = nullptr;
-  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
-  const IdType* indices_data = csr.indices.Ptr<IdType>();
-  if (csr.is_pinned) {
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
-  }
+  const IdType* indptr_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indptr));
+  const IdType* indices_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indices));
   // TODO(minjie): use binary search for sorted csr
   CUDA_KERNEL_CALL(
       dgl::cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data,
...
@@ -155,11 +152,8 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto len = rows->shape[0];
   const IdType* vid_data = rows.Ptr<IdType>();
-  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
-  if (csr.is_pinned) {
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
-  }
+  const IdType* indptr_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indptr));
   NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
   IdType* rst_data = static_cast<IdType*>(rst->data);
   const int nt = dgl::cuda::FindNumThreads(len);
...
@@ -267,19 +261,13 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
   // Copy indices.
   IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
 
-  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
-  const IdType* indices_data = csr.indices.Ptr<IdType>();
-  const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
-  if (csr.is_pinned) {
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
-    if (CSRHasData(csr)) {
-      CUDA_CALL(
-          cudaHostGetDevicePointer(&data_data, csr.data.Ptr<IdType>(), 0));
-    }
-  }
+  const IdType* indptr_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indptr));
+  const IdType* indices_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indices));
+  const IdType* data_data =
+      CSRHasData(csr) ? static_cast<IdType*>(GetDevicePointer(csr.data))
+                      : nullptr;
 
   CUDA_KERNEL_CALL(
       _SegmentCopyKernel, nb, nt, 0, stream, indptr_data, indices_data,
...
@@ -381,14 +369,10 @@ std::vector<NDArray> CSRGetDataAndIndices(
   const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
   cudaStream_t stream = runtime::getCurrentCUDAStream();
 
-  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
-  const IdType* indices_data = csr.indices.Ptr<IdType>();
-  if (csr.is_pinned) {
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
-  }
+  const IdType* indptr_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indptr));
+  const IdType* indices_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indices));
 
   // Generate a 0-1 mask for matched (row, col) positions.
   IdArray mask = Full(0, nnz, nbits, ctx);
...
@@ -618,14 +602,10 @@ CSRMatrix CSRSliceMatrix(
     hashmap.Insert(key[i]);
   });
 
-  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
-  const IdType* indices_data = csr.indices.Ptr<IdType>();
-  if (csr.is_pinned) {
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
-    CUDA_CALL(
-        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
-  }
+  const IdType* indptr_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indptr));
+  const IdType* indices_data =
+      static_cast<IdType*>(GetDevicePointer(csr.indices));
 
   // Execute SegmentMaskColKernel
   const int64_t num_rows = csr.num_rows;
...
...
@@ -264,6 +264,14 @@ void MaskSelect(
   device->FreeWorkspace(ctx, workspace);
 }
 
+inline void* GetDevicePointer(runtime::NDArray array) {
+  void* ptr = array->data;
+  if (array.IsPinned()) {
+    CUDA_CALL(cudaHostGetDevicePointer(&ptr, ptr, 0));
+  }
+  return ptr;
+}
+
 }  // namespace cuda
 }  // namespace dgl
...
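The new GetDevicePointer() helper centralizes the pointer translation that each call site above used to repeat: for a pinned host array it returns the device-visible alias of the same buffer, otherwise the raw data pointer. A standalone sketch of the underlying CUDA pattern (not DGL code, host-side only):

// Pin an existing host allocation and obtain its device alias, which is what
// PinData() and GetDevicePointer() do for DGL arrays kept in host memory.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> host(1024, 7);
  const size_t nbytes = host.size() * sizeof(int);
  // Register (pin) the existing host buffer.
  cudaError_t err =
      cudaHostRegister(host.data(), nbytes, cudaHostRegisterDefault);
  if (err != cudaSuccess) {
    printf("cudaHostRegister failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  void* dev_alias = nullptr;
  // Ask CUDA for the device-visible alias of the pinned buffer.
  err = cudaHostGetDevicePointer(&dev_alias, host.data(), 0);
  printf("device alias: %p (%s)\n", dev_alias, cudaGetErrorString(err));
  cudaHostUnregister(host.data());
  return 0;
}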
...
@@ -25,8 +25,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
   std::vector<int64_t> shape{len};
 
   CHECK(array.IsPinned());
-  const DType* array_data = nullptr;
-  CUDA_CALL(cudaHostGetDevicePointer(&array_data, array.Ptr<DType>(), 0));
+  const DType* array_data = static_cast<DType*>(cuda::GetDevicePointer(array));
   CHECK_EQ(index->ctx.device_type, kDGLCUDA);
 
   for (int d = 1; d < array->ndim; ++d) {
...
@@ -35,7 +34,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
   }
 
   NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
-  if (len == 0) return ret;
+  if (len == 0 || arr_len * num_feat == 0) return ret;
   DType* ret_data = static_cast<DType*>(ret->data);
 
   if (num_feat == 1) {
...
@@ -85,8 +84,7 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
   std::vector<int64_t> shape{len};
 
   CHECK(dest.IsPinned());
-  DType* dest_data = nullptr;
-  CUDA_CALL(cudaHostGetDevicePointer(&dest_data, dest.Ptr<DType>(), 0));
+  DType* dest_data = static_cast<DType*>(cuda::GetDevicePointer(dest));
   CHECK_EQ(index->ctx.device_type, kDGLCUDA);
   CHECK_EQ(source->ctx.device_type, kDGLCUDA);
...
...
@@ -121,8 +121,9 @@ void DeviceAPI::SyncStreamFromTo(
   LOG(FATAL) << "Device does not support stream api.";
 }
 
-void DeviceAPI::PinData(void* ptr, size_t nbytes) {
+bool DeviceAPI::PinData(void* ptr, size_t nbytes) {
   LOG(FATAL) << "Device does not support cudaHostRegister api.";
+  return false;
 }
 
 void DeviceAPI::UnpinData(void* ptr) {
...
...
@@ -211,10 +211,11 @@ class CUDADeviceAPI final : public DeviceAPI {
    * The pinned memory can be seen by all CUDA contexts,
    * not just the one that performed the allocation
    */
-  void PinData(void* ptr, size_t nbytes) {
+  bool PinData(void* ptr, size_t nbytes) override {
     // prevent users from pinning empty tensors or graphs
-    if (ptr == nullptr || nbytes == 0) return;
+    if (ptr == nullptr || nbytes == 0) return false;
     CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
+    return true;
   }
 
   void UnpinData(void* ptr) {
...
...
@@ -211,8 +211,8 @@ void NDArray::PinContainer(NDArray::Container* ptr) {
   auto* tensor = &(ptr->dl_tensor);
   CHECK_EQ(tensor->ctx.device_type, kDGLCPU)
       << "Only NDArray on CPU can be pinned";
-  DeviceAPI::Get(kDGLCUDA)->PinData(tensor->data, GetDataSize(*tensor));
-  ptr->pinned_by_dgl_ = true;
+  ptr->pinned_by_dgl_ =
+      DeviceAPI::Get(kDGLCUDA)->PinData(tensor->data, GetDataSize(*tensor));
 }
 
 void NDArray::UnpinContainer(NDArray::Container* ptr) {
...
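With PinData() returning bool, the container flag now records whether pinning actually happened, so IsPinned() stays false for empty tensors and the CUDA paths above never ask for a device pointer they do not have. A condensed sketch of the combined control flow (simplified names and bodies, following the diffs above; not verbatim DGL code):

#include <cstddef>

// DeviceAPI::PinData reports whether pinning happened, and the container
// flag stores that result instead of being unconditionally set to true.
struct DeviceAPI {
  virtual ~DeviceAPI() = default;
  virtual bool PinData(void* ptr, size_t nbytes) = 0;
};

struct CUDADeviceAPI final : DeviceAPI {
  bool PinData(void* ptr, size_t nbytes) override {
    if (ptr == nullptr || nbytes == 0) return false;  // nothing to pin
    // real code: CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
    return true;
  }
};

struct Container {
  void* data = nullptr;
  size_t nbytes = 0;
  bool pinned_by_dgl_ = false;
};

void PinContainer(Container* c, DeviceAPI* api) {
  // An empty tensor leaves pinned_by_dgl_ false, so IsPinned() stays false.
  c->pinned_by_dgl_ = api->PinData(c->data, c->nbytes);
}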
 import unittest
 
 import backend as F
+import dgl
 import networkx as nx
 import numpy as np
 import pytest
 import scipy.sparse as ssp
 from test_utils import parametrize_idtype
-import dgl
 
 D = 5
...
@@ -130,6 +130,32 @@ def create_test_heterograph(idtype):
     return g
 
 
+def create_test_heterograph2(idtype):
+    """test heterograph from the docstring, with an empty relation"""
+    # 3 users, 2 games, 2 developers
+    # metagraph:
+    #   ('user', 'follows', 'user'),
+    #   ('user', 'plays', 'game'),
+    #   ('user', 'wishes', 'game'),
+    #   ('developer', 'develops', 'game')
+    g = dgl.heterograph(
+        {
+            ("user", "follows", "user"): ([0, 1], [1, 2]),
+            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
+            ("user", "wishes", "game"): ([0, 2], [1, 0]),
+            ("developer", "develops", "game"): ([], []),
+        },
+        idtype=idtype,
+        device=F.ctx(),
+    )
+    for etype in g.etypes:
+        g.edges[etype].data["weight"] = F.randn((g.num_edges(etype),))
+    assert g.idtype == idtype
+    assert g.device == F.ctx()
+    return g
+
+
 @unittest.skipIf(
     dgl.backend.backend_name == "mxnet",
     reason="MXNet doesn't support bool tensor",
...
@@ -788,14 +814,14 @@ def test_subframes(parent_idx_device, child_device):
 @unittest.skipIf(
     F._default_context_str != "gpu", reason="UVA only available on GPU"
 )
-@pytest.mark.parametrize("device", [F.cpu(), F.cuda()])
 @unittest.skipIf(
     dgl.backend.backend_name != "pytorch",
     reason="UVA only supported for PyTorch",
 )
+@pytest.mark.parametrize("device", [F.cpu(), F.cuda()])
 @parametrize_idtype
 def test_uva_subgraph(idtype, device):
-    g = create_test_heterograph(idtype)
+    g = create_test_heterograph2(idtype)
     g = g.to(F.cpu())
     g.create_formats_()
     g.pin_memory_()
...
@@ -805,16 +831,13 @@ def test_uva_subgraph(idtype, device):
     assert g.edge_subgraph(edge_indices).device == device
     assert g.in_subgraph(indices).device == device
     assert g.out_subgraph(indices).device == device
-    if dgl.backend.backend_name != "tensorflow":
-        # (BarclayII) Most of Tensorflow functions somehow do not preserve device: a CPU tensor
-        # becomes a GPU tensor after operations such as concat(), unique() or even sin().
-        # Not sure what should be the best fix.
-        assert g.khop_in_subgraph(indices, 1)[0].device == device
-        assert g.khop_out_subgraph(indices, 1)[0].device == device
+    assert g.khop_in_subgraph(indices, 1)[0].device == device
+    assert g.khop_out_subgraph(indices, 1)[0].device == device
     assert g.sample_neighbors(indices, 1).device == device
     g.unpin_memory_()
 
 
 if __name__ == "__main__":
     test_edge_subgraph()
-    # test_uva_subgraph(F.int64, F.cpu())
+    test_uva_subgraph(F.int64, F.cpu())
+    test_uva_subgraph(F.int64, F.cuda())