"vscode:/vscode.git/clone" did not exist on "9e36648284c47fda760e11f10bb0a6968b78ae8e"
Commit 823c8801 authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'e2e_kernellib' of https://github.com/aska-0096/navi3x_ck into e2e_kernellib

parents e305e41e efee4541
......@@ -8,6 +8,8 @@ add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_pe
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
endif()
add_custom_target(example_gemm_scale_softmax_gemm)
......
......@@ -67,77 +67,200 @@ static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecial
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
// clang-format off
// #define CK_MHA_USE_WAVE_1
// #define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
B0DataType,
B1DataType,
CDataType,
Acc0BiasDataType,
Acc0DataType,
Acc1BiasDataType,
Acc1DataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
TensorSpecA,
TensorSpecB0,
TensorSpecB1,
TensorSpecC,
1,
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, // MPerBlock
64, // LPerBlock
64, // KPerBlock
8, // K1
128, 128, 64, 8, 8,
// Gemm 1
64, // NPerBlock
64, // LTilePerBlock
8, // L1
16, // MPerWMMA
16, // LPerWMMA
16, // NPerWMMA
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, // MRepeat
4, // LRepeat
4, // NRepeat
S<4, 64, 1>, // ABlockTransfer MK -> K0 M K1
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // B0BlockTransfer LK -> K0 L K1
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1
S<0, 2, 1>,
S<0, 2, 1>,
1,
8,
1,
false,
1, // CShuffleMWmmaPerWavePerShuffle
2, // CShuffleNWmmaPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>
#endif
>;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using Acc0DataType = F32;
using Acc1DataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 32, 160, 8, 8,
// Gemm 1
80, 32, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 80, 8, 8,
// Gemm 1
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 80, 8, 8,
// Gemm 1
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 32, 160, 8, 8,
// Gemm 1
80, 32, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 128, 80, 8, 8,
// Gemm 1
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 192, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 12, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 64, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 192, 48, 8,4,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 12, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>
#endif
>;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
Acc0DataType,
Acc1DataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
Acc1DataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_cross_attention.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
......@@ -127,16 +127,31 @@ int run(int argc, char* argv[])
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0 ; unit: b1 pass
case 6: // Rand: a b0 ; unit: B1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0 pass
case 7: // Rand: a b1 ; unit: b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
......@@ -160,39 +175,37 @@ int run(int argc, char* argv[])
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = .0;
float best_time = .0;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
auto gemm = DeviceGemmInstance{};
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
M,
N,
K,
O,
G0,
G1,
alpha,
input_permute,
output_permute);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
// return 0;
}
ck::index_t BatchCount = G0 * G1;
......@@ -208,9 +221,14 @@ int run(int argc, char* argv[])
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
......@@ -242,7 +260,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......@@ -258,8 +276,12 @@ int run(int argc, char* argv[])
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
......@@ -285,14 +307,34 @@ int run(int argc, char* argv[])
atol = 1e-2;
}
return ck::utils::check_err(c_gs_ms_os_device_result.mData,
bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol)
? 0
: 1;
}
atol);
printf("Verification: %s, Pass: %s\n",
do_verification ? "ON" : "OFF",
this_run_verification ? "YES" : "NO");
return 0;
if(!this_run_verification)
{
not_pass = 1;
printf("%d th MHA instance verification Failed \n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256;
ck::index_t N = 64;
ck::index_t K = 80;
ck::index_t O = 80;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 2;
ck::index_t G1 = 8;
float alpha = 1;
bool input_permute = false;
bool output_permute = true;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
case 4: // A, B0, B1 1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5: // Rand: b1 b0; unit: a
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0 ; unit: B1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
std::vector<ck::index_t> kv_gs_ns_ks_lengths{G0, G1, N, 2, K};
std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
N * G1 * 2 * K, 2 * K, G1 * 2 * K, K, 1}; // kv layout [G0, M, G1, 2, K]
Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides);
// merge kv into a packed pointer send to device
b0_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); });
b1_gs_os_ns.ForEach(
[&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[3], 1, idx[2]) = self(idx); });
DeviceMem q_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kv_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() +
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(a_gs_ms_ks.mData.data());
kv_device_buf.ToDevice(kv_gs_ns_ks.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = .0;
float best_time = .0;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeCrossAttnInvoker();
auto argument =
gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0,
M,
N,
G1,
K,
alpha);
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
ck::index_t BatchCount = G0 * G1;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
// default absolute error and relative error is 0.001
double rtol = 1e-3;
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
printf("Verification: %s, Pass: %s\n",
do_verification ? "ON" : "OFF",
this_run_verification ? "YES" : "NO");
if(!this_run_verification)
{
not_pass = 1;
printf("%d th MHA instance verification Failed \n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256;
ck::index_t N = 256;
ck::index_t K = 80;
ck::index_t O = 80;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 2;
ck::index_t G1 = 8;
float alpha = 1;
bool input_permute = false;
bool output_permute = true;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl;
std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
case 4: // A, B0, B1 1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5: // Rand: b1 b0; unit: a
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0 ; unit: B1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
std::vector<ck::index_t> qkv_gs_ms_ks_lengths{G0, G1, M, 3, K};
std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
M * G1 * 3 * K, 3 * K, G1 * 3 * K, K, 1}; // qkv layout [G0, M, G1, 3, K]
Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides);
// merge qkv into a packed pointer send to device
a_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); });
b0_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 1, idx[3]) = self(idx); });
b1_gs_os_ns.ForEach(
[&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[3], 2, idx[2]) = self(idx); });
DeviceMem qkv_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize() +
sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() +
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) *
c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
qkv_device_buf.ToDevice(qkv_gs_ms_ks.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = .0;
float best_time = .0;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeSelfAttnInvoker();
auto argument =
gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0,
M,
G1,
K,
alpha);
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
ck::index_t BatchCount = G0 * G1;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// softmax
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
// default absolute error and relative error is 0.001
double rtol = 1e-3;
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
printf("Verification: %s, Pass: %s\n",
do_verification ? "ON" : "OFF",
this_run_verification ? "YES" : "NO");
if(!this_run_verification)
{
not_pass = 1;
printf("%d th MHA instance verification Failed \n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
*/
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using Acc0DataType = F32;
using Acc1DataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
// clang-format off
#define CK_MHA_USE_WAVE_1
#define CK_MHA_USE_WAVE_2
#define CK_MHA_USE_WAVE_4
#define CK_MHA_USE_WAVE_8
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_2
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_4
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 64, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
// Gemm 1
64, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>
#endif
>;
// clang-format on
// Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
Acc0DataType,
Acc1DataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Softmax: fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<Acc0DataType, ADataType, Acc0DataType>;
// Ref Gemm1: fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
Acc1DataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_self_attention.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
......@@ -5,7 +5,11 @@
#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
......@@ -22,6 +26,417 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename DeviceOp,
typename GridwiseOp,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
index_t M,
index_t N,
index_t K,
index_t O,
index_t G0,
index_t G1,
float alpha,
bool input_permute,
bool output_permute)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
constexpr index_t array_size = 4;
std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
input_permute
? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::array<ck::index_t, array_size>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
input_permute
? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::array<ck::index_t, array_size>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
input_permute
? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::array<ck::index_t, array_size>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
std::array<ck::index_t, array_size> c_gs_ms_os_strides =
output_permute
? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::array<ck::index_t, array_size>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
const auto a_element_op = AElementwiseOperation{};
const auto b0_element_op = B0ElementwiseOperation{};
const auto acc0_element_op = AccElementwiseOperation{alpha};
const auto b1_element_op = B1ElementwiseOperation{};
const auto c_element_op = CElementwiseOperation{};
// fail to reuse DeviceOp::MakeArgument() because of the __device__ function required.
const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto a_grid_desc_g_m_k =
DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
const auto b0_grid_desc_g_l_k =
DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
const auto b1_grid_desc_g_n_l =
DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
// clang-format on
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b0_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = M;
ignore = N;
ignore = K;
ignore = O;
ignore = G0;
ignore = G1;
ignore = input_permute;
ignore = output_permute;
#endif // end of if (defined(__gfx1100__))
}
// Self-Attention
template <typename DeviceOp,
typename GridwiseOp,
typename QKVDataType,
typename ODataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
ODataType* __restrict__ p_out_grid,
index_t batch_size,
index_t sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
constexpr index_t array_size = 4;
std::array<ck::index_t, array_size> qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size};
std::array<ck::index_t, array_size> qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1};
std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length};
std::array<ck::index_t, array_size> v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
const auto a_element_op = AElementwiseOperation{};
const auto b0_element_op = B0ElementwiseOperation{};
const auto acc0_element_op = AccElementwiseOperation{alpha};
const auto b1_element_op = B1ElementwiseOperation{};
const auto c_element_op = CElementwiseOperation{};
const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto a_grid_desc_g_m_k =
DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b0_grid_desc_g_l_k =
DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
const auto b1_grid_desc_g_n_l =
DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
// clang-format on
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size);
#ifdef CK_SELF_ATTN_DEBUG
if(get_thread_global_1d_id() == 0)
{
printf("batch_size: %d\n", batch_size);
printf("sequence_length: %d\n", sequence_length);
printf("head_count: %d\n", head_count);
printf("head_size: %d\n", head_size);
printf("qkv_gap: %d\n", qkv_gap);
printf("get_grid_size(): %d\n", get_grid_size());
printf("batch_count: %d\n", batch_count);
printf("blockid: %d\n", get_block_1d_id());
printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
printf("g_idx: %d\n", g_idx);
printf("a_batch_offset: %ld\n", a_batch_offset);
printf("b0_batch_offset: %ld\n", b0_batch_offset);
printf("b1_batch_offset: %ld\n", b1_batch_offset);
}
#endif
GridwiseOp::template Run<HasMainKBlockLoop>(p_qkv_grid + 0 * qkv_gap + a_batch_offset,
p_qkv_grid + 1 * qkv_gap + b0_batch_offset,
p_qkv_grid + 2 * qkv_gap + b1_batch_offset,
p_out_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_qkv_grid;
ignore = p_out_grid;
ignore = batch_size;
ignore = sequence_length;
ignore = head_count;
ignore = head_size;
ignore = alpha;
#endif // end of if (defined(__gfx1100__))
}
// Cross-Attention
// Self-Attention
template <typename DeviceOp,
typename GridwiseOp,
typename QDataType,
typename KVDataType,
typename ODataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
const KVDataType* __restrict__ p_kv_grid,
ODataType* __restrict__ p_out_grid,
index_t batch_size,
index_t q_sequence_length,
index_t kv_sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
// clang-format off
// ***************************************************
// Make Tensor Descriptors
// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
constexpr index_t array_size = 4;
std::array<ck::index_t, array_size> q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size};
std::array<ck::index_t, array_size> q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
std::array<ck::index_t, array_size> k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size};
std::array<ck::index_t, array_size> k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1};
std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length};
std::array<ck::index_t, array_size> v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size};
std::array<ck::index_t, array_size> c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
const auto a_element_op = AElementwiseOperation{};
const auto b0_element_op = B0ElementwiseOperation{};
const auto acc0_element_op = AccElementwiseOperation{alpha};
const auto b1_element_op = B1ElementwiseOperation{};
const auto c_element_op = CElementwiseOperation{};
const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto a_grid_desc_g_m_k =
DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
const auto b0_grid_desc_g_l_k =
DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
const auto b1_grid_desc_g_n_l =
DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
// clang-format on
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size);
#ifdef CK_SELF_ATTN_DEBUG
if(get_thread_global_1d_id() == 0)
{
printf("batch_size: %d\n", batch_size);
printf("q_sequence_length: %d\n", q_sequence_length);
printf("k_sequence_length: %d\n", kv_sequence_length);
printf("head_count: %d\n", head_count);
printf("head_size: %d\n", head_size);
printf("kv_gap: %d\n", kv_gap);
printf("get_grid_size(): %d\n", get_grid_size());
printf("batch_count: %d\n", batch_count);
printf("blockid: %d\n", get_block_1d_id());
printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
printf("g_idx: %d\n", g_idx);
printf("a_batch_offset: %ld\n", a_batch_offset);
printf("b0_batch_offset: %ld\n", b0_batch_offset);
printf("b1_batch_offset: %ld\n", b1_batch_offset);
}
#endif
GridwiseOp::template Run<HasMainKBlockLoop>(p_q_grid + a_batch_offset,
p_kv_grid + 0 * kv_gap + b0_batch_offset,
p_kv_grid + 1 * kv_gap + b1_batch_offset,
p_out_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_q_grid;
ignore = p_kv_grid;
ignore = p_out_grid;
ignore = batch_size;
ignore = q_sequence_length;
ignore = kv_sequence_length;
ignore = head_count;
ignore = head_size;
ignore = alpha;
#endif // end of if (defined(__gfx1100__))
}
// Computes C = A * B0 * B1
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
......@@ -55,7 +470,8 @@ template <index_t NumDimG,
ck::index_t MPerBlock,
ck::index_t LPerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t NPerBlock,
ck::index_t LTilePerBlock,
ck::index_t L1,
......@@ -165,14 +581,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1Spec,
CSpec>;
static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
__host__ __device__ static auto MakeAGridDescriptor(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
if constexpr(AEnableLds)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<K1>{});
Number<AK1>{});
}
else
{
......@@ -184,19 +601,20 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number<MRepeat>{},
Number<MWaves>{},
Number<MPerWmma>{},
Number<K1>{});
Number<AK1>{});
}
}
static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
__host__ __device__ static auto MakeB0GridDescriptor(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides_vec)
{
if constexpr(B0EnableLds)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
b0_gs_ls_ks_strides_vec),
Number<K1>{});
Number<BK1>{});
}
else
{
......@@ -208,12 +626,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
Number<LRepeat>{},
Number<LWaves>{},
Number<LPerWmma>{},
Number<K1>{});
Number<BK1>{});
}
}
static auto MakeB1GridDescriptor(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
__host__ __device__ static auto MakeB1GridDescriptor(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides_vec)
{
if constexpr(B1EnableLds)
{
......@@ -245,7 +664,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
__host__ __device__ constexpr static auto make_MaskOutPredicate()
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
......@@ -260,7 +679,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
__host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
const CGridDesc_G_M_N& c_grid_desc_g_m_n)
......@@ -324,7 +743,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
MPerBlock,
LPerBlock,
KPerBlock,
K1,
AK1,
BK1,
NPerBlock,
LTilePerBlock,
L1,
......@@ -373,6 +793,323 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
LoopSched,
PipelineVer>;
struct RawArg : public BaseArgument
{
RawArg(const ADataType* p_a_grid,
const B0DataType* p_b0_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t O,
index_t G0,
index_t G1,
float alpha,
bool input_permute,
bool output_permute)
: p_a_grid_{p_a_grid},
p_b0_grid_{p_b0_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
M_{M},
N_{N},
K_{K},
O_{O},
G0_{G0},
G1_{G1},
alpha_{alpha},
input_permute_{input_permute},
output_permute_{output_permute}
{
}
// Pointers
const ADataType* p_a_grid_;
const B0DataType* p_b0_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
// Raw Problem Size
index_t M_;
index_t N_;
index_t K_;
index_t O_;
index_t G0_;
index_t G1_;
float alpha_;
bool input_permute_;
bool output_permute_;
};
static auto MakeArgument(const ADataType* p_a,
const B0DataType* p_b0,
const B1DataType* p_b1,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t O,
index_t G0,
index_t G1,
float alpha,
bool input_permute,
bool output_permute)
{
return RawArg{
p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute};
}
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
printf("DeviceOp: Acc0 Type err");
return false;
}
if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
{
printf("DeviceOp: Acc1 Type err");
return false;
}
}
else
{
printf("DeviceOp: Arch err");
return false;
}
constexpr index_t array_size = 4;
ck::index_t G0 = arg.G0_;
ck::index_t G1 = arg.G1_;
ck::index_t M = arg.M_;
ck::index_t N = arg.N_;
ck::index_t K = arg.K_;
ck::index_t O = arg.O_;
bool input_permute = arg.input_permute_;
bool output_permute = arg.output_permute_;
std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
input_permute ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1}
// A layout [G0, M, G1, K]
: std::array<ck::index_t, array_size>{
G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
input_permute ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1}
// B0 layout [G0, N, G1, K]
: std::array<ck::index_t, array_size>{
G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
input_permute ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O}
// B1 layout [G0, N, G1, O]
: std::array<ck::index_t, array_size>{
G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
std::array<ck::index_t, array_size> c_gs_ms_os_strides =
output_permute ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1}
// C layout [G0, M, G1, O]
: std::array<ck::index_t, array_size>{
G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
const auto a_grid_desc =
DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
const auto b0_grid_desc =
DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
const auto b1_grid_desc =
DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
const auto c_grid_desc_m_n =
DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
const auto c_grid_desc_g_m_n =
DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
if(!GridwiseOp::CheckValidity(
a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
{
return false;
}
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded
if(!(c_g == batch_count))
{
printf("DeviceOp: BatchCount err");
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = M;
const auto LzRaw = N;
const auto KzRaw = K;
const auto NzRaw = O;
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
const auto c_extent_lowest = NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
printf("DeviceOp: Data Transfer Vector scalar err");
return false;
}
std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
c_gs_ms_os_strides[NumDimG + NumDimM - 1],
c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
const auto b0_stride_lowest =
B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
const auto c_stride_lowest = c_mz_nz_strides_[1];
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
printf("DeviceOp: Data Vectorize transfer err");
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
}
struct SelfAttnArg : public BaseArgument
{
SelfAttnArg(const ADataType* p_qkv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t sequence_length,
index_t head_count,
index_t head_size,
float alpha)
: p_qkv_grid_{p_qkv_grid},
p_out_grid_{p_out_grid},
batch_size_{batch_size},
sequence_length_{sequence_length},
head_count_{head_count},
head_size_{head_size},
alpha_{alpha}
{
}
// Pointers
const ADataType* p_qkv_grid_;
CDataType* p_out_grid_;
// Raw Problem Size
index_t batch_size_;
index_t sequence_length_;
index_t head_count_;
index_t head_size_;
float alpha_;
};
static auto MakeSelfAttnArgument(const ADataType* p_qkv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
return SelfAttnArg{
p_qkv_grid, p_out_grid, batch_size, sequence_length, head_count, head_size, alpha};
}
struct CrossAttnArg : public BaseArgument
{
CrossAttnArg(const ADataType* p_q_grid,
const B0DataType* p_kv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t q_sequence_length,
index_t kv_sequence_length,
index_t head_count,
index_t head_size,
float alpha)
: p_q_grid_{p_q_grid},
p_kv_grid_{p_kv_grid},
p_out_grid_{p_out_grid},
batch_size_{batch_size},
q_sequence_length_{q_sequence_length},
kv_sequence_length_{kv_sequence_length},
head_count_{head_count},
head_size_{head_size},
alpha_{alpha}
{
}
// Pointers
const ADataType* p_q_grid_;
const B0DataType* p_kv_grid_;
CDataType* p_out_grid_;
// Raw Problem Size
index_t batch_size_;
index_t q_sequence_length_;
index_t kv_sequence_length_;
index_t head_count_;
index_t head_size_;
float alpha_;
};
static auto MakeCrossAttnArgument(const ADataType* p_q_grid,
const B0DataType* p_kv_grid,
CDataType* p_out_grid,
index_t batch_size,
index_t q_sequence_length,
index_t kv_sequence_length,
index_t head_count,
index_t head_size,
float alpha)
{
return CrossAttnArg{p_q_grid,
p_kv_grid,
p_out_grid,
batch_size,
q_sequence_length,
kv_sequence_length,
head_count,
head_size,
alpha};
}
// Argument
struct Argument : public BaseArgument
{
......@@ -383,14 +1120,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType* p_c_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
......@@ -497,11 +1234,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// Strides for the last M/N/K dimensions of A/B0/B1/C
// for sanity check of vector load/store
std::vector<index_t> raw_lengths_mz_lz_kz_nz_;
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b0_lz_kz_strides_;
std::vector<index_t> b1_nz_lz_strides_;
std::vector<index_t> c_mz_nz_strides_;
std::array<index_t, NumDimG + NumDimM + NumDimN> raw_lengths_mz_lz_kz_nz_;
std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_;
std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_;
std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_;
std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_;
index_t batch_count_;
// Batch Offset
......@@ -509,46 +1246,151 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
};
// Invoker
struct Invoker : public BaseInvoker
struct SelfAttnInvoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
using Argument = DeviceOp::SelfAttnArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
const auto M0 = math::integer_divide_ceil(arg.sequence_length_, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock);
const auto K = [&]() {
if constexpr(AEnableLds)
const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0;
const auto K = arg.head_size_;
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_wmma_self_attention_forward<DeviceOp,
GridwiseOp,
ADataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_qkv_grid_,
arg.p_out_grid_,
arg.batch_size_,
arg.sequence_length_,
arg.head_count_,
arg.head_size_,
arg.alpha_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
return launch_kernel(integral_constant<bool, false>{});
}
}
}();
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; }
// Invoker
struct CrossAttnInvoker : public BaseInvoker
{
using Argument = DeviceOp::CrossAttnArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock);
const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0;
const auto K = arg.head_size_;
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_wmma_cross_attention_forward<DeviceOp,
GridwiseOp,
ADataType,
B0DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_q_grid_,
arg.p_kv_grid_,
arg.p_out_grid_,
arg.batch_size_,
arg.q_sequence_length_,
arg.kv_sequence_length_,
arg.head_count_,
arg.head_size_,
arg.alpha_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; }
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::RawArg;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock);
const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0;
const auto K = arg.K_;
// printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
const auto kernel =
kernel_batched_gemm_softmax_gemm_wmma_cshuffle<DeviceOp,
GridwiseOp,
ADataType,
B0DataType,
B1DataType,
CDataType,
DeviceOp::AGridDesc,
DeviceOp::B0GridDesc,
DeviceOp::B1GridDesc,
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
typename GridwiseOp::DefaultBlock2CTileMap,
has_main_k_block_loop>;
return launch_and_time_kernel(stream_config,
......@@ -560,19 +1402,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
arg.p_b0_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.a_grid_desc,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_,
arg.b0_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.batch_count_,
arg.compute_ptr_offset_of_batch_,
arg.c0_matrix_mask_,
arg.block_2_ctile_map_);
arg.M_,
arg.N_,
arg.K_,
arg.O_,
arg.G0_,
arg.G1_,
arg.alpha_,
arg.input_permute_,
arg.output_permute_);
};
if(GridwiseOp::CalculateHasMainKBlockLoop(K))
......@@ -598,7 +1436,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// TODO: properly implement this check
return true;
}
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
......@@ -695,14 +1533,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
CDataType* p_c,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b0_gs_ls_ks_lengths,
const std::vector<index_t>& b0_gs_ls_ks_strides,
const std::vector<index_t>& b1_gs_ns_ls_lengths,
const std::vector<index_t>& b1_gs_ns_ls_strides,
const std::vector<index_t>& c_gs_ms_ns_lengths,
const std::vector<index_t>& c_gs_ms_ns_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
......@@ -739,6 +1577,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
b1_element_op,
c_element_op};
}
#endif
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -766,20 +1605,60 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
std::array<index_t, NumDimG + NumDimM + NumDimN> a_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> a_strides;
std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> b0_strides;
std::array<index_t, NumDimG + NumDimM + NumDimN> b1_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> b1_strides;
std::array<index_t, NumDimG + NumDimM + NumDimN> c_lengths;
std::array<index_t, NumDimG + NumDimM + NumDimN> c_strides;
std::transform(a_gs_ms_ks_lengths.begin(),
a_gs_ms_ks_lengths.end(),
a_lengths.begin(),
[](index_t i) { return i; });
std::transform(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.end(),
a_strides.begin(),
[](index_t i) { return i; });
std::transform(b0_gs_ls_ks_lengths.begin(),
b0_gs_ls_ks_lengths.end(),
b0_lengths.begin(),
[](index_t i) { return i; });
std::transform(b0_gs_ls_ks_strides.begin(),
b0_gs_ls_ks_strides.end(),
b0_strides.begin(),
[](index_t i) { return i; });
std::transform(b1_gs_ns_ls_lengths.begin(),
b1_gs_ns_ls_lengths.end(),
b1_lengths.begin(),
[](index_t i) { return i; });
std::transform(b1_gs_ns_ls_strides.begin(),
b1_gs_ns_ls_strides.end(),
b1_strides.begin(),
[](index_t i) { return i; });
std::transform(c_gs_ms_ns_lengths.begin(),
c_gs_ms_ns_lengths.end(),
c_lengths.begin(),
[](index_t i) { return i; });
std::transform(c_gs_ms_ns_strides.begin(),
c_gs_ms_ns_strides.end(),
c_strides.begin(),
[](index_t i) { return i; });
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const B0DataType*>(p_b0),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ls_ks_lengths,
b0_gs_ls_ks_strides,
b1_gs_ns_ls_lengths,
b1_gs_ns_ls_strides,
c_gs_ms_ns_lengths,
c_gs_ms_ns_strides,
a_lengths,
a_strides,
b0_lengths,
b0_strides,
b1_lengths,
b1_strides,
c_lengths,
c_strides,
acc0_biases_gs_ms_ls_lengths,
acc0_biases_gs_ms_ls_strides,
acc1_biases_gs_ms_ns_lengths,
......@@ -819,11 +1698,12 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<< MPerBlock << ", "
<< LPerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< LTilePerBlock << ", "
<< L1
<< L1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
......
......@@ -650,7 +650,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
......@@ -667,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
......
......@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ C0MatrixMask_impl(index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}
__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
......
......@@ -18,102 +18,6 @@
namespace ck {
template <typename GridwiseOp,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AGridDesc,
typename B0GridDesc,
typename B1GridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
typename Block2CTileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const B0GridDesc b0_grid_desc,
const B1GridDesc b1_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const B0ElementwiseOperation b0_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b0_grid + b0_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_grid_desc,
b0_grid_desc,
b1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b0_element_op,
acc_element_op,
b1_element_op,
c_element_op,
c0_matrix_mask,
block_2_ctile_map);
#else
ignore = p_a_grid;
ignore = p_b0_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_grid_desc;
ignore = b0_grid_desc;
ignore = b1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b0_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx1100__))
}
// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
template <typename ADataType,
......@@ -136,7 +40,8 @@ template <typename ADataType,
index_t MPerBlock,
index_t LPerBlock,
index_t KPerBlock,
index_t K1Value,
index_t AK1Value,
index_t BK1Value,
index_t NPerBlock,
index_t LTilePerBlock,
index_t L1Value,
......@@ -194,9 +99,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto AK1 = Number<K1Value>{};
static constexpr auto BK0 = Number<KPerBlock / K1Value>{};
static constexpr auto BK1 = Number<K1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto L0PerBlock = LTilePerBlock / L1Value;
static constexpr auto AL0 = Number<L0PerBlock / 2>{};
......@@ -714,7 +619,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
const index_t num_loop = math::integer_divide_ceil(K, KPerBlock);
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
......@@ -887,7 +792,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
constexpr auto K0PerWmma = WmmaK/2/AK1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());
......@@ -903,7 +808,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Number<AK1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
ABlockTransferSrcScalarPerVector,
......@@ -966,7 +871,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
// Thread-wise copy
// KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
constexpr auto K0PerWmma = WmmaK/2/BK1Value;
auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
b0_block_desc.GetElementSpaceSize());
......@@ -982,7 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Number<AK1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B0BlockTransferSrcScalarPerVector,
......@@ -1009,7 +914,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
/*******************************************************************************/
// Gemm0
constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);
constexpr auto KPack = math::integer_least_multiple(math::integer_least_multiple(AK1Value,BK1Value), WmmaK);
auto blockwise_gemm0 = BlockwiseGemmWMMA<
BlockSize,
......
......@@ -16,14 +16,15 @@ template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
device::TensorSpecialization TensorSpec>
static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
const std::vector<index_t>& gs_ms_ns_strides_vec)
__host__ __device__ static auto
MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
{
if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
{
throw std::runtime_error("wrong! dimension must match input lengths");
}
// if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
// gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
// {
// throw std::runtime_error("wrong! dimension must match input lengths");
// }
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
......@@ -143,21 +144,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
// A
//
static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
__host__ __device__ static auto MakeAGridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
a_gs_ms_ks_strides_vec);
}
// TODO: rename to G_MRaw_KRaw
static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
__host__ __device__ static auto MakeAGridDescriptor_G_M_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
}
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
__host__ __device__ static auto MakeAGridDescriptor_M_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return matrix_padder.PadADescriptor_M_K(
MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
......@@ -212,21 +216,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
// B (alias of B0)
//
static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
__host__ __device__ static auto MakeB0GridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
b0_gs_ns_ks_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
__host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
}
static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
__host__ __device__ static auto MakeB0GridDescriptor_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
// alias of matrix_padder.PadB0Descriptor_N_K
return matrix_padder.PadBDescriptor_N_K(
......@@ -282,21 +289,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
// B1
//
static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
__host__ __device__ static auto MakeB1GridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
b1_gs_os_ns_strides_vec);
}
// TODO: rename to G_NRaw_KRaw
static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
__host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
}
static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
const std::vector<index_t>& b1_gs_os_ns_strides_vec)
__host__ __device__ static auto MakeB1GridDescriptor_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
// alias of matrix_padder.PadB1Descriptor_O_N
return matrix_padder.PadB1Descriptor_N_K(
......@@ -353,21 +363,24 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
//
// C
//
static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
__host__ __device__ static auto MakeCGridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
c_gs_ms_os_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
__host__ __device__ static auto MakeCGridDescriptor_G_M_N(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
}
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
const std::vector<index_t>& c_gs_ms_os_strides_vec)
__host__ __device__ static auto MakeCGridDescriptor_M_N(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
......
#!/bin/bash
while getopts e: flag
do
case "${flag}" in
e) executable=${OPTARG};;
esac
done
echo "CK-NAVI31 Performance Test: MHA for AITemplate"
VERIFICATION=0
INITIALIZE=1
TIMING=1
ALL_TEST_CASE=0
SELF_ATTENTION=1
CROSS_ATTENTION=0
CAUSAL_MASK=0
# self attention with causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $SELF_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 1 ]; }; then
echo "Test launched: self attention with causal mask"
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 4096 4096 40 40 2 8 0.158113881945610 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 1024 1024 80 80 2 8 0.111803397536277 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 256 256 160 160 2 8 0.079056940972805 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 64 64 160 160 2 8 0.079056940972805 1 1
fi
# cross attention with causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $CROSS_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 1 ]; }; then
echo "Test launched: cross attention with causal mask"
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 4096 64 40 40 2 8 0.158113881945610 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 1024 64 80 80 2 8 0.111803397536277 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 256 64 160 160 2 8 0.079056940972805 1 1
./bin/example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 $VERIFICATION 1 $TIMING 64 64 160 160 2 8 0.079056940972805 1 1
fi
# self attention without causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $SELF_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 0 ]; }; then
echo "Test launched: self attention without causal mask"
$executable $VERIFICATION $INITIALIZE $TIMING 4096 4096 64 64 2 5 0.125 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 1024 1024 64 64 2 10 0.125 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 256 256 64 64 2 20 0.125 1 1
$executable $VERIFICATION $INITIALIZE $TIMING 64 64 64 64 2 20 0.125 1 1
fi
# cross attention without causal mask
if [ $ALL_TEST_CASE -eq 1 ] || { [ $CROSS_ATTENTION -eq 1 ] && [ $CAUSAL_MASK -eq 0 ]; }; then
echo "Test launched: cross attention without causal mask"
$executable $VERIFICATION 1 $TIMING 4096 64 40 40 2 8 0.158113881945610 1 1
$executable $VERIFICATION 1 $TIMING 1024 64 80 80 2 8 0.111803397536277 1 1
$executable $VERIFICATION 1 $TIMING 256 64 160 160 2 8 0.079056940972805 1 1
$executable $VERIFICATION 1 $TIMING 64 64 160 160 2 8 0.079056940972805 1 1
fi
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment