gaoqiong / composable_kernel · Commits

Commit a045e0be
Authored Feb 28, 2023 by aska-0096
Example branch provide to compiler team
Parent: fbc576b5

Showing 6 changed files with 1 addition and 2137 deletions:
  CMakeLists.txt  (+0, -1)
  example/01_gemm/gemm_wmma_fp16.cpp  (+1, -1)
  example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  (+0, -8)
  example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp  (+0, -162)
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  (+0, -771)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  (+0, -1194)
CMakeLists.txt  (view file @ a045e0be)

@@ -240,7 +240,6 @@ include_directories(BEFORE
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
    add_compile_options(-Werror)
    add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
example/01_gemm/gemm_wmma_fp16.cpp  (view file @ a045e0be)

@@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     256,    // BlockSize
     128,    // MPerBlock
     128,    // NPerBlock
-    64,     // KPerBlock
+    32,     // KPerBlock
     8,      // K1
     16,     // MPerWmma
     16,     // NPerWmma
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt  (view file @ a045e0be)

@@ -5,9 +5,6 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
 add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
 add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
-if(GPU_TARGETS MATCHES "gfx1100")
-    add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
-endif()
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)

@@ -17,8 +14,3 @@ add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_soft
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
 add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
-if(GPU_TARGETS MATCHES "gfx1100")
-    add_custom_target(example_gemm_scale_softmax_gemm_wmma)
-    add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
-endif()
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp  (deleted, 100644 → 0, view file @ fbc576b5)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
                                                          |-----------------|
                                                                  Gemm0
                                                          |-------------------------------------|
                                                                          Gemm1
*/

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using Acc0DataType     = F32;
using Acc1DataType     = F32;
using CShuffleDataType = F32;
using CDataType        = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;

static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

static constexpr auto GemmSpec    = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;

static constexpr auto TensorSpecA  = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType,
        Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType,
        CShuffleDataType, CDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC,
        256,
        128,            // MPerBlock
        128,            // LPerBlock
        4,              // K0PerBlock
        8,              // K1
        64,             // NPerBlock
        4,              // L0PerBlock
        8,              // L1
        16,             // MPerWMMA
        16,             // LPerWMMA
        16,             // NPerWMMA
        // Per repeat = wave_m = wave_num, wave_n = 1
        1,              // MRepeat
        8,              // LRepeat
        4,              // NRepeat
        S<4, 64, 1>,    // ABlockTransfer MK -> K0 M K1
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        S<4, 64, 1>,    // B0BlockTransfer LK -> K0 L K1
        S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        S<4, 8, 8>,     // B1BlockTransfer LN -> L0 N L1
        S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
        1,              // CShuffleMWmmaPerWavePerShuffle
        2,              // CShuffleNWmmaPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        4,              // CShuffleBlockTransferScalarPerVector_NPerBlock
        MaskingSpec>;   // MaskingSpecialization

// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
    ADataType, B0DataType, Acc0DataType, Acc1DataType, AElementOp, B0ElementOp, Acc0ElementOp>;

// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
    ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;

// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
    ADataType, B1DataType, CDataType, Acc1DataType, AElementOp, B1ElementOp, CElementOp>;

#include "run_batched_gemm_scale_softmax_gemm_permute.inc"

int main(int argc, char* argv[]) { return run(argc, argv); }
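For orientation, the following is a minimal host-side sketch (not part of the commit) of the math the deleted example targets: C = Softmax(A * B0) * B1 for one batch, with the row-wise scale applied before the softmax the way the Acc0ElementOp (Scale) is used above. The plain triple loops are illustrative stand-ins for the ReferenceBatchedGemm/ReferenceSoftmax instances; all names below are local to this sketch.

// Illustrative host reference for one batch: C[M][N] = Softmax_rowwise(scale * A[M][K] * B0[K][L]) * B1[L][N].
// Not repository code; it only restates the fused operation in plain C++.
#include <algorithm>
#include <cmath>
#include <vector>

void reference_gemm_softmax_gemm(const std::vector<float>& A,  // M x K, row-major
                                 const std::vector<float>& B0, // K x L, row-major
                                 const std::vector<float>& B1, // L x N, row-major
                                 std::vector<float>& C,        // M x N, row-major
                                 int M, int K, int L, int N, float scale)
{
    std::vector<float> S(static_cast<size_t>(M) * L);

    // Gemm0 (+ scale, mirroring Acc0ElementOp = Scale)
    for(int m = 0; m < M; ++m)
        for(int l = 0; l < L; ++l)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += A[m * K + k] * B0[k * L + l];
            S[m * L + l] = scale * acc;
        }

    // Row-wise softmax over the L dimension (numerically stabilized by the row max)
    for(int m = 0; m < M; ++m)
    {
        float row_max = S[m * L];
        for(int l = 1; l < L; ++l)
            row_max = std::max(row_max, S[m * L + l]);
        float sum = 0.f;
        for(int l = 0; l < L; ++l)
        {
            S[m * L + l] = std::exp(S[m * L + l] - row_max);
            sum += S[m * L + l];
        }
        for(int l = 0; l < L; ++l)
            S[m * L + l] /= sum;
    }

    // Gemm1
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int l = 0; l < L; ++l)
                acc += S[m * L + l] * B1[l * N + n];
            C[m * N + n] = acc;
        }
}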
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  (deleted, 100644 → 0, view file @ fbc576b5)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Computes C = A * B0 * B1
//         MN = MK * KL * LN
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG, index_t NumDimM, index_t NumDimL, index_t NumDimK, index_t NumDimN,
          typename ADataType, typename B0DataType, typename B1DataType,
          typename Acc0BiasDataType, typename Acc0DataType,
          typename Acc1BiasDataType, typename Acc1DataType,
          typename CShuffleDataType, typename CDataType,
          typename AElementwiseOperation, typename B0ElementwiseOperation,
          typename AccElementwiseOperation, typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec, TensorSpecialization B0Spec,
          TensorSpecialization B1Spec, TensorSpecialization CSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t LPerBlock,
          ck::index_t K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
          ck::index_t K1,
          ck::index_t NPerBlock,
          ck::index_t L0PerBlock,
          ck::index_t L1,
          ck::index_t MPerWMMA, ck::index_t LPerWMMA, ck::index_t NPerWMMA,
          ck::index_t MRepeat, ck::index_t LRepeat, ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename B0BlockTransferThreadClusterLengths_K0_L_K1,
          typename B0BlockTransferThreadClusterArrangeOrder,
          typename B0BlockTransferSrcAccessOrder,
          ck::index_t B0BlockTransferSrcVectorDim,
          ck::index_t B0BlockTransferSrcScalarPerVector,
          ck::index_t B0BlockTransferDstScalarPerVector_K1,
          bool B0BlockLdsAddExtraL,
          typename B1BlockTransferThreadClusterLengths_L0_N_L1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          ck::index_t B1BlockTransferSrcVectorDim,
          ck::index_t B1BlockTransferSrcScalarPerVector,
          ck::index_t B1BlockTransferDstScalarPerVector_L1,
          bool B1BlockLdsAddExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          MaskingSpecialization MaskingSpec,
          ck::index_t NumPrefetch         = 1,
          ck::LoopScheduler LoopSched     = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    : public DeviceBatchedGemmSoftmaxGemmPermute<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN,
                                                 ADataType, B0DataType, B1DataType, CDataType,
                                                 Acc0BiasDataType, Acc1BiasDataType,
                                                 AElementwiseOperation, B0ElementwiseOperation,
                                                 AccElementwiseOperation, B1ElementwiseOperation,
                                                 CElementwiseOperation, MaskingSpec>
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
                  "Number of dimension must be greater than 0");

    static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
    static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();

    // TODO ANT: implement bias combination
    static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");

    static constexpr index_t NumDimGemm0M = NumDimM;
    static constexpr index_t NumDimGemm0N = NumDimL;
    static constexpr index_t NumDimGemm0K = NumDimK;
    static constexpr index_t NumDimGemm1M = NumDimM;
    static constexpr index_t NumDimGemm1N = NumDimN;
    static constexpr index_t NumDimGemm1K = NumDimL;

    static constexpr index_t KPerBlock = K0PerBlock * K1;

    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
        Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
        GemmSpec, ASpec, B0Spec, B1Spec, CSpec>;

    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                              const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return Transform::MakeAGridDescriptor_AK0_M_AK1(
            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
            Number<K1>{});
    }

    static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
                                               const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
    {
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
            Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
            Number<K1>{});
    }

    static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
                                               const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
    {
        return Transform::MakeB1GridDescriptor_BK0_N_BK1(
            Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
            Number<L1>{});
    }

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
    using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
    using B0GridDesc_G_L_K     = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
    using B1GridDesc_G_N_L     = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
    using CGridDesc_G_M_N      = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

    constexpr static auto make_MaskOutPredicate()
    {
        if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
        {
            return MaskDisabledPredicate{};
        }
        else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
        {
            return MaskOutUpperTrianglePredicate{};
        }
    }
    using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
                                     const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
                                     const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
                                     const CGridDesc_G_M_N& c_grid_desc_g_m_n)
            : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
              b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
              b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
              c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
        {
            return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
        {
            return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
        B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
    };

    // GridwiseOp
    using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle<
        // DataType Family
        ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType,
        // ElementwiseOp Family
        AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation,
        B1ElementwiseOperation, CElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        // InMemory Data Descriptor
        AGridDesc_AK0_M_AK1, B0GridDesc_BK0_L_BK1, B1GridDesc_BL0_N_BL1, CGridDesc_M_N,
        // Tiling Family
        MPerBlock,
        LPerBlock,
        K0PerBlock, // K0 * K1 = Gemm0 GEMM_K Dim
        K1,
        NPerBlock,
        L0PerBlock,
        L1,
        MPerWMMA, LPerWMMA, NPerWMMA,
        MRepeat, LRepeat, NRepeat,
        // ThreadCluster Family
        BlockSize,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        true,
        ABlockLdsAddExtraM,
        B0BlockTransferThreadClusterLengths_K0_L_K1,
        B0BlockTransferThreadClusterArrangeOrder,
        B0BlockTransferSrcAccessOrder,
        B0BlockTransferSrcVectorDim,
        B0BlockTransferSrcScalarPerVector,
        B0BlockTransferDstScalarPerVector_K1,
        true,
        B0BlockLdsAddExtraL,
        B1BlockTransferThreadClusterLengths_L0_N_L1,
        B1BlockTransferThreadClusterArrangeOrder,
        B1BlockTransferSrcAccessOrder,
        B1BlockTransferSrcVectorDim,
        B1BlockTransferSrcScalarPerVector,
        B1BlockTransferDstScalarPerVector_L1,
        false,
        B1BlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        Transform::matrix_padder.PadN,
        MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
        NumPrefetch,
        LoopSched,
        PipelineVer>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const B0DataType* p_b0_grid,
                 const B1DataType* p_b1_grid,
                 CDataType* p_c_grid,
                 const std::array<void*, NumAcc0Bias> p_acc0_biases,
                 const std::array<void*, NumAcc1Bias> p_acc1_biases,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b0_gs_ls_ks_lengths,
                 const std::vector<index_t>& b0_gs_ls_ks_strides,
                 const std::vector<index_t>& b1_gs_ns_ls_lengths,
                 const std::vector<index_t>& b1_gs_ns_ls_strides,
                 const std::vector<index_t>& c_gs_ms_ns_lengths,
                 const std::vector<index_t>& c_gs_ms_ns_strides,
                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
                 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
                 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
                 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
                 const index_t M01,
                 const index_t N01,
                 AElementwiseOperation a_element_op,
                 B0ElementwiseOperation b0_element_op,
                 AccElementwiseOperation acc_element_op,
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b0_grid_{p_b0_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b0_grid_desc_bk0_l_bk1_{
                  DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_bl0_n_bl1_{
                  DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{
                  Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
              a_grid_desc_g_m_k_{
                  Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b0_grid_desc_g_l_k_{
                  Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_g_n_l_{
                  Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_g_m_n_{
                  Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
              a_element_op_{a_element_op},
              b0_element_op_{b0_element_op},
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)},
              raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                       b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
                                       b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
                                       b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
              a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
                               a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
              b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1],
                                b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
              b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1],
                                b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
              c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1],
                               c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
              batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
              compute_ptr_offset_of_batch_{
                  a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_}
        {
            // TODO ANT: implement bias addition
            ignore = p_acc0_biases;
            ignore = p_acc1_biases;
            ignore = acc0_biases_gs_ms_ls_lengths;
            ignore = acc0_biases_gs_ms_ls_strides;
            ignore = acc1_biases_gs_ms_ns_lengths;
            ignore = acc1_biases_gs_ms_ns_strides;

            if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                         b0_grid_desc_bk0_l_bk1_,
                                         b1_grid_desc_bl0_n_bl1_,
                                         c_grid_desc_m_n_,
                                         block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }

        // Pointers
        const ADataType* p_a_grid_;
        const B0DataType* p_b0_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;

        // Tensor Descriptors
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
        B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        AGridDesc_G_M_K a_grid_desc_g_m_k_;
        B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
        B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;

        // Block to Tile mapping
        typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_;

        // ElementwiseOp
        AElementwiseOperation a_element_op_;
        B0ElementwiseOperation b0_element_op_;
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;

        // check C0 masking and padding
        C0MatrixMask c0_matrix_mask_;

        // Strides for the last M/N/K dimensions of A/B0/B1/C
        // for sanity check of vector load/store
        std::vector<index_t> raw_lengths_mz_lz_kz_nz_;
        std::vector<index_t> a_mz_kz_strides_;
        std::vector<index_t> b0_lz_kz_strides_;
        std::vector<index_t> b1_nz_lz_strides_;
        std::vector<index_t> c_mz_nz_strides_;

        index_t batch_count_;

        // Batch Offset
        ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    B0DataType,
                    B1DataType,
                    CDataType,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::B0GridDesc_BK0_L_BK1,
                    DeviceOp::B1GridDesc_BL0_N_BL1,
                    typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    AElementwiseOperation,
                    B0ElementwiseOperation,
                    AccElementwiseOperation,
                    B1ElementwiseOperation,
                    CElementwiseOperation,
                    ComputeBasePtrOfStridedBatch,
                    C0MatrixMask,
                    typename GridwiseOp::DefaultBlock2CTileMap,
                    has_main_k_block_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b0_grid_,
                                              arg.p_b1_grid_,
                                              arg.p_c_grid_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b0_grid_desc_bk0_l_bk1_,
                                              arg.b1_grid_desc_bl0_n_bl1_,
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.a_element_op_,
                                              arg.b0_element_op_,
                                              arg.acc_element_op_,
                                              arg.b1_element_op_,
                                              arg.c_element_op_,
                                              arg.batch_count_,
                                              arg.compute_ptr_offset_of_batch_,
                                              arg.c0_matrix_mask_,
                                              arg.block_2_ctile_map_);
            };

            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::get_device_name() == "gfx1100")
        {
            if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
            {
                return false;
            }
            if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                      arg.b0_grid_desc_bk0_l_bk1_,
                                      arg.b1_grid_desc_bl0_n_bl1_,
                                      arg.c_grid_desc_m_n_,
                                      arg.block_2_ctile_map_))
        {
            return false;
        }

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
        const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
        const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);

        if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
        {
            return false;
        }

        // Note: we need raw lengths since threadwise copy can not handle vector load when part of
        // vector is out of bounds
        // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
        const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
        const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
        const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
        const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];

        // Check scalar per vector requirement
        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
        const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
        const auto c_extent_lowest  = NzRaw;

        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            return false;
        }

        // Check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
        const auto b0_stride_lowest =
            B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
        const auto b1_stride_lowest =
            B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
        const auto c_stride_lowest = arg.c_mz_nz_strides_[1];

        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
            return false;
        }

        return true;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(
        const ADataType* p_a,
        const B0DataType* p_b0,
        const B1DataType* p_b1,
        CDataType* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b0_gs_ls_ks_lengths,
        const std::vector<index_t>& b0_gs_ls_ks_strides,
        const std::vector<index_t>& b1_gs_ns_ls_lengths,
        const std::vector<index_t>& b1_gs_ns_ls_strides,
        const std::vector<index_t>& c_gs_ms_ns_lengths,
        const std::vector<index_t>& c_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
        AElementwiseOperation a_element_op,
        B0ElementwiseOperation b0_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b0,
                        p_b1,
                        p_c,
                        p_acc0_biases,
                        p_acc1_biases,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b0_gs_ls_ks_lengths,
                        b0_gs_ls_ks_strides,
                        b1_gs_ns_ls_lengths,
                        b1_gs_ns_ls_strides,
                        c_gs_ms_ns_lengths,
                        c_gs_ms_ns_strides,
                        acc0_biases_gs_ms_ls_lengths,
                        acc0_biases_gs_ms_ls_strides,
                        acc1_biases_gs_ms_ns_lengths,
                        acc1_biases_gs_ms_ns_strides,
                        1,
                        1,
                        a_element_op,
                        b0_element_op,
                        acc_element_op,
                        b1_element_op,
                        c_element_op};
    }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const void* p_a,
        const void* p_b0,
        const void* p_b1,
        void* p_c,
        const std::array<void*, NumAcc0Bias> p_acc0_biases,
        const std::array<void*, NumAcc1Bias> p_acc1_biases,
        const std::vector<index_t>& a_gs_ms_ks_lengths,
        const std::vector<index_t>& a_gs_ms_ks_strides,
        const std::vector<index_t>& b0_gs_ls_ks_lengths,
        const std::vector<index_t>& b0_gs_ls_ks_strides,
        const std::vector<index_t>& b1_gs_ns_ls_lengths,
        const std::vector<index_t>& b1_gs_ns_ls_strides,
        const std::vector<index_t>& c_gs_ms_ns_lengths,
        const std::vector<index_t>& c_gs_ms_ns_strides,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
        const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
        AElementwiseOperation a_element_op,
        B0ElementwiseOperation b0_element_op,
        AccElementwiseOperation acc_element_op,
        B1ElementwiseOperation b1_element_op,
        CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const B0DataType*>(p_b0),
                                          static_cast<const B1DataType*>(p_b1),
                                          static_cast<CDataType*>(p_c),
                                          p_acc0_biases,
                                          p_acc1_biases,
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
                                          b0_gs_ls_ks_lengths,
                                          b0_gs_ls_ks_strides,
                                          b1_gs_ns_ls_lengths,
                                          b1_gs_ns_ls_strides,
                                          c_gs_ms_ns_lengths,
                                          c_gs_ms_ns_strides,
                                          acc0_biases_gs_ms_ls_lengths,
                                          acc0_biases_gs_ms_ls_strides,
                                          acc1_biases_gs_ms_ns_lengths,
                                          acc1_biases_gs_ms_ns_strides,
                                          1,
                                          1,
                                          a_element_op,
                                          b0_element_op,
                                          acc_element_op,
                                          b1_element_op,
                                          c_element_op);
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << LPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << L0PerBlock << ", "
            << L1
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
            << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
            << getMaskingSpecializationString(MaskingSpec)
            << ">"
            << " NumPrefetch: " << NumPrefetch << ", "
            << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
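The Invoker above launches CalculateGridSize(c_grid_desc_m_n_) * batch_count_ workgroups, and the kernel in the next file recovers the batch index by dividing the 1-D block id by the per-batch block count. The following standalone sketch (not from the commit; tile sizes and the tiles_per_batch helper are made up for illustration) shows that index arithmetic in isolation.

// Illustrative sketch of the grid-size / batch-offset scheme used by Invoker::Run and the kernel.
#include <cassert>
#include <cstdio>

// Stand-in for Block2CTileMap::CalculateGridSize(): number of C tiles per batch.
static int tiles_per_batch(int M, int N, int MPerBlock, int NPerBlock)
{
    return ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock);
}

int main()
{
    const int M = 256, N = 128, MPerBlock = 128, NPerBlock = 64, batch_count = 3;

    const int num_blocks_per_batch = tiles_per_batch(M, N, MPerBlock, NPerBlock);
    const int grid_size            = num_blocks_per_batch * batch_count; // what Invoker::Run launches

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx    = block_id / num_blocks_per_batch; // batch handled by this workgroup
        const int tile_idx = block_id % num_blocks_per_batch; // C tile within that batch
        assert(g_idx < batch_count);
        std::printf("block %d -> batch %d, tile %d\n", block_id, g_idx, tile_idx);
    }
    return 0;
}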
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  (deleted, 100644 → 0, view file @ fbc576b5)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
FloatB0
,
typename
FloatB1
,
typename
FloatC
,
typename
AGridDesc_AK0_M_AK1
,
typename
B0GridDesc_BK0_L_BK1
,
typename
B1GridDesc_BL0_N_BL1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputeBasePtrOfStridedBatch
,
typename
C0MatrixMask
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB0
*
__restrict__
p_b0_grid
,
const
FloatB1
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
B0GridDesc_BK0_L_BK1
b0_grid_desc_bk0_l_bk1
,
const
B1GridDesc_BL0_N_BL1
b1_grid_desc_l0_n_l1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
a_element_op
,
const
B0ElementwiseOperation
b0_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b0_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB0BasePtr
(
g_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b0_grid
+
b0_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_shared
,
a_grid_desc_ak0_m_ak1
,
b0_grid_desc_bk0_l_bk1
,
b1_grid_desc_l0_n_l1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
b0_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
,
c0_matrix_mask
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b0_grid
;
ignore
=
p_b1_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b0_grid_desc_bk0_l_bk1
;
ignore
=
b1_grid_desc_l0_n_l1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
ignore
=
b0_element_op
;
ignore
=
acc_element_op
;
ignore
=
b1_element_op
;
ignore
=
c_element_op
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx1100__))
}
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
template
<
typename
FloatA
,
typename
FloatB0
,
typename
FloatAcc0
,
typename
FloatB1
,
typename
FloatAcc1
,
typename
FloatCShuffle
,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
B0GridDesc_BK0_L_BK1
,
typename
B1GridDesc_BL0_N_BL1
,
typename
CGridDesc_M_N
,
index_t
MPerBlock
,
index_t
LPerBlock
,
index_t
K0PerBlock
,
// K0 * K1Value = Gemm0 GEMM_K Dim
index_t
K1Value
,
index_t
NPerBlock
,
index_t
L0PerBlock
,
index_t
L1Value
,
index_t
MPerWmma
,
index_t
LPerWmma
,
index_t
NPerWmma
,
index_t
MRepeat
,
index_t
LRepeat
,
index_t
NRepeat
,
index_t
BlockSize
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
ABlockLdsExtraM
,
typename
B0BlockTransferThreadClusterLengths_K0_L_K1
,
typename
B0BlockTransferThreadClusterArrangeOrder
,
typename
B0BlockTransferSrcAccessOrder
,
index_t
B0BlockTransferSrcVectorDim
,
index_t
B0BlockTransferSrcScalarPerVector
,
index_t
B0BlockTransferDstScalarPerVector_K1
,
bool
B0ThreadTransferSrcResetCoordinateAfterRun
,
bool
B0BlockLdsExtraN
,
typename
B1BlockTransferThreadClusterLengths_L0_N_L1
,
typename
B1BlockTransferThreadClusterArrangeOrder
,
typename
B1BlockTransferSrcAccessOrder
,
index_t
B1BlockTransferSrcVectorDim
,
index_t
B1BlockTransferSrcScalarPerVector
,
index_t
B1BlockTransferDstScalarPerVector_L1
,
bool
B1ThreadTransferSrcResetCoordinateAfterRun
,
bool
B1BlockLdsExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
bool
PadN
,
bool
MaskOutUpperTriangle
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
static
constexpr
auto
AK0
=
Number
<
K0PerBlock
>
{};
static
constexpr
auto
AK1
=
Number
<
K1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
K0PerBlock
>
{};
static
constexpr
auto
BK1
=
Number
<
K1Value
>
{};
static
constexpr
auto
AL0
=
Number
<
L0PerBlock
/
2
>
{};
static
constexpr
auto
AL1
=
Number
<
L1Value
>
{};
static
constexpr
auto
BL0
=
Number
<
L0PerBlock
>
{};
static
constexpr
auto
BL1
=
Number
<
L1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
template
<
typename
A0BlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeA0BlockDescriptor_K0_M0_M1_M2_K1
(
const
A0BlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
A_K0
=
A0BlockDesc_AK0_M_AK1
{}.
GetLength
(
I0
);
constexpr
index_t
A_K1
=
A0BlockDesc_AK0_M_AK1
{}.
GetLength
(
I2
);
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
return
transform_tensor_descriptor
(
A0BlockDesc_AK0_M_AK1
{},
make_tuple
(
make_pass_through_transform
(
Number
<
A_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
template
<
typename
B0BlockDesc_BK0_L_BK1
>
__host__
__device__
static
constexpr
auto
MakeB0BlockDescriptor_K0_L0_L1_L2_K1
(
const
B0BlockDesc_BK0_L_BK1
&
)
{
constexpr
index_t
B_K0
=
B0BlockDesc_BK0_L_BK1
{}.
GetLength
(
I0
);
constexpr
index_t
B_K1
=
B0BlockDesc_BK0_L_BK1
{}.
GetLength
(
I2
);
constexpr
index_t
LWaves
=
LPerBlock
/
(
LRepeat
*
LPerWmma
);
return
transform_tensor_descriptor
(
B0BlockDesc_BK0_L_BK1
{},
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
LRepeat
>
{},
Number
<
LWaves
>
{},
Number
<
LPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
template
<
typename
A1BlockDesc_AL0_M_AL1
>
__host__
__device__
static
constexpr
auto
MakeA1BlockDescriptor_L0_M0_M1_M2_L1
(
const
A1BlockDesc_AL0_M_AL1
&
)
{
constexpr
index_t
A_L0
=
A1BlockDesc_AL0_M_AL1
{}.
GetLength
(
I0
);
constexpr
index_t
A_L1
=
A1BlockDesc_AL0_M_AL1
{}.
GetLength
(
I2
);
return
transform_tensor_descriptor
(
A1BlockDesc_AL0_M_AL1
{},
make_tuple
(
make_pass_through_transform
(
Number
<
A_L0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
)),
make_pass_through_transform
(
Number
<
A_L1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
template
<
typename
B1BlockDesc_BL0_N_BL1
>
__host__
__device__
static
constexpr
auto
MakeB1BlockDescriptor_L0_N0_N1_N2_L1
(
const
B1BlockDesc_BL0_N_BL1
&
)
{
constexpr
index_t
B_K0
=
B1BlockDesc_BL0_N_BL1
{}.
GetLength
(
I0
);
constexpr
index_t
B_K1
=
B1BlockDesc_BL0_N_BL1
{}.
GetLength
(
I2
);
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
return
transform_tensor_descriptor
(
B1BlockDesc_BL0_N_BL1
{},
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
LPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
LPerBlock
+
B0BlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1
()
{
// B1 matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BL0
,
Number
<
NPerBlock
>
{},
BL1
),
make_tuple
(
Number
<
NPerBlock
+
B1BlockLdsExtraN
>
{}
*
BL1
,
BL1
,
I1
));
}
__host__
__device__
static
constexpr
auto
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
*
MWave
*
MPerWmma
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
{}));
return
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
const
index_t
gemm0_bytes_end
=
(
SharedMemTrait
::
a_block_space_size_aligned
*
sizeof
(
FloatA
)
+
SharedMemTrait
::
b0_block_space_size_aligned
*
sizeof
(
FloatB0
));
const
index_t
gemm1_bytes_end
=
(
SharedMemTrait
::
b1_block_space_offset
+
SharedMemTrait
::
b1_block_space_size_aligned
)
*
sizeof
(
FloatB1
);
const
index_t
softmax_bytes_end
=
(
SharedMemTrait
::
reduction_space_offset
+
SharedMemTrait
::
reduction_space_size_aligned
)
*
sizeof
(
FloatAcc0
);
const
index_t
c_block_bytes_end
=
SharedMemTrait
::
c_block_space_size
*
sizeof
(
FloatCShuffle
);
return
math
::
max
(
gemm0_bytes_end
,
gemm1_bytes_end
,
softmax_bytes_end
,
c_block_bytes_end
);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
B0GridDesc_BK0_L_BK1
&
b0_grid_desc_bk0_l_bk1
,
const
B1GridDesc_BL0_N_BL1
&
b1_grid_desc_l0_n_l1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
((
MPerBlock
%
(
MPerWmma
*
MRepeat
)
==
0
)
&&
(
LPerBlock
%
(
LPerWmma
*
LRepeat
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
);
const
auto
L
=
b0_grid_desc_bk0_l_bk1
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
N
=
b1_grid_desc_l0_n_l1
.
GetLength
(
I1
);
const
auto
KPerBlock
=
K0PerBlock
*
K1Value
;
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
{
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
L
%
LPerBlock
==
0
&&
K
%
KPerBlock
==
0
&&
N
%
NPerBlock
==
0
))
{
return
false
;
}
// check gemm0 gridwise gemm pipeline
const
auto
num_gemm0_k_loop
=
K
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm0_k_loop
))
{
return
false
;
}
// check gemm1 gridwise gemm pipeline
if
(
!
(
LPerBlock
%
(
L0PerBlock
*
L1Value
)
==
0
))
{
return
false
;
}
const
auto
num_gemm1_k_inner_loop
=
LPerBlock
/
(
L0PerBlock
*
L1Value
);
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_gemm1_k_inner_loop
))
{
return
false
;
}
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1Value
);
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
/* M01 */
,
index_t
/* N01 */
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
struct
SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
static
constexpr
auto
b0_block_desc_bk0_l_bk1
=
GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1
();
static
constexpr
auto
b1_block_desc_bl0_n_bl1
=
GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1
();
static
constexpr
auto
max_lds_align
=
math
::
lcm
(
math
::
lcm
(
AK1
,
BK1
),
BL1
);
static
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b0_block_space_size_aligned
=
math
::
integer_least_multiple
(
b0_block_desc_bk0_l_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
b1_block_space_size_aligned
=
math
::
integer_least_multiple
(
b1_block_desc_bl0_n_bl1
.
GetElementSpaceSize
(),
max_lds_align
);
static
constexpr
auto
a_block_space_offset
=
0
;
static
constexpr
auto
b0_block_space_offset
=
a_block_space_size_aligned
.
value
;
static
constexpr
auto
b1_block_space_offset
=
0
;
// LDS allocation for reduction
static
constexpr
index_t
reduction_space_size_aligned
=
math
::
integer_least_multiple
(
BlockSize
,
max_lds_align
);
static
constexpr
auto
reduction_space_offset
=
0
;
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
();
static
constexpr
auto
c_block_space_size
=
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
.
GetElementSpaceSize
();
};
template
<
bool
HasMainKBlockLoop
,
typename
C0MatrixMask
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
__device__
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB0
*
__restrict__
p_b0_grid
,
const
FloatB1
*
__restrict__
p_b1_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_k0_m_k1
,
const
B0GridDesc_BK0_L_BK1
&
b0_grid_desc_k0_l_k1
,
const
B1GridDesc_BL0_N_BL1
&
b1_grid_desc_l0_n_l1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
&
a_element_op
,
const
B0ElementwiseOperation
&
b0_element_op
,
const
AccElementwiseOperation
&
acc_element_op
,
const
B1ElementwiseOperation
&
b1_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
const
auto
b0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b0_grid
,
b0_grid_desc_k0_l_k1
.
GetElementSpaceSize
());
const
auto
b1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b1_grid
,
b1_grid_desc_l0_n_l1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
// Store BlockId into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
/*******************************************************************************/
        // set up Gemm0
/*******************************************************************************/

/*******************************************************************************/
        // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
        constexpr auto a_block_desc_k0perblock_mperblock_k1  = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            /* typename SrcElementwiseOperation,              */ AElementwiseOperation,
            /* typename DstElementwiseOperation,              */ ck::tensor_operation::element_wise::PassThrough,
            /* InMemoryDataOperationEnum DstInMemOp,           */ InMemoryDataOperationEnum::Set,
            /* typename BlockSliceLengths,                     */ Sequence<AK0, MPerBlock, AK1>,
            /* typename ThreadClusterLengths,                  */ ABlockTransferThreadClusterLengths_K0_M_K1,
            /* typename ThreadClusterArrangeOrder,             */ ABlockTransferThreadClusterArrangeOrder,
            /* typename SrcData,                               */ FloatA,
            /* typename DstData,                               */ FloatA,
            /* typename SrcDesc,                               */ decltype(a_grid_desc_k0_m_k1),
            /* typename DstDesc,                               */ decltype(a_block_desc_k0perblock_mperblock_k1),
            /* typename SrcDimAccessOrder,                     */ ABlockTransferSrcAccessOrder,
            /* typename DstDimAccessOrder,                     */ Sequence<0, 1, 2>,
            /* index_t SrcVectorDim,                           */ ABlockTransferSrcVectorDim,
            /* index_t DstVectorDim,                           */ 2,
            /* index_t SrcScalarPerVector,                     */ ABlockTransferSrcScalarPerVector,
            /* index_t DstScalarPerVector,                     */ ABlockTransferDstScalarPerVector_K1,
            /* index_t SrcScalarStrideInVector,                */ 1,
            /* index_t DstScalarStrideInVector,                */ 1,
            /* bool ThreadTransferSrcResetCoordinateAfterRun,  */ AThreadTransferSrcResetCoordinateAfterRun,
            /* bool ThreadTransferDstResetCoordinateAfterRun,  */ true>(
            a_grid_desc_k0_m_k1,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op,
            a_block_desc_k0perblock_mperblock_k1,
            make_multi_index(0, 0, 0),
            ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b0_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            B0ElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0, LPerBlock, BK1>,
            B0BlockTransferThreadClusterLengths_K0_L_K1,
            B0BlockTransferThreadClusterArrangeOrder,
            FloatB0,
            FloatB0,
            decltype(b0_grid_desc_k0_l_k1),
            decltype(b0_block_desc_k0perblock_lperblock_k1),
            B0BlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            B0BlockTransferSrcVectorDim,
            2,
            B0BlockTransferSrcScalarPerVector,
            B0BlockTransferDstScalarPerVector_K1,
            1,
            1,
            B0ThreadTransferSrcResetCoordinateAfterRun,
            true>(
            b0_grid_desc_k0_l_k1,
            make_multi_index(0, 0, 0),
            b0_element_op,
            b0_block_desc_k0perblock_lperblock_k1,
            make_multi_index(0, 0, 0),
            ck::tensor_operation::element_wise::PassThrough{});

/*******************************************************************************/
        // Gemm0
        constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);
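        // Note (added for clarity): assuming math::integer_least_multiple(x, y) returns the
        // smallest multiple of y that is >= x, KPack is K1Value rounded up to the WMMA K-width
        // of 16; e.g. K1Value = 8 would give KPack = 16 (illustrative value only).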
        auto blockwise_gemm0 = BlockwiseGemmWMMA<
            BlockSize,
            FloatA,
            FloatB0,
            FloatAcc0,
            decltype(MakeA0BlockDescriptor_K0_M0_M1_M2_K1(a_block_desc_k0perblock_mperblock_k1)),
            decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
            MPerBlock,
            LPerBlock,
            K0PerBlock * K1Value,
            MPerWmma,
            LPerWmma,
            MRepeat,
            LRepeat,
            KPack,
            true>{};
        // C' = B' x A'

        // Prepare Register for A*B0 matrix
        auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer();

        constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
            blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

        constexpr auto mrepeat            = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
        constexpr auto mwave              = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
        constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
        constexpr auto lrepeat            = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
        constexpr auto lwave              = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
        constexpr auto lsubgroup          = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
        constexpr auto laccvgprs          = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);

        constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
            acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
                       make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)),
                       make_pass_through_transform(laccvgprs)),
            make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

/*******************************************************************************/
        // LDS allocation for A and B: be careful of alignment
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());

        auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
            b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());

        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step  = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        const auto a_block_reset_copy_step  = make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
        const auto b0_block_reset_copy_step = make_multi_index(-b0_grid_desc_k0_l_k1.GetLength(I0), LPerBlock, 0);

        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
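        // Note (added for clarity): the main K loop runs K0 / K0PerBlock iterations; wrapping the
        // count in __builtin_amdgcn_readfirstlane keeps it in a scalar register, which is safe
        // because every lane computes the same value.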
/*******************************************************************************/
        // softmax
/*******************************************************************************/
        auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAcc0*>(p_shared) + SharedMemTrait::reduction_space_offset,
            SharedMemTrait::reduction_space_size_aligned);

        // get acc0 7D thread cluster
        constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
            blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() /
            blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();

        constexpr auto t_mrepeat            = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0);
        constexpr auto t_mwave              = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1);
        constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2);
        constexpr auto t_lrepeat            = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3);
        constexpr auto t_lwave              = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4);
        constexpr auto t_lsubgroup          = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5);
        constexpr auto t_laccvgprs          = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6);

        // get acc0 thread map
        constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)),
                       make_pass_through_transform(I1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

        constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(t_mrepeat * t_mwave,
                                                       t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs,
                                                       t_mthreadpersubgroup))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        const auto threadid_to_l_n_thread_cluster_adaptor =
            chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor);

        // get acc0 2D thread cluster & 2D thread slice
        constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed(
            make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup,
                       t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs));

        constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed(
            make_tuple(mrepeat * mwave * mthreadpersubgroup,
                       lrepeat * lwave * lsubgroup * laccvgprs));

        auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
                                                  FloatAcc0,
                                                  decltype(threadid_to_l_n_thread_cluster_adaptor),
                                                  decltype(thread_cluster_desc_m_l),
                                                  decltype(thread_slice_desc_m_l)>{};
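        // Note (added for clarity): per its use in the update formulas below, blockwise_softmax
        // reduces each row of the acc0 tile along the L dimension; its max_value_buf /
        // sum_value_buf members (bound to `max` and `sum` later) hold the per-row maximum and the
        // per-row sum of exponentials for the current tile.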
        // Initialize running sum and max of exponentiating row vectors
        using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
        SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
        running_sum     = 0;
        running_sum_new = 0;
        running_max     = NumericLimits<FloatAcc0>::Lowest();
        running_max_new = NumericLimits<FloatAcc0>::Lowest();

/*******************************************************************************/
        // set up Gemm1
/*******************************************************************************/
        // B1 matrix in LDS memory, dst of blockwise copy
        constexpr auto b1_block_desc_l0perblock_nperblock_l1 = GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();

        constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);

        // A1 matrix in VGPR
        constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple(
            Number<AL0 * AL1 / laccvgprs>{}, Number<mrepeat * mwave * mthreadpersubgroup>{}, Number<laccvgprs>{});

        // Data duplicated dimension
        constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
        constexpr auto A1ThreadSliceMPerBlock  = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
        constexpr auto A1ThreadSliceL1         = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];

        // A1 has duplicated data
        constexpr auto A1ThreadDuplicatedDim = I2 * A1ThreadSliceL1;

        constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
            make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadDuplicatedDim),
            make_tuple(A1ThreadSliceMPerBlock * A1ThreadDuplicatedDim, A1ThreadDuplicatedDim, I1));

        // A1 matrix blockwise copy
        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
            FloatAcc0,
            FloatA,
            decltype(acc0_thread_desc_l0perblock_mperblock_l1),
            decltype(a1_thread_desc_l0perblock_mperblock_l1),
            tensor_operation::element_wise::PassThrough,
            Sequence<A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1>,
            Sequence<0, 1, 2>,
            2,
            laccvgprs,
            // dst Rowlane
            // 0x76543210 0xfedcba98
            // src Rowlane
            0x76543210,
            0xfedcba98,
            false>{tensor_operation::element_wise::PassThrough{}};

        // B1 matrix blockwise copy
        auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            /* typename SrcElementwiseOperation,              */ B1ElementwiseOperation,
            /* typename DstElementwiseOperation,              */ tensor_operation::element_wise::PassThrough,
            /* InMemoryDataOperationEnum DstInMemOp,           */ InMemoryDataOperationEnum::Set,
            /* typename BlockSliceLengths,                     */ Sequence<BL0, NPerBlock, BL1>,
            /* typename ThreadClusterLengths,                  */ B1BlockTransferThreadClusterLengths_L0_N_L1,
            /* typename ThreadClusterArrangeOrder,             */ B1BlockTransferThreadClusterArrangeOrder,
            /* typename SrcData,                               */ FloatB1,
            /* typename DstData,                               */ FloatB1,
            /* typename SrcDesc,                               */ decltype(b1_grid_desc_l0_n_l1),
            /* typename DstDesc,                               */ decltype(b1_block_desc_l0perblock_nperblock_l1),
            /* typename SrcDimAccessOrder,                     */ B1BlockTransferSrcAccessOrder,
            /* typename DstDimAccessOrder,                     */ Sequence<1, 0, 2>,
            /* index_t SrcVectorDim,                           */ B1BlockTransferSrcVectorDim,
            /* index_t DstVectorDim,                           */ 2,
            /* index_t SrcScalarPerVector,                     */ B1BlockTransferSrcScalarPerVector,
            /* index_t DstScalarPerVector,                     */ B1BlockTransferDstScalarPerVector_L1,
            /* index_t SrcScalarStrideInVector,                */ 1,
            /* index_t DstScalarStrideInVector,                */ 1,
            /* bool ThreadTransferSrcResetCoordinateAfterRun,  */ B1ThreadTransferSrcResetCoordinateAfterRun,
            /* bool ThreadTransferDstResetCoordinateAfterRun,  */ true, // DstResetCoord
            NumGemmKPrefetchStage>(
            b1_grid_desc_l0_n_l1,
            make_multi_index(0, n_block_data_idx_on_grid, 0),
            b1_element_op,
            b1_block_desc_l0perblock_nperblock_l1,
            make_multi_index(0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
            a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());

        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatB1*>(p_shared) + SharedMemTrait::b1_block_space_offset,
            b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());

        auto blockwise_gemm1 = BlockwiseGemmWMMA<
            BlockSize,
            FloatA,
            FloatB1,
            FloatAcc1,
            decltype(MakeA1BlockDescriptor_L0_M0_M1_M2_L1(a1_thread_desc_l0perblock_mperblock_l1)),
            decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
            MPerBlock,
            NPerBlock,
            BL0 * BL1,
            MPerWmma,
            NPerWmma,
            MRepeat,
            NRepeat,
            KPack,
            true>{make_tuple(0, 0, 0, 0, 0)};

        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();

        const index_t num_gemm1_l_block_outer_loop     = b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
        constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (BL0 * BL1);
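        // Note (added for clarity): the outer loop walks GEMM_L in LPerBlock-sized tiles, and the
        // inner loop consumes each tile in chunks of BL0 * BL1; e.g. LPerBlock = 128 with
        // BL0 * BL1 = 32 (illustrative values only) gives 4 inner iterations per outer iteration.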
        // Initialize C
        StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
        c_thread_buf.Clear();

/*******************************************************************************/
        // Flash Attention
        // Dao, Tri, et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." arXiv preprint arXiv:2205.14135 (2022).
        index_t gemm1_l_block_outer_index = 0;
        // Outer loop, along GEMM_L
        // Inner loop, along GEMM_K
        do
        {
            auto l_block_data_idx_on_grid =
                __builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock);
            if(c0_matrix_mask.IsTileSkippable(
                   m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock))
            {
                continue;
            }
            // gemm0 start, A-B swapped
            GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
                                                              a_block_desc_k0perblock_mperblock_k1,
                                                              a_blockwise_copy,
                                                              a_grid_buf,
                                                              a_block_buf,
                                                              a_block_slice_copy_step,
                                                              b0_grid_desc_k0_l_k1,
                                                              b0_block_desc_k0perblock_lperblock_k1,
                                                              b0_blockwise_copy,
                                                              b0_grid_buf,
                                                              b0_block_buf,
                                                              b0_block_slice_copy_step,
                                                              blockwise_gemm0,
                                                              acc0_thread_buf,
                                                              K0BlockMainLoop);

            // do MNK padding or upper triangular masking
            if constexpr(MaskOutUpperTriangle || PadN)
            {
                // 7d thread_desc in thread scope
                constexpr auto c_thread_lengths =
                    blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();

                // 7d block_desc in block scope
                constexpr auto c_block_lengths =
                    blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths();

                constexpr auto MREPEAT         = c_block_lengths[I0];
                constexpr auto MWAVE           = c_block_lengths[I1];
                constexpr auto MTHREADSubGroup = c_block_lengths[I2];
                constexpr auto LREPEAT         = c_block_lengths[I3];
                constexpr auto LWAVE           = c_block_lengths[I4];
                constexpr auto LSUBGROUP       = c_block_lengths[I5];
                constexpr auto LACCVGPRS       = c_block_lengths[I6];

                // works like multi-dimension static_for (static_ford), but provides both the linear
                // index as well as n-d index
                using Acc0TileIterator = SpaceFillingCurve<
                    decltype(c_thread_lengths),
                    typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
                    typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
                    false>; // SnakeCurved

                auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D(
                    Number<0>{}, Number<0>{});

                constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor(
                    make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)),
                               make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))),
                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                    make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}));

                static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
                    auto m_local  = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
                    auto l_local  = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
                    auto m_global = m_local + m_block_data_idx_on_grid;
                    auto l_global = l_local + l_block_data_idx_on_grid;
                    if(c0_matrix_mask.IsMaskedElement(m_global, l_global))
                    {
                        acc0_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
                    }
                    else
                    {
                        acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]);
                    }
                });
            }
            else
            {
                static_for<0, acc0_thread_buf.Size(), 1>{}(
                    [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
            }
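            // Note (added for clarity): masked positions are set to -inf so the softmax below maps
            // them to exp(-inf) = 0, i.e. they contribute nothing to the attention weights;
            // unmasked positions only pass through acc_element_op (e.g. the attention scale).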
            block_sync_lds();
            // gemm0 end

            // gemm0 incorrect
            // Tiled softmax start
            // softmax
            SoftmaxBuf& max = blockwise_softmax.max_value_buf;
            SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;

            blockwise_softmax.Run(acc0_thread_buf, workspace_buf);

            // TODO: may convert to log domain
            running_max_new = mathext::max(max, running_max);
            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
                              mathext::exp(max - running_max_new) * sum;
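            // Note (added for clarity): this is the standard online-softmax recurrence. With m the
            // running row maximum and s the running row sum,
            //   m_new = max(m_old, m_tile)
            //   s_new = exp(m_old - m_new) * s_old + exp(m_tile - m_new) * s_tile
            // so partial sums computed against different maxima combine without revisiting
            // earlier tiles.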
            // gemm1
            {
                // TODO: explore using dynamic buffer for a1 thread buffer
                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
                // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
                // the A1 source buffer is a static buffer holding the output of the first GEMM and
                // requires a constexpr offset by design. Therefore, we pass the tensor coordinate
                // offset explicitly in Run() below.

                // Initialize acc1
                acc1_thread_buf.Clear();

                // preload data into LDS
                b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
                b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1, b1_block_slice_copy_step);

                block_sync_lds(); // wait for reduction LDS read

                b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);

                // main body
                if constexpr(num_gemm1_l_block_inner_loop > 1)
                {
                    static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
                        // Data cast from FloatAcc0 to FloatA happens here
                        a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
                                              make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0),
                                              acc0_thread_buf,
                                              a1_thread_desc_l0perblock_mperblock_l1,
                                              make_tuple(I0, I0, I0),
                                              a1_thread_buf);

                        b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);

                        block_sync_lds();

                        blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);

                        block_sync_lds();

                        b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1, b1_block_slice_copy_step);
                        b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
                    });
                }

                // tail
                {
                    a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
                                          make_tuple(Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0),
                                          acc0_thread_buf,
                                          a1_thread_desc_l0perblock_mperblock_l1,
                                          make_tuple(I0, I0, I0),
                                          a1_thread_buf);

                    block_sync_lds();

                    blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
                }
            }
            // end gemm1

            constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
                blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

            constexpr auto c_mrepeat            = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
            constexpr auto c_mwave              = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
            constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
            constexpr auto c_nrepeat            = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
            constexpr auto c_nwave              = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
            constexpr auto c_nsubgroup          = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
            constexpr auto c_naccvgprs          = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);

            constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
                make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
                           c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs));
            constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
            constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);

            static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
                static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
                    auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
                    FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V
                    FloatAcc1 c    = c_thread_buf[I];    // O
                    FloatAcc1 c_new =
                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
                        running_sum_new[iM];
                    c_thread_buf(I) = c_new; // O_new
                });
            });
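            // Note (added for clarity): the loop above applies the matching output correction,
            //   O_new = (s_old * exp(m_old - m_new) * O_old + exp(m_tile - m_new) * (P*V)_tile) / s_new,
            // rescaling the previously accumulated output to the new max/sum before folding in the
            // current tile's P*V contribution.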
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
                                                a_block_reset_copy_step); // rewind K
            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
                                                 b0_block_reset_copy_step); // rewind K and step N

            // update before next j iteration
            running_max = running_max_new;
            running_sum = running_sum_new;

            block_sync_lds(); // wait for gemm1 LDS read
        } while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop);
/*******************************************************************************/
        // write out to C, implement shuffle
        {
            constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
                blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

            // This API provides all the dimension sizes you need
            constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
                blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

            constexpr auto MWave              = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
            constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
            constexpr auto NWave              = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
            constexpr auto NSubGroup          = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
            constexpr auto NAccVgprs          = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);

            // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared),
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

            constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
                transform_tensor_descriptor(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                    make_tuple(
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
                            MWave,                               // MWave
                            MThreadPerSubGroup)),                // MThreadPerSubGroup = MPerWmma
                        make_freeze_transform(I0),
                        make_unmerge_transform(make_tuple(
                            Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
                            NWave,                               // NWave
                            NSubGroup,
                            NAccVgprs))),                        // NSubGroup * NAccVgprs = NPerWmma
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                    make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))),
                    make_tuple(Sequence<0, 1, 2, 3>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc1,
                FloatCShuffle,
                decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
                decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<CShuffleMRepeatPerShuffle, I1, I1, CShuffleNRepeatPerShuffle, I1, I1, NAccVgprs>,
                Sequence<0, 1, 2, 3, 4, 5, 6>,
                6,
                8, // vector write pixel
                InMemoryDataOperationEnum::Set,
                1,
                true>{c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
                      make_multi_index(0,
                                       m_thread_data_on_block_idx[I1],
                                       m_thread_data_on_block_idx[I2],
                                       0,
                                       n_thread_data_on_block_idx[I1],
                                       n_thread_data_on_block_idx[I2],
                                       n_thread_data_on_block_idx[I3]),
                      ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename SrcData,
                FloatC,               // typename DstData,
                decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
                 c_element_op};

            // space filling curve for local reg & global memory
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr = SpaceFillingCurve<
                Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
                Sequence<0, 1, 2, 3, 4, 5, 6>,
                Sequence<CShuffleMRepeatPerShuffle, 1, 1, CShuffleNRepeatPerShuffle, 1, 1, NAccVgprs>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global = SpaceFillingCurve<
                Sequence<1, MPerBlock, 1, NPerBlock>,
                Sequence<0, 2, 1, 3>,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
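            // Note (added for clarity): each pass of the loop below moves one
            // CShuffleMRepeatPerShuffle x CShuffleNRepeatPerShuffle slab of the C tile through LDS
            // (VGPR -> LDS -> global), with block_sync_lds() guarding the write/read hand-off; the
            // two space-filling curves must therefore agree on the pass count (see the
            // static_assert above).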
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(
                    c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
                    sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                    c_thread_buf,
                    c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
                    c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copies its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
        // clang-format on
    }
};

} // namespace ck