Commit 6dbced07 authored by letaoqin

Change MHA infer class and file names

parent 63ea1d70
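The diff below is a mechanical rename from "forward" to "infer". As a minimal sketch (not part of the commit) of what the batched example now instantiates, assuming the template argument list is otherwise unchanged and that NumDimG/NumDimM/NumDimN are defined earlier in the example as before; the remaining data-type, specialization and tile parameters are elided:

// Before this commit (sketch): header device_batched_mha_fwd_xdl_cshuffle.hpp,
// class DeviceBatchedMultiheadAttentionForward_Xdl.
// After this commit (sketch):
#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp"

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle<
        NumDimG, // number of batch/head (G) dimensions, defined in the example
        NumDimM, // number of M dimensions
        NumDimN, // number of N dimensions
        /* ...remaining template arguments as in the example... */>;

The grouped and gridwise types follow the same pattern: DeviceGroupedMultiheadAttentionForward_Xdl becomes DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle, and GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle becomes GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle.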
-add_example_executable(example_batched_multihead_attention_forward batched_gemm_multihead_attention_forward.cpp)
+add_example_executable(example_batched_multihead_attention_infer batched_gemm_multihead_attention_infer.cpp)
-add_example_executable(example_batched_multihead_attention_bias_forward batched_gemm_multihead_attention_bias_forward.cpp)
+add_example_executable(example_batched_multihead_attention_bias_infer batched_gemm_multihead_attention_bias_infer.cpp)
-add_example_executable(example_grouped_multihead_attention_bias_forward grouped_mutihead_attention_bias_forward.cpp)
+add_example_executable(example_grouped_multihead_attention_bias_infer grouped_mutihead_attention_bias_infer.cpp)
 add_example_executable(example_batched_multihead_attention_bias_forward_v2 batched_multihead_attention_bias_forward_v2.cpp)
 add_example_executable(example_grouped_multihead_attention_bias_forward_v2 grouped_multihead_attention_bias_forward_v2.cpp)
...
@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -67,7 +67,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle<
 NumDimG,
 NumDimM,
 NumDimN,
...
@@ -18,7 +18,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -67,7 +67,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle<
 NumDimG,
 NumDimM,
 NumDimN,
...
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/utility/check_err.hpp"
@@ -66,7 +66,8 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl<
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle<
 NumDimG,
 NumDimM,
 NumDimN,
...
@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -44,7 +44,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-kernel_batched_multiple_head_flash_attention_forward(
+kernel_batched_multiple_head_flash_attention_infer(
 const FloatAB* __restrict__ p_a_grid,
 const FloatAB* __restrict__ p_b_grid,
 const D0DataType* p_d0_grid,
@@ -205,7 +205,7 @@ template <index_t NumDimG,
 MaskingSpecialization MaskingSpec,
 int D0sTransferSrcScalarPerVector = 4,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceBatchedMultiheadAttentionForward_Xdl
+struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
 : public DeviceBatchedMultiheadAttentionInfer<NumDimG,
 NumDimM,
 NumDimN,
@@ -243,7 +243,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
 static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl;
+using DeviceOp = DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle;
 static constexpr auto I0 = Number<0>{};
 static constexpr auto I1 = Number<1>{};
@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
 D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
 };
-using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
+using GridwiseGemm = GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle<
 ADataType, // TODO: distinguish A/B datatype
 D0DataType,
 GemmAccDataType,
@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
 float ave_time = 0;
 auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_batched_multiple_head_flash_attention_forward<
+const auto kernel = kernel_batched_multiple_head_flash_attention_infer<
 GridwiseGemm,
 ADataType, // TODO: distiguish A/B datatype
 D0DataType,
@@ -925,7 +925,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceBatchedMultiheadAttentionForward_Xdl"
+str << "DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
...
@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_mha_infer.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_fwd_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_mha_infer_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -35,7 +35,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
+kernel_grouped_multiple_head_flash_attention_infer(
 const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
 const index_t group_count,
 const AElementwiseOperation a_element_op,
@@ -194,7 +194,7 @@ template <index_t NumDimG,
 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 MaskingSpecialization MaskingSpec,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl
+struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
 : public DeviceGroupedMultiheadAttentionInfer<NumDimG,
 NumDimM,
 NumDimN,
@@ -230,7 +230,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
 static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl;
+using DeviceOp = DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle;
 using ProblemDesc = typename DeviceGroupedMultiheadAttentionInfer<NumDimG,
 NumDimM,
 NumDimN,
@@ -382,7 +382,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
 };
 // GridwiseGemm
-using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
+using GridwiseGemm = GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle<
 ADataType, // TODO: distinguish A/B datatype
 Acc0BiasDataType,
 GemmAccDataType,
@@ -698,7 +698,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
 auto launch_kernel = [&](auto has_main_k_block_loop_) {
 const auto kernel =
-kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
+kernel_grouped_multiple_head_flash_attention_infer<GridwiseGemm,
 D0DataType,
 GroupKernelArg,
 AElementwiseOperation,
@@ -944,7 +944,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceGroupedMultiheadAttentionForward_Xdl"
+str << "DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
...
@@ -86,7 +86,7 @@ template <typename FloatAB,
 bool PadN,
 bool MaskOutUpperTriangle,
 PipelineVersion PipelineVer = PipelineVersion::v1>
-struct GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle
+struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
 {
 static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
 D0BlockTransferSrcScalarPerVector == 2 ||
...