Commit c703632f authored by letaoqin's avatar letaoqin
Browse files

fix class name

parent 6bc73d41
...@@ -44,7 +44,7 @@ __global__ void ...@@ -44,7 +44,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_mutiple_head_flash_attention_forward( kernel_batched_multiple_head_flash_attention_forward(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const D0DataType* p_d0_grid, const D0DataType* p_d0_grid,
...@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl ...@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
D0GridDesc_G_M_N d0_grid_desc_g_m_n_; D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
}; };
using GridwiseGemm = GridwiseMutiHeadFlashAttentionForward_Xdl_CShuffle< using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
D0DataType, D0DataType,
GemmAccDataType, GemmAccDataType,
...@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl ...@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_mutiple_head_flash_attention_forward< const auto kernel = kernel_batched_multiple_head_flash_attention_forward<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
D0DataType, D0DataType,
......
...@@ -86,7 +86,7 @@ template <typename FloatAB, ...@@ -86,7 +86,7 @@ template <typename FloatAB,
bool PadN, bool PadN,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMutiHeadFlashAttentionForward_Xdl_CShuffle struct GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle
{ {
static_assert(D0BlockTransferSrcScalarPerVector == 1 || static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 || D0BlockTransferSrcScalarPerVector == 2 ||
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment