Commit f7b4c93d authored by sangwzh
Browse files

Fix UVA bug: obtain the device pointer for pinned input via hipHostGetDevicePointer

parent 50c27a8e
...@@ -112,7 +112,15 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) { ...@@ -112,7 +112,15 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
{return_len, original_feature_size}, torch::TensorOptions() {return_len, original_feature_size}, torch::TensorOptions()
.dtype(input.dtype()) .dtype(input.dtype())
.device(c10::DeviceType::CUDA)); .device(c10::DeviceType::CUDA));
DType* input_ptr = reinterpret_cast<DType*>(input.data_ptr()); DType* input_ptr = nullptr;
if(input.is_pinned())
{
CUDA_CALL(hipHostGetDevicePointer((void**)&input_ptr, input.data_ptr(), 0));
}
else{
input_ptr= reinterpret_cast<DType*>(input.data_ptr());
}
DType* ret_ptr = reinterpret_cast<DType*>(ret.data_ptr()); DType* ret_ptr = reinterpret_cast<DType*>(ret.data_ptr());
// Sort the index to improve the memory access pattern. // Sort the index to improve the memory access pattern.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment