Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7ae26b79
Commit
7ae26b79
authored
Sep 15, 2022
by
wangshaojie6
Browse files
Rename the OnlyLowerTriangle template parameter to MaskOutUpperTriangle and remove its default value
parent
1dc91af9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
7 deletions
+8
-7
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
...mm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
...gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
+2
-1
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+3
-3
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
7ae26b79
...
...
@@ -118,7 +118,7 @@ using DeviceGemmInstance =
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
true
>
;
//
OnlyLowerTriangle
true
>
;
//
MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemmUpperTriangleMinusInf
<
ADataType
,
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
7ae26b79
...
...
@@ -117,7 +117,8 @@ using DeviceGemmInstance =
1
,
// CShuffleMXdlPerWavePerShuffle
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
>
;
// CShuffleBlockTransferScalarPerVector_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
false
>
;
// MaskOutUpperTriangle
// Ref Gemm0: fp16 in, fp32 out
using
ReferenceGemm0Instance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchedGemm
<
ADataType
,
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
7ae26b79
...
...
@@ -168,7 +168,7 @@ template <typename ALayout,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
bool
OnlyLow
erTriangle
=
false
,
bool
MaskOutUpp
erTriangle
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
:
public
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
...
...
@@ -500,7 +500,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
matrix_padder
.
PadN
,
OnlyLow
erTriangle
>
;
MaskOutUpp
erTriangle
>
;
// Argument
// FIXME: constness
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
View file @
7ae26b79
...
...
@@ -77,7 +77,7 @@ template <typename FloatAB,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
bool
PadN
,
bool
OnlyLow
erTriangle
=
false
>
bool
MaskOutUpp
erTriangle
>
struct
GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
static_assert
(
LoopSched
==
LoopScheduler
::
Default
,
...
...
@@ -767,7 +767,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t
gemm1_k_block_outer_index
=
0
;
do
{
if
constexpr
(
OnlyLow
erTriangle
)
if
constexpr
(
MaskOutUpp
erTriangle
)
{
auto
gemm0_n_block_idx
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
((
m_block_data_idx_on_grid
<
gemm0_n_block_idx
)
&&
((
m_block_data_idx_on_grid
+
MPerBlock
-
1
)
<
(
gemm0_n_block_idx
+
NPerBlock
-
1
)))
...
...
@@ -792,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
acc_thread_buf
,
num_k_block_main_loop
);
if
constexpr
(
!
OnlyLow
erTriangle
)
if
constexpr
(
!
MaskOutUpp
erTriangle
)
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment