Unverified commit c94ad7a0, authored by Gao, Xiang, committed by GitHub

Fix cuaev build on CUDA 11.5 and latest torch (#603)



* Fix cuaev build on CUDA 11.5 and latest torch

* Update aev.cu
Co-authored-by: Jinze (Richard) Xue <yueyericardo@gmail.com>
parent 258e6c36
@@ -3,12 +3,10 @@
 #include <cuaev_cub.cuh>
 #include <ATen/Context.h>
-#include <THC/THC.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAException.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>
-#include <THC/THCThrustAllocator.cuh>
 #include <vector>

 #define PI 3.141592653589793
...
 #pragma once
-// include cub in a safe manner, see:
-// https://github.com/pytorch/pytorch/pull/55292
-#undef CUB_NS_POSTFIX // undef to avoid redefinition warnings
-#undef CUB_NS_PREFIX
-#define CUB_NS_PREFIX namespace cuaev {
-#define CUB_NS_POSTFIX }
 #include <cub/cub.cuh>
-#undef CUB_NS_POSTFIX
-#undef CUB_NS_PREFIX

 template <typename DataT>
 void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
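For context, the wrapper deleted above is the part that stops compiling on newer toolchains: the CUB bundled with CUDA 11.5 requires CUB_NS_QUALIFIER to be defined whenever CUB_NS_PREFIX/CUB_NS_POSTFIX are overridden. A minimal sketch of what keeping the namespace isolation would have looked like under that newer requirement (a hypothetical alternative, not what this commit does):

// Hypothetical alternative to removing the wrapper (sketch only): newer CUB
// (CUDA 11.5 and later) also wants CUB_NS_QUALIFIER so it can refer to its own
// namespace through the user-supplied prefix/postfix.
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER
#define CUB_NS_PREFIX namespace cuaev {
#define CUB_NS_POSTFIX }
#define CUB_NS_QUALIFIER ::cuaev::cub
#include <cub/cub.cuh>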
@@ -17,14 +9,14 @@ void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream
   // Determine temporary device storage requirements
   void* d_temp_storage = NULL;
   size_t temp_storage_bytes = 0;
-  cuaev::cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
+  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

   // Allocate temporary storage
   auto buffer_tmp = allocator.allocate(temp_storage_bytes);
   d_temp_storage = buffer_tmp.get();

   // Run exclusive prefix sum
-  cuaev::cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
+  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
 }

 template <typename DataT, typename IndexT>
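Every helper in this header follows the same two-phase CUB idiom: call the primitive once with a NULL scratch pointer so it only reports the temporary-storage size, allocate that scratch through PyTorch's CUDA caching allocator, then call it again to do the work. A minimal self-contained sketch of the idiom, assuming int data and the current-stream helper from c10; the function name exclusive_scan_example is illustrative and not part of the repository:

#include <cub/cub.cuh>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

void exclusive_scan_example(const int* d_in, int* d_out, int num_items) {
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();

  // Phase 1: with a NULL temp pointer, CUB only writes the required byte count.
  void* d_temp_storage = NULL;
  size_t temp_storage_bytes = 0;
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

  // Phase 2: grab scratch from the caching allocator (released when buffer_tmp
  // goes out of scope) and run the actual exclusive prefix sum.
  auto buffer_tmp = allocator.allocate(temp_storage_bytes);
  d_temp_storage = buffer_tmp.get();
  cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}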
@@ -40,7 +32,7 @@ int cubEncode(
   // Determine temporary device storage requirements
   void* d_temp_storage = NULL;
   size_t temp_storage_bytes = 0;
-  cuaev::cub::DeviceRunLengthEncode::Encode(
+  cub::DeviceRunLengthEncode::Encode(
       d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

   // Allocate temporary storage
@@ -48,7 +40,7 @@ int cubEncode(
   d_temp_storage = buffer_tmp.get();

   // Run encoding
-  cuaev::cub::DeviceRunLengthEncode::Encode(
+  cub::DeviceRunLengthEncode::Encode(
       d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);

   int num_selected = 0;
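For readers who have not used DeviceRunLengthEncode::Encode: it compresses consecutive runs of equal values into (unique value, run length) pairs and writes the number of runs to d_num_runs_out. A tiny illustration with hypothetical values (not data from this repository):

// d_in            = {0, 0, 1, 1, 1, 2, 0}   (num_items = 7)
// after Encode:
// d_unique_out    = {0, 1, 2, 0}            (the value of each run)
// d_counts_out    = {2, 3, 1, 1}            (the length of each run)
// *d_num_runs_out = 4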
@@ -70,7 +62,7 @@ int cubDeviceSelect(
   // Determine temporary device storage requirements
   void* d_temp_storage = NULL;
   size_t temp_storage_bytes = 0;
-  cuaev::cub::DeviceSelect::If(
+  cub::DeviceSelect::If(
       d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);

   // Allocate temporary storage
@@ -78,7 +70,7 @@ int cubDeviceSelect(
   d_temp_storage = buffer_tmp.get();

   // Run selection
-  cuaev::cub::DeviceSelect::If(
+  cub::DeviceSelect::If(
       d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, stream);

   int num_selected = 0;
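DeviceSelect::If copies into d_out only the items for which select_op returns true, so the predicate must be a callable with a device-side operator(). A minimal sketch of such a predicate, assuming the caller wants to drop a -1 padding sentinel (the struct name and the sentinel are illustrative, not taken from this diff):

// Illustrative select_op for cub::DeviceSelect::If (sketch only).
struct IsNotPadding {
  __host__ __device__ __forceinline__ bool operator()(const int& x) const {
    return x != -1;  // keep everything except the assumed -1 padding value
  }
};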
@@ -94,14 +86,14 @@ DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream
   // Determine temporary device storage requirements
   void* d_temp_storage = NULL;
   size_t temp_storage_bytes = 0;
-  cuaev::cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
+  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

   // Allocate temporary storage
   auto buffer_tmp = allocator.allocate(temp_storage_bytes);
   d_temp_storage = buffer_tmp.get();

   // Run min-reduction
-  cuaev::cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
+  cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);

   int maxVal = 0;
   cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
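One caveat in the readback above: cudaMemcpyAsync only enqueues the device-to-host copy on the stream, so the host must synchronize that stream before maxVal can be read. A minimal sketch of the pattern (a standalone illustration; the synchronization call is assumed, since the rest of cubMax is not shown in this hunk):

// Sketch: reading back a single device-side reduction result.
int maxVal = 0;
cudaMemcpyAsync(&maxVal, d_out, sizeof(int), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);  // wait for the reduction and the copy to finish
// only after the synchronization is maxVal valid on the host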
...