OpenDAS / torch-harmonics / Commits / 37b08bb8
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8c249d1401f12d55a59a7bdb2329b29921ae864e"
Commit 37b08bb8, authored Jun 13, 2025 by Max Rietmann

Changed to qdotk_max single loop torch reference kernel

Parent: a07c5b2b
Showing 2 changed files with 20 additions and 22 deletions (+20 -22)
torch_harmonics/_neighborhood_attention.py   +19 -21
torch_harmonics/csrc/attention/attention_fwd_cuda.cu   +1 -1
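The rewrite replaces the two-pass softmax in the torch reference kernel (one loop to find the row maximum, a second to accumulate) with the online-softmax recurrence cited in the new comments. Writing $x_t$ for the $t$-th score $q \cdot k$, $w_t$ for its quadrature weight, and $v_t$ for its value vector, the running maximum $m$, normalizer $s$, and output $y$ are updated per neighbor as

$$m_t = \max(m_{t-1}, x_t), \qquad s_t = s_{t-1}\, e^{m_{t-1} - m_t} + w_t\, e^{x_t - m_t}, \qquad y_t = y_{t-1}\, e^{m_{t-1} - m_t} + w_t\, e^{x_t - m_t}\, v_t,$$

so the final $y_T / s_T$ equals the weighted softmax computed by the old two-loop version, in a single pass. Because $s_0 = y_0 = 0$, the result does not depend on the initial $m_0$, which is why the new code can initialize qdotk_max with zeros rather than $-\infty$.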
torch_harmonics/_neighborhood_attention.py
@@ -43,25 +43,28 @@ except ImportError as err:
     attention_cuda_extension = None
     _cuda_extension_available = False
+# s2 neighborhood attention forward pass
+# uses qdotk_max update trick to avoid two loops when computing the softmax
+# see e.g., https://arxiv.org/abs/1805.02867
+# and https://alexdremov.me/understanding-flash-attention-writing-the-algorithm-from-scratch-in-triton/
 def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy: torch.Tensor,
                                          quad_weights: torch.Tensor, col_idx: torch.Tensor, row_off: torch.Tensor,
                                          nlon_in: int, nlat_out: int, nlon_out: int) -> torch.Tensor:
     # prepare result tensor
     y = torch.zeros_like(qy)
     for ho in range(nlat_out):
         # get number of nonzeros
         zstart = row_off[ho]
         zend = row_off[ho + 1]
         for wo in range(nlon_out):
             alpha_sum = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
-            qdotk_nz = torch.zeros((y.shape[0], zend - zstart), dtype=y.dtype, device=y.device)
+            qdotk_max = torch.zeros((y.shape[0],), dtype=y.dtype, device=y.device)
             for idz in range(zstart, zend):
                 nz_col_idx = col_idx[idz]
@@ -75,24 +78,19 @@ def _neighborhood_attention_s2_fwd_torch(kx: torch.Tensor, vx: torch.Tensor, qy:
                 # compute correlation & softmax numerator
                 q_ho_wo = qy[:, :, ho, wo]
                 k_hi_wip = kx[:, :, hi, wip]
-                qdotk_nz[:, idz - zstart] = torch.sum(q_ho_wo * k_hi_wip, dim=1)
-            qdotk_max, _ = torch.max(qdotk_nz, dim=1)
-            for idz in range(zstart, zend):
-                nz_col_idx = col_idx[idz]
-                # compute input indices from psi datastructure
-                hi = nz_col_idx // nlon_in
-                # account for output shift and ensure positive index due to circular condition
-                wi = nz_col_idx % nlon_in
-                wip = (wi + wo) % nlon_in
-                alpha = torch.exp(qdotk_nz[:, idz - zstart] - qdotk_max)
-                # softmax denominator
-                alpha_sum[:] += alpha[:] * quad_weights[hi]
-                y[:, :, ho, wo] += alpha[:, None] * vx[:, :, hi, wip] * quad_weights[hi]
+                qdotk = torch.sum(q_ho_wo * k_hi_wip, dim=1)
+                # tmp max
+                qdotk_max_tmp = torch.maximum(qdotk_max, qdotk)
+                # alpha sum update
+                alpha = torch.exp(qdotk - qdotk_max_tmp) * quad_weights[hi]
+                alpha_sum = alpha + alpha_sum * torch.exp(qdotk_max - qdotk_max_tmp)
+                # update output
+                y[:, :, ho, wo] = y[:, :, ho, wo] * torch.exp(qdotk_max - qdotk_max_tmp).unsqueeze(1) + alpha[:, None] * vx[:, :, hi, wip]
+                # define new max
+                qdotk_max = qdotk_max_tmp
             y[:, :, ho, wo] = y[:, :, ho, wo] / alpha_sum[:, None]
...
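To check that the new single-loop update reproduces the removed two-pass computation, here is a self-contained comparison on random data; the function names and shapes are illustrative, not part of the library:

import torch

def softmax_two_pass(x, w, v):
    # old approach: find the max first, then accumulate numerator and denominator
    m = x.max()
    alpha = torch.exp(x - m) * w
    return (alpha[:, None] * v).sum(dim=0) / alpha.sum()

def softmax_one_pass(x, w, v):
    # new approach: running max with rescaling of the partial sums
    m = torch.tensor(0.0)        # any start value works because s = y = 0
    s = torch.tensor(0.0)
    y = torch.zeros(v.shape[1])
    for t in range(x.shape[0]):
        m_new = torch.maximum(m, x[t])
        alpha = torch.exp(x[t] - m_new) * w[t]
        s = alpha + s * torch.exp(m - m_new)
        y = y * torch.exp(m - m_new) + alpha * v[t]
        m = m_new
    return y / s

torch.manual_seed(0)
x, w, v = torch.randn(6), torch.rand(6), torch.randn(6, 4)
assert torch.allclose(softmax_two_pass(x, w, v), softmax_one_pass(x, w, v), atol=1e-6)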
torch_harmonics/csrc/attention/attention_fwd_cuda.cu
@@ -256,7 +256,7 @@ torch::Tensor s2_attention_fwd_cuda(at::Tensor kx,
   CHECK_CUDA(cudaEventRecord(stop, stream));
   CHECK_CUDA(cudaEventSynchronize(stop));
   CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
-  // printf("s2_attention_kernel_mbT execution time: %f ms\n", milliseconds);
+  // printf("s2_attention_kernel_fwd execution time: %f ms\n", milliseconds);
   CHECK_CUDA(cudaEventDestroy(start));
   CHECK_CUDA(cudaEventDestroy(stop));
...
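The only change in the CUDA file is the label in a commented-out timing printf (mbT becomes fwd); the surrounding lines are the usual cudaEvent timing pattern. The same measurement can also be taken from the Python side with torch.cuda.Event; a sketch, assuming a CUDA device and using a matmul as a stand-in for the attention kernel:

import torch

if torch.cuda.is_available():
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    a = torch.randn(1024, 1024, device="cuda")
    start.record()                         # analogous to cudaEventRecord(start, stream)
    b = a @ a                              # stand-in for the attention forward kernel
    stop.record()                          # analogous to cudaEventRecord(stop, stream)
    torch.cuda.synchronize()               # analogous to cudaEventSynchronize(stop)
    print(f"execution time: {start.elapsed_time(stop):f} ms")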