Unverified Commit 4f5c3aa2 authored by David Min, committed by GitHub

[Bugfix] Add UVM specialized IndexSelect kernels which perform boundary checks (#3293)



* Add pytorch-direct version

* remove

* add documentation for UnifiedTensor

* Revert "add documentation for UnifiedTensor"

This reverts commit 63ba42644d4aba197c1cb4ea4b85fa1bc43b8849.

* add boundary check for UVM IndexSelect

* relocate boundary check index kernels to cuda

* fix function name

* fix indexkernel in nccl api

* fix argument ordering

* simplify code

* Add a comment for the uvm version
Co-authored-by: shhssdm <shhssdm@gmail.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 04ed6126
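
Note on the change, with a runnable illustration (not part of the diff): with UVM or zero-copy tensors, an out-of-range index does not necessarily segfault; it can silently gather garbage across the PCIe bus. The fix threads the source array's length (arr_len) into every IndexSelect kernel and asserts before the dereference. A minimal sketch of that pattern, with all names (BoundedGather, etc.) invented for illustration:

// Hedged sketch, not part of the diff: why the boundary check matters for
// UVM. With managed memory, a bad index reads the wrong data instead of
// faulting, so the kernels now assert before the gather.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void BoundedGather(const float* array, const int64_t* index,
                              int64_t length, int64_t arr_len, float* out) {
  int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t stride_x = (int64_t)gridDim.x * blockDim.x;
  while (tx < length) {
    // Same check the commit adds: trap instead of reading out of bounds.
    assert(index[tx] >= 0 && index[tx] < arr_len);
    out[tx] = array[index[tx]];
    tx += stride_x;
  }
}

int main() {
  const int64_t arr_len = 8, length = 4;
  float *array, *out;
  int64_t *index;
  cudaMallocManaged(&array, arr_len * sizeof(float));   // UVM allocation
  cudaMallocManaged(&index, length * sizeof(int64_t));
  cudaMallocManaged(&out, length * sizeof(float));
  for (int64_t i = 0; i < arr_len; ++i) array[i] = static_cast<float>(i);
  const int64_t ids[] = {0, 3, 7, 8};  // 8 is out of range on purpose
  for (int64_t i = 0; i < length; ++i) index[i] = ids[i];
  BoundedGather<<<1, 64>>>(array, index, length, arr_len, out);
  // A failed device-side assert aborts the kernel; the next API call
  // reports cudaErrorAssert rather than returning silently-wrong data.
  printf("sync: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
  return 0;
}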
@@ -36,7 +36,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
     CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, thr_entry->stream,
-                     array_data, idx_data, len, ret_data);
+                     array_data, idx_data, len, arr_len, ret_data);
   } else {
     dim3 block(256, 1);
     while (static_cast<int64_t>(block.x) >= 2*num_feat) {
@@ -45,7 +45,7 @@ NDArray IndexSelect(NDArray array, IdArray index) {
     }
     const dim3 grid((len+block.y-1)/block.y);
     CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, thr_entry->stream,
-                     array_data, num_feat, idx_data, len, ret_data);
+                     array_data, num_feat, idx_data, len, arr_len, ret_data);
   }
   return ret;
 }
...
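
Aside: the while loop above elides its body in the hunk. Assuming it halves block.x and doubles block.y (which is what the loop condition implies), the launch geometry for the multi-feature kernel works out as in this host-side sketch:

// Host-side sketch of the block-shape heuristic; the loop body is assumed,
// not shown in the hunk. Narrow the block in x until it is under twice the
// feature width, packing more output rows per block in y.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  for (int64_t num_feat : {4, 32, 100, 300}) {
    unsigned x = 256, y = 1;
    while ((int64_t)x >= 2 * num_feat) {
      x /= 2;  // fewer threads across the feature dimension...
      y *= 2;  // ...more output rows handled per block
    }
    printf("num_feat=%3lld -> block(%3u, %3u)\n", (long long)num_feat, x, y);
  }
  return 0;
}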
@@ -12,11 +12,15 @@ namespace aten {
 namespace impl {
 
 template <typename DType, typename IdType>
-__global__ void IndexSelectSingleKernel(const DType* array, const IdType* index,
-                                        int64_t length, DType* out) {
+__global__ void IndexSelectSingleKernel(const DType* array,
+                                        const IdType* index,
+                                        const int64_t length,
+                                        const int64_t arr_len,
+                                        DType* out) {
   int tx = blockIdx.x * blockDim.x + threadIdx.x;
   int stride_x = gridDim.x * blockDim.x;
   while (tx < length) {
+    assert(index[tx] >= 0 && index[tx] < arr_len);
     out[tx] = array[index[tx]];
     tx += stride_x;
   }
@@ -24,10 +28,11 @@ __global__ void IndexSelectSingleKernel(const DType* array,
 template <typename DType, typename IdType>
 __global__ void IndexSelectMultiKernel(
     const DType* const array,
     const int64_t num_feat,
     const IdType* const index,
     const int64_t length,
+    const int64_t arr_len,
     DType* const out) {
   int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
@@ -36,6 +41,7 @@ __global__ void IndexSelectMultiKernel(
   while (out_row < length) {
     int64_t col = threadIdx.x;
     const int64_t in_row = index[out_row];
+    assert(in_row >= 0 && in_row < arr_len);
     while (col < num_feat) {
       out[out_row*num_feat+col] = array[in_row*num_feat+col];
       col += blockDim.x;
...
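
Two things worth noting about the kernels above: the row stride of IndexSelectMultiKernel is elided by the hunk (assumed to be gridDim.x * blockDim.y below), and device-side assert follows host semantics, so building with -DNDEBUG compiles the new checks away. A standalone sketch of the same row/column mapping, with GatherRows as a stand-in name:

// Sketch of the 2D mapping: threadIdx.y + blockIdx.x select the output
// row, threadIdx.x strides across feature columns. Compiles with nvcc -c.
#include <cassert>
#include <cstdint>

__global__ void GatherRows(const float* const array, const int64_t num_feat,
                           const int64_t* const index, const int64_t length,
                           const int64_t arr_len, float* const out) {
  int64_t out_row = blockIdx.x * blockDim.y + threadIdx.y;
  const int64_t stride = (int64_t)gridDim.x * blockDim.y;  // assumed stride
  while (out_row < length) {
    const int64_t in_row = index[out_row];
    assert(in_row >= 0 && in_row < arr_len);  // compiled out under -DNDEBUG
    for (int64_t col = threadIdx.x; col < num_feat; col += blockDim.x) {
      out[out_row * num_feat + col] = array[in_row * num_feat + col];
    }
    out_row += stride;
  }
}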
@@ -40,8 +40,8 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
   if (num_feat == 1) {
     const int nt = cuda::FindNumThreads(len);
     const int nb = (len + nt - 1) / nt;
-    CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, thr_entry->stream,
-                     array_data, idx_data, len, ret_data);
+    CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0,
+                     thr_entry->stream, array_data, idx_data, len, arr_len, ret_data);
   } else {
     dim3 block(256, 1);
     while (static_cast<int64_t>(block.x) >= 2*num_feat) {
@@ -51,10 +51,12 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
     const dim3 grid((len+block.y-1)/block.y);
     if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
       CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0,
-                       thr_entry->stream, array_data, num_feat, idx_data, len, ret_data);
+                       thr_entry->stream, array_data, num_feat, idx_data,
+                       len, arr_len, ret_data);
     } else {
       CUDA_KERNEL_CALL(IndexSelectMultiKernelAligned, grid, block, 0,
-                       thr_entry->stream, array_data, num_feat, idx_data, len, ret_data);
+                       thr_entry->stream, array_data, num_feat, idx_data,
+                       len, arr_len, ret_data);
     }
   }
   return ret;
...
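
The dispatch condition above is plain arithmetic; assuming CACHE_LINE_SIZE is 128 bytes (the actual constant lives in the UVM header), it resolves as follows:

// Arithmetic behind num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE: rows
// narrower than two cache lines gain little from the alignment shuffle,
// so they take the plain multi-feature kernel. 128 bytes is assumed here.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  const int64_t kCacheLineSize = 128;  // bytes; assumed value
  for (int64_t elem_bytes : {4, 8}) {  // float and double features
    const int64_t threshold = 2 * kCacheLineSize / elem_bytes;
    printf("%lld-byte elements: rows under %lld features -> plain kernel, "
           "wider rows -> IndexSelectMultiKernelAligned\n",
           (long long)elem_bytes, (long long)threshold);
  }
  return 0;
}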
@@ -13,12 +13,17 @@ namespace dgl {
 namespace aten {
 namespace impl {
 
+/* This is a cross-device access version of IndexSelectMultiKernel.
+ * Since the memory access over PCIe is more sensitive to the
+ * data access alignment (cacheline), we need a separate version here.
+ */
 template <typename DType, typename IdType>
 __global__ void IndexSelectMultiKernelAligned(
     const DType* const array,
     const int64_t num_feat,
     const IdType* const index,
     const int64_t length,
+    const int64_t arr_len,
     DType* const out) {
   int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
@@ -27,6 +32,7 @@ __global__ void IndexSelectMultiKernelAligned(
   while (out_row < length) {
     int64_t col = threadIdx.x;
     const int64_t in_row = index[out_row];
+    assert(in_row >= 0 && in_row < arr_len);
     const int64_t idx_offset =
       ((uint64_t)(&array[in_row*num_feat]) % CACHE_LINE_SIZE) / sizeof(DType);
     col = col - idx_offset;
...
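
The idx_offset arithmetic above shifts each row's starting column back to a cache-line boundary, so the first PCIe transaction of every row is fully aligned; threads whose shifted column goes negative skip their first load (that guard is elided by the hunk). A worked host-side example with a fabricated address:

// Worked example of the idx_offset math in IndexSelectMultiKernelAligned.
// CACHE_LINE_SIZE = 128 is assumed, and the row address is made up.
#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t kCacheLineSize = 128;
  const uint64_t sizeof_dtype = sizeof(float);
  // Suppose &array[in_row * num_feat] lands 48 bytes into a cache line:
  const uint64_t row_addr = 0x700000010030ULL;  // 0x30 = 48-byte misalignment
  const int64_t idx_offset = (row_addr % kCacheLineSize) / sizeof_dtype;
  // Thread 0 then starts at column -idx_offset (= -12 here); negative
  // columns skip the load, so the first real access begins a full line.
  printf("misalignment = %llu bytes -> idx_offset = %lld elements\n",
         (unsigned long long)(row_addr % kCacheLineSize),
         (long long)idx_offset);
  return 0;
}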
@@ -336,6 +336,7 @@ NDArray SparsePull(
         static_cast<const IdType*>(req_idx->data),
         perm,
         num_in,
+        req_idx->shape[0],
         send_idx.get());
     CUDA_CALL(cudaGetLastError());
   }
@@ -443,6 +444,7 @@ NDArray SparsePull(
         num_feat,
         static_cast<IdType*>(recv_idx->data),
         response_prefix_host.back(),
+        local_tensor->shape[0],
         filled_response_value.get());
     CUDA_CALL(cudaGetLastError());
   }
...