Commit 65058287 authored by Max Rietmann
Browse files

Merge formatting changes

parents c46b6925 76836abf
......@@ -56,17 +56,23 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
def get_compile_args(module_name):
    """Return compiler-flag dicts for building a C++/CUDA extension module.

    Behavior is controlled by two environment variables:
      * ``TORCH_HARMONICS_DEBUG=1``   -> debug build (``-g -O0``, device-side
        debug info via nvcc ``-G``).
      * ``TORCH_HARMONICS_PROFILE=1`` -> append ``-lineinfo`` to the nvcc
        flags so profilers can map SASS back to source lines.

    Args:
        module_name: Name of the extension being built (used only in the
            informational message printed to stdout).

    Returns:
        dict with keys ``'cxx'`` and ``'nvcc'``, each mapping to a list of
        compiler flags suitable for ``extra_compile_args``.
    """
    debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
    profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'

    nvcc_extra_flags = []
    if profile_mode:
        # -lineinfo embeds source-line info without disabling optimizations.
        nvcc_extra_flags.append("-lineinfo")

    if debug_mode:
        print(f"WARNING: Compiling {module_name} with debugging flags")
        return {
            'cxx': ['-g', '-O0', '-Wall'],
            'nvcc': ['-g', '-G', '-O0'] + nvcc_extra_flags,
        }
    else:
        print(f"NOTE: Compiling {module_name} with release flags")
        return {
            'cxx': ['-O3', "-DNDEBUG"],
            'nvcc': ['-O3', "-DNDEBUG"] + nvcc_extra_flags,
        }
def get_ext_modules():
......
......@@ -51,7 +51,7 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
......@@ -312,14 +312,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// s2_attention_bwd_kernel execution time: 51.231743 ms
// s2_attention_bwd_kernel execution time: 52.971519 ms
// s2_attention_bwd_kernel execution time: 50.724865 ms
// [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel execution time: 11.679744 ms
printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
// printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
......
......@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)
......
......@@ -40,7 +40,7 @@
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
......
......@@ -140,11 +140,11 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
if constexpr (PSCALE != 0) {
......
......@@ -146,11 +146,11 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
const int pscale, const int64_t *__restrict__ roff,
const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
const REAL_T *__restrict__ inp, REAL_T *__restrict__ out)
{
disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
......
// Stably sorts, in place, the COO representation (ker_h, row_h, col_h, val_h)
// of the psi matrix by kernel index using a counting sort, then builds the
// CSR-style row-offset array roff_h over the sorted entries.
//
// Preconditions (from the caller in preprocess_psi): all four data arrays hold
// nnz entries, ker_h values lie in [0, K), and roff_h has room for Ho*K + 1
// entries. On return, nrows holds the number of distinct row runs and
// roff_h[nrows] == nnz.
template <typename REAL_T>
void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, int64_t *row_h, int64_t *col_h,
                           int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
{
    // Histogram: number of nonzeros per kernel index.
    int64_t *Koff = new int64_t[K];
    for (int i = 0; i < K; i++) { Koff[i] = 0; }
    for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }

    // Exclusive prefix sum: Koff[k] becomes the first output slot of kernel k.
    int64_t prev = Koff[0];
    Koff[0] = 0;
    for (int i = 1; i < K; i++) {
        int64_t save = Koff[i];
        Koff[i] = prev + Koff[i - 1];
        prev = save;
    }

    // Scatter into temporaries (stable counting sort by kernel index).
    // NOTE: val_sort must be REAL_T, not float, so double-precision values
    // are not silently truncated when REAL_T == double.
    int64_t *ker_sort = new int64_t[nnz];
    int64_t *row_sort = new int64_t[nnz];
    int64_t *col_sort = new int64_t[nnz];
    REAL_T *val_sort = new REAL_T[nnz];
    for (int64_t i = 0; i < nnz; i++) {
        const int64_t ker = ker_h[i];
        const int64_t off = Koff[ker]++;
        ker_sort[off] = ker;
        row_sort[off] = row_h[i];
        col_sort[off] = col_h[i];
        val_sort[off] = val_h[i];
    }

    // Copy the sorted data back into the caller's arrays.
    for (int64_t i = 0; i < nnz; i++) {
        ker_h[i] = ker_sort[i];
        row_h[i] = row_sort[i];
        col_h[i] = col_sort[i];
        val_h[i] = val_sort[i];
    }
    delete[] Koff;
    delete[] ker_sort;
    delete[] row_sort;
    delete[] col_sort;
    delete[] val_sort;

    // Compute row offsets: a new row run starts wherever row_h changes.
    nrows = 1;
    roff_h[0] = 0;
    for (int64_t i = 1; i < nnz; i++) {
        if (row_h[i - 1] == row_h[i]) continue;
        roff_h[nrows++] = i;
        if (nrows > Ho * K) {
            // More row runs than roff_h can hold; bail out before overflow.
            fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%lld)\n", __FILE__, __LINE__,
                    static_cast<long long>(Ho) * K);
            exit(EXIT_FAILURE);
        }
    }
    roff_h[nrows] = nnz;

    return;
}
// Host-side entry point exposed to Python: stably sorts the psi COO matrix by
// kernel index (in place on ker_idx/row_idx/col_idx/val) and returns the
// CSR-style row-offset index tensor of length nrows + 1.
//
// All input tensors must be contiguous (checked via CHECK_INPUT_TENSOR);
// ker_idx/row_idx/col_idx are int64 host tensors, val is a floating-point
// host tensor of the same length.
torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx, torch::Tensor row_idx,
                             torch::Tensor col_idx, torch::Tensor val)
{
    CHECK_INPUT_TENSOR(ker_idx);
    CHECK_INPUT_TENSOR(row_idx);
    CHECK_INPUT_TENSOR(col_idx);
    CHECK_INPUT_TENSOR(val);

    int64_t nnz = val.size(0);

    int64_t *ker_h = ker_idx.data_ptr<int64_t>();
    int64_t *row_h = row_idx.data_ptr<int64_t>();
    int64_t *col_h = col_idx.data_ptr<int64_t>();

    // Worst case: one row run per (output row, kernel) pair, plus the
    // closing offset written by the kernel.
    int64_t *roff_h = new int64_t[Ho * K + 1];
    int64_t nrows;

    // Dispatch on the value dtype; the sort mutates the index/value arrays
    // in place and fills roff_h / nrows.
    AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
                                   preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
                                                                   val.data_ptr<scalar_t>(), nrows);
                               }));

    // Copy the used prefix of the offsets into the output tensor.
    auto options = torch::TensorOptions().dtype(row_idx.dtype());
    auto roff_idx = torch::empty({nrows + 1}, options);
    int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
    for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
    delete[] roff_h;

    return roff_idx;
}
// Python bindings for the psi-matrix preprocessing helper.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment