Commit 65058287 authored by Max Rietmann

Merge formatting changes

parents c46b6925 76836abf
@@ -56,17 +56,23 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
 def get_compile_args(module_name):
     """If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
     debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
+    profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
+    nvcc_extra_flags = []
+    if profile_mode:
+        nvcc_extra_flags.append("-lineinfo")
     if debug_mode:
         print(f"WARNING: Compiling {module_name} with debugging flags")
         return {
             'cxx': ['-g', '-O0', '-Wall'],
-            'nvcc': ['-g', '-G', '-O0']
+            'nvcc': ['-g', '-G', '-O0'] + nvcc_extra_flags
         }
     else:
         print(f"NOTE: Compiling {module_name} with release flags")
         return {
             'cxx': ['-O3', "-DNDEBUG"],
-            'nvcc': ['-O3', "-DNDEBUG"]
+            'nvcc': ['-O3', "-DNDEBUG"] + nvcc_extra_flags
         }
 def get_ext_modules():
...
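For orientation, a minimal sketch (not part of the repository) of how the new TORCH_HARMONICS_PROFILE switch is meant to be used next to the existing TORCH_HARMONICS_DEBUG one: setting it before the build adds -lineinfo to the nvcc flags so profilers can map device code back to source lines. The install command and working directory below are assumptions, not taken from the repository.

# Hypothetical usage sketch: enable the new profiling switch before building the
# extensions. Assumes the current directory is the repository root and that an
# editable pip install triggers the CUDA extension build; both are assumptions.
import os
import subprocess

env = dict(os.environ)
env["TORCH_HARMONICS_PROFILE"] = "1"   # nvcc gets -lineinfo in addition to the release flags
env["TORCH_HARMONICS_DEBUG"] = "0"     # keep -O3 -DNDEBUG; set to "1" for -g -G -O0 instead

subprocess.run(["pip", "install", "-e", "."], env=env, check=True)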
@@ -51,7 +51,7 @@
 #define THREADS (64)
 #endif
 #ifndef DIV_UP
-#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
+#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
 #endif
 #ifndef CHECK_CUDA
 #define CHECK_CUDA(call) \
@@ -312,14 +312,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
-    // s2_attention_bwd_kernel execution time: 51.231743 ms
-    // s2_attention_bwd_kernel execution time: 52.971519 ms
     // s2_attention_bwd_kernel execution time: 50.724865 ms
     // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
     // s2_attention_bwd_kernel execution time: 11.679744 ms
-    printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
+    // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
...
@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
 #define WARP_SIZE (32)
 #define FULL_MASK (0xFFFFFFFF)
 #define THREADS (64)
-#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
+#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
 #define NNZ_TRESH (32)
...
@@ -40,7 +40,7 @@
     CHECK_CUDA_TENSOR(x); \
     CHECK_CONTIGUOUS_TENSOR(x)
-#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
+#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
 #define MIN_THREADS (64)
 #define ELXTH_MAX (32)
...
@@ -140,11 +140,11 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
 template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
 __global__
 __launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                                const int pscale, const int64_t *__restrict__ roff,
                                                const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                                const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
                                                const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
 {
     if constexpr (PSCALE != 0) {
...
@@ -146,11 +146,11 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
 template <int BDIM_X, int ELXTH, typename REAL_T>
 __global__
 __launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                                const int pscale, const int64_t *__restrict__ roff,
                                                const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                                const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
                                                const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
 {
     disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
...
@@ -35,100 +35,100 @@ void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, i
                            int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
 {
     int64_t *Koff = new int64_t[K];
     for (int i = 0; i < K; i++) { Koff[i] = 0; }
     for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }
     int64_t prev = Koff[0];
     Koff[0] = 0;
     for (int i = 1; i < K; i++) {
         int64_t save = Koff[i];
         Koff[i] = prev + Koff[i - 1];
         prev = save;
     }
     int64_t *ker_sort = new int64_t[nnz];
     int64_t *row_sort = new int64_t[nnz];
     int64_t *col_sort = new int64_t[nnz];
     float *val_sort = new float[nnz];
     for (int64_t i = 0; i < nnz; i++) {
         const int64_t ker = ker_h[i];
         const int64_t off = Koff[ker]++;
         ker_sort[off] = ker;
         row_sort[off] = row_h[i];
         col_sort[off] = col_h[i];
         val_sort[off] = val_h[i];
     }
     for (int64_t i = 0; i < nnz; i++) {
         ker_h[i] = ker_sort[i];
         row_h[i] = row_sort[i];
         col_h[i] = col_sort[i];
         val_h[i] = val_sort[i];
     }
     delete[] Koff;
     delete[] ker_sort;
     delete[] row_sort;
     delete[] col_sort;
     delete[] val_sort;

     // compute rows offsets
     nrows = 1;
     roff_h[0] = 0;
     for (int64_t i = 1; i < nnz; i++) {
         if (row_h[i - 1] == row_h[i]) continue;
         roff_h[nrows++] = i;
         if (nrows > Ho * K) {
             fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n", __FILE__, __LINE__,
                     int64_t(Ho) * K);
             exit(EXIT_FAILURE);
         }
     }
     roff_h[nrows] = nnz;

     return;
 }

 torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                              torch::Tensor col_idx, torch::Tensor val)
 {
     CHECK_INPUT_TENSOR(ker_idx);
     CHECK_INPUT_TENSOR(row_idx);
     CHECK_INPUT_TENSOR(col_idx);
     CHECK_INPUT_TENSOR(val);

     int64_t nnz = val.size(0);
     int64_t *ker_h = ker_idx.data_ptr<int64_t>();
     int64_t *row_h = row_idx.data_ptr<int64_t>();
     int64_t *col_h = col_idx.data_ptr<int64_t>();
     int64_t *roff_h = new int64_t[Ho * K + 1];
     int64_t nrows;
     // float *val_h = val.data_ptr<float>();

     AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
                                    preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
                                                                    val.data_ptr<scalar_t>(), nrows);
                                }));

     // create output tensor
     auto options = torch::TensorOptions().dtype(row_idx.dtype());
     auto roff_idx = torch::empty({nrows + 1}, options);
     int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
     for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }

     delete[] roff_h;

     return roff_idx;
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
 }
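As a reading aid, here is a rough Python sketch of what preprocess_psi_kernel does in the hunk above: a stable counting sort of the COO entries by kernel index, followed by a CSR-style row-offset array over the sorted rows, closed with nnz. The sketch is not part of the extension; the function name and the use of NumPy are illustrative assumptions, and the real code works in place on raw pointers handed over from the torch tensors.

# Rough Python equivalent of preprocess_psi_kernel, for orientation only.
import numpy as np

def preprocess_psi_py(K, Ho, ker, row, col, val):
    ker, row, col, val = map(np.asarray, (ker, row, col, val))
    # stable sort by kernel index, mirroring the Koff counting-sort / prefix-sum loops
    order = np.argsort(ker, kind="stable")
    ker, row, col, val = ker[order], row[order], col[order], val[order]
    # row offsets: the start index of every run of equal row values, closed with nnz
    roff = [0]
    for i in range(1, len(row)):
        if row[i] != row[i - 1]:
            roff.append(i)
    if len(roff) > Ho * K:
        raise ValueError("found more rows in the K COOs than Ho*K")
    roff.append(len(row))
    return ker, row, col, val, np.asarray(roff, dtype=np.int64)

The returned roff array plays the same role as the roff_idx tensor that preprocess_psi builds from roff_h.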