OpenDAS / torch-harmonics · Commits

Commit 65058287
Merge formatting changes

Authored Jun 18, 2025 by Max Rietmann
Parents: c46b6925, 76836abf

Showing 7 changed files with 106 additions and 104 deletions (+106, -104)
  setup.py                                                +8   -2
  torch_harmonics/csrc/attention/attention_bwd_cuda.cu    +2   -6
  torch_harmonics/csrc/attention/attention_fwd_cuda.cu    +1   -1
  torch_harmonics/csrc/disco/disco_cuda.cuh               +1   -1
  torch_harmonics/csrc/disco/disco_cuda_bwd.cu            +5   -5
  torch_harmonics/csrc/disco/disco_cuda_fwd.cu            +5   -5
  torch_harmonics/csrc/disco/disco_helpers.cpp            +84  -84
setup.py

The hunk adds a TORCH_HARMONICS_PROFILE switch next to the existing TORCH_HARMONICS_DEBUG one and appends the resulting extra nvcc flags to both return paths:

@@ -56,17 +56,23 @@ except (ImportError, TypeError, AssertionError, AttributeError) as e:
 def get_compile_args(module_name):
     """If user runs build with TORCH_HARMONICS_DEBUG=1 set, it will use debugging flags to build"""
     debug_mode = os.environ.get('TORCH_HARMONICS_DEBUG', '0') == '1'
+    profile_mode = os.environ.get('TORCH_HARMONICS_PROFILE', '0') == '1'
+
+    nvcc_extra_flags = []
+    if profile_mode:
+        nvcc_extra_flags.append("-lineinfo")
+
     if debug_mode:
         print(f"WARNING: Compiling {module_name} with debugging flags")
         return {'cxx': ['-g', '-O0', '-Wall'],
-                'nvcc': ['-g', '-G', '-O0']}
+                'nvcc': ['-g', '-G', '-O0'] + nvcc_extra_flags}
     else:
         print(f"NOTE: Compiling {module_name} with release flags")
         return {'cxx': ['-O3', "-DNDEBUG"],
-                'nvcc': ['-O3', "-DNDEBUG"]}
+                'nvcc': ['-O3', "-DNDEBUG"] + nvcc_extra_flags}

 def get_ext_modules():
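The new switch is read from the environment when the flags are computed, alongside the existing debug switch. A minimal sketch of exercising it (the direct import of get_compile_args from setup.py is illustrative only; a real build just sets the variables in the shell before invoking pip or setup.py):

    import os

    # Build-time switches read by get_compile_args; set them before the build starts.
    os.environ["TORCH_HARMONICS_PROFILE"] = "1"  # adds -lineinfo so profilers can map SASS to source lines
    os.environ["TORCH_HARMONICS_DEBUG"] = "0"    # keep optimized release flags

    from setup import get_compile_args  # illustrative import, for demonstration only

    args = get_compile_args("disco")
    print(args["nvcc"])  # expected: ['-O3', '-DNDEBUG', '-lineinfo']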
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

The first hunk is a whitespace-only reflow of the DIV_UP macro (the old and new lines flatten to the same text, shown once):

@@ -51,7 +51,7 @@
 #define THREADS (64)
 #endif

 #ifndef DIV_UP
 #define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
 #endif

 #ifndef CHECK_CUDA
 #define CHECK_CUDA(call) \

The second hunk comments out the kernel-timing printf:

@@ -312,14 +312,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> s2_attention_bwd_dkvq_cuda(at::Te
     CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
     // [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
     // s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
     // s2_attention_bwd_kernel execution time: 51.231743 ms
     // s2_attention_bwd_kernel execution time: 52.971519 ms
     // s2_attention_bwd_kernel execution time: 50.724865 ms
     // [1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
     // s2_attention_bwd_kernel execution time: 11.679744 ms
-    printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
+    // printf("s2_attention_bwd_kernel execution time: %f ms\n", milliseconds);
     CHECK_CUDA(cudaEventDestroy(start));
     CHECK_CUDA(cudaEventDestroy(stop));
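For reference, the cudaEventRecord / cudaEventElapsedTime pattern used here can be reproduced from Python with torch's CUDA events; a minimal sketch with a stand-in workload (the softmax is not the s2 attention backward kernel, just something measurable):

    import torch

    # Host-side mirror of the CUDA-event timing pattern in the C++ code.
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)

    x = torch.randn(256, 721 * 1440, device="cuda")

    start.record()
    y = torch.softmax(x, dim=-1)  # stand-in workload
    stop.record()

    torch.cuda.synchronize()  # make sure both events have completed
    print(f"execution time: {start.elapsed_time(stop):f} ms")  # milliseconds, like cudaEventElapsedTime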
torch_harmonics/csrc/attention/attention_fwd_cuda.cu

Again a whitespace-only reflow of the DIV_UP macro (both sides identical once flattened, shown once):

@@ -45,7 +45,7 @@ using BlockReduceFloat512 = cub::BlockReduce<float, 512>;
 #define WARP_SIZE (32)
 #define FULL_MASK (0xFFFFFFFF)
 #define THREADS (64)
 #define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))
 #define NNZ_TRESH (32)
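DIV_UP is plain integer ceiling division, used to size thread grids so that every element is covered. The same helper in Python, with a worked example at THREADS = 64:

    def div_up(a: int, b: int) -> int:
        """Integer ceiling division, the Python analogue of the DIV_UP macro."""
        return (a + b - 1) // b

    assert div_up(1440, 64) == 23  # ceil(22.5): 23 blocks of 64 threads cover 1440 points
    assert div_up(1472, 64) == 23  # exact multiple: no spare block
    assert div_up(1473, 64) == 24  # one extra element forces one more block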
torch_harmonics/csrc/disco/disco_cuda.cuh

The same whitespace-only DIV_UP reflow, inside the header shared by the disco kernels:

@@ -40,7 +40,7 @@
     CHECK_CUDA_TENSOR(x); \
     CHECK_CONTIGUOUS_TENSOR(x)

 #define DIV_UP(a, b) (((a) + ((b) - 1)) / (b))

 #define MIN_THREADS (64)
 #define ELXTH_MAX (32)
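The two CHECK_*_TENSOR lines above are the tail of a combined input-check macro whose name sits above the hunk; it chains a device check and a contiguity check. What it enforces, expressed in Python (check_input_tensor is a hypothetical name for illustration):

    import torch

    def check_input_tensor(x: torch.Tensor, name: str = "tensor") -> None:
        # Same preconditions as CHECK_CUDA_TENSOR + CHECK_CONTIGUOUS_TENSOR
        assert x.is_cuda, f"{name} must live on a CUDA device"
        assert x.is_contiguous(), f"{name} must be contiguous in memory"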
torch_harmonics/csrc/disco/disco_cuda_bwd.cu

The hunk re-indents the disco_bwd_blk_k kernel signature (+5, -5, whitespace only; both sides flatten to the same tokens, shown once):

@@ -140,11 +140,11 @@ __device__ void disco_bwd_d(const int Hi, const int Wi, const int K, const int H
 template<int BDIM_X, int ELXTH, int PSCALE, typename REAL_T>
 __global__ __launch_bounds__(BDIM_X)
 void disco_bwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                      const int pscale,
                      const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                      const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                      const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp,
                      REAL_T *__restrict__ out)
 {
     if constexpr (PSCALE != 0) {
torch_harmonics/csrc/disco/disco_cuda_fwd.cu

The matching re-indentation of the forward kernel's signature (+5, -5, whitespace only; shown once):

@@ -146,11 +146,11 @@ __device__ void disco_fwd_d(const int Hi, const int Wi, const int K, const int H
 template<int BDIM_X, int ELXTH, typename REAL_T>
 __global__ __launch_bounds__(BDIM_X)
 void disco_fwd_blk_k(const int Hi, const int Wi, const int K, const int Ho, const int Wo,
                      const int pscale,
                      const int64_t *__restrict__ roff, const int64_t *__restrict__ kers,
                      const int64_t *__restrict__ rows, const int64_t *__restrict__ cols,
                      const REAL_T *__restrict__ vals, const REAL_T *__restrict__ inp,
                      REAL_T *__restrict__ out)
 {
     disco_fwd_d<BDIM_X, ELXTH>(Hi, Wi, K, Ho, Wo, pscale, roff, kers, rows, cols, vals, inp, out);
torch_harmonics/csrc/disco/disco_helpers.cpp

All 84 changed lines here are formatting only: the diff shows the old and new copies of preprocess_psi_kernel and preprocess_psi token for token identical, so the file's content is reproduced once below.

@@ -35,100 +35,100 @@ void preprocess_psi_kernel(int64_t nnz, int64_t K, int64_t Ho, int64_t *ker_h, i
                            int64_t *roff_h, REAL_T *val_h, int64_t &nrows)
 {
     int64_t *Koff = new int64_t[K];
     for (int i = 0; i < K; i++) { Koff[i] = 0; }
     for (int64_t i = 0; i < nnz; i++) { Koff[ker_h[i]]++; }

     int64_t prev = Koff[0];
     Koff[0] = 0;
     for (int i = 1; i < K; i++) {
         int64_t save = Koff[i];
         Koff[i] = prev + Koff[i - 1];
         prev = save;
     }

     int64_t *ker_sort = new int64_t[nnz];
     int64_t *row_sort = new int64_t[nnz];
     int64_t *col_sort = new int64_t[nnz];
     float *val_sort = new float[nnz];

     for (int64_t i = 0; i < nnz; i++) {
         const int64_t ker = ker_h[i];
         const int64_t off = Koff[ker]++;
         ker_sort[off] = ker;
         row_sort[off] = row_h[i];
         col_sort[off] = col_h[i];
         val_sort[off] = val_h[i];
     }

     for (int64_t i = 0; i < nnz; i++) {
         ker_h[i] = ker_sort[i];
         row_h[i] = row_sort[i];
         col_h[i] = col_sort[i];
         val_h[i] = val_sort[i];
     }

     delete[] Koff;
     delete[] ker_sort;
     delete[] row_sort;
     delete[] col_sort;
     delete[] val_sort;

     // compute rows offsets
     nrows = 1;
     roff_h[0] = 0;
     for (int64_t i = 1; i < nnz; i++) {
         if (row_h[i - 1] == row_h[i]) continue;
         roff_h[nrows++] = i;
         if (nrows > Ho * K) {
             fprintf(stderr, "%s:%d: error, found more rows in the K COOs than Ho*K (%ld)\n",
                     __FILE__, __LINE__, int64_t(Ho) * K);
             exit(EXIT_FAILURE);
         }
     }
     roff_h[nrows] = nnz;

     return;
 }
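preprocess_psi_kernel is a counting sort of the COO entries by kernel index (stable within each kernel), followed by a pass that records the offset of every row change. The same algorithm on toy data, as a small Python sketch:

    # Counting sort by kernel index, then row-offset construction, mirroring
    # preprocess_psi_kernel on a 5-entry COO matrix with K = 3 kernels.
    ker = [1, 0, 1, 0, 2]
    row = [3, 0, 3, 1, 2]
    col = [7, 4, 9, 5, 6]
    val = [0.1, 0.2, 0.3, 0.4, 0.5]
    K = 3

    # Histogram of entries per kernel, then exclusive prefix sum -> segment starts.
    counts = [0] * K
    for k in ker:
        counts[k] += 1
    koff, total = [], 0
    for c in counts:
        koff.append(total)
        total += c

    # Scatter each entry into its kernel's segment, preserving input order.
    out = [None] * len(ker)
    for i, k in enumerate(ker):
        out[koff[k]] = (k, row[i], col[i], val[i])
        koff[k] += 1
    ker, row, col, val = map(list, zip(*out))

    # A new "row" starts wherever row[i] differs from row[i - 1].
    roff = [0]
    for i in range(1, len(row)):
        if row[i] != row[i - 1]:
            roff.append(i)
    roff.append(len(row))

    print(ker)   # [0, 0, 1, 1, 2]  -- entries grouped by kernel
    print(roff)  # [0, 1, 2, 4, 5]  -- offsets delimiting each row run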
The hunk continues with the host-side wrapper that dispatches on the value dtype and copies the row offsets into an output tensor:

 torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ker_idx,
                              torch::Tensor row_idx, torch::Tensor col_idx, torch::Tensor val)
 {
     CHECK_INPUT_TENSOR(ker_idx);
     CHECK_INPUT_TENSOR(row_idx);
     CHECK_INPUT_TENSOR(col_idx);
     CHECK_INPUT_TENSOR(val);

     int64_t nnz = val.size(0);
     int64_t *ker_h = ker_idx.data_ptr<int64_t>();
     int64_t *row_h = row_idx.data_ptr<int64_t>();
     int64_t *col_h = col_idx.data_ptr<int64_t>();
     int64_t *roff_h = new int64_t[Ho * K + 1];
     int64_t nrows;

     // float *val_h = val.data_ptr<float>();
     AT_DISPATCH_FLOATING_TYPES(val.scalar_type(), "preprocess_psi", ([&] {
         preprocess_psi_kernel<scalar_t>(nnz, K, Ho, ker_h, row_h, col_h, roff_h,
                                         val.data_ptr<scalar_t>(), nrows);
     }));

     // create output tensor
     auto options = torch::TensorOptions().dtype(row_idx.dtype());
     auto roff_idx = torch::empty({nrows + 1}, options);
     int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
     for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }

     delete[] roff_h;
     return roff_idx;
 }
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("preprocess_psi", &preprocess_psi, "Sort psi matrix, required for using disco_cuda.");
 }
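The binding exposes preprocess_psi to Python. A hedged usage sketch (the import name is an assumption: the compiled module's name comes from TORCH_EXTENSION_NAME at build time, so adjust it to your build):

    import torch
    import disco_helpers  # hypothetical module name for the compiled extension

    K, Ho = 2, 6

    # Tiny COO psi matrix, entries already grouped by (kernel, row).
    ker_idx = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int64)
    row_idx = torch.tensor([0, 0, 2, 1, 3], dtype=torch.int64)
    col_idx = torch.tensor([3, 7, 5, 2, 9], dtype=torch.int64)
    val     = torch.tensor([0.50, 0.25, 1.00, 0.75, 0.10])

    # Sorts the COO arrays in place by kernel index and returns the row
    # offsets, terminated by nnz.
    roff_idx = disco_helpers.preprocess_psi(K, Ho, ker_idx, row_idx, col_idx, val)
    print(roff_idx)  # tensor([0, 2, 3, 4, 5])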