Commit 883c060a authored by guangzlu's avatar guangzlu
Browse files

bwd qloop v1 passed

parent 72498367
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <vector>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp"
......@@ -14,13 +16,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
......@@ -51,11 +46,9 @@ static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
......@@ -64,12 +57,6 @@ static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = false;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
......@@ -86,70 +73,10 @@ struct SimpleDeviceMem
void* p_mem_;
};
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k,
const TensorV& v_g_n_o,
const float alpha,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
ref_softmax_invoker.Run(ref_softmax_argument);
// P_dropped
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
int main(int argc, char* argv[])
{
int init_method = 1;
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
......@@ -164,53 +91,12 @@ int main(int argc, char* argv[])
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
......@@ -264,120 +150,19 @@ int main(int argc, char* argv[])
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method)
{
case 0: break;
case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break;
default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward<2,
SimpleDeviceMem q_device_buf(sizeof(InputDataType) * G0 * G1 * M * K);
SimpleDeviceMem k_device_buf(sizeof(InputDataType) * G0 * G1 * N * K);
SimpleDeviceMem z_device_buf(sizeof(ZDataType) * G0 * G1 * M * N);
SimpleDeviceMem v_device_buf(sizeof(InputDataType) * G0 * G1 * O * N);
SimpleDeviceMem y_device_buf(sizeof(InputDataType) * G0 * G1 * M * O);
SimpleDeviceMem lse_device_buf(sizeof(LSEDataType) * G0 * G1 * M);
SimpleDeviceMem qgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * M * K);
SimpleDeviceMem kgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * N * K);
SimpleDeviceMem vgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * O * N);
SimpleDeviceMem ygrad_device_buf(sizeof(InputDataType) * G0 * G1 * M * O);
using DeviceOp =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
......@@ -388,11 +173,11 @@ int main(int argc, char* argv[])
LSEDataType,
ck::Tuple<>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
MaskingSpec>;
// get device op instances
......@@ -409,13 +194,12 @@ int main(int argc, char* argv[])
// profile device op instances
std::cout << "Run all instances and do timing" << std::endl;
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(q_device_buf.GetDeviceBuffer(),
auto argument_ptr = op_ptr->MakeArgumentPointer(
q_device_buf.GetDeviceBuffer(),
k_device_buf.GetDeviceBuffer(),
nullptr, // set to nullptr
v_device_buf.GetDeviceBuffer(),
......@@ -450,13 +234,11 @@ int main(int argc, char* argv[])
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
......@@ -496,7 +278,8 @@ int main(int argc, char* argv[])
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(q_device_buf.GetDeviceBuffer(),
auto argument_ptr = op_ptr->MakeArgumentPointer(
q_device_buf.GetDeviceBuffer(),
k_device_buf.GetDeviceBuffer(),
nullptr, // set to nullptr
v_device_buf.GetDeviceBuffer(),
......
......@@ -18,8 +18,9 @@ namespace device {
namespace instance {
void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
1,
......@@ -34,12 +35,11 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances);
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
......@@ -55,12 +55,11 @@ void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instance
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
MaskingSpecialization::MaskDisabled>>>& instances);
void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
......@@ -76,12 +75,11 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances);
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
......@@ -97,8 +95,7 @@ void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instan
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
MaskingSpecialization::MaskDisabled>>>& instances);
template <typename InputDataType,
typename OutputDataType,
......@@ -151,8 +148,7 @@ struct DeviceOperationInstanceFactory<
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
op_ptrs);
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
......
......@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
//static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
......@@ -50,8 +51,7 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
std::tuple<
using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
......@@ -65,8 +65,8 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
>;
void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
......@@ -82,11 +82,9 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances)
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
add_device_operation_instances(
instances,
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
......@@ -97,8 +95,8 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
}
void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
......@@ -114,11 +112,9 @@ void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instan
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
MaskingSpecialization::MaskDisabled>>>& instances)
{
add_device_operation_instances(
instances,
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
......
......@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
//static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
//static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
......@@ -50,8 +51,7 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
......@@ -65,8 +65,8 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
>;
void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
......@@ -82,11 +82,9 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>&
instances)
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
add_device_operation_instances(
instances,
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
......@@ -97,8 +95,8 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
}
void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<
std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<2,
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
......@@ -114,11 +112,9 @@ void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instance
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances)
MaskingSpecialization::MaskDisabled>>>& instances)
{
add_device_operation_instances(
instances,
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment