Unverified Commit 2613f7f0 authored by David Min, committed by GitHub
Browse files

[Performance] Cacheline-aligned access for UnifiedTensor (#3254)



* Add pytorch-direct version

* remove

* add documentation for UnifiedTensor

* Revert "add documentation for UnifiedTensor"

This reverts commit 63ba42644d4aba197c1cb4ea4b85fa1bc43b8849.

* alignment fix for UnifiedTensor access

* fix linting issue
Co-authored-by: shhssdm <shhssdm@gmail.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent 76af2a2e
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
#include "../array_index_select.cuh"
#include "./array_index_select_uvm.cuh" #include "./array_index_select_uvm.cuh"
#include "../utils.h" #include "../utils.h"
...@@ -48,8 +49,13 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) { ...@@ -48,8 +49,13 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
block.y *= 2; block.y *= 2;
} }
const dim3 grid((len+block.y-1)/block.y); const dim3 grid((len+block.y-1)/block.y);
CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, thr_entry->stream, if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
array_data, num_feat, idx_data, len, ret_data); CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0,
thr_entry->stream, array_data, num_feat, idx_data, len, ret_data);
} else {
CUDA_KERNEL_CALL(IndexSelectMultiKernelAligned, grid, block, 0,
thr_entry->stream, array_data, num_feat, idx_data, len, ret_data);
}
} }
return ret; return ret;
} }
......
...@@ -7,23 +7,14 @@ ...@@ -7,23 +7,14 @@
#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_ #ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_
#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_ #define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_
#define CACHE_LINE_SIZE 128
namespace dgl { namespace dgl {
namespace aten { namespace aten {
namespace impl { namespace impl {
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(const DType* array, const IdType* index, __global__ void IndexSelectMultiKernelAligned(
int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
out[tx] = array[index[tx]];
tx += stride_x;
}
}
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
const DType* const array, const DType* const array,
const int64_t num_feat, const int64_t num_feat,
const IdType* const index, const IdType* const index,
...@@ -36,8 +27,12 @@ __global__ void IndexSelectMultiKernel( ...@@ -36,8 +27,12 @@ __global__ void IndexSelectMultiKernel(
while (out_row < length) { while (out_row < length) {
int64_t col = threadIdx.x; int64_t col = threadIdx.x;
const int64_t in_row = index[out_row]; const int64_t in_row = index[out_row];
const int64_t idx_offset =
((uint64_t)(&array[in_row*num_feat]) % CACHE_LINE_SIZE) / sizeof(DType);
col = col - idx_offset;
while (col < num_feat) { while (col < num_feat) {
out[out_row*num_feat+col] = array[in_row*num_feat+col]; if (col >= 0)
out[out_row*num_feat+col] = array[in_row*num_feat+col];
col += blockDim.x; col += blockDim.x;
} }
out_row += stride; out_row += stride;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment