Commit 6dbced07 authored by letaoqin's avatar letaoqin
Browse files

change mha infer class and file name

parent 63ea1d70
add_example_executable(example_batched_multihead_attention_forward batched_gemm_multihead_attention_forward.cpp)
add_example_executable(example_batched_multihead_attention_bias_forward batched_gemm_multihead_attention_bias_forward.cpp)
add_example_executable(example_grouped_multihead_attention_bias_forward grouped_mutihead_attention_bias_forward.cpp)
add_example_executable(example_batched_multihead_attention_infer batched_gemm_multihead_attention_infer.cpp)
add_example_executable(example_batched_multihead_attention_bias_infer batched_gemm_multihead_attention_bias_infer.cpp)
add_example_executable(example_grouped_multihead_attention_bias_infer grouped_mutihead_attention_bias_infer.cpp)
add_example_executable(example_batched_multihead_attention_bias_forward_v2 batched_multihead_attention_bias_forward_v2.cpp)
add_example_executable(example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp)
......
......@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -67,72 +67,73 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
DIM, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
DIM / 32, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
DIM, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
DIM / 32, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
......@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -67,72 +67,73 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
DIM, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
DIM / 32, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
DIM, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
DIM / 32, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
......@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -66,72 +66,73 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
4,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
......
......@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -44,7 +44,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_multiple_head_flash_attention_forward(
kernel_batched_multiple_head_flash_attention_infer(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const D0DataType* p_d0_grid,
......@@ -205,7 +205,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
int D0sTransferSrcScalarPerVector = 4,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl
struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
: public DeviceBatchedMultiheadAttentionInfer<NumDimG,
NumDimM,
NumDimN,
......@@ -243,7 +243,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl;
using DeviceOp = DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
};
using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
using GridwiseGemm = GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
D0DataType,
GemmAccDataType,
......@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multiple_head_flash_attention_forward<
const auto kernel = kernel_batched_multiple_head_flash_attention_infer<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
D0DataType,
......@@ -925,7 +925,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionForward_Xdl"
str << "DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_mha_infer.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -35,7 +35,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
kernel_grouped_multiple_head_flash_attention_infer(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
......@@ -194,7 +194,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionForward_Xdl
struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
: public DeviceGroupedMultiheadAttentionInfer<NumDimG,
NumDimM,
NumDimN,
......@@ -230,7 +230,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl;
using DeviceOp = DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedMultiheadAttentionInfer<NumDimG,
NumDimM,
NumDimN,
......@@ -382,7 +382,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
};
// GridwiseGemm
using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
using GridwiseGemm = GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
Acc0BiasDataType,
GemmAccDataType,
......@@ -698,15 +698,15 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
D0DataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
kernel_grouped_multiple_head_flash_attention_infer<GridwiseGemm,
D0DataType,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
......@@ -944,7 +944,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedMultiheadAttentionForward_Xdl"
str << "DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -86,7 +86,7 @@ template <typename FloatAB,
bool PadN,
bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle
struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
{
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment