Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
336a7065
Commit
336a7065
authored
Sep 14, 2022
by
danyao12
Browse files
add decoder lower triangular mask calculation
parent
3f9100cc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
42 additions
and
9 deletions
+42
-9
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+42
-9
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
100644 → 100755
View file @
336a7065
...
...
@@ -749,6 +749,15 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max
=
NumericLimits
<
FloatGemmAcc
>::
Lowest
();
running_max_new
=
NumericLimits
<
FloatGemmAcc
>::
Lowest
();
// decoder lower triangular mask: locate this thread's M/N index within the thread cluster
const
auto
thread_cluster_idx
=
threadid_to_m_n_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}];
const
auto
thread_n_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}];
const
index_t
MPerRepeat
=
MPerBlock
/
MXdlPerWave
;
const
index_t
NPerRepeat
=
NPerBlock
/
NXdlPerWave
;
const
index_t
mstart
=
m_block_data_idx_on_grid
+
thread_m_cluster_id
;
// gemm1 K loop
index_t
gemm1_k_block_outer_index
=
0
;
do
...
...
@@ -770,16 +779,40 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf
,
num_k_block_main_loop
);
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
#else
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
ElementOpPredicatedResetNaNToMinusInf
<
PadN
>
{}.
Run
(
acc_thread_buf
(
i
),
acc_element_op
,
acc_thread_buf
[
i
]);
const
index_t
nstart
=
gemm1_k_block_outer_index
*
NPerBlock
;
static_for
<
0
,
m0
,
1
>
{}([
&
](
auto
m0_i
)
{
const
index_t
m_global
=
mstart
+
m0_i
*
MPerRepeat
;
const
index_t
acc_idx_m0
=
m0_i
*
n0
*
n2
*
n4
;
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
n0_i
)
{
// constexpr auto nrepeat_i = n0_i * NPerRepeat;
// const index_t nstartxdl = nstart + nrepeat_i;
const
index_t
nstartxdl
=
nstart
+
n0_i
*
NPerRepeat
;
const
index_t
acc_idx_n0
=
acc_idx_m0
+
n0_i
*
n2
*
n4
;
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
n2_i
)
{
const
index_t
nstartgroup
=
nstartxdl
+
thread_n_cluster_id
*
n4
+
n2_i
*
n3
*
n4
;
const
index_t
acc_idx_n2
=
acc_idx_n0
+
n2_i
*
n4
;
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
n4_i
)
{
const
index_t
n_global
=
nstartgroup
+
n4_i
;
const
auto
acc_offset
=
Number
<
acc_idx_n2
+
n4_i
>
{};
if
(
n_global
>
m_global
)
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
acc_element_op
(
acc_thread_buf
(
acc_offset
),
acc_thread_buf
[
acc_offset
]);
#else
ElementOpPredicatedResetNaNToMinusInf
<
PadN
>
{}.
Run
(
acc_thread_buf
(
acc_offset
),
acc_element_op
,
acc_thread_buf
[
acc_offset
]);
#endif
}
});
});
});
});
#endif
block_sync_lds
();
// wait for the LDS reads in the gemm0 blockwise gemm to complete
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment