Commit 73427720 authored by traveller59's avatar traveller59
Browse files

fix #45 release requirement of kernel size

parent 10db9b67
...@@ -147,8 +147,7 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut, ...@@ -147,8 +147,7 @@ assignIndicePairsKernel(tv::TensorView<Index> indicesOut,
} }
} }
template <typename Index, typename IndexGrid, unsigned NDim, template <typename Index, typename IndexGrid, unsigned NDim>
int KernelMaxVolume = 256>
__global__ void __global__ void
prepareSubMGridKernel(tv::TensorView<const Index> indicesIn, prepareSubMGridKernel(tv::TensorView<const Index> indicesIn,
tv::TensorView<IndexGrid> gridsOut, tv::TensorView<IndexGrid> gridsOut,
......
...@@ -47,7 +47,7 @@ getIndicePair(torch::Tensor indices, int64_t batchSize, ...@@ -47,7 +47,7 @@ getIndicePair(torch::Tensor indices, int64_t batchSize,
for (int i = 1; i < kernelSize.size(); ++i) { for (int i = 1; i < kernelSize.size(); ++i) {
kernelVolume *= kernelSize[i]; kernelVolume *= kernelSize[i];
} }
TV_ASSERT_RT_ERR(kernelVolume <= 256, "error"); TV_ASSERT_RT_ERR(kernelVolume <= 4096, "error");
auto outputVolume = outSpatialShape[0]; auto outputVolume = outSpatialShape[0];
for (int i = 1; i < outSpatialShape.size(); ++i) { for (int i = 1; i < outSpatialShape.size(); ++i) {
outputVolume *= outSpatialShape[i]; outputVolume *= outSpatialShape[i];
...@@ -159,7 +159,7 @@ getIndicePairPreGrid(torch::Tensor indices, torch::Tensor gridOut, int64_t batch ...@@ -159,7 +159,7 @@ getIndicePairPreGrid(torch::Tensor indices, torch::Tensor gridOut, int64_t batch
for (int i = 1; i < kernelSize.size(); ++i) { for (int i = 1; i < kernelSize.size(); ++i) {
kernelVolume *= kernelSize[i]; kernelVolume *= kernelSize[i];
} }
TV_ASSERT_RT_ERR(kernelVolume <= 256, "error"); TV_ASSERT_RT_ERR(kernelVolume <= 4096, "error");
auto outputVolume = outSpatialShape[0]; auto outputVolume = outSpatialShape[0];
for (int i = 1; i < outSpatialShape.size(); ++i) { for (int i = 1; i < outSpatialShape.size(); ++i) {
outputVolume *= outSpatialShape[i]; outputVolume *= outSpatialShape[i];
......
...@@ -102,7 +102,7 @@ void sstream_print(SStream &ss, T val, TArgs... args) { ...@@ -102,7 +102,7 @@ void sstream_print(SStream &ss, T val, TArgs... args) {
struct GPU { struct GPU {
GPU(cudaStream_t s = 0) : mStream(s) {} GPU(cudaStream_t s = 0) : mStream(s) {}
cudaStream_t stream() const { return mStream; } virtual cudaStream_t getStream() const { return mStream; }
cudaStream_t mStream = 0; cudaStream_t mStream = 0;
}; };
struct CPU {}; struct CPU {};
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
namespace tv { namespace tv {
struct TorchGPU: public tv::GPU { struct TorchGPU: public tv::GPU {
TorchGPU(){ virtual cudaStream_t getStream() const override {
mStream = at::cuda::getCurrentCUDAStream(); return at::cuda::getCurrentCUDAStream();
} }
}; };
...@@ -48,7 +48,11 @@ template <typename T> void check_torch_dtype(const torch::Tensor &tensor) { ...@@ -48,7 +48,11 @@ template <typename T> void check_torch_dtype(const torch::Tensor &tensor) {
TV_ASSERT_RT_ERR(val, "error"); TV_ASSERT_RT_ERR(val, "error");
break; break;
} }
case at::ScalarType::Long: {
auto val = std::is_same<std::remove_const_t<T>, long>::value;
TV_ASSERT_RT_ERR(val, "error");
break;
}
default: default:
TV_ASSERT_RT_ERR(false, "error"); TV_ASSERT_RT_ERR(false, "error");
} }
......
...@@ -45,15 +45,15 @@ struct CreateConvIndicePairFunctorP1<tv::GPU, Index, IndexGrid, NDim> { ...@@ -45,15 +45,15 @@ struct CreateConvIndicePairFunctorP1<tv::GPU, Index, IndexGrid, NDim> {
return 0; return 0;
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
if (transpose) if (transpose)
prepareDeConvIndicePairsKernel<Index, IndexGrid, NDim, 256> prepareDeConvIndicePairsKernel<Index, IndexGrid, NDim, 4096>
<<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicesIn, indicesOut, gridsOut, indicePairs, d.getStream()>>>(indicesIn, indicesOut, gridsOut, indicePairs,
indiceNum, indicePairUnique, kernelSize, stride, indiceNum, indicePairUnique, kernelSize, stride,
padding, dilation, outSpatialShape); padding, dilation, outSpatialShape);
else else
prepareIndicePairsKernel<Index, IndexGrid, NDim, 256> prepareIndicePairsKernel<Index, IndexGrid, NDim, 4096>
<<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicesIn, indicesOut, gridsOut, indicePairs, d.getStream()>>>(indicesIn, indicesOut, gridsOut, indicePairs,
indiceNum, indicePairUnique, kernelSize, stride, indiceNum, indicePairUnique, kernelSize, stride,
padding, dilation, outSpatialShape); padding, dilation, outSpatialShape);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
...@@ -80,18 +80,18 @@ struct CreateConvIndicePairFunctorP2<tv::GPU, Index, IndexGrid, NDim> { ...@@ -80,18 +80,18 @@ struct CreateConvIndicePairFunctorP2<tv::GPU, Index, IndexGrid, NDim> {
Index numAct = indicePairUnique.dim(0) - 1; Index numAct = indicePairUnique.dim(0) - 1;
assignGridAndIndiceOutKernel<Index, IndexGrid, NDim> assignGridAndIndiceOutKernel<Index, IndexGrid, NDim>
<<<tv::launch::getBlocks(numAct), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numAct), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicesOut, gridsOut, numAct, indicePairs, d.getStream()>>>(indicesOut, gridsOut, numAct, indicePairs,
indicePairUnique, outSpatialShape, batchSize); indicePairUnique, outSpatialShape, batchSize);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
assignIndicePairsKernel<Index, IndexGrid, NDim> assignIndicePairsKernel<Index, IndexGrid, NDim>
<<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicesOut, gridsOut, numActIn, indicePairs, d.getStream()>>>(indicesOut, gridsOut, numActIn, indicePairs,
indicePairUnique, outSpatialShape); indicePairUnique, outSpatialShape);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
if (resetGrid) { if (resetGrid) {
resetGridKernel<Index, IndexGrid, NDim> resetGridKernel<Index, IndexGrid, NDim>
<<<tv::launch::getBlocks(numAct), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numAct), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicePairUnique.data(), gridsOut, numAct); d.getStream()>>>(indicePairUnique.data(), gridsOut, numAct);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
return numAct; return numAct;
...@@ -116,18 +116,18 @@ struct CreateSubMIndicePairFunctor<tv::GPU, Index, IndexGrid, NDim> { ...@@ -116,18 +116,18 @@ struct CreateSubMIndicePairFunctor<tv::GPU, Index, IndexGrid, NDim> {
// auto timer = spconv::CudaContextTimer<>(); // auto timer = spconv::CudaContextTimer<>();
prepareSubMGridKernel<Index, IndexGrid, NDim> prepareSubMGridKernel<Index, IndexGrid, NDim>
<<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicesIn, gridsOut, outSpatialShape); d.getStream()>>>(indicesIn, gridsOut, outSpatialShape);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
getSubMIndicePairsKernel<Index, IndexGrid, NDim> getSubMIndicePairsKernel<Index, IndexGrid, NDim, 4096>
<<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicesIn, gridsOut, indicePairs, indiceNum, d.getStream()>>>(indicesIn, gridsOut, indicePairs, indiceNum,
kernelSize, stride, padding, dilation, outSpatialShape); kernelSize, stride, padding, dilation, outSpatialShape);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
// std::cout << "subm gene time " << timer.report() / 1000.0 << std::endl; // std::cout << "subm gene time " << timer.report() / 1000.0 << std::endl;
if (resetGrid) { if (resetGrid) {
resetGridSubMKernel<Index, IndexGrid, NDim> resetGridSubMKernel<Index, IndexGrid, NDim>
<<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0, <<<tv::launch::getBlocks(numActIn), tv::launch::CUDA_NUM_THREADS, 0,
d.stream()>>>(indicesIn.data(), gridsOut, outSpatialShape, numActIn); d.getStream()>>>(indicesIn.data(), gridsOut, outSpatialShape, numActIn);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
return numActIn; return numActIn;
......
...@@ -329,7 +329,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> { ...@@ -329,7 +329,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
maxPoolFwdVecBlockKernel<T, Index, int(NumTLP), NumILP, vecload_type_t> maxPoolFwdVecBlockKernel<T, Index, int(NumTLP), NumILP, vecload_type_t>
<<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP), <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.stream()>>>(outFeatures.data(), inFeatures.data(), d.getStream()>>>(outFeatures.data(), inFeatures.data(),
indices.subview(0).data(), indices.subview(0).data(),
indices.subview(1).data(), numHotBlock, indices.subview(1).data(), numHotBlock,
numPlanes / vecloadFactor); numPlanes / vecloadFactor);
...@@ -339,7 +339,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> { ...@@ -339,7 +339,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolFwdGenericKernel<T, Index, int(NumTLP), NumILP> maxPoolFwdGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, d.stream()>>>(outFeatures.data(), inFeatures.data(), 0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
indices.subview(0).data() + numHotBlock, indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock, indices.subview(1).data() + numHotBlock,
size - numHotBlock, numPlanes); size - numHotBlock, numPlanes);
...@@ -357,7 +357,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> { ...@@ -357,7 +357,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
if (numHotBlock >= NumTLP) { if (numHotBlock >= NumTLP) {
maxPoolFwdGenericBlockKernel<T, Index, NumTLP, NumILP> maxPoolFwdGenericBlockKernel<T, Index, NumTLP, NumILP>
<<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)), <<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.stream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(), outFeatures.data(), inFeatures.data(),
indices.subview(0).data(), indices.subview(1).data(), indices.subview(0).data(), indices.subview(1).data(),
numHotBlock, numPlanes); numHotBlock, numPlanes);
...@@ -367,7 +367,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> { ...@@ -367,7 +367,7 @@ struct SparseMaxPoolForwardFunctor<tv::GPU, T, Index> {
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolFwdGenericKernel<T, Index, NumTLP, NumILP> maxPoolFwdGenericKernel<T, Index, NumTLP, NumILP>
<<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)), <<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.stream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(), outFeatures.data(), inFeatures.data(),
indices.subview(0).data() + numHotBlock, indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock, size - numHotBlock, indices.subview(1).data() + numHotBlock, size - numHotBlock,
...@@ -403,7 +403,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -403,7 +403,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
maxPoolBwdVecBlockKernel<T, Index, int(NumTLP), NumILP, vecload_type_t> maxPoolBwdVecBlockKernel<T, Index, int(NumTLP), NumILP, vecload_type_t>
<<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP), <<<dim3(std::min(size / NumTLP, 512), numPlanes / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.stream()>>>(outFeatures.data(), inFeatures.data(), d.getStream()>>>(outFeatures.data(), inFeatures.data(),
dout.data(), din.data(), dout.data(), din.data(),
indices.subview(0).data(), indices.subview(0).data(),
indices.subview(1).data(), numHotBlock, indices.subview(1).data(), numHotBlock,
...@@ -414,7 +414,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -414,7 +414,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolBwdGenericKernel<T, Index, int(NumTLP), NumILP> maxPoolBwdGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, d.stream()>>>(outFeatures.data(), inFeatures.data(), 0, d.getStream()>>>(outFeatures.data(), inFeatures.data(),
dout.data(), din.data(), dout.data(), din.data(),
indices.subview(0).data() + numHotBlock, indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock, indices.subview(1).data() + numHotBlock,
...@@ -433,7 +433,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -433,7 +433,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
if (numHotBlock >= NumTLP) { if (numHotBlock >= NumTLP) {
maxPoolBwdGenericBlockKernel<T, Index, NumTLP, NumILP> maxPoolBwdGenericBlockKernel<T, Index, NumTLP, NumILP>
<<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)), <<<dim3(size / NumTLP, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.stream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(), dout.data(), din.data(), outFeatures.data(), inFeatures.data(), dout.data(), din.data(),
indices.subview(0).data(), indices.subview(1).data(), indices.subview(0).data(), indices.subview(1).data(),
numHotBlock, numPlanes); numHotBlock, numPlanes);
...@@ -443,7 +443,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> { ...@@ -443,7 +443,7 @@ struct SparseMaxPoolBackwardFunctor<tv::GPU, T, Index> {
if (size > numHotBlock) { if (size > numHotBlock) {
maxPoolBwdGenericKernel<T, Index, NumTLP, NumILP> maxPoolBwdGenericKernel<T, Index, NumTLP, NumILP>
<<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)), <<<dim3(1, tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.stream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), inFeatures.data(), dout.data(), din.data(), outFeatures.data(), inFeatures.data(), dout.data(), din.data(),
indices.subview(0).data() + numHotBlock, indices.subview(0).data() + numHotBlock,
indices.subview(1).data() + numHotBlock, size - numHotBlock, indices.subview(1).data() + numHotBlock, size - numHotBlock,
......
...@@ -50,7 +50,7 @@ struct SparseGatherFunctor<tv::GPU, T, Index> { ...@@ -50,7 +50,7 @@ struct SparseGatherFunctor<tv::GPU, T, Index> {
gatherVecBlockKernel<T, Index, int(NumTLP), NumILP, vecload_type_t> gatherVecBlockKernel<T, Index, int(NumTLP), NumILP, vecload_type_t>
<<<dim3(numPlanes / NumTLP, size / NumTLP), <<<dim3(numPlanes / NumTLP, size / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.stream()>>>(buffer.data(), features.data(), indices.data(), d.getStream()>>>(buffer.data(), features.data(), indices.data(),
nHotBlock, numPlanes / vecloadFactor); nHotBlock, numPlanes / vecloadFactor);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
...@@ -59,7 +59,7 @@ struct SparseGatherFunctor<tv::GPU, T, Index> { ...@@ -59,7 +59,7 @@ struct SparseGatherFunctor<tv::GPU, T, Index> {
gatherVecKernel<T, Index, int(NumTLP), NumILP, vecload_type_t> gatherVecKernel<T, Index, int(NumTLP), NumILP, vecload_type_t>
<<<dim3(1, numPlanes / NumTLP), <<<dim3(1, numPlanes / NumTLP),
dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0, dim3(NumTLP / NumILP, NumTLP / vecloadFactor), 0,
d.stream()>>>(buffer.data() + nHotBlock * numPlanes, d.getStream()>>>(buffer.data() + nHotBlock * numPlanes,
features.data(), indices.data() + nHotBlock, features.data(), indices.data() + nHotBlock,
size - nHotBlock, numPlanes / vecloadFactor); size - nHotBlock, numPlanes / vecloadFactor);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
...@@ -75,7 +75,7 @@ struct SparseGatherFunctor<tv::GPU, T, Index> { ...@@ -75,7 +75,7 @@ struct SparseGatherFunctor<tv::GPU, T, Index> {
gatherGenericKernel<T, Index, NumTLP, NumILP> gatherGenericKernel<T, Index, NumTLP, NumILP>
<<<dim3(tv::launch::DivUp(size, NumTLP), <<<dim3(tv::launch::DivUp(size, NumTLP),
tv::launch::DivUp(numPlanes, NumTLP)), tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.stream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
buffer.data(), features.data(), indices.data(), size, numPlanes); buffer.data(), features.data(), indices.data(), size, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
} }
...@@ -107,7 +107,7 @@ struct SparseScatterAddFunctor<tv::GPU, T, Index> { ...@@ -107,7 +107,7 @@ struct SparseScatterAddFunctor<tv::GPU, T, Index> {
vecload_type_t> vecload_type_t>
<<<dim3(numPlanes / NumTLP, size / NumTLP), <<<dim3(numPlanes / NumTLP, size / NumTLP),
dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0, dim3(NumTLP / vecloadFactor, NumTLP / NumILP), 0,
d.stream()>>>(outFeatures.data(), buffer.data(), d.getStream()>>>(outFeatures.data(), buffer.data(),
indices.data(), nHotBlock, indices.data(), nHotBlock,
numPlanes / vecloadFactor); numPlanes / vecloadFactor);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
...@@ -115,7 +115,7 @@ struct SparseScatterAddFunctor<tv::GPU, T, Index> { ...@@ -115,7 +115,7 @@ struct SparseScatterAddFunctor<tv::GPU, T, Index> {
if (size - nHotBlock > 0) { if (size - nHotBlock > 0) {
scatterAddGenericKernel<T, Index, int(NumTLP), NumILP> scatterAddGenericKernel<T, Index, int(NumTLP), NumILP>
<<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP), <<<dim3(1, numPlanes / NumTLP), dim3(NumTLP / NumILP, NumTLP),
0, d.stream()>>>( 0, d.getStream()>>>(
outFeatures.data(), buffer.data() + nHotBlock * numPlanes, outFeatures.data(), buffer.data() + nHotBlock * numPlanes,
indices.data() + nHotBlock, size - nHotBlock, numPlanes); indices.data() + nHotBlock, size - nHotBlock, numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
...@@ -130,7 +130,7 @@ struct SparseScatterAddFunctor<tv::GPU, T, Index> { ...@@ -130,7 +130,7 @@ struct SparseScatterAddFunctor<tv::GPU, T, Index> {
scatterAddGenericKernel<T, Index, NumTLP, NumILP> scatterAddGenericKernel<T, Index, NumTLP, NumILP>
<<<dim3(tv::launch::DivUp(size, NumTLP), <<<dim3(tv::launch::DivUp(size, NumTLP),
tv::launch::DivUp(numPlanes, NumTLP)), tv::launch::DivUp(numPlanes, NumTLP)),
dim3(NumTLP / NumILP, NumTLP), 0, d.stream()>>>( dim3(NumTLP / NumILP, NumTLP), 0, d.getStream()>>>(
outFeatures.data(), buffer.data(), indices.data(), size, outFeatures.data(), buffer.data(), indices.data(), size,
numPlanes); numPlanes);
TV_CHECK_CUDA_ERR(); TV_CHECK_CUDA_ERR();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment