Commit fad30002 authored by Yan Yan's avatar Yan Yan
Browse files

fix a strange bug of cuda 11

parent f22dd9ae
......@@ -105,6 +105,7 @@ int create_conv_indice_pair_p2_cuda(
auto kernelVolume = indiceNum.size(0);
if (numActIn == 0)
return 0;
bool failed = false;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
using Index = TV_DECLTYPE(IndexValue);
using IndexGrid = int32_t;
......@@ -131,7 +132,8 @@ int create_conv_indice_pair_p2_cuda(
cudaFree(d_values);
TV_CHECK_CUDA_ERR_V2("cudaFree failed");
if (!res) {
return -1; // use -1 to tell outside use CPU implementation
failed = true;
return;
}
assignIndiceOutKernel<Index, NDim>
<<<tv::cuda::getBlocks(numAct), tv::cuda::CUDA_NUM_THREADS, 0,
......@@ -190,6 +192,9 @@ int create_conv_indice_pair_p2_cuda(
}
});
});
if (failed){
return -1;
}
return numAct;
}
......@@ -211,6 +216,8 @@ int create_submconv_indice_pair_cuda(
auto kernelVolume = indiceNum.size(0);
if (numActIn == 0)
return 0;
bool failed = false;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
using Index = TV_DECLTYPE(IndexValue);
using IndexGrid = int32_t;
......@@ -245,7 +252,8 @@ int create_submconv_indice_pair_cuda(
cudaFree(d_keyvalues);
TV_CHECK_CUDA_ERR_V2("cudaFree failed");
if (!res) {
return -1; // use -1 to tell outside use CPU implementation
failed = true;
return;
}
auto tableSize = table.get_table_size();
auto tableData = table.data();
......@@ -349,6 +357,10 @@ int create_submconv_indice_pair_cuda(
}
});
});
if (failed){
return -1;
}
return numActIn;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment