Commit 65058287 authored by Max Rietmann

Merge formatting changes

parents c46b6925 76836abf
@@ -56,17 +56,23 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
def get_compile_args(module_name):
    """If the user runs the build with TORCH_HARMONICS_DEBUG=1 set, debugging flags are used for the build"""
    debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
+   profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
+   nvcc_extra_flags = []
+   if profile_mode:
+       nvcc_extra_flags.append("-lineinfo")
    if debug_mode:
        print(f"WARNING: Compiling {module_name} with debugging flags")
        return {
            'cxx': ['-g', '-O0', '-Wall'],
-           'nvcc': ['-g', '-G', '-O0']
+           'nvcc': ['-g', '-G', '-O0'] + nvcc_extra_flags
        }
    else:
        print(f"NOTE: Compiling {module_name} with release flags")
        return {
            'cxx': ['-O3', "-DNDEBUG"],
-           'nvcc': ['-O3', "-DNDEBUG"]
+           'nvcc': ['-O3', "-DNDEBUG"] + nvcc_extra_flags
        }
def get_ext_modules():
......
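For reference, the dict returned by get_compile_args has the shape that torch.utils.cpp_extension.CUDAExtension expects for extra_compile_args. A hedged sketch of how it could be wired up follows; the module name and source path are placeholders, since the real get_ext_modules is not shown in this diff:

from torch.utils.cpp_extension import CUDAExtension

# Placeholder wiring, not the repository's actual get_ext_modules().
# get_compile_args is the function defined in the hunk above.
def get_ext_modules_sketch():
    name = "torch_harmonics.example_ext"        # hypothetical module name
    return [
        CUDAExtension(
            name=name,
            sources=["csrc/example_ext.cu"],    # hypothetical source file
            extra_compile_args=get_compile_args(name),
        )
    ]

# TORCH_HARMONICS_PROFILE=1 appends "-lineinfo" to the nvcc flags so profilers
# such as Nsight Compute can map SASS instructions back to CUDA source lines:
#   TORCH_HARMONICS_PROFILE=1 pip install -e .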
@@ -51,7 +51,7 @@
#define THREADS (64)
#endif
#ifndef DIV_UP
-#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
+#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#endif
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
@@ -312,14 +312,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
    // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
    // s2_attention_bwd_kernel execution time: 51.231743 ms
    // s2_attention_bwd_kernel execution time: 52.971519 ms
    // s2_attention_bwd_kernel execution time: 50.724865 ms
    // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
    // s2_attention_bwd_kernel execution time: 11.679744 ms
-   printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
+   // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
......
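For reference, the timing above is the standard cudaEvent_t pattern: record events around the kernel launch, synchronize on the stop event, then read the elapsed time. A self-contained sketch of that pattern follows, with a simple error-checking macro in the spirit of CHECK_CUDA; the repository's own macro is not fully shown in this diff, and the kernel here is a stand-in:

#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDA(call)                                              \
    do {                                                              \
        cudaError_t err = (call);                                     \
        if (err != cudaSuccess) {                                     \
            fprintf(stderr, "CUDA error %s at %s:%d\n",               \
                    cudaGetErrorString(err), __FILE__, __LINE__);     \
            exit(EXIT_FAILURE);                                       \
        }                                                             \
    } while (0)

__global__ void dummy_kernel() {}

void time_kernel() {
    cudaEvent_t start, stop;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));

    CHECK_CUDA(cudaEventRecord(start));
    dummy_kernel<<<1, 64>>>();
    CHECK_CUDA(cudaEventRecord(stop));
    CHECK_CUDA(cudaEventSynchronize(stop));  // wait until the kernel has finished

    float milliseconds = 0.0f;
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    printf("kernel execution time: %f ms\n", milliseconds);

    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
}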
@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
-#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
+#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define NNZ_TRESH (32)
......
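WARP_SIZE and FULL_MASK are the standard constants for warp-level reductions with shuffle intrinsics, while the cub::BlockReduce alias covers the block-wide case. A minimal warp-sum sketch using these constants follows; it is illustrative and not the reduction from this file:

#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)

// "Shift down" reduction: after log2(WARP_SIZE) steps, lane 0 holds the
// sum of the values contributed by all 32 lanes of the warp.
__device__ float warp_reduce_sum(float val) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(FULL_MASK, val, offset);
    }
    return val;
}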
@@ -40,7 +40,7 @@
    CHECK_CUDA_TENSOR(x); \
    CHECK_CONTIGUOUS_TENSOR(x)
-#define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
+#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
......
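DIV_UP is plain ceiling division, used to round an element count up to a whole number of blocks when sizing a launch grid. A minimal sketch of the usual usage follows; the kernel and the 64-thread block size are illustrative, echoing MIN_THREADS above rather than code from this file:

#include <cuda_runtime.h>

#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))
#define MIN_THREADS (64)

__global__ void scale_kernel(float *x, const float alpha, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { x[i] *= alpha; }  // guard the rounded-up tail threads
}

void launch_scale(float *d_x, const float alpha, const int n, cudaStream_t stream) {
    // DIV_UP(n, MIN_THREADS) blocks cover all n elements even when n % MIN_THREADS != 0.
    scale_kernel<<<DIV_UP(n, MIN_THREADS), MIN_THREADS, 0, stream>>>(d_x, alpha, n);
}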
@@ -140,7 +140,7 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
template <int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
__global__
-__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
+__launch_bounds__(BDIM_X) void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                                const int pscale, const int64_t *__restrict__ roff,
                                                const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                                const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
......
@@ -146,7 +146,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__
-__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
+__launch_bounds__(BDIM_X) void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                                                const int pscale, const int64_t *__restrict__ roff,
                                                const int64_t *__restrict__ kers, const int64_t *__restrict__ rows,
                                                const int64_t *__restrict__ cols, const REAL_T *__restrict__ vals,
......
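Both disco_bwd_blk_k and disco_fwd_blk_k follow the same structure: BDIM_X (threads per block) and ELXTH (elements per thread) are compile-time template parameters, and __launch_bounds__(BDIM_X) tells the compiler the maximum block size so register allocation can target that occupancy. A stripped-down sketch of the pattern follows; the copy kernel and the 64/4 configuration are illustrative, not taken from the repository:

#include <cuda_runtime.h>

#define DIV_UP(a, b) (((a) + ((b)-1)) / (b))

// BDIM_X threads per block, each thread handling ELXTH block-strided elements.
// __launch_bounds__(BDIM_X) promises the compiler the block is never larger,
// letting it budget registers for that exact block size.
template <int BDIM_X, int ELXTH, typename REAL_T>
__global__ __launch_bounds__(BDIM_X) void copy_blk_k(const int n,
                                                     const REAL_T *__restrict__ src,
                                                     REAL_T *__restrict__ dst) {
    const int base = blockIdx.x * BDIM_X * ELXTH + threadIdx.x;
#pragma unroll
    for (int e = 0; e < ELXTH; e++) {
        const int idx = base + e * BDIM_X;
        if (idx < n) { dst[idx] = src[idx]; }
    }
}

template <typename REAL_T>
void launch_copy(const int n, const REAL_T *src, REAL_T *dst, cudaStream_t stream) {
    constexpr int BDIM_X = 64;  // illustrative block size, in the spirit of MIN_THREADS
    constexpr int ELXTH  = 4;   // illustrative elements per thread (capped by ELXTH_MAX in the real code)
    const int grid = DIV_UP(n, BDIM_X * ELXTH);
    copy_blk_k<BDIM_X, ELXTH, REAL_T><<<grid, BDIM_X, 0, stream>>>(n, src, dst);
}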