Unverified commit 16e771c0, authored by Xin Yao, committed by GitHub

[Bugfix] Fix UVA not working on old GPUs (#4781)

* Get device pointers explicitly

* Change the if condition to IsPinned()
parent 7c788f53
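Background for the fix (editor's note): with CUDA's unified virtual addressing (UVA), pinned host memory has the same address in the host and device address spaces, so a host pointer can be handed straight to a kernel. On GPUs or platforms without UVA (it requires a 64-bit process and a sufficiently recent GPU), the device-side alias of a mapped pinned buffer is a different address and must be fetched with cudaHostGetDevicePointer(); passing the raw host pointer typically faults with an illegal-address error. The call is also valid on UVA systems, where it simply returns the same address, which is why the pattern below is safe everywhere. A minimal standalone sketch of the pattern (illustrative names only, not DGL code):

// check_mapped_access.cu -- illustrative sketch, names are made up.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void ReadFirst(const int* data, int* out) { *out = data[0]; }

int main() {
  int can_map = 0, uva = 0;
  cudaDeviceGetAttribute(&can_map, cudaDevAttrCanMapHostMemory, 0);
  cudaDeviceGetAttribute(&uva, cudaDevAttrUnifiedAddressing, 0);
  if (!can_map) { printf("device cannot map host memory\n"); return 1; }

  int* host_buf = nullptr;
  cudaHostAlloc(reinterpret_cast<void**>(&host_buf), sizeof(int),
                cudaHostAllocMapped);
  host_buf[0] = 42;

  // The crux of this commit: ask the runtime for the device-side alias
  // instead of handing the host pointer to the kernel directly.
  int* dev_alias = nullptr;
  cudaHostGetDevicePointer(reinterpret_cast<void**>(&dev_alias), host_buf, 0);
  printf("uva=%d host=%p alias=%p\n", uva, (void*)host_buf, (void*)dev_alias);

  int* d_out = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&d_out), sizeof(int));
  ReadFirst<<<1, 1>>>(dev_alias, d_out);

  int result = 0;
  cudaMemcpy(&result, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("read back: %d\n", result);  // expect 42
  cudaFree(d_out);
  cudaFreeHost(host_buf);
  return 0;
}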
@@ -17,7 +17,10 @@ namespace impl {
 template <DGLDeviceType XPU, typename DType, typename IdType>
 NDArray IndexSelect(NDArray array, IdArray index) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
-  const DType* array_data = static_cast<DType*>(array->data);
+  const DType* array_data = array.Ptr<DType>();
+  if (array.IsPinned()) {
+    CUDA_CALL(cudaHostGetDevicePointer(&array_data, array.Ptr<DType>(), 0));
+  }
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
   const int64_t len = index->shape[0];
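The map-if-pinned dance above recurs for every array in the hunks that follow. A small helper along these lines (hypothetical; this commit does not add one, and it is not part of DGL's API) could express the same logic once:

// Hypothetical helper, not in DGL: return a pointer the device can
// dereference, mapping pinned host memory through the CUDA runtime
// when the array does not live in device memory.
template <typename T>
const T* DeviceReadablePtr(dgl::runtime::NDArray arr) {
  const T* ptr = arr.Ptr<T>();
  if (arr.IsPinned()) {
    CUDA_CALL(cudaHostGetDevicePointer(&ptr, arr.Ptr<T>(), 0));
  }
  return ptr;
}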
@@ -41,11 +41,24 @@ NDArray CSRGetData(
   BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
       "DType does not match row's dtype.";
+  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
+  const IdType* indices_data = csr.indices.Ptr<IdType>();
+  const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
+  if (csr.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indices_data, csr.indices.Ptr<IdType>(), 0));
+    if (CSRHasData(csr)) {
+      CUDA_CALL(cudaHostGetDevicePointer(
+          &data_data, csr.data.Ptr<IdType>(), 0));
+    }
+  }
   // TODO(minjie): use binary search for sorted csr
   CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
       nb, nt, 0, stream,
-      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
-      CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
+      indptr_data, indices_data, data_data,
       rows.Ptr<IdType>(), cols.Ptr<IdType>(),
       row_stride, col_stride, rstlen,
       return_eids ? nullptr : weights.Ptr<DType>(), filler, rst.Ptr<DType>());
@@ -253,14 +253,23 @@ COOMatrix _CSRRowWiseSamplingUniform(
   IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
   IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
   IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
-  const IdType * const in_ptr = static_cast<const IdType*>(mat.indptr->data);
-  const IdType * const in_cols = static_cast<const IdType*>(mat.indices->data);
   IdType* const out_rows = static_cast<IdType*>(picked_row->data);
   IdType* const out_cols = static_cast<IdType*>(picked_col->data);
   IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
-  const IdType* const data = CSRHasData(mat) ?
-      static_cast<IdType*>(mat.data->data) : nullptr;
+  const IdType* in_ptr = mat.indptr.Ptr<IdType>();
+  const IdType* in_cols = mat.indices.Ptr<IdType>();
+  const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
+  if (mat.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &in_ptr, mat.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &in_cols, mat.indices.Ptr<IdType>(), 0));
+    if (CSRHasData(mat)) {
+      CUDA_CALL(cudaHostGetDevicePointer(
+          &data, mat.data.Ptr<IdType>(), 0));
+    }
+  }
   // compute degree
   IdType * out_deg = static_cast<IdType*>(
@@ -502,15 +502,26 @@ COOMatrix _CSRRowWiseSampling(
   IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
   IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
   IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
-  const IdType * const in_ptr = static_cast<const IdType*>(mat.indptr->data);
-  const IdType * const in_cols = static_cast<const IdType*>(mat.indices->data);
   IdType* const out_rows = static_cast<IdType*>(picked_row->data);
   IdType* const out_cols = static_cast<IdType*>(picked_col->data);
   IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
-  const IdType* const data = CSRHasData(mat) ?
-      static_cast<IdType*>(mat.data->data) : nullptr;
-  const FloatType* const prob_data = static_cast<const FloatType*>(prob->data);
+  const IdType* in_ptr = mat.indptr.Ptr<IdType>();
+  const IdType* in_cols = mat.indices.Ptr<IdType>();
+  const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
+  const FloatType* prob_data = prob.Ptr<FloatType>();
+  if (mat.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &in_ptr, mat.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &in_cols, mat.indices.Ptr<IdType>(), 0));
+    if (CSRHasData(mat)) {
+      CUDA_CALL(cudaHostGetDevicePointer(
+          &data, mat.data.Ptr<IdType>(), 0));
+    }
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &prob_data, prob.Ptr<FloatType>(), 0));
+  }
   // compute degree
   // out_deg: the size of each row in the sampled matrix
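For context on where these pinned arrays come from: DGL pins existing graph buffers in place rather than reallocating them (its pinning path is NDArray::PinMemory_, which may differ in detail from this sketch). A reduced sketch of in-place pinning with the CUDA runtime, using a plain host vector to stand in for a graph array:

// In-place pinning sketch; illustrative, not DGL's implementation.
#include <cstdint>
#include <vector>
#include <cuda_runtime.h>

int main() {
  // On non-UVA devices, mapped pinned memory must be requested before
  // the CUDA context is created.
  cudaSetDeviceFlags(cudaDeviceMapHost);

  std::vector<int64_t> indptr(1024, 0);  // stands in for csr.indptr
  cudaHostRegister(indptr.data(), indptr.size() * sizeof(int64_t),
                   cudaHostRegisterMapped);

  // After registration the buffer is pinned; kernels may read it through
  // the device alias returned here (identical to indptr.data() under UVA).
  int64_t* dev_indptr = nullptr;
  cudaHostGetDevicePointer(reinterpret_cast<void**>(&dev_indptr),
                           indptr.data(), 0);

  cudaHostUnregister(indptr.data());
  return 0;
}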
@@ -59,10 +59,18 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
   const int nt = dgl::cuda::FindNumThreads(rstlen);
   const int nb = (rstlen + nt - 1) / nt;
   const IdType* data = nullptr;
+  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
+  const IdType* indices_data = csr.indices.Ptr<IdType>();
+  if (csr.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indices_data, csr.indices.Ptr<IdType>(), 0));
+  }
   // TODO(minjie): use binary search for sorted csr
   CUDA_KERNEL_CALL(dgl::cuda::_LinearSearchKernel,
       nb, nt, 0, stream,
-      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), data,
+      indptr_data, indices_data, data,
       row.Ptr<IdType>(), col.Ptr<IdType>(),
       row_stride, col_stride, rstlen,
       static_cast<IdType*>(nullptr), static_cast<IdType>(-1), rst.Ptr<IdType>());
@@ -150,8 +158,12 @@ template <DGLDeviceType XPU, typename IdType>
 NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
   const auto len = rows->shape[0];
-  const IdType* vid_data = static_cast<IdType*>(rows->data);
-  const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
+  const IdType* vid_data = rows.Ptr<IdType>();
+  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
+  if (csr.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
+  }
   NDArray rst = NDArray::Empty({len}, rows->dtype, rows->ctx);
   IdType* rst_data = static_cast<IdType*>(rst->data);
   const int nt = dgl::cuda::FindNumThreads(len);
@@ -255,16 +267,31 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
   // Copy indices.
   IdArray ret_indices = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
+  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
+  const IdType* indices_data = csr.indices.Ptr<IdType>();
+  const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
+  if (csr.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indices_data, csr.indices.Ptr<IdType>(), 0));
+    if (CSRHasData(csr)) {
+      CUDA_CALL(cudaHostGetDevicePointer(
+          &data_data, csr.data.Ptr<IdType>(), 0));
+    }
+  }
   CUDA_KERNEL_CALL(_SegmentCopyKernel,
       nb, nt, 0, stream,
-      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
+      indptr_data, indices_data,
       rows.Ptr<IdType>(), nnz, len,
       ret_indptr.Ptr<IdType>(), ret_indices.Ptr<IdType>());
   // Copy data.
   IdArray ret_data = NDArray::Empty({nnz}, csr.indptr->dtype, rows->ctx);
   CUDA_KERNEL_CALL(_SegmentCopyKernel,
       nb, nt, 0, stream,
-      csr.indptr.Ptr<IdType>(), CSRHasData(csr)? csr.data.Ptr<IdType>() : nullptr,
+      indptr_data, data_data,
       rows.Ptr<IdType>(), nnz, len,
       ret_indptr.Ptr<IdType>(), ret_data.Ptr<IdType>());
   return CSRMatrix(len, csr.num_cols,
@@ -360,13 +387,22 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col
   const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
   cudaStream_t stream = runtime::getCurrentCUDAStream();
+  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
+  const IdType* indices_data = csr.indices.Ptr<IdType>();
+  if (csr.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indices_data, csr.indices.Ptr<IdType>(), 0));
+  }
   // Generate a 0-1 mask for matched (row, col) positions.
   IdArray mask = Full(0, nnz, nbits, ctx);
   const int nt = dgl::cuda::FindNumThreads(len);
   const int nb = (len + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentMaskKernel,
       nb, nt, 0, stream,
-      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(),
+      indptr_data, indices_data,
       row.Ptr<IdType>(), col.Ptr<IdType>(),
       row_stride, col_stride, len,
       mask.Ptr<IdType>());
@@ -382,7 +418,7 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col
   const int nb2 = (idx->shape[0] + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SortedSearchKernel,
       nb2, nt2, 0, stream,
-      csr.indptr.Ptr<IdType>(), csr.num_rows,
+      indptr_data, csr.num_rows,
       idx.Ptr<IdType>(), idx->shape[0],
       ret_row.Ptr<IdType>());
@@ -471,11 +507,20 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols
       0, sizeof(IdType)*8, stream));
   device->FreeWorkspace(ctx, workspace);
+  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
+  const IdType* indices_data = csr.indices.Ptr<IdType>();
+  if (csr.is_pinned) {
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indptr_data, csr.indptr.Ptr<IdType>(), 0));
+    CUDA_CALL(cudaHostGetDevicePointer(
+        &indices_data, csr.indices.Ptr<IdType>(), 0));
+  }
   // Execute SegmentMaskColKernel
   int nb = (nnz_csr + nt - 1) / nt;
   CUDA_KERNEL_CALL(_SegmentMaskColKernel,
       nb, nt, 0, stream,
-      csr.indptr.Ptr<IdType>(), csr.indices.Ptr<IdType>(), csr.num_rows, nnz_csr,
+      indptr_data, indices_data, csr.num_rows, nnz_csr,
       ptr_sorted_cols, cols_size,
       mask.Ptr<IdType>(), count.Ptr<IdType>());
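Note the two guards in play across this commit: standalone NDArray paths branch on the array's own IsPinned() (as in the IndexSelect hunk at the top), while the CSR/COO paths above use the matrix-level is_pinned flag. The dedicated UVA kernels below do not branch at all: they assert pinnedness with CHECK(...IsPinned()) and unconditionally fetch the device alias, since they only accept pinned inputs.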
@@ -17,7 +17,6 @@ namespace impl {
 template<typename DType, typename IdType>
 NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
-  const DType* array_data = static_cast<DType*>(array->data);
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = array->shape[0];
   const int64_t len = index->shape[0];
@@ -25,6 +24,8 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
   std::vector<int64_t> shape{len};
   CHECK(array.IsPinned());
+  const DType* array_data = nullptr;
+  CUDA_CALL(cudaHostGetDevicePointer(&array_data, array.Ptr<DType>(), 0));
   CHECK_EQ(index->ctx.device_type, kDGLCUDA);
   for (int d = 1; d < array->ndim; ++d) {
@@ -76,7 +77,6 @@ template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
 template<typename DType, typename IdType>
 void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
   cudaStream_t stream = runtime::getCurrentCUDAStream();
-  DType* dest_data = static_cast<DType*>(dest->data);
   const DType* source_data = static_cast<DType*>(source->data);
   const IdType* idx_data = static_cast<IdType*>(index->data);
   const int64_t arr_len = dest->shape[0];
@@ -85,6 +85,8 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
   std::vector<int64_t> shape{len};
   CHECK(dest.IsPinned());
+  DType* dest_data = nullptr;
+  CUDA_CALL(cudaHostGetDevicePointer(&dest_data, dest.Ptr<DType>(), 0));
   CHECK_EQ(index->ctx.device_type, kDGLCUDA);
   CHECK_EQ(source->ctx.device_type, kDGLCUDA);
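The shape of the zero-copy access these UVA kernels perform, reduced to a minimal sketch (simplified and renamed; the launch geometry, multi-dimensional feature handling, and bounds checks of the real IndexSelectCPUFromGPU are omitted): each GPU thread dereferences the mapped alias of pinned host memory, so every load crosses the PCIe/NVLink bus, which makes coalesced indexing even more important than for device memory.

// Illustrative gather kernel: `array` is the device alias of pinned host
// memory obtained via cudaHostGetDevicePointer; `index` and `out` are
// ordinary device buffers.
template <typename DType, typename IdType>
__global__ void GatherFromPinnedHost(const DType* array, const IdType* index,
                                     int64_t len, DType* out) {
  int64_t tid = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (tid < len) {
    out[tid] = array[index[tid]];  // zero-copy read from host memory
  }
}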