OpenDAS / torch-harmonics · Commits · 07fa44d6
"vscode:/vscode.git/clone" did not exist on "dfaf2b20fb7c7e2f24553342af27ef67989426a7"
Commit 07fa44d6, authored Jul 11, 2025 by Mauro Bisson, committed by Thorsten Kurth on Jul 15, 2025

Added missing kernel launch error checks.
Parent: 9689109f · Changes: 4
Showing 4 changed files with 13 additions and 0 deletions (+13 −0)
setup.py (+1 −0)
torch_harmonics/csrc/attention/attention_bwd_cuda.cu (+4 −0)
torch_harmonics/csrc/attention/attention_fwd_cuda.cu (+4 −0)
torch_harmonics/csrc/attention/attention_utils.cu (+4 −0)
setup.py

@@ -61,6 +61,7 @@ def get_compile_args(module_name):

```python
    nvcc_extra_flags = []
    if profile_mode:
        nvcc_extra_flags.append("-lineinfo")
        nvcc_extra_flags.append("-Xptxas=-v")
    if debug_mode:
        print(f"WARNING: Compiling {module_name} with debugging flags")
```
torch_harmonics/csrc/attention/attention_bwd_cuda.cu

@@ -715,6 +715,8 @@ void launch_gen_attn_bwd(int batch_size,

```cpp
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp, _dyp,
                                              _row_idx, _row_off, _col_idx, _quad_weights,
                                              _dkxp, _dvxp, _dqyp);
        CHECK_ERROR("s2_attn_bwd_generic_vec_k");
        return;
    }
```

@@ -754,6 +756,8 @@ void launch_spc_attn_bwd(int batch_size,

```cpp
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp, _dyp,
                                              _row_idx, _row_off, _col_idx, _quad_weights,
                                              _dkxp, _dvxp, _dqyp);
        CHECK_ERROR("s2_attn_bwd_special_vec_k");
        return;
    }
    if constexpr (CUR_LOC_SIZE < MAX_LOC_SIZE) {
```
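CHECK_ERROR is called after every launch above, but its definition is not part of this commit. As a point of reference only, launch-check macros of this kind typically wrap cudaGetLastError() and report the offending kernel's name; the sketch below is an assumption about what such a macro does, not the actual definition used in torch_harmonics.

```cpp
// Hypothetical sketch of a CHECK_ERROR-style launch check; the real macro in
// torch_harmonics/csrc may differ in naming, formatting, and error handling.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_ERROR(kernel_name)                                               \
    do {                                                                       \
        const cudaError_t err_ = cudaGetLastError();                           \
        if (err_ != cudaSuccess) {                                             \
            std::fprintf(stderr, "Kernel %s failed to launch: %s (%s:%d)\n",   \
                         (kernel_name), cudaGetErrorString(err_),              \
                         __FILE__, __LINE__);                                  \
            std::abort();                                                      \
        }                                                                      \
    } while (0)
```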
torch_harmonics/csrc/attention/attention_fwd_cuda.cu

@@ -311,6 +311,8 @@ void launch_gen_attn_fwd(int batch_size,

```cpp
        s2_attn_fwd_generic_vec_k<THREADS>
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp,
                                              _row_idx, _row_off, _col_idx, _quad_weights, _yp);
        CHECK_ERROR("s2_attn_fwd_generic_vec_k");
        return;
    }
```

@@ -346,6 +348,8 @@ void launch_spc_attn_fwd(int batch_size,

```cpp
        s2_attn_fwd_special_vec_k<BDIM_X, BDIM_Y, CUR_LOC_SIZE>
            <<<grid, block, shsize, stream>>>(nchans, nlat_in, nlon_in, nlat_out, nlon_out,
                                              _kxp, _vxp, _qyp,
                                              _row_idx, _row_off, _col_idx, _quad_weights, _yp);
        CHECK_ERROR("s2_attn_fwd_special_vec_k");
        return;
    }
    if constexpr (CUR_LOC_SIZE < MAX_LOC_SIZE) {
```
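The checks sit immediately after each `<<<...>>>` launch because launches are asynchronous: cudaGetLastError() reports configuration failures (invalid grid or block dimensions, excessive dynamic shared memory, no kernel image for the target architecture) right away, without waiting for the kernel to finish. The following is a minimal, self-contained illustration of the same launch-then-check pattern; the kernel and sizes are invented for the example and are not part of the library.

```cpp
// Standalone illustration of the launch-then-check pattern used in this commit.
// Compile with nvcc; everything here is example code, not library code.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_k(float* out, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { out[i] = 2.0f * static_cast<float>(i); }
}

int main() {
    const int n = 1 << 20;
    float* d_out = nullptr;
    if (cudaMalloc(&d_out, n * sizeof(float)) != cudaSuccess) { return 1; }

    const dim3 block(256);
    const dim3 grid((n + block.x - 1) / block.x);
    scale_k<<<grid, block>>>(d_out, n);

    // Report launch failures immediately, without forcing a device sync.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::fprintf(stderr, "scale_k launch failed: %s\n", cudaGetErrorString(err));
        cudaFree(d_out);
        return 1;
    }

    cudaFree(d_out);
    return 0;
}
```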
torch_harmonics/csrc/attention/attention_utils.cu

@@ -206,6 +206,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {

```cpp
             src.size(3),
             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0231_k_tile_generic");
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0231_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
```

@@ -214,6 +215,7 @@ at::Tensor permute_4D_floatT_to0231(at::Tensor src, cudaStream_t stream) {

```cpp
             src.size(3),
             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0231_k_tile_sm100");
    }
    return dst;
```

@@ -306,6 +308,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {

```cpp
             src.size(2),
             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0312_k_tile_generic");
    } else {
        block.y = TRANSP_WARPS_X_TILE_SM100;
        permute_to0312_k<WARP_SIZE, TRANSP_WARPS_X_TILE_SM100>
```

@@ -314,6 +317,7 @@ at::Tensor permute_4D_floatT_to0312(at::Tensor src, cudaStream_t stream) {

```cpp
             src.size(2),
             src.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
             dst.packed_accessor32<float, 4, torch::RestrictPtrTraits>());
        CHECK_ERROR("permute_to0312_k_tile_sm100");
    }
    return dst;
```
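Because cudaGetLastError() only surfaces errors that are already visible on the host, failures that happen while a kernel executes (for example an out-of-bounds access) show up later, at the next synchronizing call. A stricter debug-only variant, purely hypothetical and not part of this commit, could additionally synchronize the launch stream so such errors are attributed to the right kernel; the trade-off is a full stream synchronization per launch, which is why a check like this usually sits behind a debug flag.

```cpp
// Hypothetical debug-build variant that also synchronizes the launch stream.
// Not part of this commit; shown only to contrast with the non-blocking check.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK_ERROR_SYNC(kernel_name, stream)                                  \
    do {                                                                       \
        cudaError_t err_ = cudaGetLastError();                                 \
        if (err_ == cudaSuccess) { err_ = cudaStreamSynchronize(stream); }     \
        if (err_ != cudaSuccess) {                                             \
            std::fprintf(stderr, "Kernel %s failed: %s (%s:%d)\n",             \
                         (kernel_name), cudaGetErrorString(err_),              \
                         __FILE__, __LINE__);                                  \
            std::abort();                                                      \
        }                                                                      \
    } while (0)
```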