Unverified commit efee4541 authored by Haocong WANG, committed by GitHub

AIT Attention API refactor (#8)

* sanity pass

* sanity pass 2

* confirm significant performance regression.

* turn on all instances

* turn off instance format

* Fix bug & tuning & format

* DML meta, self_attn+cross_attn

* sanity pass

* remove useless flag

* update tile and problem size used in AIT attention

* bug fix in grouped conv support check
parent bee4e344
...@@ -8,6 +8,8 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
    add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
    add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
    add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
    add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
endif()
add_custom_target(example_gemm_scale_softmax_gemm)
......
...@@ -67,77 +67,200 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
// #define CK_MHA_USE_WAVE_1
// #define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>
#endif
>;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
                                                                  |-----------------|
                                                                          Gemm0
                                                          |-------------------------------------|
                                                                          Gemm1
*/
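// A minimal host-side sketch of the math above (illustration only, not the device kernel;
// `alpha` is the Acc0ElementOp scale used later in this example): for each batch g and row m,
//   acc0[l]  = sum_k A[g,m,k] * B0[g,k,l]                          // Gemm0
//   p[l]     = exp(alpha*acc0[l] - max_l') / sum_l' exp(...)       // row-wise scaled Softmax
//   C[g,m,n] = sum_l p[l] * B1[g,l,n]                              // Gemm1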
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using Acc0DataType = F32;
using Acc1DataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
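// Each CK_MHA_USE_WAVE_* flag gates a group of tile configurations in the factory tuple
// below; commenting one out drops those instances from the compile-time sweep (all four
// wave widths are enabled in this example).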
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 32, 160, 8, 8,
// Gemm 1
80, 32, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 80, 8, 8,
// Gemm 1
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 80, 8, 8,
// Gemm 1
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 32, 160, 8, 8,
// Gemm 1
80, 32, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 128, 80, 8, 8,
// Gemm 1
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 192, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 12, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 64, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 192, 48, 8, 4,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 12, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>
#endif
>;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
Acc0DataType,
Acc1DataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
Acc1DataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_cross_attention.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
...@@ -127,16 +127,31 @@ int run(int argc, char* argv[])
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0 ; unit: B1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
...@@ -160,139 +175,166 @@ int run(int argc, char* argv[])
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = .0;
float best_time = .0;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
G0,
G1,
alpha,
input_permute,
output_permute);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
}
ck::index_t BatchCount = G0 * G1;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
// default absolute error and relative error is 0.001
double rtol = 1e-3;
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
printf("Verification: %s, Pass: %s\n",
do_verification ? "ON" : "OFF",
this_run_verification ? "YES" : "NO");
if(!this_run_verification)
{
not_pass = 1;
printf("%d th MHA instance verification Failed \n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256;
ck::index_t N = 64;
ck::index_t K = 80;
ck::index_t O = 80;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 2;
ck::index_t G1 = 8;
float alpha = 1;
bool input_permute = false;
bool output_permute = true;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
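// Illustration of the stride choices above (using this file's defaults G0=2, G1=8, M=256,
// K=80): with input_permute=true, A is stored in memory as [G0, M, G1, K] = [2, 256, 8, 80],
// i.e. heads interleaved per token, while the strides still expose the logical
// [G0, G1, M, K] view to the kernel without any copy; with input_permute=false the storage
// is the plain [G0, G1, M, K] order.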
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
case 4: // A, B0, B1 1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5: // Rand: b1 b0; unit: a
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0 ; unit: B1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
std::vector<ck::index_t> kv_gs_ns_ks_lengths{G0, G1, N, 2, K};
std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
N * G1 * 2 * K, 2 * K, G1 * 2 * K, K, 1}; // kv layout [G0, N, G1, 2, K]
Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides);
// merge kv into a packed pointer send to device
b0_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); });
b1_gs_os_ns.ForEach(
[&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[3], 1, idx[2]) = self(idx); });
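// Sketch of the packed-KV addressing implied by the strides above (illustrative only):
// element (g0, g1, n, i, k), with i = 0 for K (b0) and i = 1 for V (b1), lands at flat offset
//   g0*(N*G1*2*K) + n*(G1*2*K) + g1*(2*K) + i*K + k
// so the K and V rows for the same (g0, n, g1) sit next to each other, which is what lets a
// single pointer carry both operands to the device.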
DeviceMem q_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kv_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() +
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(a_gs_ms_ks.mData.data());
kv_device_buf.ToDevice(kv_gs_ns_ks.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = .0;
float best_time = .0;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
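// `i` is a compile-time index object (its value is carried in its type), so std::get<i> and
// the decltype above select a different kernel instance on every iteration; ck::static_for
// effectively unrolls this loop over all entries of DeviceMHAFactory at compile time.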
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeCrossAttnInvoker();
auto argument =
gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0,
M,
N,
G1,
K,
alpha);
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
ck::index_t BatchCount = G0 * G1;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
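// Worked example with this file's defaults (M=256, N=64, K=80, O=80, G0=2, G1=8):
// flop = (256*64*80*2 + 256*64*80*2) * 16 = 83,886,080, roughly 0.084 GFLOP per launch,
// so an ave_time of 0.01 ms would report about 8.4 TFlops.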
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
// default absolute error and relative error is 0.001
double rtol = 1e-3;
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
printf("Verification: %s, Pass: %s\n",
do_verification ? "ON" : "OFF",
this_run_verification ? "YES" : "NO");
if(!this_run_verification)
{
not_pass = 1;
printf("%d th MHA instance verification Failed \n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256;
ck::index_t N = 256;
ck::index_t K = 80;
ck::index_t O = 80;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 2;
ck::index_t G1 = 8;
float alpha = 1;
bool input_permute = false;
bool output_permute = true;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
case 4: // A, B0, B1 1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5: // Rand: b1 b0; unit: a
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0 ; unit: B1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
std::vector<ck::index_t> qkv_gs_ms_ks_lengths{G0, G1, M, 3, K};
std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
M * G1 * 3 * K, 3 * K, G1 * 3 * K, K, 1}; // qkv layout [G0, M, G1, 3, K]
Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides);
// merge qkv into a packed pointer send to device
a_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); });
b0_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 1, idx[3]) = self(idx); });
b1_gs_os_ns.ForEach(
[&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[3], 2, idx[2]) = self(idx); });
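// Sketch of the packed-QKV addressing implied by the strides above (illustrative only):
// element (g0, g1, m, i, k), with i = 0 for Q, 1 for K and 2 for V, lands at flat offset
//   g0*(M*G1*3*K) + m*(G1*3*K) + g1*(3*K) + i*K + k
// so the Q, K and V vectors of a token sit back to back in one device allocation.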
DeviceMem qkv_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize() +
sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() +
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
qkv_device_buf.ToDevice(qkv_gs_ms_ks.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = .0;
float best_time = .0;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeSelfAttnInvoker();
auto argument =
gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0,
M,
G1,
K,
alpha);
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
ck::index_t BatchCount = G0 * G1;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
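        // Apply the same mask the device kernel would use; with
        // MaskingSpecialization::MaskDisabled the predicate should mask nothing and this loop is
        // effectively a no-op, but it keeps the reference path identical to the masked variants.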
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
        // default absolute and relative error tolerances are 0.001
        double rtol = 1e-3;
        double atol = 1e-3;
        // when all data types are BF16, relax both tolerances to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
printf("Verification: %s, Pass: %s\n",
do_verification ? "ON" : "OFF",
this_run_verification ? "YES" : "NO");
if(!this_run_verification)
{
not_pass = 1;
printf("%d th MHA instance verification Failed \n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
    Gemm0:  Acc0_g_m_l = A_g_m_k * B0_g_k_l
    Gemm1:  C_g_m_n    = Softmax(Acc0_g_m_l) * B1_g_l_n
*/
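/*
Illustrative per-row math (host-side sketch only; alpha comes from Acc0ElementOp = Scale, and the
device kernels tile and fuse these steps rather than materializing S and P):
    S(m, l) = alpha * sum_k A(m, k) * B0(k, l)
    P(m, :) = softmax over l of S(m, :)
    C(m, n) = sum_l P(m, l) * B1(l, n)
*/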
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using Acc0DataType = F32;
using Acc1DataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
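// Each CK_MHA_USE_WAVE_<n> macro enables the instances sized for <n> waves per workgroup:
// BlockSize 32/64/128/256 and Gemm0 MPerBlock 16/32/64/128 for wave 1/2/4/8 respectively
// (see the template arguments below); a macro can be commented out to prune the search space.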
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>
#endif
>;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
Acc0DataType,
Acc1DataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
Acc1DataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_self_attention.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
...@@ -5,7 +5,11 @@ ...@@ -5,7 +5,11 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -22,6 +26,417 @@ namespace ck { ...@@ -22,6 +26,417 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename DeviceOp,
typename GridwiseOp,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
index_t M,
index_t N,
index_t K,
index_t O,
index_t G0,
index_t G1,
float alpha,
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
constexpr index_t array_size = 4;
std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
input_permute
? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::array<ck::index_t, array_size>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
input_permute
? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::array<ck::index_t, array_size>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
input_permute
? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::array<ck::index_t, array_size>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
std::array<ck::index_t, array_size> c_gs_ms_os_strides =
output_permute
? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::array<ck::index_t, array_size>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
const auto a_element_op = AElementwiseOperation{};
const auto b0_element_op = B0ElementwiseOperation{};
const auto acc0_element_op = AccElementwiseOperation{alpha};
const auto b1_element_op = B1ElementwiseOperation{};
const auto c_element_op = CElementwiseOperation{};
    // cannot reuse DeviceOp::MakeArgument() here because __device__-callable descriptor construction is required.
const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto a_grid_desc_g_m_k =
DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
const auto b0_grid_desc_g_l_k =
DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
const auto b1_grid_desc_g_n_l =
DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
// clang-format on
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b0_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = M;
ignore = N;
ignore = K;
ignore = O;
ignore = G0;
ignore = G1;
ignore = input_permute;
    ignore = output_permute;
    ignore = alpha;
#endif // end of gfx11 (gfx1100/gfx1101/gfx1102) guard
}
// Self-Attention
template <typename DeviceOp,
typename GridwiseOp,
typename QKVDataType,
typename ODataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
ODataType* __restrict__ p_out_grid,
index_t batch_size,
index_t sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
constexpr index_t array_size = 4;
std::array<ck::index_t, array_size> qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size};
std::array<ck::index_t, array_size> qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1};
std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length};
std::array<ck::index_t, array_size> v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
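    // Strides follow the packed-QKV layout: within a token, heads are 3 * head_size apart and
    // consecutive tokens are head_count * 3 * head_size apart; V reuses the same buffer but is
    // described as [head_size, sequence_length] (transposed) for Gemm1.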
const auto a_element_op = AElementwiseOperation{};
const auto b0_element_op = B0ElementwiseOperation{};
const auto acc0_element_op = AccElementwiseOperation{alpha};
const auto b1_element_op = B1ElementwiseOperation{};
const auto c_element_op = CElementwiseOperation{};
const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto a_grid_desc_g_m_k =
DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b0_grid_desc_g_l_k =
DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b1_grid_desc_g_n_l =
DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
// clang-format on
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size);
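    // Each (token, head) slot of the packed buffer stores Q, K and V back to back, head_size
    // elements each, so offsets of 0/1/2 * qkv_gap (plus the per-batch offsets) select the
    // Q, K and V operands passed to GridwiseOp below.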
#ifdef CK_SELF_ATTN_DEBUG
if(get_thread_global_1d_id() == 0)
{
printf("batch_size: %d\n", batch_size);
printf("sequence_length: %d\n", sequence_length);
printf("head_count: %d\n", head_count);
printf("head_size: %d\n", head_size);
printf("qkv_gap: %d\n", qkv_gap);
printf("get_grid_size(): %d\n", get_grid_size());
printf("batch_count: %d\n", batch_count);
printf("blockid: %d\n", get_block_1d_id());
printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
printf("g_idx: %d\n", g_idx);
printf("a_batch_offset: %ld\n", a_batch_offset);
printf("b0_batch_offset: %ld\n", b0_batch_offset);
printf("b1_batch_offset: %ld\n", b1_batch_offset);
}
#endif
GridwiseOp::template Run<HasMainKBlockLoop>(p_qkv_grid + 0 * qkv_gap + a_batch_offset,
p_qkv_grid + 1 * qkv_gap + b0_batch_offset,
p_qkv_grid + 2 * qkv_gap + b1_batch_offset,
p_out_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_qkv_grid;
ignore = p_out_grid;
ignore = batch_size;
ignore = sequence_length;
ignore = head_count;
ignore = head_size;
ignore = alpha;
#endif // end of gfx11 (gfx1100/gfx1101/gfx1102) guard
}
// Cross-Attention
template <typename DeviceOp,
typename GridwiseOp,
typename QDataType,
typename KVDataType,
typename ODataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
const KVDataType* __restrict__ p_kv_grid,
ODataType* __restrict__ p_out_grid,
index_t batch_size,
index_t q_sequence_length,
index_t kv_sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
    // o Cross-attention: Q [batchSize, qSequenceLength, headCount, headSize],
    //   packed KV        [batchSize, kvSequenceLength, headCount, 2, headSize]
constexpr index_t array_size = 4;
std::array<ck::index_t, array_size> q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size};
std::array<ck::index_t, array_size> q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
std::array<ck::index_t, array_size> k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size};
std::array<ck::index_t, array_size> k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1};
std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length};
std::array<ck::index_t, array_size> v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
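    // Q is a plain [batch, q_seq, head_count, head_size] tensor, while K and V are read from a
    // packed KV buffer, hence the extra factor of 2 * head_size in the K/V strides.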
const auto a_element_op = AElementwiseOperation{};
const auto b0_element_op = B0ElementwiseOperation{};
const auto acc0_element_op = AccElementwiseOperation{alpha};
const auto b1_element_op = B1ElementwiseOperation{};
const auto c_element_op = CElementwiseOperation{};
const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto a_grid_desc_g_m_k =
DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
const auto b0_grid_desc_g_l_k =
DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
const auto b1_grid_desc_g_n_l =
DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
// clang-format on
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size);
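    // In the packed KV buffer each (token, head) slot stores K followed by V, so adding kv_gap
    // (= head_size) to the KV base pointer below selects V instead of K.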
#ifdef CK_SELF_ATTN_DEBUG
if(get_thread_global_1d_id() == 0)
{
printf("batch_size: %d\n", batch_size);
printf("q_sequence_length: %d\n", q_sequence_length);
printf("k_sequence_length: %d\n", kv_sequence_length);
printf("head_count: %d\n", head_count);
printf("head_size: %d\n", head_size);
printf("kv_gap: %d\n", kv_gap);
printf("get_grid_size(): %d\n", get_grid_size());
printf("batch_count: %d\n", batch_count);
printf("blockid: %d\n", get_block_1d_id());
printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
printf("g_idx: %d\n", g_idx);
printf("a_batch_offset: %ld\n", a_batch_offset);
printf("b0_batch_offset: %ld\n", b0_batch_offset);
printf("b1_batch_offset: %ld\n", b1_batch_offset);
}
#endif
GridwiseOp::template Run<HasMainKBlockLoop>(p_q_grid + a_batch_offset,
p_kv_grid + 0 * kv_gap + b0_batch_offset,
p_kv_grid + 1 * kv_gap + b1_batch_offset,
p_out_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_q_grid;
ignore = p_kv_grid;
ignore = p_out_grid;
ignore = batch_size;
ignore = q_sequence_length;
ignore = kv_sequence_length;
ignore = head_count;
ignore = head_size;
ignore = alpha;
#endif // end of gfx11 (gfx1100/gfx1101/gfx1102) guard
}
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// MN = MK * KL * LN // MN = MK * KL * LN
// ^^^^^^ (Acc0) // ^^^^^^ (Acc0)
...@@ -55,7 +470,8 @@ template <index_t NumDimG, ...@@ -55,7 +470,8 @@ template <index_t NumDimG,
ck::index_t MPerBlock, ck::index_t MPerBlock,
ck::index_t LPerBlock, ck::index_t LPerBlock,
ck::index_t KPerBlock, ck::index_t KPerBlock,
ck::index_t K1, ck::index_t AK1,
ck::index_t BK1,
ck::index_t NPerBlock, ck::index_t NPerBlock,
ck::index_t LTilePerBlock, ck::index_t LTilePerBlock,
ck::index_t L1, ck::index_t L1,
...@@ -165,14 +581,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -165,14 +581,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1Spec, B1Spec,
CSpec>; CSpec>;
static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, __host__ __device__ static auto MakeAGridDescriptor(
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{ {
if constexpr(AEnableLds) if constexpr(AEnableLds)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<K1>{}); Number<AK1>{});
} }
else else
{ {
...@@ -184,19 +601,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -184,19 +601,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number<MRepeat>{}, Number<MRepeat>{},
Number<MWaves>{}, Number<MWaves>{},
Number<MPerWmma>{}, Number<MPerWmma>{},
Number<K1>{}); Number<AK1>{});
} }
} }
static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec, __host__ __device__ static auto MakeB0GridDescriptor(
const std::vector<index_t>& b0_gs_ls_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides_vec)
{ {
if constexpr(B0EnableLds) if constexpr(B0EnableLds)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec), b0_gs_ls_ks_strides_vec),
Number<K1>{}); Number<BK1>{});
} }
else else
{ {
...@@ -208,12 +626,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -208,12 +626,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number<LRepeat>{}, Number<LRepeat>{},
Number<LWaves>{}, Number<LWaves>{},
Number<LPerWmma>{}, Number<LPerWmma>{},
Number<K1>{}); Number<BK1>{});
} }
} }
static auto MakeB1GridDescriptor(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec, __host__ __device__ static auto MakeB1GridDescriptor(
const std::vector<index_t>& b1_gs_ns_ls_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides_vec)
{ {
if constexpr(B1EnableLds) if constexpr(B1EnableLds)
{ {
...@@ -245,7 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -245,7 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate() __host__ __device__ constexpr static auto make_MaskOutPredicate()
{ {
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{ {
...@@ -260,10 +679,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -260,10 +679,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
const CGridDesc_G_M_N& c_grid_desc_g_m_n) const CGridDesc_G_M_N& c_grid_desc_g_m_n)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
...@@ -324,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -324,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
MPerBlock, MPerBlock,
LPerBlock, LPerBlock,
KPerBlock, KPerBlock,
K1, AK1,
BK1,
NPerBlock, NPerBlock,
LTilePerBlock, LTilePerBlock,
L1, L1,
...@@ -373,6 +793,323 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -373,6 +793,323 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
struct RawArg : public BaseArgument
{
RawArg(const ADataType* p_a_grid,
const B0DataType* p_b0_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t O,
index_t G0,
index_t G1,
float alpha,
bool input_permute,
bool output_permute)
: p_a_grid_{p_a_grid},
p_b0_grid_{p_b0_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
M_{M},
N_{N},
K_{K},
O_{O},
G0_{G0},
G1_{G1},
alpha_{alpha},
input_permute_{input_permute},
output_permute_{output_permute}
{
}
// Pointers
const ADataType* p_a_grid_;
const B0DataType* p_b0_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
// Raw Problem Size
index_t M_;
index_t N_;
index_t K_;
index_t O_;
index_t G0_;
index_t G1_;
float alpha_;
bool input_permute_;
bool output_permute_;
};
static auto MakeArgument(const ADataType* p_a,
const B0DataType* p_b0,
const B1DataType* p_b1,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t O,
index_t G0,
index_t G1,
float alpha,
bool input_permute,
bool output_permute)
{
return RawArg{
p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute};
}
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
printf("DeviceOp: Acc0 Type err");
return false;
}
if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
{
printf("DeviceOp: Acc1 Type err");
return false;
}
}
else
{
printf("DeviceOp: Arch err");
return false;
}
constexpr index_t array_size = 4;
ck::index_t G0 = arg.G0_;
ck::index_t G1 = arg.G1_;
ck::index_t M = arg.M_;
ck::index_t N = arg.N_;
ck::index_t K = arg.K_;
ck::index_t O = arg.O_;
bool input_permute = arg.input_permute_;
bool output_permute = arg.output_permute_;
std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
input_permute ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1}
// A layout [G0, M, G1, K]
: std::array<ck::index_t, array_size>{
G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
input_permute ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1}
// B0 layout [G0, N, G1, K]
: std::array<ck::index_t, array_size>{
G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
input_permute ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O}
// B1 layout [G0, N, G1, O]
: std::array<ck::index_t, array_size>{
G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
std::array<ck::index_t, array_size> c_gs_ms_os_strides =
output_permute ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1}
// C layout [G0, M, G1, O]
: std::array<ck::index_t, array_size>{
G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
const auto a_grid_desc =
DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
if(!GridwiseOp::CheckValidity(
a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded
if(!(c_g == batch_count))
{
printf("DeviceOp: BatchCount err");
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = M;
const auto LzRaw = N;
const auto KzRaw = K;
const auto NzRaw = O;
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
const auto c_extent_lowest = NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
printf("DeviceOp: Data Transfer Vector scalar err");
return false;
}
std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
c_gs_ms_os_strides[NumDimG + NumDimM - 1],
c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
const auto c_stride_lowest = c_mz_nz_strides_[1];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
printf("DeviceOp: Data Vectorize transfer err");
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
}
struct SelfAttnArg : public BaseArgument
{
SelfAttnArg(const ADataType* p_qkv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t sequence_length,
index_t head_count,
index_t head_size,
float alpha)
: p_qkv_grid_{p_qkv_grid},
p_out_grid_{p_out_grid},
batch_size_{batch_size},
sequence_length_{sequence_length},
head_count_{head_count},
head_size_{head_size},
alpha_{alpha}
{
}
// Pointers
const ADataType* p_qkv_grid_;
CDataType* p_out_grid_;
// Raw Problem Size
index_t batch_size_;
index_t sequence_length_;
index_t head_count_;
index_t head_size_;
float alpha_;
};
static auto MakeSelfAttnArgument(const ADataType* p_qkv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
return SelfAttnArg{
p_qkv_grid, p_out_grid, batch_size, sequence_length, head_count, head_size, alpha};
}
struct CrossAttnArg : public BaseArgument
{
CrossAttnArg(const ADataType* p_q_grid,
const B0DataType* p_kv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t q_sequence_length,
index_t kv_sequence_length,
index_t head_count,
index_t head_size,
float alpha)
: p_q_grid_{p_q_grid},
p_kv_grid_{p_kv_grid},
p_out_grid_{p_out_grid},
batch_size_{batch_size},
q_sequence_length_{q_sequence_length},
kv_sequence_length_{kv_sequence_length},
head_count_{head_count},
head_size_{head_size},
alpha_{alpha}
{
}
// Pointers
const ADataType* p_q_grid_;
const B0DataType* p_kv_grid_;
CDataType* p_out_grid_;
// Raw Problem Size
index_t batch_size_;
index_t q_sequence_length_;
index_t kv_sequence_length_;
index_t head_count_;
index_t head_size_;
float alpha_;
};
static auto MakeCrossAttnArgument(const ADataType* p_q_grid,
const B0DataType* p_kv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t q_sequence_length,
index_t kv_sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
return CrossAttnArg{p_q_grid,
p_kv_grid,
p_out_grid,
batch_size,
q_sequence_length,
kv_sequence_length,
head_count,
head_size,
alpha};
}
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -383,14 +1120,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -383,14 +1120,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType* p_c_grid, CDataType* p_c_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::vector<index_t>& b0_gs_ls_ks_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b0_gs_ls_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_ns_ls_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_ns_ls_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_ns_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_ns_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
...@@ -497,11 +1234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -497,11 +1234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// Strides for the last M/N/K dimensions of A/B0/B1/C // Strides for the last M/N/K dimensions of A/B0/B1/C
// for sanity check of vector load/store // for sanity check of vector load/store
std::vector<index_t> raw_lengths_mz_lz_kz_nz_; std::array<index_t, NumDimG + NumDimM + NumDimN> raw_lengths_mz_lz_kz_nz_;
std::vector<index_t> a_mz_kz_strides_; std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_;
std::vector<index_t> b0_lz_kz_strides_; std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_;
std::vector<index_t> b1_nz_lz_strides_; std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_;
std::vector<index_t> c_mz_nz_strides_; std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_;
index_t batch_count_; index_t batch_count_;
// Batch Offset // Batch Offset
...@@ -509,47 +1246,152 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -509,47 +1246,152 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}; };
// Invoker // Invoker
struct SelfAttnInvoker : public BaseInvoker
{
using Argument = DeviceOp::SelfAttnArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.sequence_length_, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock);
const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0;
const auto K = arg.head_size_;
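            // One workgroup per MPerBlock x NPerBlock output tile, replicated over every
            // (batch, head) pair; the reduction (K) extent is head_size.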
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_wmma_self_attention_forward<DeviceOp,
GridwiseOp,
ADataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_qkv_grid_,
arg.p_out_grid_,
arg.batch_size_,
arg.sequence_length_,
arg.head_count_,
arg.head_size_,
arg.alpha_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; }
// Invoker
struct CrossAttnInvoker : public BaseInvoker
{
using Argument = DeviceOp::CrossAttnArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock);
const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0;
const auto K = arg.head_size_;
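            // Same tiling as SelfAttnInvoker, but the M tiles cover the query sequence length
            // while the attended (L) dimension is kv_sequence_length.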
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_wmma_cross_attention_forward<DeviceOp,
GridwiseOp,
ADataType,
B0DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_q_grid_,
arg.p_kv_grid_,
arg.p_out_grid_,
arg.batch_size_,
arg.q_sequence_length_,
arg.kv_sequence_length_,
arg.head_count_,
arg.head_size_,
arg.alpha_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; }
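    // Usage sketch (mirrors the host code in run_self_attention.inc; buffer names are
    // illustrative, not part of this header):
    //   auto op       = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<...>{};
    //   auto invoker  = op.MakeSelfAttnInvoker();
    //   auto argument = op.MakeSelfAttnArgument(p_qkv, p_out, batch_size, sequence_length,
    //                                           head_count, head_size, alpha);
    //   float ave_ms  = invoker.Run(argument, StreamConfig{nullptr, /*time_kernel=*/true});
    // The cross-attention path is analogous via MakeCrossAttnInvoker() / MakeCrossAttnArgument().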
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const index_t grid_size = const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock);
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock);
const auto K = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
}
else
{
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
}
}();
const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0;
const auto K = arg.K_;
// printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle< const auto kernel =
GridwiseOp, kernel_batched_gemm_softmax_gemm_wmma_cshuffle<DeviceOp,
ADataType, GridwiseOp,
B0DataType, ADataType,
B1DataType, B0DataType,
CDataType, B1DataType,
DeviceOp::AGridDesc, CDataType,
DeviceOp::B0GridDesc, AElementwiseOperation,
DeviceOp::B1GridDesc, B0ElementwiseOperation,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, AccElementwiseOperation,
AElementwiseOperation, B1ElementwiseOperation,
B0ElementwiseOperation, CElementwiseOperation,
AccElementwiseOperation, has_main_k_block_loop>;
B1ElementwiseOperation,
CElementwiseOperation,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
typename GridwiseOp::DefaultBlock2CTileMap,
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -560,19 +1402,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -560,19 +1402,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
arg.p_b0_grid_, arg.p_b0_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc, arg.M_,
arg.b0_grid_desc, arg.N_,
arg.b1_grid_desc, arg.K_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.O_,
arg.a_element_op_, arg.G0_,
arg.b0_element_op_, arg.G1_,
arg.acc_element_op_, arg.alpha_,
arg.b1_element_op_, arg.input_permute_,
arg.c_element_op_, arg.output_permute_);
arg.batch_count_,
arg.compute_ptr_offset_of_batch_,
arg.c0_matrix_mask_,
arg.block_2_ctile_map_);
}; };
if(GridwiseOp::CalculateHasMainKBlockLoop(K)) if(GridwiseOp::CalculateHasMainKBlockLoop(K))
...@@ -598,7 +1436,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -598,7 +1436,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// TODO: properly implement this check // TODO: properly implement this check
return true; return true;
} }
#if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
...@@ -695,14 +1533,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -695,14 +1533,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType* p_c, CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::vector<index_t>& b0_gs_ls_ks_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b0_gs_ls_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_ns_ls_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_ns_ls_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_ns_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_ns_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
...@@ -739,6 +1577,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -739,6 +1577,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
b1_element_op, b1_element_op,
c_element_op}; c_element_op};
} }
#endif
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
...@@ -766,20 +1605,60 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -766,20 +1605,60 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op) override
{ {
std::array<index_t, NumDimG + NumDimM + NumDimN> a_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> a_strides;
std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> b0_strides;
std::array<index_t, NumDimG + NumDimM + NumDimN> b1_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> b1_strides;
std::array<index_t, NumDimG + NumDimM + NumDimN> c_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> c_strides;
std::transform(a_gs_ms_ks_lengths.begin(),
a_gs_ms_ks_lengths.end(),
a_lengths.begin(),
[](index_t i) { return i; });
std::transform(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.end(),
a_strides.begin(),
[](index_t i) { return i; });
std::transform(b0_gs_ls_ks_lengths.begin(),
b0_gs_ls_ks_lengths.end(),
b0_lengths.begin(),
[](index_t i) { return i; });
std::transform(b0_gs_ls_ks_strides.begin(),
b0_gs_ls_ks_strides.end(),
b0_strides.begin(),
[](index_t i) { return i; });
std::transform(b1_gs_ns_ls_lengths.begin(),
b1_gs_ns_ls_lengths.end(),
b1_lengths.begin(),
[](index_t i) { return i; });
std::transform(b1_gs_ns_ls_strides.begin(),
b1_gs_ns_ls_strides.end(),
b1_strides.begin(),
[](index_t i) { return i; });
std::transform(c_gs_ms_ns_lengths.begin(),
c_gs_ms_ns_lengths.end(),
c_lengths.begin(),
[](index_t i) { return i; });
std::transform(c_gs_ms_ns_strides.begin(),
c_gs_ms_ns_strides.end(),
c_strides.begin(),
[](index_t i) { return i; });
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0), static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1), static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
p_acc0_biases, p_acc0_biases,
p_acc1_biases, p_acc1_biases,
a_gs_ms_ks_lengths, a_lengths,
a_gs_ms_ks_strides, a_strides,
b0_gs_ls_ks_lengths, b0_lengths,
b0_gs_ls_ks_strides, b0_strides,
b1_gs_ns_ls_lengths, b1_lengths,
b1_gs_ns_ls_strides, b1_strides,
c_gs_ms_ns_lengths, c_lengths,
c_gs_ms_ns_strides, c_strides,
acc0_biases_gs_ms_ls_lengths, acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ls_strides, acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_ns_lengths, acc1_biases_gs_ms_ns_lengths,
...@@ -819,11 +1698,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -819,11 +1698,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<< MPerBlock << ", " << MPerBlock << ", "
<< LPerBlock << ", " << LPerBlock << ", "
<< KPerBlock << ", " << KPerBlock << ", "
<< K1 << ", " << AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< LTilePerBlock << ", " << LTilePerBlock << ", "
<< L1 << L1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", " << getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", " << "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(B0Spec) << ", " << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
......
...@@ -650,7 +650,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -650,7 +650,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
...@@ -667,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -667,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1 conv // check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
......
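For context on the index change above (i + 2 -> i + 3): for grouped forward convolution the weight length array appears to be laid out as [G, K, C, X0, ..., X_{NDimSpatial-1}], so the spatial filter sizes start at offset 3, after the group, output-channel and input-channel dimensions. A minimal standalone sketch of the corrected 1x1, stride=1, pad=0 supporting check, assuming that layout (names here are illustrative, not the library's):

#include <vector>

// Every spatial dimension must have filter size 1, stride 1 and zero padding.
bool is_1x1_stride1_pad0(const std::vector<int>& b_g_k_c_xs_lengths,
                         const std::vector<int>& conv_filter_strides,
                         const std::vector<int>& input_left_pads,
                         const std::vector<int>& input_right_pads,
                         int ndim_spatial)
{
    for(int i = 0; i < ndim_spatial; ++i)
    {
        const int X          = b_g_k_c_xs_lengths[i + 3]; // skip G, K, C
        const int ConvStride = conv_filter_strides[i];
        const int LeftPad    = input_left_pads[i];
        const int RightPad   = input_right_pads[i];
        if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
        {
            return false;
        }
    }
    return true;
}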
...@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate ...@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate> template <typename MaskOutPredicate>
struct C0MatrixMask_impl struct C0MatrixMask_impl
{ {
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} __host__ __device__ C0MatrixMask_impl(index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{ {
......
...@@ -18,102 +18,6 @@ ...@@ -18,102 +18,6 @@
namespace ck { namespace ck {
template <typename GridwiseOp,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AGridDesc,
typename B0GridDesc,
typename B1GridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const B0GridDesc b0_grid_desc,
const B1GridDesc b1_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const B0ElementwiseOperation b0_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b0_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_grid_desc;
ignore = b0_grid_desc;
ignore = b1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b0_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
}
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] // Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] // Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
template <typename ADataType, template <typename ADataType,
...@@ -136,7 +40,8 @@ template <typename ADataType, ...@@ -136,7 +40,8 @@ template <typename ADataType,
index_t MPerBlock, index_t MPerBlock,
index_t LPerBlock, index_t LPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t K1Value, index_t AK1Value,
index_t BK1Value,
index_t NPerBlock, index_t NPerBlock,
index_t LTilePerBlock, index_t LTilePerBlock,
index_t L1Value, index_t L1Value,
...@@ -194,9 +99,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -194,9 +99,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto AK1 = Number<K1Value>{}; static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / K1Value>{}; static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto BK1 = Number<K1Value>{}; static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto L0PerBlock = LTilePerBlock / L1Value; static constexpr auto L0PerBlock = LTilePerBlock / L1Value;
static constexpr auto AL0 = Number<L0PerBlock / 2>{}; static constexpr auto AL0 = Number<L0PerBlock / 2>{};
...@@ -714,7 +619,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -714,7 +619,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t num_loop = K / KPerBlock; const index_t num_loop = math::integer_divide_ceil(K, KPerBlock);
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
...@@ -887,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -887,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value; constexpr auto K0PerWmma = WmmaK/2/AK1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>( auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize()); a_block_desc.GetElementSpaceSize());
...@@ -903,7 +808,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -903,7 +808,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Number<K0PerWmma>{}, Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<AK1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>, Sequence<0, 1, 2, 3, 4, 5, 6>,
6, 6,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
...@@ -966,7 +871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -966,7 +871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy // Thread-wise copy
// KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1 // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value; constexpr auto K0PerWmma = WmmaK/2/BK1Value;
auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>( auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
b0_block_desc.GetElementSpaceSize()); b0_block_desc.GetElementSpaceSize());
...@@ -982,7 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -982,7 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Number<K0PerWmma>{}, Number<K0PerWmma>{},
I1, I1,
I1, I1,
Number<K1Value>{}>, Number<AK1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>, Sequence<0, 1, 2, 3, 4, 5, 6>,
6, 6,
B0BlockTransferSrcScalarPerVector, B0BlockTransferSrcScalarPerVector,
...@@ -1009,7 +914,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma ...@@ -1009,7 +914,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
/*******************************************************************************/ /*******************************************************************************/
// Gemm0 // Gemm0
constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK); constexpr auto KPack = math::integer_least_multiple(math::integer_least_multiple(AK1Value,BK1Value), WmmaK);
auto blockwise_gemm0 = BlockwiseGemmWMMA< auto blockwise_gemm0 = BlockwiseGemmWMMA<
BlockSize, BlockSize,
......
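The CalculateHasMainKBlockLoop change above switches from plain integer division to ceiling division, so a K that is not a multiple of KPerBlock still counts its padded tail block when deciding whether the main loop runs. A small worked sketch of the difference (values chosen purely for illustration):

#include <cassert>

int main()
{
    constexpr int KPerBlock = 32;
    constexpr int K         = 40; // not a multiple of KPerBlock
    const int num_loop_floor = K / KPerBlock;                   // 1: the 8-element tail is dropped
    const int num_loop_ceil  = (K + KPerBlock - 1) / KPerBlock; // 2: the padded tail block is counted
    assert(num_loop_floor == 1 && num_loop_ceil == 2);
    return 0;
}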
...@@ -16,14 +16,15 @@ template <index_t NumDimG, ...@@ -16,14 +16,15 @@ template <index_t NumDimG,
index_t NumDimM, index_t NumDimM,
index_t NumDimN, index_t NumDimN,
device::TensorSpecialization TensorSpec> device::TensorSpecialization TensorSpec>
static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec, __host__ __device__ static auto
const std::vector<index_t>& gs_ms_ns_strides_vec) MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
{ {
if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) // gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
{ // {
throw std::runtime_error("wrong! dimension must match input lengths"); // throw std::runtime_error("wrong! dimension must match input lengths");
} // }
const auto to_tuple = [&](auto& vec, auto start, auto end) { const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{}); return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
...@@ -143,21 +144,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -143,21 +144,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
// //
// A // A
// //
static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, __host__ __device__ static auto MakeAGridDescriptorPair(
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{ {
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec, return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
a_gs_ms_ks_strides_vec); a_gs_ms_ks_strides_vec);
} }
// TODO: rename to G_MRaw_KRaw // TODO: rename to G_MRaw_KRaw
static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, __host__ __device__ static auto MakeAGridDescriptor_G_M_K(
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{ {
return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
} }
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, __host__ __device__ static auto MakeAGridDescriptor_M_K(
const std::vector<index_t>& a_gs_ms_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{ {
return matrix_padder.PadADescriptor_M_K( return matrix_padder.PadADescriptor_M_K(
MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
...@@ -212,21 +216,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -212,21 +216,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
// //
// B (alias of B0) // B (alias of B0)
// //
static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec, __host__ __device__ static auto MakeB0GridDescriptorPair(
const std::vector<index_t>& b0_gs_ns_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{ {
return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec, return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
b0_gs_ns_ks_strides_vec); b0_gs_ns_ks_strides_vec);
} }
// TODO: rename to G_MRaw_NRaw // TODO: rename to G_MRaw_NRaw
static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec, __host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
const std::vector<index_t>& b0_gs_ns_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{ {
return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
} }
static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec, __host__ __device__ static auto MakeB0GridDescriptor_N_K(
const std::vector<index_t>& b0_gs_ns_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{ {
// alias of matrix_padder.PadB0Descriptor_N_K // alias of matrix_padder.PadB0Descriptor_N_K
return matrix_padder.PadBDescriptor_N_K( return matrix_padder.PadBDescriptor_N_K(
...@@ -282,21 +289,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -282,21 +289,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
// //
// B1 // B1
// //
static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec, __host__ __device__ static auto MakeB1GridDescriptorPair(
const std::vector<index_t>& b1_gs_os_ns_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{ {
return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec, return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
b1_gs_os_ns_strides_vec); b1_gs_os_ns_strides_vec);
} }
// TODO: rename to G_NRaw_KRaw // TODO: rename to G_NRaw_KRaw
static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec, __host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
const std::vector<index_t>& b1_gs_os_ns_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{ {
return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
} }
static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec, __host__ __device__ static auto MakeB1GridDescriptor_N_K(
const std::vector<index_t>& b1_gs_os_ns_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{ {
// alias of matrix_padder.PadB1Descriptor_O_N // alias of matrix_padder.PadB1Descriptor_O_N
return matrix_padder.PadB1Descriptor_N_K( return matrix_padder.PadB1Descriptor_N_K(
...@@ -353,21 +363,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -353,21 +363,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
// //
// C // C
// //
static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec, __host__ __device__ static auto MakeCGridDescriptorPair(
const std::vector<index_t>& c_gs_ms_os_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{ {
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec, return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
c_gs_ms_os_strides_vec); c_gs_ms_os_strides_vec);
} }
// TODO: rename to G_MRaw_NRaw // TODO: rename to G_MRaw_NRaw
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec, __host__ __device__ static auto MakeCGridDescriptor_G_M_N(
const std::vector<index_t>& c_gs_ms_os_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{ {
return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
} }
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec, __host__ __device__ static auto MakeCGridDescriptor_M_N(
const std::vector<index_t>& c_gs_ms_os_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{ {
return matrix_padder.PadCDescriptor_M_N( return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
......
#!/bin/bash
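# Usage (assumed invocation): run this script with -e pointing at the attention
# example binary to exercise, e.g.  bash <this script> -e ./bin/<attention example>.
# The hard-coded causal-mask cases below always use the batched_gemm_lower_triangle
# example binary and ignore -e.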
while getopts e: flag
do
case "${flag}" in
e) executable=${OPTARG};;
esac
done
echo "CK-NAVI31 Performance Test: MHA for AITemplate"
VERIFICATION=0
INITIALIZE=1
TIMING=1
ALL_TEST_CASE=0
SELF_ATTENTION=1
CROSS_ATTENTION=0
CAUSAL_MASK=0
# self attention with causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $SELF_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 1 ]; }; then
echo "Test launched: self attention with causal mask"
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 4096 4096 40 40 2 8 0.158113881945610 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 1024 1024 80 80 2 8 0.111803397536277 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 256 256 160 160 2 8 0.079056940972805 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 64 64 160 160 2 8 0.079056940972805 1 1
fi
# cross attention with causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $CROSS_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 1 ]; }; then
echo "Test launched: cross attention with causal mask"
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 4096 64 40 40 2 8 0.158113881945610 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 1024 64 80 80 2 8 0.111803397536277 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 256 64 160 160 2 8 0.079056940972805 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 64 64 160 160 2 8 0.079056940972805 1 1
fi
# self attention without causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $SELF_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 0 ]; }; then
echo "Test launched: self attention without causal mask"
$executable $VERIFICATION $INITIALIZE $TIMING 4096 4096 64 64 2 5 0.125 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 1024 1024 64 64 2 10 0.125 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 256 256 64 64 2 20 0.125 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 64 64 64 64 2 20 0.125 1 1
fi
# cross attention without causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $CROSS_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 0 ]; }; then
echo "Test launched: cross attention without causal mask"
$executable $VERIFICATION $INITIALIZE $TIMING 4096 64 40 40 2 8 0.158113881945610 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 1024 64 80 80 2 8 0.111803397536277 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 256 64 160 160 2 8 0.079056940972805 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 64 64 160 160 2 8 0.079056940972805 1 1
fi
\ No newline at end of file