Unverified Commit c94ad7a0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Fix cuaev build on CUDA 11.5 and latest torch (#603)



* Fix cuaev build on CUDA 11.5 and latest torch

* Update aev.cu
Co-authored-by: default avatarJinze (Richard) Xue <yueyericardo@gmail.com>
parent 258e6c36
......@@ -3,12 +3,10 @@
#include <cuaev_cub.cuh>
#include <ATen/Context.h>
#include <THC/THC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <THC/THCThrustAllocator.cuh>
#include <vector>
#define PI 3.141592653589793
......
#pragma once
// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
#undef CUB_NS_POSTFIX // undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#define CUB_NS_PREFIX namespace cuaev {
#define CUB_NS_POSTFIX }
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
......@@ -17,14 +9,14 @@ void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run exclusive prefix sum
cuaev::cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}
template <typename DataT, typename IndexT>
......@@ -40,7 +32,7 @@ int cubEncode(
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceRunLengthEncode::Encode(
cub::DeviceRunLengthEncode::Encode(
d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);
// Allocate temporary storage
......@@ -48,7 +40,7 @@ int cubEncode(
d_temp_storage = buffer_tmp.get();
// Run encoding
cuaev::cub::DeviceRunLengthEncode::Encode(
cub::DeviceRunLengthEncode::Encode(
d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);
int num_selected = 0;
......@@ -70,7 +62,7 @@ int cubDeviceSelect(
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceSelect::If(
cub::DeviceSelect::If(
d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);
// Allocate temporary storage
......@@ -78,7 +70,7 @@ int cubDeviceSelect(
d_temp_storage = buffer_tmp.get();
// Run selection
cuaev::cub::DeviceSelect::If(
cub::DeviceSelect::If(
d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, stream);
int num_selected = 0;
......@@ -94,14 +86,14 @@ DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run min-reduction
cuaev::cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
int maxVal = 0;
cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment