Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
65058287
Commit
65058287
authored
Jun 18, 2025
by
Max Rietmann
Browse files
Merge formatting changes
parents
c46b6925
76836abf
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
22 additions
and
20 deletions
+22
-20
setup.py
setup.py
+8
-2
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
+2
-6
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
+1
-1
torch_harmonics/csrc/disco/disco_cuda.cuh
torch_harmonics/csrc/disco/disco_cuda.cuh
+1
-1
torch_harmonics/csrc/disco/disco_cuda_bwd.cu
torch_harmonics/csrc/disco/disco_cuda_bwd.cu
+5
-5
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
+5
-5
No files found.
setup.py
View file @
65058287
...
@@ -56,17 +56,23 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
...
@@ -56,17 +56,23 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
def
get_compile_args
(
module_name
):
def
get_compile_args
(
module_name
):
"""If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
"""If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
debug_mode
=
os
.
environ
.
get
(
'TORCH_HARMONICS_DEBUG'
,
'0'
)
==
'1'
debug_mode
=
os
.
environ
.
get
(
'TORCH_HARMONICS_DEBUG'
,
'0'
)
==
'1'
profile_mode
=
os
.
environ
.
get
(
'TORCH_HARMONICS_PROFILE'
,
'0'
)
==
'1'
nvcc_extra_flags
=
[]
if
profile_mode
:
nvcc_extra_flags
.
append
(
"-lineinfo"
)
if
debug_mode
:
if
debug_mode
:
print
(
f
"WARNING: Compiling
{
module_name
}
with debugging flags"
)
print
(
f
"WARNING: Compiling
{
module_name
}
with debugging flags"
)
return
{
return
{
'cxx'
:
[
'-g'
,
'-O0'
,
'-Wall'
],
'cxx'
:
[
'-g'
,
'-O0'
,
'-Wall'
],
'nvcc'
:
[
'-g'
,
'-G'
,
'-O0'
]
'nvcc'
:
[
'-g'
,
'-G'
,
'-O0'
]
+
nvcc_extra_flags
}
}
else
:
else
:
print
(
f
"NOTE: Compiling
{
module_name
}
with release flags"
)
print
(
f
"NOTE: Compiling
{
module_name
}
with release flags"
)
return
{
return
{
'cxx'
:
[
'-O3'
,
"-DNDEBUG"
],
'cxx'
:
[
'-O3'
,
"-DNDEBUG"
],
'nvcc'
:
[
'-O3'
,
"-DNDEBUG"
]
'nvcc'
:
[
'-O3'
,
"-DNDEBUG"
]
+
nvcc_extra_flags
}
}
def
get_ext_modules
():
def
get_ext_modules
():
...
...
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
View file @
65058287
...
@@ -51,7 +51,7 @@
...
@@ -51,7 +51,7 @@
#define THREADS (64)
#define THREADS (64)
#endif
#endif
#ifndef DIV_UP
#ifndef DIV_UP
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#endif
#endif
#ifndef CHECK_CUDA
#ifndef CHECK_CUDA
#define CHECK_CUDA(call) \
#define CHECK_CUDA(call) \
...
@@ -312,14 +312,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
...
@@ -312,14 +312,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
CHECK_CUDA
(
cudaEventElapsedTime
(
&
milliseconds
,
start
,
stop
));
CHECK_CUDA
(
cudaEventElapsedTime
(
&
milliseconds
,
start
,
stop
));
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// s2_attention_bwd_kernel execution time: 51.231743 ms
// s2_attention_bwd_kernel execution time: 52.971519 ms
// s2_attention_bwd_kernel execution time: 50.724865 ms
// s2_attention_bwd_kernel execution time: 50.724865 ms
// [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
// [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel execution time: 11.679744 ms
// s2_attention_bwd_kernel execution time: 11.679744 ms
printf
(
"s2_attention_bwd_kernel execution time: %f ms
\n
"
,
milliseconds
);
//
printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
CHECK_CUDA
(
cudaEventDestroy
(
start
));
CHECK_CUDA
(
cudaEventDestroy
(
start
));
CHECK_CUDA
(
cudaEventDestroy
(
stop
));
CHECK_CUDA
(
cudaEventDestroy
(
stop
));
...
...
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
View file @
65058287
...
@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
...
@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
#define WARP_SIZE (32)
#define WARP_SIZE (32)
#define FULL_MASK (0xFFFFFFFF)
#define FULL_MASK (0xFFFFFFFF)
#define THREADS (64)
#define THREADS (64)
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define NNZ_TRESH (32)
#define NNZ_TRESH (32)
...
...
torch_harmonics/csrc/disco/disco_cuda.cuh
View file @
65058287
...
@@ -40,7 +40,7 @@
...
@@ -40,7 +40,7 @@
CHECK_CUDA_TENSOR(x); \
CHECK_CUDA_TENSOR(x); \
CHECK_CONTIGUOUS_TENSOR(x)
CHECK_CONTIGUOUS_TENSOR(x)
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define DIV_UP(a, b) (((a) + ((b)
-
1)) / (b))
#define MIN_THREADS (64)
#define MIN_THREADS (64)
#define ELXTH_MAX (32)
#define ELXTH_MAX (32)
...
...
torch_harmonics/csrc/disco/disco_cuda_bwd.cu
View file @
65058287
...
@@ -140,7 +140,7 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
...
@@ -140,7 +140,7 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
template
<
int
BDIM_X
,
int
ELXTH
,
int
PSCALE
,
typename
REAL_T
>
template
<
int
BDIM_X
,
int
ELXTH
,
int
PSCALE
,
typename
REAL_T
>
__global__
__global__
__launch_bounds__
(
BDIM_X
)
void
disco_bwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
__launch_bounds__
(
BDIM_X
)
void
disco_bwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
const
int
pscale
,
const
int64_t
*
__restrict__
roff
,
const
int
pscale
,
const
int64_t
*
__restrict__
roff
,
const
int64_t
*
__restrict__
kers
,
const
int64_t
*
__restrict__
rows
,
const
int64_t
*
__restrict__
kers
,
const
int64_t
*
__restrict__
rows
,
const
int64_t
*
__restrict__
cols
,
const
REAL_T
*
__restrict__
vals
,
const
int64_t
*
__restrict__
cols
,
const
REAL_T
*
__restrict__
vals
,
...
...
torch_harmonics/csrc/disco/disco_cuda_fwd.cu
View file @
65058287
...
@@ -146,7 +146,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
...
@@ -146,7 +146,7 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
template
<
int
BDIM_X
,
int
ELXTH
,
typename
REAL_T
>
template
<
int
BDIM_X
,
int
ELXTH
,
typename
REAL_T
>
__global__
__global__
__launch_bounds__
(
BDIM_X
)
void
disco_fwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
__launch_bounds__
(
BDIM_X
)
void
disco_fwd_blk_k
(
const
int
Hi
,
const
int
Wi
,
const
int
K
,
const
int
Ho
,
const
int
Wo
,
const
int
pscale
,
const
int64_t
*
__restrict__
roff
,
const
int
pscale
,
const
int64_t
*
__restrict__
roff
,
const
int64_t
*
__restrict__
kers
,
const
int64_t
*
__restrict__
rows
,
const
int64_t
*
__restrict__
kers
,
const
int64_t
*
__restrict__
rows
,
const
int64_t
*
__restrict__
cols
,
const
REAL_T
*
__restrict__
vals
,
const
int64_t
*
__restrict__
cols
,
const
REAL_T
*
__restrict__
vals
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment