Commit 883c060a authored by guangzlu

bwd qloop v1 passed

parent 72498367
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <vector>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_mha_bwd_qloop.hpp"
@@ -14,13 +16,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -51,11 +46,9 @@ static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 8;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskUpperTriangleFromTopLeft;
@@ -64,12 +57,6 @@ static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
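With USING_MASK enabled, MaskUpperTriangleFromTopLeft is the causal-attention mask: score positions whose key index exceeds the query index are excluded, which the host reference further down reproduces by writing -infinity into s_g_m_n before the softmax. A minimal sketch of that predicate (illustrative names, not part of the CK API):

// illustrative only: element (m, n) is masked out when the key index n lies past the query index m
bool is_masked_upper_triangle_from_top_left(int m_idx, int n_idx) { return n_idx > m_idx; }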
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr bool Deterministic = false;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
@@ -86,70 +73,10 @@ struct SimpleDeviceMem
void* p_mem_;
};
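The body of SimpleDeviceMem is collapsed by the hunk above. Purely for orientation, and not the elided code, a wrapper of roughly this shape is what the later GetDeviceBuffer() calls assume; the HIP calls and the sketch's name are assumptions:

// sketch only (assumes <hip/hip_runtime.h> and <cstddef>); error handling omitted
struct SimpleDeviceMemSketch
{
    explicit SimpleDeviceMemSketch(std::size_t mem_size) { (void)hipMalloc(&p_mem_, mem_size); }
    void* GetDeviceBuffer() const { return p_mem_; }
    ~SimpleDeviceMemSketch() { (void)hipFree(p_mem_); }
    void* p_mem_ = nullptr;
};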
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k,
const TensorV& v_g_n_o,
const float alpha,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ZDataType p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
auto M = s_g_m_n.GetLengths()[1];
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(M, N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
// P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
ref_softmax_invoker.Run(ref_softmax_argument);
// P_dropped
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argment =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argment);
// Y = P_dropout * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
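For reference, the host routine above (which this commit drops from the example) checks the device result against the usual attention forward chain. With row-wise softmax over the key dimension N, keep probability 1 - p_drop, and Z the dropout mask derived from the random tensor, it computes:

S = \alpha\, Q K^{\top}, \qquad \alpha = 1/\sqrt{K}
P = \mathrm{softmax}(S) \text{ (row-wise)}, \qquad \mathrm{lse}_m = \log \textstyle\sum_n e^{S_{mn}}
P_{\mathrm{drop}} = (Z \odot P)/(1 - p_{\mathrm{drop}})
Y = P_{\mathrm{drop}}\, V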
int main(int argc, char* argv[])
{
int init_method = 1;
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
@@ -164,53 +91,12 @@
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
p_drop = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
float p_dropout = 1 - p_drop;
ZDataType p_dropout_in_16bits = ZDataType(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
std::cout << "time_kernel: " << time_kernel << std::endl;
std::cout << "M: " << M << std::endl;
std::cout << "N: " << N << std::endl;
std::cout << "K: " << K << std::endl;
@@ -264,137 +150,36 @@ int main(int argc, char* argv[])
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);

std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;

z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method)
{
case 0: break;
case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break;
default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
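The comments in cases 5, 6 and the default case all lean on the same softmax-backward identity that the backward kernel has to realize per query row. Ignoring the dropout rescale for clarity, with P = softmax(S), Y = P V and dP = dY V^T:

dS_{ij} = P_{ij}\left(dP_{ij} - \sum_k dP_{ik} P_{ik}\right) = P_{ij}\left(dP_{ij} - dY_i \cdot Y_i\right)

since \sum_k dP_{ik} P_{ik} = dY_i \cdot \sum_k P_{ik} V_k = dY_i \cdot Y_i; that is why the comments track the per-row dot product "dO dot O" rather than a reduction over dP.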
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
using DeviceOp = ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
1,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
ck::Tuple<>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec>;

SimpleDeviceMem q_device_buf(sizeof(InputDataType) * G0 * G1 * M * K);
SimpleDeviceMem k_device_buf(sizeof(InputDataType) * G0 * G1 * N * K);
SimpleDeviceMem z_device_buf(sizeof(ZDataType) * G0 * G1 * M * N);
SimpleDeviceMem v_device_buf(sizeof(InputDataType) * G0 * G1 * O * N);
SimpleDeviceMem y_device_buf(sizeof(InputDataType) * G0 * G1 * M * O);
SimpleDeviceMem lse_device_buf(sizeof(LSEDataType) * G0 * G1 * M);
SimpleDeviceMem qgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * M * K);
SimpleDeviceMem kgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * N * K);
SimpleDeviceMem vgrad_device_buf(sizeof(OutputDataType) * G0 * G1 * O * N);
SimpleDeviceMem ygrad_device_buf(sizeof(InputDataType) * G0 * G1 * M * O);

using DeviceOp =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
1,
InputDataType,
OutputDataType,
ZDataType,
LSEDataType,
ck::Tuple<>,
ck::Tuple<>,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
MaskingSpec>;
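The two DeviceOp aliases above plug functors into the same five element-wise slots of DeviceBatchedMultiheadAttentionBackward; only the names changed. The slot meanings below are inferred from how CK's attention examples are usually wired (gemm0: A = Q, B0 = K; gemm1: B1 = V, C = Y) and are not spelled out in this commit:

// slot                          old alias        new alias
// A  (Q) element-wise op:       AElementOp    -> QKVElementOp (PassThrough)
// B0 (K) element-wise op:       B0ElementOp   -> QKVElementOp (PassThrough)
// Acc0 (S = Q*K^T) scaling:     Acc0ElementOp -> Scale
// B1 (V) element-wise op:       B1ElementOp   -> QKVElementOp (PassThrough)
// C  (Y) element-wise op:       CElementOp    -> YElementOp (PassThrough)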
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
@@ -409,54 +194,51 @@ int main(int argc, char* argv[])
// profile device op instances
std::cout << "Run all instances and do timing" << std::endl;
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
    q_device_buf.GetDeviceBuffer(),
    k_device_buf.GetDeviceBuffer(),
    nullptr, // set to nullptr
    v_device_buf.GetDeviceBuffer(),
    y_device_buf.GetDeviceBuffer(),
    lse_device_buf.GetDeviceBuffer(),
    ygrad_device_buf.GetDeviceBuffer(),
    qgrad_device_buf.GetDeviceBuffer(),
    kgrad_device_buf.GetDeviceBuffer(),
    vgrad_device_buf.GetDeviceBuffer(),
    {}, // std::array<void*, 1> p_acc0_biases;
    {}, // std::array<void*, 1> p_acc1_biases;
    q_gs_ms_ks_lengths,
    q_gs_ms_ks_strides,
    k_gs_ns_ks_lengths,
    k_gs_ns_ks_strides,
    z_gs_ms_ns_lengths,
    z_gs_ms_ns_strides,
    v_gs_os_ns_lengths,
    v_gs_os_ns_strides,
    y_gs_ms_os_lengths,
    y_gs_ms_os_strides,
    lse_gs_ms_lengths,
    {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
    {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
    {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
    {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
    QKVElementOp{},
    QKVElementOp{},
    Scale{alpha},
    QKVElementOp{},
    YElementOp{},
    p_drop,
    std::tuple<unsigned long long, unsigned long long>(seed, offset));
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
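The flop estimate on the line above matches a count of five GEMM-sized products per batch in the backward pass, three over (M, N, K) and two over (M, N, O), at 2 * (product of the three extents) flops each; which of these the kernel actually fuses is not visible in this hunk, so read this only as the accounting behind the formula:

flop ≈ 2 * (3*M*N*K + 2*M*N*O) * G0*G1
       // e.g. S = Q*K^T, dQ = dS*K, dK = dS^T*Q over (M,N,K); dP = dY*V^T, dV = P^T*dY over (M,N,O)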
@@ -496,40 +278,41 @@ int main(int argc, char* argv[])
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
          << std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
    q_device_buf.GetDeviceBuffer(),
    k_device_buf.GetDeviceBuffer(),
    nullptr, // set to nullptr
    v_device_buf.GetDeviceBuffer(),
    y_device_buf.GetDeviceBuffer(),
    lse_device_buf.GetDeviceBuffer(),
    ygrad_device_buf.GetDeviceBuffer(),
    qgrad_device_buf.GetDeviceBuffer(),
    kgrad_device_buf.GetDeviceBuffer(),
    vgrad_device_buf.GetDeviceBuffer(),
    {}, // std::array<void*, 1> p_acc0_biases;
    {}, // std::array<void*, 1> p_acc1_biases;
    q_gs_ms_ks_lengths,
    q_gs_ms_ks_strides,
    k_gs_ns_ks_lengths,
    k_gs_ns_ks_strides,
    z_gs_ms_ns_lengths,
    z_gs_ms_ns_strides,
    v_gs_os_ns_lengths,
    v_gs_os_ns_strides,
    y_gs_ms_os_lengths,
    y_gs_ms_os_strides,
    lse_gs_ms_lengths,
    {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
    {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
    {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
    {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
    QKVElementOp{},
    QKVElementOp{},
    Scale{alpha},
    QKVElementOp{},
    YElementOp{},
    p_drop,
    std::tuple<unsigned long long, unsigned long long>(seed, offset));
auto invoker_ptr = op_ptr->MakeInvokerPointer();
...
@@ -18,15 +18,36 @@ namespace device {
namespace instance {
void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
1,
F16,
F16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
1,
F16,
F16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
@@ -34,31 +55,29 @@ void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>& instances);
void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
1,
BF16,
BF16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances);
void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
@@ -68,7 +87,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
BF16,
BF16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
@@ -76,29 +95,7 @@ void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>& instances);
template <typename InputDataType,
typename OutputDataType,
@@ -151,8 +148,7 @@ struct DeviceOperationInstanceFactory<
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskUpperTriangleFromTopLeft)
{
add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
}
else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
{
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
@@ -50,9 +51,8 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances = std::tuple<
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// clang-format on
>;
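The tuple above is only a type list (all three candidate configurations are commented out at this revision); the two functions below hand it to add_device_operation_instances, which appends one heap-allocated instance per listed type to the caller's vector. A generic sketch of that expansion, as an illustration rather than the actual CK helper:

#include <memory>
#include <tuple>
#include <vector>

// illustration: append one default-constructed OpT per element of the type list to `instances`;
// assumes each OpT derives from BaseOp and is default-constructible
template <typename BaseOp, typename... OpTs>
void add_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances, std::tuple<OpTs...>)
{
    (instances.push_back(std::make_unique<OpTs>()), ...); // C++17 fold expression
}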
void add_device_batched_mha_bwd_qloop_casual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
1,
BF16,
BF16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
}
void add_device_batched_mha_bwd_qloop_noncasual_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
1,
BF16,
BF16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>& instances)
{
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_bf16_bf16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
...
@@ -32,10 +32,11 @@ using YElementOp = PassThrough;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// static constexpr auto TensorDefault =
// ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
@@ -50,9 +51,8 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO,
MaskingSpecialization MaskingSpec>
using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
// clang-format off
// ########################################################################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| InputDataType| OutputDataType| GemmDataType| ZDataType| LSEDataType| Acc0BiasDataType| Acc1BiasDataType| GemmAcc| CShuffle| A| B| Acc| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| Gemm2| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CShuffleBlockTransferScalarPerVector_NPerBlock| MaskingSpec| Deterministic|
// ########################################################################################| | | | | | | | | | | | | DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| | | |
// ########################################################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| | | |
@@ -61,71 +61,67 @@ using device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>,
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, unsigned short, F32, Acc0BiasDataType, Acc1BiasDataType, F32, F32, QKVElementOp, QKVElementOp, Scale, QKVElementOp, YElementOp, GemmSpec, TensorSpecQ, TensorSpecK, TensorSpecV, TensorSpecY, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 4, 1, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 2, S<1, 32, 1, 8>, CShuffleBlockTransferScalarPerVector_NPerBlock, MaskingSpec, Deterministic>
// clang-format on
>;
void add_device_batched_mha_bwd_qloop_casual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<DeviceBatchedMultiheadAttentionBackward<
2,
1,
1,
1,
1,
F16,
F16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>>>& instances)
{
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskUpperTriangleFromTopLeft>{});
}
void add_device_batched_mha_bwd_qloop_noncasual_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedMultiheadAttentionBackward<2,
1,
1,
1,
1,
F16,
F16,
unsigned short,
F32,
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
PassThrough,
Scale,
PassThrough,
PassThrough,
MaskingSpecialization::MaskDisabled>>>& instances)
{
add_device_operation_instances(instances,
device_batched_mha_bwd_qloop_f16_f16_gmk_gnk_gno_gmo_instances<
2,
1,
1,
1,
1,
MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
...