Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
191ba149
Commit
191ba149
authored
Jul 03, 2025
by
Max Rietmann
Browse files
Re-introduce inline softmax from main
parent
3dd35b45
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
21 deletions
+10
-21
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
+10
-21
No files found.
torch_harmonics/csrc/attention/attention_bwd_cuda.cu
View file @
191ba149
...
@@ -159,21 +159,7 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
...
@@ -159,21 +159,7 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
const
int64_t
rend
=
psi_row_offset
[
ho
+
1
];
const
int64_t
rend
=
psi_row_offset
[
ho
+
1
];
const
int
rlen
=
rend
-
rbeg
;
const
int
rlen
=
rend
-
rbeg
;
// First pass: find qdotk_max
// 1st pass: accumulate alpha_sum, integral, and shared stats, along with a progressively computed qdotk_max.
for
(
int
off
=
0
;
off
<
rlen
;
off
++
)
{
const
int64_t
col
=
psi_col_idx
[
rbeg
+
off
];
const
int
hi
=
col
/
nlon_in
;
const
int
wi
=
col
-
(
hi
*
nlon_in
);
const
int
wip
=
(
wi
+
wo
)
-
((
wi
+
wo
)
/
nlon_in
)
*
nlon_in
;
float
qdotk
=
0.0f
;
for
(
int
chan
=
tidx
;
chan
<
num_channels
;
chan
+=
WARP_SIZE
)
{
qdotk
+=
sh_qy
[
chan
]
*
kx
[
batchId
][
chan
][
hi
][
wip
];
}
qdotk
=
__warp_sum_cub
(
qdotk
);
qdotk_max
=
max
(
qdotk_max
,
qdotk
);
}
// Second pass: accumulate alpha_sum, integral, and shared stats
for
(
int
off
=
0
;
off
<
rlen
;
off
++
)
{
for
(
int
off
=
0
;
off
<
rlen
;
off
++
)
{
const
int64_t
col
=
psi_col_idx
[
rbeg
+
off
];
const
int64_t
col
=
psi_col_idx
[
rbeg
+
off
];
const
int
hi
=
col
/
nlon_in
;
const
int
hi
=
col
/
nlon_in
;
...
@@ -186,15 +172,18 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
...
@@ -186,15 +172,18 @@ __global__ __launch_bounds__(BDIM_X) void s2_attention_bwd_dkvq_kernel(
}
}
qdotk
=
__warp_sum_cub
(
qdotk
);
qdotk
=
__warp_sum_cub
(
qdotk
);
gdotv
=
__warp_sum_cub
(
gdotv
);
gdotv
=
__warp_sum_cub
(
gdotv
);
float
alpha_inz
=
expf
(
qdotk
-
qdotk_max
)
*
quad_weights
[
hi
];
float
qdotk_max_tmp
=
max
(
qdotk_max
,
qdotk
);
alpha_sum
+=
alpha_inz
;
float
alpha_inz
=
expf
(
qdotk
-
qdotk_max_tmp
)
*
quad_weights
[
hi
];
integral
+=
alpha_inz
*
gdotv
;
float
max_correction
=
expf
(
qdotk_max
-
qdotk_max_tmp
);
alpha_sum
=
alpha_sum
*
max_correction
+
alpha_inz
;
integral
=
integral
*
max_correction
+
alpha_inz
*
gdotv
;
for
(
int
chan
=
tidx
;
chan
<
num_channels
;
chan
+=
WARP_SIZE
)
{
for
(
int
chan
=
tidx
;
chan
<
num_channels
;
chan
+=
WARP_SIZE
)
{
float
kxval
=
kx
[
batchId
][
chan
][
hi
][
wip
];
float
kxval
=
kx
[
batchId
][
chan
][
hi
][
wip
];
sh_alpha_k
[
chan
]
+=
alpha_inz
*
kxval
;
sh_alpha_k
[
chan
]
=
sh_alpha_k
[
chan
]
*
max_correction
+
alpha_inz
*
kxval
;
sh_alpha_vw
[
chan
]
+=
alpha_inz
*
gdotv
;
sh_alpha_vw
[
chan
]
=
sh_alpha_vw
[
chan
]
*
max_correction
+
alpha_inz
*
gdotv
;
sh_alpha_kvw
[
chan
]
+=
alpha_inz
*
kxval
*
gdotv
;
sh_alpha_kvw
[
chan
]
=
sh_alpha_kvw
[
chan
]
*
max_correction
+
alpha_inz
*
kxval
*
gdotv
;
}
}
qdotk_max
=
qdotk_max_tmp
;
}
}
integral
/=
alpha_sum
;
integral
/=
alpha_sum
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment