Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
54dfedcd
"examples/unconditional_image_generation/README.md" did not exist on "85244d4a5901cc5062562b15361b06fdbeebf528"
Commit
54dfedcd
authored
Feb 20, 2023
by
guangzlu
Browse files
added switch for lse storing in attn fwd
parent
d2eed8e6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
15 deletions
+19
-15
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+19
-15
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
54dfedcd
...
...
@@ -416,6 +416,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
template
<
bool
HasMainKBlockLoop
,
bool
IsDropout
,
bool
IsLseStoring
,
typename
Block2CTileMap
,
typename
C0MatrixMask
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
...
...
@@ -1019,7 +1020,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
}
else
{
...
...
@@ -1149,22 +1149,26 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
}
while
(
++
gemm1_k_block_outer_index
<
num_gemm1_k_block_outer_loop
);
// end j loop
// Calculate max + ln(sum) and write out
static_for
<
0
,
MXdlPerWave
,
1
>
{}(
[
&
](
auto
I
)
{
lse_thread_buf
(
I
)
=
running_max
(
I
)
+
math
::
log
(
running_sum
(
I
));
});
if
(
get_warp_local_1d_id
()
<
AccM2
)
if
constexpr
(
IsLseStoring
)
{
static_for
<
0
,
MXdlPerWave
,
1
>
{}([
&
](
auto
I
)
{
// copy from VGPR to Global
lse_thread_copy_vgpr_to_global
.
Run
(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
Number
<
I
>
{},
I0
,
I0
),
lse_thread_buf
,
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
lse_grid_buf
);
lse_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
0
,
1
,
0
,
0
));
});
static_for
<
0
,
MXdlPerWave
,
1
>
{}(
[
&
](
auto
I
)
{
lse_thread_buf
(
I
)
=
running_max
(
I
)
+
math
::
log
(
running_sum
(
I
));
});
if
(
get_warp_local_1d_id
()
<
AccM2
)
{
static_for
<
0
,
MXdlPerWave
,
1
>
{}([
&
](
auto
I
)
{
// copy from VGPR to Global
lse_thread_copy_vgpr_to_global
.
Run
(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl
,
make_tuple
(
I0
,
Number
<
I
>
{},
I0
,
I0
),
lse_thread_buf
,
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
lse_grid_buf
);
lse_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
0
,
1
,
0
,
0
));
});
}
}
// shuffle C and write out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment