Commit c703632f authored by letaoqin's avatar letaoqin
Browse files

fix class name

parent 6bc73d41
......@@ -44,7 +44,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_mutiple_head_flash_attention_forward(
kernel_batched_multiple_head_flash_attention_forward(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const D0DataType* p_d0_grid,
......@@ -376,7 +376,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
};
using GridwiseGemm = GridwiseMutiHeadFlashAttentionForward_Xdl_CShuffle<
using GridwiseGemm = GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
D0DataType,
GemmAccDataType,
......@@ -641,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_mutiple_head_flash_attention_forward<
const auto kernel = kernel_batched_multiple_head_flash_attention_forward<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
D0DataType,
......
......@@ -86,7 +86,7 @@ template <typename FloatAB,
bool PadN,
bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMutiHeadFlashAttentionForward_Xdl_CShuffle
struct GridwiseMultiHeadFlashAttentionForward_Xdl_CShuffle
{
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment