Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
744d2269
Unverified
Commit
744d2269
authored
Jul 07, 2025
by
Thorsten Kurth
Committed by
GitHub
Jul 07, 2025
Browse files
Merge pull request #87 from NVIDIA/maurob/devel
Use 64-bit for pointer offsets
parents
1eed5673
c9ae23d9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
15 deletions
+14
-15
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
+14
-15
No files found.
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
View file @
744d2269
...
@@ -214,10 +214,10 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
...
@@ -214,10 +214,10 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
shy
[
chan
]
=
__vset
<
FLOATV_T
>
(
0.
f
);
shy
[
chan
]
=
__vset
<
FLOATV_T
>
(
0.
f
);
}
}
kx
+=
batch
*
nlat_in
*
nlon_in
*
nchan
;
kx
+=
int64_t
(
batch
)
*
nlat_in
*
nlon_in
*
nchan
;
vx
+=
batch
*
nlat_in
*
nlon_in
*
nchan
;
vx
+=
int64_t
(
batch
)
*
nlat_in
*
nlon_in
*
nchan
;
qy
+=
batch
*
nlat_out
*
nlon_out
*
nchan
+
ho
*
nchan
*
nlon_out
+
wo
*
nchan
;
qy
+=
int64_t
(
batch
)
*
nlat_out
*
nlon_out
*
nchan
+
int64_t
(
ho
)
*
nchan
*
nlon_out
+
int64_t
(
wo
)
*
nchan
;
y
+=
batch
*
nlat_out
*
nlon_out
*
nchan
+
ho
*
nchan
*
nlon_out
+
wo
*
nchan
;
y
+=
int64_t
(
batch
)
*
nlat_out
*
nlon_out
*
nchan
+
int64_t
(
ho
)
*
nchan
*
nlon_out
+
int64_t
(
wo
)
*
nchan
;
float
alpha_sum
=
0.0
f
;
float
alpha_sum
=
0.0
f
;
float
qdotk_max
=
-
FLT_MAX
;
float
qdotk_max
=
-
FLT_MAX
;
...
@@ -235,8 +235,8 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
...
@@ -235,8 +235,8 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
const
int
wi
=
col
-
(
hi
*
nlon_in
);
const
int
wi
=
col
-
(
hi
*
nlon_in
);
const
int
wip
=
(
wi
+
wo
)
-
((
wi
+
wo
)
/
nlon_in
)
*
nlon_in
;
const
int
wip
=
(
wi
+
wo
)
-
((
wi
+
wo
)
/
nlon_in
)
*
nlon_in
;
const
FLOATV_T
*
_kx
=
kx
+
hi
*
nlon_in
*
nchan
+
wip
*
nchan
;
const
FLOATV_T
*
_kx
=
kx
+
int64_t
(
hi
)
*
nlon_in
*
nchan
+
int64_t
(
wip
)
*
nchan
;
const
FLOATV_T
*
_vx
=
vx
+
hi
*
nlon_in
*
nchan
+
wip
*
nchan
;
const
FLOATV_T
*
_vx
=
vx
+
int64_t
(
hi
)
*
nlon_in
*
nchan
+
int64_t
(
wip
)
*
nchan
;
FLOATV_T
qdotkv
=
__vset
<
FLOATV_T
>
(
0.
f
);
FLOATV_T
qdotkv
=
__vset
<
FLOATV_T
>
(
0.
f
);
...
@@ -275,7 +275,6 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
...
@@ -275,7 +275,6 @@ void s2_attn_fwd_generic_vec_k(int nchan, // no. of FLOATV_T elements along cha
template
<
typename
FLOATV_T
>
template
<
typename
FLOATV_T
>
void
launch_gen_attn_kernel
(
int
batch_size
,
void
launch_gen_attn_kernel
(
int
batch_size
,
int
nloc
,
int
nchans
,
int
nchans
,
int
nlat_in
,
int
nlat_in
,
int
nlon_in
,
int
nlon_in
,
...
@@ -352,10 +351,10 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
...
@@ -352,10 +351,10 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const
int
wo
=
ctaid
-
(
h
*
nlon_out
);
const
int
wo
=
ctaid
-
(
h
*
nlon_out
);
const
int
ho
=
row_idx
[
h
];
const
int
ho
=
row_idx
[
h
];
kx
+=
batch
*
nlat_in
*
nlon_in
*
nchan
+
tidx
;
kx
+=
int64_t
(
batch
)
*
nlat_in
*
nlon_in
*
nchan
+
tidx
;
vx
+=
batch
*
nlat_in
*
nlon_in
*
nchan
+
tidx
;
vx
+=
int64_t
(
batch
)
*
nlat_in
*
nlon_in
*
nchan
+
tidx
;
qy
+=
batch
*
nlat_out
*
nlon_out
*
nchan
+
ho
*
nlon_out
*
nchan
+
wo
*
nchan
+
tidx
;
qy
+=
int64_t
(
batch
)
*
nlat_out
*
nlon_out
*
nchan
+
int64_t
(
ho
)
*
nlon_out
*
nchan
+
int64_t
(
wo
)
*
nchan
+
tidx
;
y
+=
batch
*
nlat_out
*
nlon_out
*
nchan
+
ho
*
nlon_out
*
nchan
+
wo
*
nchan
+
tidx
;
y
+=
int64_t
(
batch
)
*
nlat_out
*
nlon_out
*
nchan
+
int64_t
(
ho
)
*
nlon_out
*
nchan
+
int64_t
(
wo
)
*
nchan
+
tidx
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NLOC
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NLOC
;
i
++
)
{
...
@@ -386,8 +385,8 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
...
@@ -386,8 +385,8 @@ void s2_attn_fwd_special_vec_k(int nchan, // no. of FLOATV_T elements along chan
const
int
wi
=
col
-
(
hi
*
nlon_in
);
const
int
wi
=
col
-
(
hi
*
nlon_in
);
const
int
wip
=
(
wi
+
wo
)
-
((
wi
+
wo
)
/
nlon_in
)
*
nlon_in
;
const
int
wip
=
(
wi
+
wo
)
-
((
wi
+
wo
)
/
nlon_in
)
*
nlon_in
;
const
FLOATV_T
*
_kx
=
kx
+
hi
*
nlon_in
*
nchan
+
wip
*
nchan
;
const
FLOATV_T
*
_kx
=
kx
+
int64_t
(
hi
)
*
nlon_in
*
nchan
+
int64_t
(
wip
)
*
nchan
;
const
FLOATV_T
*
_vx
=
vx
+
hi
*
nlon_in
*
nchan
+
wip
*
nchan
;
const
FLOATV_T
*
_vx
=
vx
+
int64_t
(
hi
)
*
nlon_in
*
nchan
+
int64_t
(
wip
)
*
nchan
;
FLOATV_T
qdotkv
=
__vset
<
FLOATV_T
>
(
0.
f
);
FLOATV_T
qdotkv
=
__vset
<
FLOATV_T
>
(
0.
f
);
...
@@ -626,7 +625,7 @@ static void s2_attention_dipatch(int batch_size,
...
@@ -626,7 +625,7 @@ static void s2_attention_dipatch(int batch_size,
case
256
:
launch_spc_attn_kernel
<
256
,
1
,
1
,
MAX_LOCAL_ARR_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
case
256
:
launch_spc_attn_kernel
<
256
,
1
,
1
,
MAX_LOCAL_ARR_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
case
512
:
launch_spc_attn_kernel
<
512
,
1
,
1
,
MAX_LOCAL_ARR_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
case
512
:
launch_spc_attn_kernel
<
512
,
1
,
1
,
MAX_LOCAL_ARR_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
case
1024
:
launch_spc_attn_kernel
<
1024
,
1
,
1
,
MAX_LOCAL_ARR_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
case
1024
:
launch_spc_attn_kernel
<
1024
,
1
,
1
,
MAX_LOCAL_ARR_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
default:
launch_gen_attn_kernel
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
default:
launch_gen_attn_kernel
(
batch_size
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp
,
_vxp
,
_qyp
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp
,
stream
);
break
;
}
}
}
else
{
}
else
{
...
@@ -649,7 +648,7 @@ static void s2_attention_dipatch(int batch_size,
...
@@ -649,7 +648,7 @@ static void s2_attention_dipatch(int batch_size,
case
256
:
launch_spc_attn_kernel
<
256
,
1
,
1
,
MAX_LOCAL_VEC_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
case
256
:
launch_spc_attn_kernel
<
256
,
1
,
1
,
MAX_LOCAL_VEC_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
case
512
:
launch_spc_attn_kernel
<
512
,
1
,
1
,
MAX_LOCAL_VEC_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
case
512
:
launch_spc_attn_kernel
<
512
,
1
,
1
,
MAX_LOCAL_VEC_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
case
1024
:
launch_spc_attn_kernel
<
1024
,
1
,
1
,
MAX_LOCAL_VEC_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
case
1024
:
launch_spc_attn_kernel
<
1024
,
1
,
1
,
MAX_LOCAL_VEC_LEN
>
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
default:
launch_gen_attn_kernel
(
batch_size
,
nloc
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
default:
launch_gen_attn_kernel
(
batch_size
,
nchans
,
nlat_in
,
nlon_in
,
nlat_out
,
nlon_out
,
_kxp4
,
_vxp4
,
_qyp4
,
row_idx
,
row_off
,
col_idx
,
quad_weights
,
_yp4
,
stream
);
break
;
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment