OpenDAS / torch-harmonics, commit 07fa44d6

Authored Jul 11, 2025 by Mauro Bisson; committed Jul 15, 2025 by Thorsten Kurth.
Parent: 9689109f

Added missing kernel launch error checks.

Showing 4 changed files with 13 additions and 0 deletions (+13, -0).
Changed files:

- setup.py (+1, -0)
- torch_harmonics/csrc/attention/attention_bwd_cuda.cu (+4, -0)
- torch_harmonics/csrc/attention/attention_fwd_cuda.cu (+4, -0)
- torch_harmonics/csrc/attention/attention_utils.cu (+4, -0)
setup.py

@@ -61,6 +61,7 @@ def get_compile_args(module_name):
     nvcc_extra_flags = []
     if profile_mode:
         nvcc_extra_flags.append("-lineinfo")
+        nvcc_extra_flags.append("-Xptxas=-v")
     if debug_mode:
         print(f"WARNING: Compiling {module_name} with debugging flags")
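For context: -lineinfo embeds source-line information in the compiled kernels so profilers such as Nsight Compute can map hot spots back to the CUDA source, while the newly added -Xptxas=-v makes ptxas print per-kernel register and shared-memory usage at build time. Both flags only take effect when profile_mode is enabled.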
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

@@ -715,6 +715,8 @@ void launch_gen_attn_bwd(int batch_size,
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
                                               _quad_weights, _dkxp, _dvxp, _dqyp);
+        CHECK_ERROR("s2_attn_bwd_generic_vec_k");
         return;
     }
@@ -754,6 +756,8 @@ void launch_spc_attn_bwd(int batch_size,
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _dyp, _row_idx, _row_off, _col_idx,
                                               _quad_weights, _dkxp, _dvxp, _dqyp);
+        CHECK_ERROR("s2_attn_bwd_special_vec_k");
         return;
     }
     if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
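The CHECK_ERROR macro itself is not defined in this diff. The sketch below shows what such a launch-check macro commonly looks like; the message format and exit behaviour are assumptions, not the project's actual definition.

```cuda
// Sketch only: a typical kernel-launch error check built on cudaGetLastError().
// The real CHECK_ERROR in torch-harmonics may differ (it could throw instead of exiting).
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_ERROR(kernel_name)                                            \
    do {                                                                    \
        cudaError_t err_ = cudaGetLastError();                              \
        if (err_ != cudaSuccess) {                                          \
            fprintf(stderr, "CUDA kernel %s failed at %s:%d: %s\n",         \
                    (kernel_name), __FILE__, __LINE__,                      \
                    cudaGetErrorString(err_));                              \
            exit(EXIT_FAILURE);                                             \
        }                                                                   \
    } while (0)
```

Placed immediately after each <<<...>>> statement, a check like this catches errors raised at launch time (bad grid or block dimensions, too much dynamic shared memory, an invalid stream) before the launcher returns.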
torch_harmonics/csrc/attention/attention_fwd_cuda.cu

@@ -311,6 +311,8 @@ void launch_gen_attn_fwd(int batch_size,
         s2_attn_fwd_generic_vec_k<THREADS>
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
+        CHECK_ERROR("s2_attn_fwd_generic_vec_k");
         return;
     }
@@ -346,6 +348,8 @@ void launch_spc_attn_fwd(int batch_size,
         s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
             <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                               _kxp, _vxp, _qyp, _row_idx, _row_off, _col_idx, _quad_weights, _yp);
+        CHECK_ERROR("s2_attn_fwd_special_vec_k");
         return;
     }
     if constexpr(CUR_LOC_SIZE < MAX_LOC_SIZE) {
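The forward launchers get the same treatment as the backward ones. As a general illustration (hypothetical code, not from this repository), cudaGetLastError() right after the <<<...>>> statement reports launch-configuration problems, while errors that occur during kernel execution are asynchronous and only surface at a later runtime call or synchronization point:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_k(float *out) { out[threadIdx.x] = (float)threadIdx.x; }

void launch_dummy(float *d_out, cudaStream_t stream)
{
    // 2048 threads per block exceeds the 1024-thread limit of current GPUs,
    // so the launch itself fails.
    dummy_k<<<1, 2048, 0, stream>>>(d_out);

    cudaError_t err = cudaGetLastError();      // the invalid configuration is reported here
    if (err != cudaSuccess) {
        fprintf(stderr, "dummy_k launch failed: %s\n", cudaGetErrorString(err));
        return;
    }

    cudaStreamSynchronize(stream);             // asynchronous execution errors surface here
}
```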
torch_harmonics/csrc/attention/attention_utils.cu

@@ -206,6 +206,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
                                  src.size(3),
                                  src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                  dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0231_k_tile_generic");
     } else {
         block.y = TRANSP_WARPS_X_TILE_SM100;
         permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
@@ -214,6 +215,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {
                                  src.size(3),
                                  src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                  dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0231_k_tile_sm100");
     }
     return dst;
@@ -306,6 +308,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
                                  src.size(2),
                                  src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                  dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0312_k_tile_generic");
     } else {
         block.y = TRANSP_WARPS_X_TILE_SM100;
         permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
@@ -314,6 +317,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {
                                  src.size(2),
                                  src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
                                  dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
+        CHECK_ERROR("permute_to0312_k_tile_sm100");
     }
     return dst;
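The permute launches pass the tensors as packed_accessor32 handles, which bundle the data pointer with 32-bit sizes and strides so the kernel can index each dimension directly. Below is a hypothetical sketch of a kernel with that signature; it illustrates the accessor API only and is not the project's permute_to0231_k implementation:

```cuda
#include <torch/extension.h>

// Naive (N, C, H, W) -> (N, H, W, C) permute written against the same
// PackedTensorAccessor32 argument types seen in the diff. One thread handles
// one (n, h, w) position and loops over channels.
__global__ void permute_0231_naive_k(
        const torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> src,
        torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> dst)
{
    const int w = blockIdx.x * blockDim.x + threadIdx.x;
    const int h = blockIdx.y;
    const int n = blockIdx.z;
    if (w >= src.size(3)) return;

    for (int c = 0; c < src.size(1); ++c) {
        dst[n][h][w][c] = src[n][c][h][w];  // dst allocated with shape (N, H, W, C)
    }
}
```

A host-side caller would allocate dst with the permuted shape, launch on the provided stream, and follow the launch with the same CHECK_ERROR-style check this commit adds.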