Commit 610e0ab0 authored by guangzlu's avatar guangzlu
Browse files

updated dropout for bwd pt4 & pt5

parent 8edf2a72
......@@ -14,6 +14,7 @@ add_example_executable(example_batched_multihead_attention_backward_v2 batched_m
add_example_executable(example_batched_multihead_attention_backward_v3 batched_multihead_attention_backward_v3.cpp)
add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
add_example_executable(example_batched_multihead_attention_train_v5 batched_multihead_attention_train_v5.cpp)
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
|-----------------|
Gemm0
|-------------------------------------|
Gemm1
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
Computation graph:
K^T V
| |
| |
Q --- * ----- Softmax ----- * --> Y
S P
Kernel inputs:
Q, K, V, Y, dY, per-row softmax stats (LSE)
Kernel outputs:
dQ, dK, dV
*/
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v5.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using U16 = unsigned short;
using INT32 = int32_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using InputDataType = F16;
using OutputDataType = F16;
using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = U16; // INT32
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
#else
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = false;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
// If 32 < DIM <= 64 , ues prototype1 2nd template.
// If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
InputDataType,
InputDataType,
InputDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128,
128,
32,
32,
32,
8,
8,
2,
32,
32,
4,
1,
1,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 64, 1, 4>,
CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpec,
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
InputDataType,
InputDataType,
InputDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
64,
128,
64,
64,
32,
8,
8,
2,
32,
32,
2,
1,
2,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
2,
S<1, 32, 1, 8>,
CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpec,
Deterministic>;
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// InputDataType,
// OutputDataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 64, // Gemm1NPerBlock
// 64, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 1, // MXdlPerWave
// 4, // NXdlPerWave
// 2, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 2,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// CShuffleBlockTransferScalarPerVector_NPerBlock,
// MaskingSpec,
// Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
InputDataType,
InputDataType,
InputDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec, // MaskingSpecialization
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
InputDataType,
OutputDataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
64,
128,
128,
128,
32,
8,
8,
2,
32,
32,
2,
1,
4,
1,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1,
4,
S<1, 32, 1, 8>,
CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpec,
Deterministic>;
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
AccDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
InputDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
// Ref Gemm for backward pass
// fp16 in, fp16 out
using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
InputDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
OutputDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref dropout
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ZDataType, InputDataType, InputDataType>;
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k,
const TensorV& v_g_n_o,
const float alpha,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstanceFWD::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
ref_softmax_invoker.Run(ref_softmax_argument);
// P_dropped
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t N = 500; // 512
ck::index_t M = 500; // 512
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4; // 54
ck::index_t G1 = 6; // 16
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
std::cout << "p_drop: " << p_drop << std::endl;
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_fwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<ZDataType> z_bwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_fwd_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os_device_result.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms_device_result.mDesc << std::endl;
z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method)
{
case 0: break;
case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break;
default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_fwd_device_buf(sizeof(ZDataType) * z_fwd_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_bwd_device_buf(sizeof(ZDataType) * z_bwd_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(InputDataType) *
y_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) *
lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) *
qgrad_gs_ms_ks_device_result.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) *
kgrad_gs_ns_ks_device_result.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) *
vgrad_gs_os_ns_device_result.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * ygrad_gs_ms_os.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_fwd_device_buf.ToDevice(z_fwd_gs_ms_ns.mData.data());
z_bwd_device_buf.ToDevice(z_bwd_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
auto gemm_fwd = DeviceGemmInstanceFWD{};
auto invoker_fwd = gemm_fwd.MakeInvoker();
auto gemm_bwd = DeviceGemmInstanceBWD{};
auto invoker_bwd = gemm_bwd.MakeInvoker();
if(time_kernel)
{
auto argument_fwd = gemm_fwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
if(!gemm_fwd.IsSupportedArgument(argument_fwd))
{
std::cout << gemm_fwd.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time_fwd = invoker_fwd.Run(argument_fwd, StreamConfig{nullptr, true});
std::size_t flop_fwd = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype_fwd =
(sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O) *
BatchCount;
float tflops_fwd = static_cast<float>(flop_fwd) / 1.E9 / ave_time_fwd;
float gb_per_sec_fwd = num_btype_fwd / 1.E6 / ave_time_fwd;
std::cout << "FWD Perf: " << ave_time_fwd << " ms, " << tflops_fwd << " TFlops, "
<< gb_per_sec_fwd << " GB/s, " << gemm_fwd.GetTypeString() << std::endl;
// not need output z matrix
auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
qgrad_device_buf.SetZero();
// kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
// vgrad_device_buf.SetZero();
float ave_time_bwd = invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, true});
// 5 GEMM ops in total:
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// 3x MNK + 2x MNO
std::size_t flop_bwd = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std::size_t num_btype_bwd =
(sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount;
float tflops_bwd = static_cast<float>(flop_bwd) / 1.E9 / ave_time_bwd;
float gb_per_sec_bwd = num_btype_bwd / 1.E6 / ave_time_bwd;
std::cout << "BWD Perf: " << ave_time_bwd << " ms, " << tflops_bwd << " TFlops, "
<< gb_per_sec_bwd << " GB/s, " << gemm_bwd.GetTypeString() << std::endl;
}
bool pass = true;
if(do_verification)
{
Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_fwd_g_m_n({BatchCount, M, N});
Tensor<ZDataType> z_bwd_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
// get kernel output matrixes
{
auto argument_fwd = gemm_fwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the
// number of elements on a thread
if(!gemm_fwd.IsSupportedArgument(argument_fwd))
{
std::cout << gemm_fwd.GetTypeString() << " does not support this problem"
<< std::endl;
return 0;
}
invoker_fwd.Run(argument_fwd, StreamConfig{nullptr, false});
// copy fwd matirx data form device
z_fwd_device_buf.FromDevice(z_fwd_gs_ms_ns.mData.data());
y_device_buf.FromDevice(y_gs_ms_os_device_result.mData.data());
lse_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
// std::cout << "z_fwd_gs_ms_ns ref:\n" << z_fwd_gs_ms_ns;
std::ofstream fwd_file("./z_fwd_matrix_txt");
fwd_file << z_fwd_gs_ms_ns << std::endl;
qgrad_device_buf.SetZero();
// kgrad_device_buf.SetZero();
// vgrad_device_buf.SetZero();
auto argument_bwd = gemm_bwd.MakeArgument(
static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_bwd_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
if(!gemm_bwd.IsSupportedArgument(argument_bwd))
{
std::cout << gemm_bwd.GetTypeString() << " does not support this problem"
<< std::endl;
return 0;
}
invoker_bwd.Run(argument_bwd, StreamConfig{nullptr, false});
// copy bwd matirx data form device
z_bwd_device_buf.FromDevice(z_bwd_gs_ms_ns.mData.data());
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
// std::cout << "z_bwd_gs_ms_ns ref:\n" << z_bwd_gs_ms_ns;
std::ofstream bwd_file("./z_bwd_matrix_txt");
bwd_file << z_bwd_gs_ms_ns << std::endl;
}
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_fwd_g_m_n,
p_dropout_in_16bits,
rp_dropout);
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
z_bwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_bwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
#if PRINT_HOST
{
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "v_g_n_o ref:\n" << v_g_n_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
}
#endif
// Gradients
auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
using RefGemm0GradArg = ReferenceGemm0GradInstance::Argument;
auto ref_gemm1_grad = ReferenceGemm1GradInstance{};
auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
using RefGemm1GradArg = ReferenceGemm1GradInstance::Argument;
// dP_dropout = dY * V^T
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST
{
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
}
#endif
// dP = dP_dropout x Z
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment = ref_dropout.MakeArgument(
z_bwd_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
}
self(idx_gmn) = ck::type_convert<InputDataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
});
#if PRINT_HOST
{
std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
std::cout << "p_g_m_n ref:\n" << p_g_m_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "y_g_m_o ref:\n" << y_g_m_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
}
#endif
// dV = P_drop^T * dY
auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST
{
std::cout << "===== dV = P^T * dY\n";
std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
}
#endif
// dQ = alpha * dS * K
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dQ = alpha * dS * K\n";
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
}
#endif
// dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dK = alpha * dS^T * Q\n";
std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
}
#endif
// permute
y_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = y_g_m_o(g, idx[2], idx[3]);
});
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = lse_g_m(g, idx[2]);
});
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
// default absolute error and relative error is 0.001
double rtol = 1e-3;
double atol = 1e-3;
// when BF16 is taken, set absolute error and relative error to 0.01
if(std::is_same_v<InputDataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
std::cout << "Checking z:\n";
pass &= ck::utils::check_err(z_fwd_gs_ms_ns.mData, z_bwd_gs_ms_ns.mData, 1);
std::cout << "Checking y:\n";
pass &= ck::utils::check_err(
y_gs_ms_os_device_result.mData, y_gs_ms_os_host_result.mData, "error", rtol, atol);
std::cout << "Checking lse:\n";
pass &= ck::utils::check_err(
lse_gs_ms_device_result.mData, lse_gs_ms_host_result.mData, "error", rtol, atol);
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData,
"error",
1e-2,
1e-2);
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
}
int main(int argc, char* argv[]) { return run(argc, argv); }
......@@ -122,6 +122,160 @@ struct BlockwiseDropout
});
}
template <typename CThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwd(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
index_t MRaw)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutAttnBwdSaveZ(CThreadBuffer& in_thread_buf,
ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf,
index_t MRaw)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8 * MRaw);
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8 * MRaw;
// }
//}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) =
execute_dropout(tmp[tmp_index] <= p_dropout_16bits, in_thread_buf(offset));
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void ApplyDropoutWithZ(CThreadBuffer& in_thread_buf,
ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = execute_dropout(z_thread_buf(offset) <= p_dropout_16bits,
in_thread_buf(offset));
tmp_index = tmp_index + 1;
// if(get_thread_global_1d_id()==0){
// printf("z at %d is %u \n", tmp_index, z_thread_buf(offset));
//}
});
});
}
// get raw z matrix with random number for shuffle
template <typename ZThreadBuffer>
__host__ __device__ void GenerateZMatrixAttnFwd(ck::philox& ph,
index_t element_global_1d_id,
ZThreadBuffer& z_thread_buf)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr int tmp_size = MRepeat * KRepeat;
int philox_calls = tmp_size / 4;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_4x16((tmp + i * 4), element_global_1d_id + i * 8);
}
// ushort tmp_id[tmp_size];
// for(int i = 0; i < philox_calls; i++)
//{
// for(int j = 0; j < 4; j++)
// {
// tmp_id[i * 4 + j] = element_global_1d_id + i * 8;
// }
//}
block_sync_lds();
int tmp_index = 0;
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
z_thread_buf(offset) = tmp[tmp_index];
tmp_index = tmp_index + 1;
});
});
}
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
......
......@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt6.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt4.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -39,7 +39,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
......@@ -71,7 +71,7 @@ __global__ void
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
......@@ -85,7 +85,9 @@ __global__ void
const C0MatrixMask c0_matrix_mask,
const float p_drop,
const unsigned long long seed,
const unsigned long long offset)
const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -144,6 +146,9 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
i);
}
}
......@@ -176,6 +181,9 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
0);
}
#else
......@@ -278,6 +286,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{};
......@@ -808,8 +826,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
// Print();
}
......@@ -869,8 +887,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -933,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
......@@ -967,7 +985,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
......@@ -979,11 +997,24 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg.c0_matrix_mask_,
arg.p_drop_,
arg.seed_,
arg.offset_);
arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
#if 1
// if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
// {
// ave_time = launch_kernel(integral_constant<bool, true>{});
// }
// else
// {
// ave_time = launch_kernel(integral_constant<bool, false>{});
// }
ave_time = launch_kernel(integral_constant<bool, false>{});
#endif
return ave_time;
}
......@@ -1003,6 +1034,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
static bool IsSupportedArgument(const Argument& arg)
{
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
......
......@@ -15,7 +15,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt7.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt5.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -38,7 +38,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M,
......@@ -70,8 +70,8 @@ __global__ void
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
......@@ -84,7 +84,9 @@ __global__ void
const C0MatrixMask c0_matrix_mask,
const float p_drop,
const unsigned long long seed,
const unsigned long long offset)
const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -134,7 +136,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
......@@ -143,6 +145,9 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
i);
}
}
......@@ -166,7 +171,7 @@ __global__ void
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
......@@ -175,6 +180,9 @@ __global__ void
c0_matrix_mask,
p_drop,
ph,
g_idx,
MRaw,
NRaw,
0);
}
#else
......@@ -284,6 +292,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{};
......@@ -821,8 +839,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(z_grid_desc_m_n_);
// Print();
}
......@@ -882,8 +900,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -950,7 +968,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M,
......@@ -984,7 +1002,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_,
......@@ -996,11 +1014,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg.c0_matrix_mask_,
arg.p_drop_,
arg.seed_,
arg.offset_);
arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
#if 1
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
......@@ -1009,7 +1030,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
#endif
return ave_time;
}
......@@ -1029,6 +1050,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static bool IsSupportedArgument(const Argument& arg)
{
#if 0
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename ZDataType,
typename FloatLSE,
typename GemmAccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename LSEGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop,
bool IsDropout,
bool IsLseStoring,
bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_multiheadattention_forward_xdl_cshuffle(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const index_t mblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask,
const ushort p_dropout_in_16bits,
const GemmAccDataType p_dropout_rescale,
const unsigned long long seed,
const unsigned long long offset,
const index_t MRaw,
const index_t NRaw)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
if constexpr(Deterministic)
{
for(index_t i = 0; i < mblock; i++)
{
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
nullptr ? nullptr : p_z_grid + z_batch_offset,
nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout_rescale,
ph,
g_idx,
MRaw,
NRaw,
i);
}
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
nullptr ? nullptr : p_z_grid + z_batch_offset,
nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
p_dropout_in_16bits,
p_dropout_rescale,
ph,
g_idx,
MRaw,
NRaw,
0);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_b1_grid;
ignore = p_c_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = b1_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
typename Acc1BiasDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization B1Spec,
TensorSpecialization CSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
: public DeviceBatchedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
ADataType,
BDataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
MaskingSpec>
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
#if 0
// TODO ANT: use alias
static constexpr index_t NumDimGemm0M = NumDimM;
static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
GemmSpec,
ASpec,
BSpec,
B1Spec,
CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
}
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
Number<BK1>{});
}
static auto
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{});
}
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
{
const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(lse_grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return lse_grid_desc_mraw;
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
constexpr static auto make_MaskOutPredicate()
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
return MaskDisabledPredicate{};
}
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{
return MaskOutUpperTrianglePredicate{};
}
}
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
}
private:
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
ZDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
LSEDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
ZGridDesc_M_N,
LSEGridDesc_M,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
// Argument
// FIXME: constness
struct Argument : public BaseArgument
{
Argument(
const ADataType* p_a_grid,
const BDataType* p_b_grid,
const B1DataType* p_b1_grid,
CDataType* p_c_grid,
ZDataType* p_z_grid,
LSEDataType* p_lse_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
p_z_grid_{p_z_grid},
p_lse_grid_{p_lse_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
a_grid_desc_g_m_k_{
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_g_n_k_{
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)},
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]},
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1],
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1],
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{
// TODO ANT: implement bias addition
ignore = p_acc0_biases;
ignore = p_acc1_biases;
ignore = acc0_biases_gs_ms_ns_lengths;
ignore = acc0_biases_gs_ms_ns_strides;
ignore = acc1_biases_gs_ms_gemm1ns_lengths;
ignore = acc1_biases_gs_ms_gemm1ns_strides;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
b1_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_);
}
is_dropout_ = p_dropout > 0.0; //
p_dropout_ = 1.f - p_dropout;
p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_dropout_ = 1.f / p_dropout_;
p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
if(p_lse_grid == nullptr)
{
is_lse_storing_ = false;
}
}
void Print() const
{
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
<< a_grid_desc_g_m_k_.GetLength(I1) << ", "
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n';
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n';
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n';
}
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
ZDataType* p_z_grid_;
LSEDataType* p_lse_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
ZGridDesc_M_N z_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_;
std::vector<index_t> b1_nz_kz_strides_;
std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
ushort p_dropout_in_16bits_;
GemmAccDataType p_dropout_rescale_;
unsigned long long seed_;
unsigned long long offset_;
bool is_dropout_;
bool is_lse_storing_ = true;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! unsupported argument");
}
const index_t grid_size =
(Deterministic ? 1
: arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
arg.batch_count_;
// Gemm0_K
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0;
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
ZDataType,
LSEDataType,
GemmAccDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
is_dropout_,
is_lse_storing_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_b1_grid_,
arg.p_c_grid_,
arg.p_z_grid_,
arg.p_lse_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.lse_grid_desc_m_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_,
arg.p_dropout_in_16bits_,
arg.p_dropout_rescale_,
arg.seed_,
arg.offset_,
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
if(arg.is_dropout_)
{
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
else
{
if(arg.is_dropout_)
{
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
if(arg.is_lse_storing_)
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
#if DEBUG_LOG
arg.Print();
#endif
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
// TODO ANT: Check if tensor specialization & strides mismatch
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
}
// Note: we need raw lengths since threadwise copy can not handle vector load when part of
// vector is out of bounds
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
// Check scalar per vector requirement
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
const auto c_extent_lowest = Gemm1NzRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
const auto b_stride_lowest =
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0];
const auto b1_stride_lowest =
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0];
const auto c_stride_lowest =
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
c_stride_lowest == 1))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const ADataType* p_a,
const BDataType* p_b,
const B1DataType* p_b1,
CDataType* p_c,
ZDataType* p_z,
LSEDataType* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a,
p_b,
p_b1,
p_c,
p_z,
p_lse,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_dropout,
seeds};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const void* p_b1,
void* p_c,
void* p_z,
void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c),
static_cast<ZDataType*>(p_z),
static_cast<LSEDataType*>(p_lse),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_dropout,
seeds);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -113,23 +113,24 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
}
__device__ static auto GetGemm0WaveIdx()
......@@ -396,8 +397,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// Q / K / V / dY
struct GemmBlockwiseCopy
......@@ -1234,8 +1235,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_O0_N_O1& v_grid_desc_o0_n_o1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
......@@ -1245,6 +1246,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw,
const index_t block_idx_n)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
......@@ -1578,70 +1582,69 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
I1, // NRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2)); // registerNum
m2, // MGroupNum
m3, // MInputNum
m4, // registerNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// z matrix global desc
// ignore = p_z_tmp_grid;
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
p_z_grid,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_n, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType,
decltype(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], // group
0, // NInputIndex
wave_m_n_id[I1]),
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
......@@ -1795,6 +1798,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
lse_grid_desc_mb_m0_m1_m2_m3_m4, make_multi_index(1, 0, 0, 0, 0, 0));
y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(1, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
continue;
}
......@@ -1950,35 +1958,167 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// save z to global
if(p_z_grid)
{
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// dropout
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, MRaw);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
//// P_dropped
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
// decltype(z_tenor_buffer),
// true,
// decltype(n0),
// decltype(i)>(
// s_slash_p_thread_buf, ph, z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
}
else
{
ignore = z_grid_buf;
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
blockwise_dropout
.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id, MRaw);
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2161,8 +2301,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
qgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
qgrad_grid_desc_m0_o0_m1_o1_m2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step M
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
lse_thread_copy_global_to_vgpr.MoveSrcSliceWindow(lse_grid_desc_mb_m0_m1_m2_m3_m4,
make_multi_index(1, 0, 0, 0, 0, 0));
y_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
......
......@@ -127,24 +127,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
constexpr auto M3 = mfma.num_groups_per_blk;
constexpr auto M4 = mfma.num_input_blks;
constexpr auto M5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, M3, M4, M5)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
make_tuple(Sequence<0, 2, 4, 6, 7, 8>{}, Sequence<1, 3, 5, 9>{}));
}
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
{
......@@ -467,8 +469,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(KGridDesc_N_K{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;
// S / dP Gemm (type 1 rcc)
struct Gemm0
......@@ -1183,8 +1185,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
......@@ -1194,6 +1196,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw,
const index_t block_idx_n)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
......@@ -1558,70 +1563,68 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// z vgpr copy to global
//
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
I1, // NRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2)); // registerNum
m2, // MGroupNum
m3, // MInputNum
m4, // registerNum
n2)); // NPerXdl
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
// z matrix global desc
// ignore = p_z_tmp_grid;
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( // tmp buffer for shuffle
p_z_grid,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_n, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_thread_copy_vgpr_to_global =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType,
decltype(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
m3, // NGroupNum
m4, // NInputNum
n2>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(0, // MBlockId
block_work_idx_n, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
0, // MPerXdl
wave_m_n_id[I0], // group
0, // NInputIndex
wave_m_n_id[I1]),
tensor_operation::element_wise::PassThrough{}};
//
// set up Y dot dY
......@@ -1743,8 +1746,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_mblock_mperblock_oblock_operblock, make_multi_index(1, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
continue;
}
......@@ -1891,35 +1894,144 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// save z to global
if(p_z_grid)
{
// P_dropped
static_for<0, n0, 1>{}([&](auto i) {
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true,
decltype(n0),
decltype(i)>(
s_slash_p_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
});
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// if(get_block_1d_id() == 0 && get_thread_local_1d_id()==64){
// printf("global_elem_id is %d \n", global_elem_id);
//}
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// if(get_thread_global_1d_id() == 0){
// printf("Acc0TileIterator::GetNumOfAccess() is %d \n",
// Acc0TileIterator::GetNumOfAccess()); printf("n0.value is %d \n", n0.value);
// printf("id_step is %d \n", id_step);
//}
// dropout
// z_tenor_buffer_tmp -> z_grid_buf_tmp -> shuffle -> z_tenor_buffer -> z_grid_buf
blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
decltype(z_tenor_buffer),
true>(
s_slash_p_thread_buf, ph, global_elem_id, z_tenor_buffer, MRaw);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
z_grid_buf);
}
else
{
ignore = z_grid_buf;
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto M3 = c_block_lengths[I5];
constexpr auto M4 = c_block_lengths[I6];
constexpr auto N2 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
make_unmerge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
//}
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
// if(get_thread_global_1d_id()==0){
// printf("tid 0 m_global & n_global is %d & %d \n", m_global , n_global);
// }
// if(get_thread_global_1d_id()==32){
// printf("tid 32 m_global & n_global is %d & %d \n", m_global , n_global);
// }
auto global_elem_id_raw =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
auto global_elem_id =
(global_elem_id_raw % 4) * MRaw + int(global_elem_id_raw / 4) * 4;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
blockwise_dropout
.template ApplyDropoutAttnBwd<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph, global_elem_id, MRaw);
}
block_sync_lds(); // wait for gemm1 LDS read
......@@ -2183,8 +2295,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_multi_index(1, 0, 0, 0));
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm0_m_block_outer_index < num_gemm0_m_block_outer_loop); // end j loop
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
template <typename FloatAB,
typename ZDataType,
typename FloatGemm,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename FloatLSE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
typename ZGridDesc_M_N,
typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
// Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in gridwise copy
__host__ __device__ static constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
const ZGridDesc_M_N& z_grid_desc_m_n) ////=> for z use
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, 1, 1>(ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(FloatGemm);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(FloatGemm);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
const index_t z_block_bytes_end =
SharedMemTrait::z_shuffle_block_space_size * sizeof(ZDataType);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
softmax_bytes_end,
c_block_bytes_end,
z_block_bytes_end);
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
// if(Gemm1N != K)
// {
// std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
// return false;
// }
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
lse_grid_desc_m,
make_tuple(make_unmerge_transform(
make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3>{}));
return lse_grid_desc_mblock_mrepeat_mwave_mperxdl;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
static constexpr auto z_shuffle_block_space_size = MPerBlock * NPerBlock;
};
template <bool HasMainKBlockLoop,
bool IsDropout,
bool IsLseStoring,
typename Block2CTileMap,
typename C0MatrixMask>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
ZDataType* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout_rescale,
ck::philox& ph,
const index_t g_idx,
const index_t MRaw,
const index_t NRaw,
const index_t block_idx_m)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
const index_t gemm1_n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
//
// set up Gemm0
//
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatGemm,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatGemm,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// Fused Gemm+Gemm pipeline
// for n in N0:
// for k in K0:
// acc[m][n] += A[m][k] * B0[k][n]
// acc1[m][o] += acc[m][n] * B1[n][o]
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatGemm,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
true>{}; // TransposeC
auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemm*>(p_shared) + SharedMemTrait::a_block_space_offset,
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b_block_space_offset,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const auto a_block_reset_copy_step =
make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
const auto b_block_reset_copy_step =
make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
// gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
//
// set up Gemm1
//
// Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: had to use merge_v3 or will spit out compilation errors
constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// A1 matrix in AccVGPR
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
constexpr auto AccN3 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
constexpr auto AccM2 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I4);
constexpr auto A1ThreadSlice_K0_M_K1 =
make_tuple(Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
A1ThreadSlice_K0_M_K1,
make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
// B1 matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A1 matrix blockwise copy
auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
FloatGemm,
decltype(acc_thread_desc_k0_m_k1),
decltype(a1_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
Sequence<1, 0, 2>,
2,
n4>{tensor_operation::element_wise::PassThrough{}};
// B1 matrix blockwise copy
auto b1_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatGemm,
decltype(b1_grid_desc_bk0_n_bk1),
decltype(b1_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>(
b1_grid_desc_bk0_n_bk1,
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
b1_element_op,
b1_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemm>(
a1_thread_desc_k0_m_k1.GetElementSpaceSize());
// reuse LDS space for gemm0's b_block_buf
auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemm*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<FloatGemm, MPerXdl, NPerXdl>::selected_mfma.group_size;
auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,
FloatGemm,
FloatGemmAcc,
decltype(a1_thread_desc_k0_m_k1),
decltype(b1_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
Gemm1NXdlPerWave,
Gemm1KPack,
true, // TransposeC
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatGemm, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
//
// Blockwise softmax
//
auto workspace_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
SharedMemTrait::reduction_space_size_aligned);
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
// get acc0 thread map
constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)),
make_pass_through_transform(I1)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto threadid_to_m_n_thread_cluster_adaptor =
chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, p_dropout_rescale};
const index_t num_gemm1_k_block_outer_loop =
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
// Initialize C
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true>
c_thread_buf;
c_thread_buf.Clear();
// Initialize running sum and max of exponentiating row vectors
using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
running_sum = 0;
running_sum_new = 0;
running_max = NumericLimits<FloatGemmAcc>::Lowest();
running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
constexpr auto lse_thread_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2));
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
FloatLSE,
decltype(lse_thread_desc_mblock_mrepeat_mwave_mperxdl),
decltype(lse_grid_desc_mblock_mrepeat_mwave_mperxdl),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, 1, 1, 1>,
Sequence<0, 1, 2, 3>,
3,
1,
InMemoryDataOperationEnum::Set,
1,
false>{lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock
0, // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4]), // mperxdl
ck::tensor_operation::element_wise::PassThrough{}};
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
// z is random number matrix for dropout verify
//
// z vgpr copy to global
//
// z matrix threadwise desc
// if(get_thread_global_1d_id()==0){
// printf("m2 is %d \n",m2.value);
// printf("n2 is %d \n",n2.value);
// printf("n3 is %d \n",n3.value);
// printf("n4 is %d \n",n4.value);
//}
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
constexpr auto z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 = // for blockwise copy
make_naive_tensor_descriptor_packed(make_tuple(m0, //
n0, //
m1, //
n1, //
m2, // m0 1
n2, // n0 4
n3, // n1 1
n4, // m1 4
I1)); // n2 1
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
constexpr auto z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// constexpr auto z_block_lengths = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLengths();
constexpr auto zM0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto zN0 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto zM1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto zN1 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto zM2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto zN2 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto zN3 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto zN4 = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4 =
transform_tensor_descriptor(
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_pass_through_transform(zM0),
make_pass_through_transform(zN0),
make_pass_through_transform(zM1),
make_pass_through_transform(zN1),
make_unmerge_transform(make_tuple(Number<zM2.value / zN4.value>{}, zN4)),
make_pass_through_transform(zN2),
make_pass_through_transform(zN3),
make_pass_through_transform(zN4)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 7>{},
Sequence<5>{},
Sequence<6>{},
Sequence<8>{}));
// ignore = z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4;
// ignore = z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4;
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize(),
true>
z_tensor_buffer; // z buffer after shuffle
z_tensor_buffer.Clear();
// z matrix global desc
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
auto z_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ZDataType*>(p_shared),
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
// if(get_thread_global_1d_id()==0){
// printf("z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize() is %ld \n",
// z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize().value);
// printf("z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n", z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// printf("z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize() is %ld
// \n",z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize().value);
// printf("z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize() is %ld
// \n",z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4.GetElementSpaceSize().value);
// }
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
// if(get_block_1d_id()==0){
// if(get_thread_local_1d_id()==0){
// printf("tid = 0 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==1){
// printf("tid = 1 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==2){
// printf("tid = 2 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==3){
// printf("tid = 3 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==32){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
// if(get_thread_local_1d_id()==64){
// printf("tid = 32 , wave_m_n_id[I0] & wave_m_n_id[I1] is %d & %d
// \n",wave_m_n_id[I0], wave_m_n_id[I1]);
// }
//}
auto z_tmp_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
auto z_shuffle_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
ZDataType,
ushort,
decltype(z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
decltype(z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4),
Sequence<m0, n0, m1, n1, m2, n2, n3, n4, I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8,
1,
1,
true>{z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_multi_index(0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
int(wave_m_n_id[I1] / 4), // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0,
wave_m_n_id[I1] % 4)};
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ZDataType,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
if constexpr(Deterministic)
{
block_sync_lds();
}
do
{
auto n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
continue;
}
// gemm0
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
acc_thread_buf,
num_k_block_main_loop);
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{
acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
}
else
{
acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
}
});
}
else
{
static_for<0, acc_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// softmax
SoftmaxBuf& max = blockwise_softmax.max_value_buf;
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
blockwise_softmax.Run(acc_thread_buf, workspace_buf);
if constexpr(IsDropout) // dropout
{
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
auto m_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
auto global_elem_id =
MRaw * NRaw * g_idx + m_global * NRaw + n_global; // unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
blockwise_dropout.template GenerateZMatrixAttnFwd<decltype(z_tensor_buffer)>(
ph, global_elem_id, z_tensor_buffer);
z_tmp_thread_copy_vgpr_to_lds.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
z_block_buf);
z_shuffle_thread_copy_lds_to_vgpr.Run(
z_block_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
z_block_buf,
z_thread_shuffle_desc_m0_n0_m1_n1_m2_n2_n3_m3_n4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer);
blockwise_dropout.template ApplyDropoutWithZ<decltype(acc_thread_buf),
decltype(z_tensor_buffer),
false>(acc_thread_buf,
z_tensor_buffer);
if(p_z_grid)
{
// ignore = z_tensor_buffer;
z_thread_copy_vgpr_to_global.Run(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
}
}
// TODO: may convert to log domain
running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
// gemm1
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// Initialize acc1
acc1_thread_buf.Clear();
// preload data into LDS
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
acc_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
b1_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
});
}
// tail
{
a1_blockwise_copy.Run(
acc_thread_desc_k0_m_k1,
make_tuple(
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
acc_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
}
} // end gemm1
// workaround compiler issue; see ck/ck.hpp
if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
(is_same_v<FloatGemm, bhalf_t>)&&MPerBlock == 256 && NPerBlock == 128 &&
Gemm1NPerBlock == 128)
{
__builtin_amdgcn_sched_barrier(0);
}
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new; // O_new
});
});
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
b_block_reset_copy_step); // rewind K and step N
// update before next j iteration
running_max = running_max_new;
running_sum = running_sum_new;
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// Calculate max + ln(sum) and write out
if constexpr(IsLseStoring)
{
static_for<0, MXdlPerWave, 1>{}(
[&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
if(get_warp_local_1d_id() < AccM2)
{
static_for<0, MXdlPerWave, 1>{}([&](auto I) {
// copy from VGPR to Global
lse_thread_copy_vgpr_to_global.Run(lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, Number<I>{}, I0, I0),
lse_thread_buf,
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf);
lse_thread_copy_vgpr_to_global.MoveDstSliceWindow(
lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(0, 1, 0, 0));
});
}
}
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2)), // M2 = MPerXdl
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2, // N2 * N3 * N4 = NPerXdl
N3,
N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
tensor_operation::element_wise::PassThrough{}};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
1,
N2,
1,
N4>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
}
}
};
} // namespace ck
......@@ -84,6 +84,17 @@ class philox
out_tmp[3] = tmp_ph.w;
}
__device__ void get_random_4x16(ushort* out, const unsigned long long subsequence)
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
out[0] = static_cast<ushort>(tmp_ph.x);
out[1] = static_cast<ushort>(tmp_ph.y);
out[2] = static_cast<ushort>(tmp_ph.z);
out[3] = static_cast<ushort>(tmp_ph.w);
}
private:
struct ull2
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment