Unverified Commit 00548245 authored by Jinze (Richard) Xue, committed by GitHub

Wrap cub in its own namespace and remove dependency on Thrust (#587)



* Wrap cub with CUB_NS_PREFIX and remove dependency on Thrust

* update readme

* format

* Update install_cuda.sh
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
parent 2007d181
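The core of the change is visible in the new `cuaev_cub.cuh` further down: redefining `CUB_NS_PREFIX`/`CUB_NS_POSTFIX` before including `cub/cub.cuh` places every CUB symbol inside an extra enclosing namespace, so the extension's copy of CUB cannot collide with the copy that PyTorch itself builds (see pytorch/pytorch#55292). A minimal sketch of the technique, using a hypothetical `myext` namespace and a made-up helper; the commit's actual wrapper is the `cuaev_cub.cuh` file shown later in this diff:

```cuda
// Sketch only: wrap CUB in an extra namespace so this translation unit's CUB
// symbols live under myext::cub:: and never clash with another CUB build
// linked into the same process. The commit does the same thing with "cuaev".
#undef CUB_NS_POSTFIX // undef first to avoid redefinition warnings
#undef CUB_NS_PREFIX
#define CUB_NS_PREFIX namespace myext {
#define CUB_NS_POSTFIX }
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX

#include <cuda_runtime.h>

// Hypothetical helper: sum an int array on the device through the wrapped CUB,
// using CUB's usual two-phase (size query, then run) calling convention.
void device_sum(const int* d_in, int* d_out, int num_items, cudaStream_t stream) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  myext::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  myext::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
  cudaFree(d_temp_storage);
}
```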
@@ -2,14 +2,13 @@
# command copy-pasted from https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=debnetwork
-wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
-sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
-sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
-sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /"
+wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-ubuntu2004.pin
+sudo mv cuda-ubuntu2004.pin /etc/apt/preferences.d/cuda-repository-pin-600
+sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub
+sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/ /"
sudo apt-get update
# from https://github.com/ptheywood/cuda-cmake-github-actions/blob/master/scripts/actions/install_cuda_ubuntu.sh
-sudo apt-get -y install cuda-command-line-tools-11-0 cuda-libraries-dev-11-0
-export CUDA_HOME=/usr/local/cuda-11.0
+sudo apt-get -y install cuda-command-line-tools-11-2 cuda-libraries-dev-11-2
+export CUDA_HOME=/usr/local/cuda-11.2
export PATH="$CUDA_HOME/bin:$PATH"
nvcc -V
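The script ends with `nvcc -V` as a smoke test. A slightly stronger check (my sketch, not part of the commit) is to compile a trivial program with the freshly installed toolkit and query the runtime, driver, and device count:

```cuda
// Sketch only: verify the CUDA 11.2 install end to end.
//   nvcc check_cuda.cu -o check_cuda && ./check_cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int runtime = 0, driver = 0, ndev = 0;
  cudaRuntimeGetVersion(&runtime); // e.g. 11020 for CUDA 11.2
  cudaDriverGetVersion(&driver);   // the driver must support at least the runtime version
  cudaError_t err = cudaGetDeviceCount(&ndev);
  if (err != cudaSuccess) {
    std::printf("cudaGetDeviceCount failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  std::printf("runtime %d, driver %d, %d device(s)\n", runtime, driver, ndev);
  return 0;
}
```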
@@ -8,7 +8,6 @@ If you use conda, you could install it by
```
conda install pytorch torchvision torchaudio cudatoolkit={YOUR_CUDA_VERSION} -c pytorch-nightly
```
-Note that [CUDA 11](https://github.com/aiqm/torchani/issues/549) is still not supported yet.
## Install
In most cases, if `gcc` and `cuda` environment are well configured, runing the following command at `torchani` directory will install torchani and cuaev together.
@@ -51,15 +50,13 @@ pip install -e . && pip install -v -e . --global-option="--cuaev"
```bash
srun -p gpu --ntasks=1 --cpus-per-task=2 --gpus=geforce:1 --time=02:00:00 --mem=10gb --pty -u bash -i
-module load cuda/10.0.130 gcc/7.3.0 git
-conda remove --name cuaev --all -y && conda create -n cuaev python=3.8 -y
+# create env if necessary
+conda create -n cuaev python=3.8
conda activate cuaev
-# install compiled torch-cu100 because pytorch droped official build for cuda 10.0
-. /home/jinzexue/pytorch/loadmodule # note that there is a space after .
-. /home/jinzexue/pytorch/install_deps
-pip install $(realpath /home/jinzexue/pytorch/dist/torch-nightly-cu100.whl)
-# check if pytorch is working, should print available's gpu infomations
-python /home/jinzexue/pytorch/testcuda/testcuda.py
+# modules
+module load cuda/11.1.0 gcc/7.3.0 git
+# pytorch
+conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-nightly -c conda-forge
# install torchani
git clone https://github.com/aiqm/torchani.git
cd torchani
......
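The old instructions ended with a `testcuda.py` run to confirm PyTorch could see the GPU. A rough C++ analogue of that check (my sketch, assuming libtorch headers and libraries are on the compiler's search path; the original step was a Python script):

```cuda
// Sketch only: confirm the freshly installed PyTorch build can actually use
// CUDA before building the cuaev extension against it.
#include <torch/torch.h>
#include <iostream>

int main() {
  std::cout << "CUDA available: " << std::boolalpha << torch::cuda::is_available() << "\n";
  std::cout << "CUDA devices:   " << torch::cuda::device_count() << "\n";
  if (torch::cuda::is_available()) {
    // Allocate a tiny tensor on the GPU and reduce it, exercising the runtime.
    auto t = torch::ones({2, 2}, torch::device(torch::kCUDA));
    std::cout << "sum on GPU: " << t.sum().item<float>() << " (expected 4)\n";
  }
  return 0;
}
```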
#include <aev.h>
-#include <thrust/equal.h>
#include <torch/extension.h>
-#include <cub/cub.cuh>
+#include <cuaev_cub.cuh>
#include <vector>
#include <ATen/Context.h>
@@ -728,105 +727,6 @@ __global__ void cuRadialAEVs_backward_or_doublebackward(
}
}
template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}
template <typename DataT, typename IndexT>
int cubEncode(
const DataT* d_in,
DataT* d_unique_out,
IndexT* d_counts_out,
int num_items,
int* d_num_runs_out,
cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceRunLengthEncode::Encode(
d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run encoding
cub::DeviceRunLengthEncode::Encode(
d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);
int num_selected = 0;
cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return num_selected;
}
template <typename DataT, typename LambdaOpT>
int cubDeviceSelect(
const DataT* d_in,
DataT* d_out,
int num_items,
int* d_num_selected_out,
LambdaOpT select_op,
cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceSelect::If(d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run selection
cub::DeviceSelect::If(
d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, stream);
int num_selected = 0;
cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return num_selected;
}
template <typename DataT>
DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run min-reduction
cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
int maxVal = 0;
cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return maxVal;
}
// NOTE: assumes size of EtaA_t = Zeta_t = EtaR_t = 1
Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const AEVScalarParams& aev_params) {
TORCH_CHECK(
@@ -850,21 +750,15 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-auto thrust_allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
-auto policy = thrust::cuda::par(thrust_allocator).on(stream);
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// buffer to store all the pairwise distance (Rij)
auto total_natom_pairs = n_molecules * max_natoms_per_mol * max_natoms_per_mol;
-auto d_options = torch::dtype(torch::kUInt8).device(coordinates_t.device());
-Tensor tensor_Rij = torch::empty(sizeof(PairDist<float>) * total_natom_pairs, d_options);
+auto d_options = torch::dtype(torch::kFloat32).device(coordinates_t.device());
+float inf = std::numeric_limits<float>::infinity();
+Tensor tensor_Rij = torch::full(sizeof(PairDist<float>) / sizeof(float) * total_natom_pairs, inf, d_options);
PairDist<float>* d_Rij = (PairDist<float>*)tensor_Rij.data_ptr();
-// init all Rij to inf
-PairDist<float> init;
-init.Rij = std::numeric_limits<float>::infinity();
-thrust::fill(policy, d_Rij, d_Rij + total_natom_pairs, init);
// buffer to store all the pairwise distance that is needed for Radial AEV
// computation
Tensor tensor_radialRij = torch::empty(sizeof(PairDist<float>) * total_natom_pairs, d_options);
@@ -986,21 +880,22 @@ Result cuaev_forward(const Tensor& coordinates_t, const Tensor& species_t, const
angular_length_aligned,
ncenter_atoms);
-return {aev_t,
-tensor_Rij,
-tensor_radialRij,
-tensor_angularRij,
-total_natom_pairs,
-nRadialRij,
-nAngularRij,
-tensor_centralAtom,
-tensor_numPairsPerCenterAtom,
-tensor_centerAtomStartIdx,
-maxnbrs_per_atom_aligned,
-angular_length_aligned,
-ncenter_atoms,
-coordinates_t,
-species_t};
+return {
+aev_t,
+tensor_Rij,
+tensor_radialRij,
+tensor_angularRij,
+total_natom_pairs,
+nRadialRij,
+nAngularRij,
+tensor_centralAtom,
+tensor_numPairsPerCenterAtom,
+tensor_centerAtomStartIdx,
+maxnbrs_per_atom_aligned,
+angular_length_aligned,
+ncenter_atoms,
+coordinates_t,
+species_t};
}
}
......
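The cuaev_forward hunk above removes the last Thrust calls: instead of allocating a raw byte buffer and running `thrust::fill` to set every `Rij` to infinity, the buffer is now created as a float tensor that is already filled with infinity. A minimal sketch of that idea, with an invented `PairDist` layout (the real definition lives in `aev.h`):

```cuda
// Sketch only: bulk-initialize a buffer of PairDist records to Rij = +inf by
// allocating it as a float tensor pre-filled with infinity, so no Thrust (and
// no extra fill kernel) is needed. The struct layout here is invented.
#include <torch/extension.h>
#include <limits>

template <typename T>
struct PairDistSketch {
  T Rij;      // pairwise distance
  int midx;   // molecule index
  short i, j; // atom indices
};

torch::Tensor make_rij_buffer(int64_t total_natom_pairs, torch::Device device) {
  static_assert(sizeof(PairDistSketch<float>) % sizeof(float) == 0, "layout must be float-sized");
  auto options = torch::dtype(torch::kFloat32).device(device);
  float inf = std::numeric_limits<float>::infinity();
  int64_t numel = sizeof(PairDistSketch<float>) / sizeof(float) * total_natom_pairs;
  // Every float slot (including Rij) starts as +inf; downstream code only
  // relies on Rij being infinite for pairs that do not exist.
  return torch::full({numel}, inf, options);
}
```

The new `cuaev_cub.cuh` below is where the wrapped CUB and the moved helper functions now live.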
#pragma once
// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
#undef CUB_NS_POSTFIX // undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#define CUB_NS_PREFIX namespace cuaev {
#define CUB_NS_POSTFIX }
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
template <typename DataT>
void cubScan(const DataT* d_in, DataT* d_out, int num_items, cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run exclusive prefix sum
cuaev::cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
}
template <typename DataT, typename IndexT>
int cubEncode(
const DataT* d_in,
DataT* d_unique_out,
IndexT* d_counts_out,
int num_items,
int* d_num_runs_out,
cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceRunLengthEncode::Encode(
d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run encoding
cuaev::cub::DeviceRunLengthEncode::Encode(
d_temp_storage, temp_storage_bytes, d_in, d_unique_out, d_counts_out, d_num_runs_out, num_items, stream);
int num_selected = 0;
cudaMemcpyAsync(&num_selected, d_num_runs_out, sizeof(int), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return num_selected;
}
template <typename DataT, typename LambdaOpT>
int cubDeviceSelect(
const DataT* d_in,
DataT* d_out,
int num_items,
int* d_num_selected_out,
LambdaOpT select_op,
cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceSelect::If(
d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run selection
cuaev::cub::DeviceSelect::If(
d_temp_storage, temp_storage_bytes, d_in, d_out, d_num_selected_out, num_items, select_op, stream);
int num_selected = 0;
cudaMemcpyAsync(&num_selected, d_num_selected_out, sizeof(int), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return num_selected;
}
template <typename DataT>
DataT cubMax(const DataT* d_in, int num_items, DataT* d_out, cudaStream_t stream) {
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cuaev::cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
// Allocate temporary storage
auto buffer_tmp = allocator.allocate(temp_storage_bytes);
d_temp_storage = buffer_tmp.get();
// Run min-reduction
cuaev::cub::DeviceReduce::Max(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
int maxVal = 0;
cudaMemcpyAsync(&maxVal, d_out, sizeof(DataT), cudaMemcpyDefault, stream);
cudaStreamSynchronize(stream);
return maxVal;
}
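All four helpers follow CUB's two-phase convention: call once with a null pointer to query `temp_storage_bytes`, allocate that many bytes through the PyTorch caching allocator, then call again to do the work. A hedged usage sketch (hypothetical function and tensor names; the real call sites are in `aev.cu`) showing how `cubScan` could turn per-atom pair counts into start offsets:

```cuda
// Sketch only: exclusive prefix sum over int32 per-atom pair counts, producing
// the start index of each atom's pairs in a flat pair list.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/extension.h>
#include <cuaev_cub.cuh>

torch::Tensor pair_start_indices(const torch::Tensor& counts /* int32, shape [natoms], on CUDA */) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  int natoms = static_cast<int>(counts.size(0));
  torch::Tensor start_idx = torch::empty({natoms}, counts.options());
  // start_idx[i] = counts[0] + ... + counts[i-1]; start_idx[0] = 0.
  cubScan(counts.data_ptr<int>(), start_idx.data_ptr<int>(), natoms, stream);
  return start_idx;
}
```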