Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1dc91af9
Commit
1dc91af9
authored
Sep 15, 2022
by
wangshaojie6
Browse files
add template to distinguish masking kernel
parent
7b18e6fd
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
31 additions
and
7 deletions
+31
-7
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
...mm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+2
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+26
-5
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
1dc91af9
...
@@ -117,7 +117,8 @@ using DeviceGemmInstance =
...
@@ -117,7 +117,8 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
true
>
;
// OnlyLowerTriangle
// Ref Gemm0: fp16 in, fp32 out
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemmUpperTriangleMinusInf
<
ADataType
,
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemmUpperTriangleMinusInf
<
ADataType
,
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
1dc91af9
...
@@ -168,6 +168,7 @@ template <typename ALayout,
...
@@ -168,6 +168,7 @@ template <typename ALayout,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
bool
OnlyLowerTriangle
=
false
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
struct
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
...
@@ -498,7 +499,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -498,7 +499,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
LoopSched
,
matrix_padder
.
PadN
>
;
matrix_padder
.
PadN
,
OnlyLowerTriangle
>
;
// Argument
// Argument
// FIXME: constness
// FIXME: constness
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
1dc91af9
...
@@ -76,7 +76,8 @@ template <typename FloatAB,
...
@@ -76,7 +76,8 @@ template <typename FloatAB,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
LoopScheduler
LoopSched
,
bool
PadN
>
bool
PadN
,
bool
OnlyLowerTriangle
=
false
>
struct
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
struct
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
...
@@ -756,8 +757,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -756,8 +757,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// decoder lower triangular mask
// decoder lower triangular mask
const
auto
thread_cluster_idx
=
const
auto
thread_cluster_idx
=
threadid_to_m_n_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
threadid_to_m_n_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}
];
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}
];
const
auto
thread_n_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
MPerRepeat
=
MPerBlock
/
MXdlPerWave
;
const
index_t
MPerRepeat
=
MPerBlock
/
MXdlPerWave
;
const
index_t
NPerRepeat
=
NPerBlock
/
NXdlPerWave
;
const
index_t
NPerRepeat
=
NPerBlock
/
NXdlPerWave
;
const
index_t
mstart
=
m_block_data_idx_on_grid
+
thread_m_cluster_id
;
const
index_t
mstart
=
m_block_data_idx_on_grid
+
thread_m_cluster_id
;
...
@@ -766,9 +767,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -766,9 +767,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t
gemm1_k_block_outer_index
=
0
;
index_t
gemm1_k_block_outer_index
=
0
;
do
do
{
{
if
((
m_block_data_idx_on_grid
<
gemm1_k_block_outer_index
*
NPerBlock
)
&&
((
m_block_data_idx_on_grid
+
MPerBlock
-
1
)
<
(
gemm1_k_block_outer_index
*
NPerBlock
+
NPerBlock
-
1
))
)
if
constexpr
(
OnlyLowerTriangle
)
{
{
continue
;
auto
gemm0_n_block_idx
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
((
m_block_data_idx_on_grid
<
gemm0_n_block_idx
)
&&
((
m_block_data_idx_on_grid
+
MPerBlock
-
1
)
<
(
gemm0_n_block_idx
+
NPerBlock
-
1
)))
{
continue
;
}
}
}
// gemm0
// gemm0
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_ak0_m_ak1
,
...
@@ -787,6 +792,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -787,6 +792,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf
,
acc_thread_buf
,
num_k_block_main_loop
);
num_k_block_main_loop
);
if
constexpr
(
!
OnlyLowerTriangle
)
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
acc_element_op
(
acc_thread_buf
(
i
),
acc_thread_buf
[
i
]);
});
#else
static_for
<
0
,
acc_thread_buf
.
Size
(),
1
>
{}([
&
](
auto
i
)
{
ElementOpPredicatedResetNaNToMinusInf
<
PadN
>
{}.
Run
(
acc_thread_buf
(
i
),
acc_element_op
,
acc_thread_buf
[
i
]);
});
#endif
}
else
{
const
index_t
nstart
=
gemm1_k_block_outer_index
*
NPerBlock
;
const
index_t
nstart
=
gemm1_k_block_outer_index
*
NPerBlock
;
static_for
<
0
,
m0
,
1
>
{}([
&
](
auto
m0_i
)
{
static_for
<
0
,
m0
,
1
>
{}([
&
](
auto
m0_i
)
{
...
@@ -821,6 +841,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
...
@@ -821,6 +841,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
});
});
});
});
});
});
}
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
block_sync_lds
();
// wait for lds read in gemm0 blockwise gemm
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment