Commit 010ed35f authored by danyao12

merge from attn-bwd-develop

parents 272b7574 042e4b8c
@@ -27,12 +27,14 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

 using BF16 = ck::bhalf_t;
 using F32 = float;
+using U16 = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

@@ -42,6 +44,7 @@ using B1DataType = BF16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using CDataType = BF16;
+using ZDataType = U16;
 using LSEDataType = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;

@@ -78,6 +81,7 @@ using DeviceGemmInstance =
 B0DataType,
 B1DataType,
 CDataType,
+ZDataType,
 LSEDataType,
 Acc0BiasDataType,
 Acc1BiasDataType,

@@ -157,6 +161,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 B1ElementOp,
 CElementOp>;

+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_batched_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
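This first example (batched, BF16) now routes a Z tensor of ZDataType = U16 through the op: Z records the per-element 16-bit random draws, so the host reference can replay the exact dropout mask the device used. A minimal host-side sketch of the keep/rescale rule those constants imply; the threshold/compare direction is my assumption, and the committed behavior is whatever ReferenceDropout implements:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Keep an element when its 16-bit draw is at or below the keep-probability
// threshold; rescale survivors by 1/(1 - p_drop). Assumed semantics only.
std::vector<float> apply_dropout(const std::vector<uint16_t>& z,
                                 const std::vector<float>& in,
                                 float p_drop)
{
    const float keep_prob = 1.0f - p_drop;
    const auto threshold  = static_cast<uint16_t>(std::floor(keep_prob * 65535.0));
    const float rescale   = 1.0f / keep_prob;

    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
        out[i] = (z[i] <= threshold) ? in[i] * rescale : 0.0f;
    return out;
}

int main()
{
    // With p_drop = 0.1 the threshold is floor(0.9 * 65535) = 58981,
    // so the two draws above 58981 are zeroed.
    std::vector<uint16_t> z = {0, 30000, 60000, 65535};
    std::vector<float> x(4, 1.0f);
    auto y = apply_dropout(z, x, 0.1f);
    return (y[2] == 0.0f && y[3] == 0.0f) ? 0 : 1;
}
```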
@@ -27,12 +27,14 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

@@ -42,6 +44,7 @@ using B1DataType = F16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using CDataType = F16;
+using ZDataType = U16;
 using LSEDataType = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;

@@ -78,6 +81,7 @@ using DeviceGemmInstance =
 B0DataType,
 B1DataType,
 CDataType,
+ZDataType,
 LSEDataType,
 Acc0BiasDataType,
 Acc1BiasDataType,

@@ -157,6 +161,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 B1ElementOp,
 CElementOp>;

+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_batched_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -27,12 +27,14 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

 using BF16 = ck::bhalf_t;
 using F32 = float;
+using U16 = unsigned short;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

@@ -42,6 +44,7 @@ using B1DataType = BF16;
 using AccDataType = F32;
 using CShuffleDataType = F32;
 using CDataType = BF16;
+using ZDataType = U16;
 using LSEDataType = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;

@@ -78,6 +81,7 @@ using DeviceGemmInstance =
 B0DataType,
 B1DataType,
 CDataType,
+ZDataType,
 LSEDataType,
 Acc0BiasDataType,
 Acc1BiasDataType,

@@ -98,8 +102,8 @@ using DeviceGemmInstance =
 128, // MPerBlock
 128, // NPerBlock
 32, // KPerBlock
-64, // Gemm1NPerBlock
-32, // Gemm1KPerBlock
+128, // Gemm1NPerBlock
+64, // Gemm1KPerBlock
 8, // AK1
 8, // BK1
 2, // B1K1

@@ -107,7 +111,7 @@ using DeviceGemmInstance =
 32, // NPerXDL
 1, // MXdlPerWave
 4, // NXdlPerWave
-2, // Gemm1NXdlPerWave
+4, // Gemm1NXdlPerWave
 S<4, 64, 1>, // ABlockTransfer
 S<1, 0, 2>,
 S<1, 0, 2>,

@@ -157,6 +161,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 B1ElementOp,
 CElementOp>;

+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_grouped_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
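The grouped examples also retune gemm1's block tile: with the driver now fixing K = O = 128 (see the run_grouped changes below), Gemm1NPerBlock doubles to 128 so one block tile spans the whole head dimension, and Gemm1NXdlPerWave doubles with it. A small consistency sketch; the single-wave-along-N coverage relation is my assumption, the diff only states the numbers:

```cpp
// Retuned gemm1 tile numbers from this diff; the coverage relation assumed
// here (one wave along gemm1's N) is not stated by the commit.
constexpr int NPerXDL          = 32;
constexpr int Gemm1NXdlPerWave = 4;   // was 2
constexpr int Gemm1NPerBlock   = 128; // was 64

static_assert(Gemm1NXdlPerWave * NPerXDL == Gemm1NPerBlock,
              "per-wave XDL tiles should cover the gemm1 N block tile");

int main() { return 0; }
```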
@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"

 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;

@@ -102,8 +103,8 @@ using DeviceGemmInstance =
 128, // MPerBlock
 128, // NPerBlock
 32, // KPerBlock
-64, // Gemm1NPerBlock
-32, // Gemm1KPerBlock
+128, // Gemm1NPerBlock
+64, // Gemm1KPerBlock
 8, // AK1
 8, // BK1
 2, // B1K1

@@ -111,7 +112,7 @@ using DeviceGemmInstance =
 32, // NPerXDL
 1, // MXdlPerWave
 4, // NXdlPerWave
-2, // Gemm1NXdlPerWave
+4, // Gemm1NXdlPerWave
 S<4, 64, 1>, // ABlockTransfer
 S<1, 0, 2>,
 S<1, 0, 2>,

@@ -161,6 +162,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 B1ElementOp,
 CElementOp>;

+// Ref dropout
+using ReferenceDropoutInstance =
+    ck::tensor_operation::host::ReferenceDropout<ZDataType, ADataType, ADataType>;
+
 #include "run_grouped_multihead_attention_forward.inc"

 int main(int argc, char* argv[]) { return run(argc, argv); }
@@ -12,7 +12,7 @@ int run(int argc, char* argv[])
 ck::index_t M = 1000; // 120
 ck::index_t N = 1000; // 1000
 ck::index_t K = 64;
-ck::index_t O = 128;
+ck::index_t O = 64;

 // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
 // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])

@@ -25,6 +25,13 @@ int run(int argc, char* argv[])
 bool input_permute = false;
 bool output_permute = true;

+float p_drop = 0.1;
+float p_dropout = 1 - p_drop;
+uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+float rp_dropout = 1.0 / p_dropout;
+const unsigned long long seed = 1;
+const unsigned long long offset = 0;
+
 if(argc == 1)
 {
 // use default case

@@ -88,6 +95,12 @@ int run(int argc, char* argv[])
 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

+std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+std::vector<ck::index_t> z_gs_ms_ns_strides =
+    output_permute
+        ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
+        : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
+
 std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
 std::vector<ck::index_t> lse_gs_ms_strides =
 std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]

@@ -97,6 +110,7 @@ int run(int argc, char* argv[])
 Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
 Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
 Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
 Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
 Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);

@@ -104,8 +118,11 @@ int run(int argc, char* argv[])
 std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl;
 std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl;
 std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl;
+std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
 std::cout << "lse_gs_ms_os: " << lse_gs_ms_host_result.mDesc << std::endl;

+z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<ZDataType>{0});
+
 switch(init_method)
 {
 case 0: break;

@@ -135,6 +152,7 @@ int run(int argc, char* argv[])
 DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize());
 DeviceMem c_device_buf(sizeof(CDataType) *
 c_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
+DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
 DeviceMem lse_device_buf(sizeof(LSEDataType) *
 lse_gs_ms_device_result.mDesc.GetElementSpaceSize());

@@ -157,6 +175,7 @@ int run(int argc, char* argv[])
 static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
 static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
 static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
 static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
 {}, // std::array<void*, 1> p_acc0_biases;
 {}, // std::array<void*, 1> p_acc1_biases;

@@ -168,6 +187,8 @@ int run(int argc, char* argv[])
 b1_gs_os_ns_strides,
 c_gs_ms_os_lengths,
 c_gs_ms_os_strides,
+z_gs_ms_ns_lengths,
+z_gs_ms_ns_strides,
 lse_gs_ms_lengths,
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
 {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},

@@ -178,8 +199,8 @@ int run(int argc, char* argv[])
 acc0_element_op,
 b1_element_op,
 c_element_op,
-0, // dropout ratio
-{0, 64}); // dropout random seed and offset, offset should be at least the number of
+p_drop, // dropout ratio
+{seed, offset}); // dropout random seed and offset, offset should be at least the number of
 // elements on a thread

 if(!gemm.IsSupportedArgument(argument))

@@ -208,6 +229,7 @@ int run(int argc, char* argv[])
 if(do_verification)
 {
 c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
 lse_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());

 Tensor<ADataType> a_g_m_k({BatchCount, M, K});

@@ -215,8 +237,10 @@ int run(int argc, char* argv[])
 Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
 Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
 Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
+Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
 Tensor<LSEDataType> lse_g_m_host_result(
 {BatchCount, M}); // scratch object after max + ln(sum)
+Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
 Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1

 // permute

@@ -229,6 +253,9 @@ int run(int argc, char* argv[])
 b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
 b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
 });
+z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+    z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+});

 // gemm 0
 auto ref_gemm0 = ReferenceGemm0Instance{};
@@ -253,11 +280,22 @@ int run(int argc, char* argv[])
 ref_softmax_invoker.Run(ref_softmax_argument);

+// dropout after softmax
+auto ref_dropout = ReferenceDropoutInstance{};
+auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+auto ref_dropout_argument = ref_dropout.MakeArgument(
+    z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+ref_dropout_invoker.Run(ref_dropout_argument);
+
 // gemm1
 auto ref_gemm1 = ReferenceGemm1Instance{};
 auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
-auto ref_gemm1_argument = ref_gemm1.MakeArgument(
-    a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
+auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n_drop,
+                                                 b1_g_n_o,
+                                                 c_g_m_o_host_result,
+                                                 PassThrough{},
+                                                 b1_element_op,
+                                                 c_element_op);

 ref_gemm1_invoker.Run(ref_gemm1_argument);
...
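The batched driver derives its dropout constants from p_drop = 0.1: the keep probability, its 16-bit fixed-point threshold, and the reciprocal used to rescale survivors. Note the retained comment still asks for a seed offset of at least the per-thread element count while this driver sets offset = 0. The arithmetic, as a standalone check:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float p_drop    = 0.1;
    float p_dropout = 1 - p_drop; // keep probability: 0.9
    uint16_t p_dropout_in_16bits =
        uint16_t(std::floor(p_dropout * 65535.0)); // threshold: 58981
    float rp_dropout = 1.0 / p_dropout;            // rescale factor: ~1.1111

    std::printf("%f %u %f\n", p_dropout, p_dropout_in_16bits, rp_dropout);
    return 0;
}
```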
@@ -10,6 +10,7 @@ int run(int argc, char* argv[])
 bool input_permute = false;
 bool output_permute = true;

 float p_drop = 0.2;
 float p_dropout = 1 - p_drop;
 uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));

@@ -47,7 +48,7 @@ int run(int argc, char* argv[])
 float alpha = 1; // scaling after 1st gemm

-std::size_t group_count = 7;
+std::size_t group_count = 8;

 // Problem descs
 std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;

@@ -76,13 +77,14 @@ int run(int argc, char* argv[])
 std::size_t flop = 0, num_byte = 0;

-std::cout << "group count " << group_count << ". printing first 4 groups\n";
+// std::cout << "group count " << group_count << ". printing first 4 groups\n";

 for(std::size_t i = 0; i < group_count; i++)
 {
 int M = 128 * (rand() % 8 + 1);
 int N = 128 * (rand() % 8 + 1);
-int K = 40;
-int O = 40 * (rand() % 2 + 1);
+int K = 128;
+int O = 128;
 int G0 = rand() % 3 + 1;
 int G1 = rand() % 5 + 1;

@@ -231,7 +233,8 @@ int run(int argc, char* argv[])
 // do GEMM
 auto gemm = DeviceGemmInstance{};
 auto invoker = gemm.MakeInvoker();
-auto argument = gemm.MakeArgument(p_a,
+auto argument =
+    gemm.MakeArgument(p_a,
 p_b0,
 p_b1,
 p_c,

@@ -246,9 +249,10 @@ int run(int argc, char* argv[])
 b1_element_op,
 c_element_op,
 p_drop, // dropout ratio
-{0, 448}); // dropout random seed and offset, offset should be
+{seed, offset}); // dropout random seed and offset, offset should be
 // at least the number of elements on a thread

 // specify workspace for problem_desc
 DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));

@@ -291,11 +295,14 @@ int run(int argc, char* argv[])
 const auto& b0_gs_ns_ks = b0_tensors[i];
 const auto& b1_gs_os_ns = b1_tensors[i];
 auto& c_gs_ms_os_device_result = c_tensors[i];
+auto& z_gs_ms_ns_device_result = z_tensors[i];
 auto& lse_gs_ms_device_result = lse_tensors[i];
 auto& c_gs_ms_os_device_buf = *c_tensors_device[i];
+auto& z_gs_ms_ns_device_buf = *z_tensors_device[i];
 auto& lse_gs_ms_device_buf = *lse_tensors_device[i];

 c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
 lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());

 Tensor<ADataType> a_g_m_k({G0 * G1, M, K});

@@ -303,8 +310,10 @@ int run(int argc, char* argv[])
 Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
 Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
 Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
+Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after dropout
 Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
 Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
 Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1
 Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);

@@ -319,6 +328,10 @@ int run(int argc, char* argv[])
 b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
 });

+z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
+    z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+});
+
 // gemm 0
 auto ref_gemm0 = ReferenceGemm0Instance{};
 auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();

@@ -342,10 +355,20 @@ int run(int argc, char* argv[])
 ref_softmax_invoker.Run(ref_softmax_argument);

+// printf("print z_g_m_n \n");
+// z_g_m_n.ForEach([&](auto& self, auto idx) {printf("%u ", self(idx));});
+
+// dropout after softmax
+auto ref_dropout = ReferenceDropoutInstance{};
+auto ref_dropout_invoker = ref_dropout.MakeInvoker();
+auto ref_dropout_argument = ref_dropout.MakeArgument(
+    z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+ref_dropout_invoker.Run(ref_dropout_argument);
+
 // gemm 1
 auto ref_gemm1 = ReferenceGemm1Instance{};
 auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
-auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
+auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n_drop,
 b1_g_n_o,
 c_g_m_o_host_result,
 PassThrough{},

@@ -387,6 +410,7 @@ int run(int argc, char* argv[])
 // bool pass_ =
 //     ck::utils::check_err(c_gs_ms_os_device_result.mData,
 //                          c_gs_ms_os_host_result.mData);
 bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
 c_gs_ms_os_host_result.mData,
 "Error: Incorrect results c!",

@@ -399,6 +423,10 @@ int run(int argc, char* argv[])
 atol);
 pass &= pass_;
 }
+if(pass)
+{
+    std::cout << "Verification passed." << std::endl;
+}
 }

 return pass ? 0 : 1;
...
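Both drivers' verification paths collapse the [G0, G1, M, N] views into a flat batch with g = g0 * G1 + g1, the mapping used by every ForEach permute lambda above. A round-trip check of that index arithmetic:

```cpp
#include <cassert>

int main()
{
    const int G0 = 3, G1 = 5;
    for(int g0 = 0; g0 < G0; ++g0)
        for(int g1 = 0; g1 < G1; ++g1)
        {
            const int g = g0 * G1 + g1;           // flat batch index in [0, G0 * G1)
            assert(g / G1 == g0 && g % G1 == g1); // uniquely recovers (g0, g1)
        }
    return 0;
}
```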
@@ -75,6 +75,7 @@ template <index_t NumDimG,
 typename B0DataType,
 typename B1DataType,
 typename CDataType,
+typename ZDataType,
 typename LSEDataType,
 typename Acc0BiasDataType,
 typename Acc1BiasDataType,

@@ -94,6 +95,7 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
 const void* p_b0,
 const void* p_b1,
 void* p_c,
+void* p_z,
 void* p_lse,
 const std::array<void*, NumAcc0Bias> p_acc0_biases,
 const std::array<void*, NumAcc1Bias> p_acc1_biases,

@@ -105,6 +107,8 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+const std::vector<index_t>& z_gs_ms_ns_lengths, // z_gs_ms_ns_lengths
+const std::vector<index_t>& z_gs_ms_ns_strides, // z_gs_ms_ns_strides
 const std::vector<index_t>& lse_gs_ms_lengths, // lse_gs_ms_lengths
 const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
 const std::array<std::vector<index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
...
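With the widened interface, a caller supplies the Z buffer plus its lengths and strides alongside C. The stride recipe below mirrors the example driver's two layouts; it is a sketch with plain int standing in for ck::index_t:

```cpp
#include <vector>
using index_t = int; // stand-in for ck::index_t

// Z is logically [G0, G1, M, N]; the strides pick the physical layout.
std::vector<index_t> z_lengths(index_t G0, index_t G1, index_t M, index_t N)
{
    return {G0, G1, M, N};
}

std::vector<index_t> z_strides(index_t G1, index_t M, index_t N, bool output_permute)
{
    return output_permute
               ? std::vector<index_t>{M * G1 * N, N, G1 * N, 1}  // [G0, M, G1, N]
               : std::vector<index_t>{G1 * M * N, M * N, N, 1};  // [G0, G1, M, N]
}

int main()
{
    auto s = z_strides(/*G1=*/2, /*M=*/4, /*N=*/8, /*output_permute=*/true);
    return s[3] == 1 ? 0 : 1; // innermost N is contiguous in both layouts
}
```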
@@ -26,6 +26,7 @@ namespace device {
 template <typename GridwiseGemm,
 typename FloatAB,
 typename FloatC,
+typename ZDataType,
 typename FloatLSE,
 typename GemmAccDataType,
 typename AElementwiseOperation,

@@ -37,6 +38,7 @@ template <typename GridwiseGemm,
 typename BGridDesc_BK0_N_BK1,
 typename B1GridDesc_BK0_N_BK1,
 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
 typename LSEGridDescriptor_M,
 typename Block2CTileMap,
 typename ComputeBasePtrOfStridedBatch,

@@ -52,6 +54,7 @@ __global__ void
 const FloatAB* __restrict__ p_b_grid,
 const FloatAB* __restrict__ p_b1_grid,
 FloatC* __restrict__ p_c_grid,
+ZDataType* __restrict__ p_z_grid,
 FloatLSE* __restrict__ p_lse_grid,
 const AElementwiseOperation a_element_op,
 const BElementwiseOperation b_element_op,

@@ -63,6 +66,8 @@ __global__ void
 const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
 c_grid_desc_mblock_mperblock_nblock_nperblock,
+const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
 const LSEGridDescriptor_M lse_grid_desc_m,
 const Block2CTileMap block_2_ctile_map,
 const index_t batch_count,

@@ -87,6 +92,8 @@ __global__ void
 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
+const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
+    static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
 const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
 static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));

@@ -98,6 +105,7 @@ __global__ void
 p_b_grid + b_batch_offset,
 p_b1_grid + b1_batch_offset,
 p_c_grid + c_batch_offset,
+p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset,
 p_lse_grid + lse_batch_offset,
 p_shared,
 a_element_op,

@@ -109,6 +117,7 @@ __global__ void
 b_grid_desc_bk0_n_bk1,
 b1_grid_desc_bk0_n_bk1,
 c_grid_desc_mblock_mperblock_nblock_nperblock,
+z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
 lse_grid_desc_m,
 block_2_ctile_map,
 c0_matrix_mask,

@@ -148,6 +157,7 @@ template <index_t NumDimG,
 typename BDataType,
 typename B1DataType,
 typename CDataType,
+typename ZDataType,
 typename LSEDataType,
 typename Acc0BiasDataType,
 typename Acc1BiasDataType,

@@ -215,6 +225,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 BDataType,
 B1DataType,
 CDataType,
+ZDataType,
 LSEDataType,
 Acc0BiasDataType,
 Acc1BiasDataType,

@@ -285,6 +296,12 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 Number<B1K1>{});
 }

+static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
+                                    const std::vector<index_t>& z_gs_ms_ns_strides_vec)
+{
+    return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
+}
+
 static auto MakeLSEGridDescriptor_M(index_t MRaw)
 {
 const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

@@ -314,11 +331,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
 using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
 using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
+using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
 using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
 using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
 using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
 using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
 using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));

 constexpr static auto make_MaskOutPredicate()
 {

@@ -339,11 +358,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 const BGridDesc_G_N_K& b_grid_desc_g_n_k,
 const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
 const CGridDesc_G_M_N& c_grid_desc_g_m_n,
+const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
 index_t BatchStrideLSE)
 : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
 b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
 b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
 c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
+z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
 BatchStrideLSE_(BatchStrideLSE)
 {
 }

@@ -368,6 +389,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
 }

+__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
+{
+    return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
+}
+
 __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
 {
 return g_idx * static_cast<long_index_t>(BatchStrideLSE_);

@@ -378,6 +404,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 BGridDesc_G_N_K b_grid_desc_g_n_k_;
 B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
 CGridDesc_G_M_N c_grid_desc_g_m_n_;
+ZGridDesc_G_M_N z_grid_desc_g_m_n_;
 index_t BatchStrideLSE_;
 };
@@ -398,6 +425,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 BGridDesc_BK0_N_BK1,
 B1GridDesc_BK0_N_BK1,
 CGridDesc_M_N,
+ZGridDesc_M_N,
 LSEGridDesc_M,
 NumGemmKPrefetchStage,
 BlockSize,

@@ -455,6 +483,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 const BDataType* p_b_grid,
 const B1DataType* p_b1_grid,
 CDataType* p_c_grid,
+ZDataType* p_z_grid,
 LSEDataType* p_lse_grid,
 const std::array<void*, NumAcc0Bias> p_acc0_biases,
 const std::array<void*, NumAcc1Bias> p_acc1_biases,

@@ -466,6 +495,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
 const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
 const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
+const std::vector<index_t>& z_gs_ms_ns_lengths,
+const std::vector<index_t>& z_gs_ms_ns_strides,
 const std::vector<index_t>& lse_gs_ms_lengths,
 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,

@@ -484,6 +515,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 p_b_grid_{p_b_grid},
 p_b1_grid_{p_b1_grid},
 p_c_grid_{p_c_grid},
+p_z_grid_{p_z_grid},
 p_lse_grid_{p_lse_grid},
 a_grid_desc_ak0_m_ak1_{
 DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},

@@ -493,6 +525,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
 c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
 c_gs_ms_gemm1ns_strides)},
+z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
 lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
 a_grid_desc_g_m_k_{
 Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},

@@ -502,6 +535,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
 c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
 c_gs_ms_gemm1ns_strides)},
+z_grid_desc_g_m_n_{
+    Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
 c_grid_desc_mblock_mperblock_nblock_nperblock_{},
 block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
 a_element_op_{a_element_op},

@@ -528,6 +563,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 b_grid_desc_g_n_k_,
 b1_grid_desc_g_n_k_,
 c_grid_desc_g_m_n_,
+z_grid_desc_g_m_n_,
 type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
 {
 // TODO ANT: implement bias addition

@@ -557,6 +593,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 seed_ = std::get<0>(seeds);
 offset_ = std::get<1>(seeds);

+z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
+    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
 }

 void Print() const

@@ -580,6 +619,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 const BDataType* p_b_grid_;
 const B1DataType* p_b1_grid_;
 CDataType* p_c_grid_;
+ZDataType* p_z_grid_;
 LSEDataType* p_lse_grid_;

 // tensor descriptor

@@ -587,13 +627,18 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
 B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
 CGridDesc_M_N c_grid_desc_m_n_;
+ZGridDesc_M_N z_grid_desc_m_n_;
 LSEGridDesc_M lse_grid_desc_m_;
 AGridDesc_G_M_K a_grid_desc_g_m_k_;
 BGridDesc_G_N_K b_grid_desc_g_n_k_;
 B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
 CGridDesc_G_M_N c_grid_desc_g_m_n_;
+ZGridDesc_G_M_N z_grid_desc_g_m_n_;
 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
 c_grid_desc_mblock_mperblock_nblock_nperblock_;
+typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;

 // block-to-c-tile map
 typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -652,6 +697,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -652,6 +697,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
ZDataType,
LSEDataType, LSEDataType,
GemmAccDataType, GemmAccDataType,
AElementwiseOperation, AElementwiseOperation,
...@@ -663,6 +709,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -663,6 +709,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
...@@ -679,6 +726,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -679,6 +726,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg.p_b_grid_, arg.p_b_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_z_grid_,
arg.p_lse_grid_, arg.p_lse_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -689,6 +737,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -689,6 +737,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
...@@ -827,6 +876,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -827,6 +876,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
const BDataType* p_b, const BDataType* p_b,
const B1DataType* p_b1, const B1DataType* p_b1,
CDataType* p_c, CDataType* p_c,
ZDataType* p_z,
LSEDataType* p_lse, LSEDataType* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
...@@ -838,6 +888,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -838,6 +888,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
...@@ -857,6 +909,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -857,6 +909,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
p_b, p_b,
p_b1, p_b1,
p_c, p_c,
p_z,
p_lse, p_lse,
p_acc0_biases, p_acc0_biases,
p_acc1_biases, p_acc1_biases,
...@@ -868,6 +921,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -868,6 +921,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ns_strides,
...@@ -891,6 +946,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -891,6 +946,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
const void* p_b, const void* p_b,
const void* p_b1, const void* p_b1,
void* p_c, void* p_c,
void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
...@@ -902,6 +958,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -902,6 +958,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
...@@ -921,6 +979,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -921,6 +979,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
static_cast<const B1DataType*>(p_b1), static_cast<const B1DataType*>(p_b1),
static_cast<CDataType*>(p_c), static_cast<CDataType*>(p_c),
static_cast<ZDataType*>(p_z),
static_cast<LSEDataType*>(p_lse), static_cast<LSEDataType*>(p_lse),
p_acc0_biases, // cast in struct Argument p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_biases, // cast in struct Argument
...@@ -932,6 +991,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -932,6 +991,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ns_strides,
......
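GetZBasePtr mirrors GetCBasePtr: it asks the G_M_N descriptor for the offset of element (g, 0, 0). For a packed [G, M, N] layout that reduces to g * M * N; the packed-stride simplification below is mine, the real descriptor handles arbitrary strides:

```cpp
#include <cassert>
using long_index_t = long long;

// CalculateOffset({g, 0, 0}) specialized to packed strides {M * N, N, 1}.
long_index_t z_base_offset(int g_idx, long_index_t M, long_index_t N)
{
    return static_cast<long_index_t>(g_idx) * M * N;
}

int main()
{
    assert(z_base_offset(2, 128, 256) == 2LL * 128 * 256);
    return 0;
}
```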
@@ -32,7 +32,8 @@ template <typename GridwiseGemm,
 typename B1ElementwiseOperation,
 typename CElementwiseOperation,
 bool HasMainKBlockLoop,
-bool IsDropout>
+bool IsDropout,
+bool IsLseStoring>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -97,18 +98,16 @@ __global__ void
 const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
 arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
-// unsigned short* p_z_grid_in =
-//     (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
-//                                             : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
+// GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
 GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
 arg_ptr[group_id].p_a_grid_ + a_batch_offset,
 arg_ptr[group_id].p_b_grid_ + b_batch_offset,
 arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
 arg_ptr[group_id].p_c_grid_ + c_batch_offset,
 arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
 : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
-arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr
+    : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
+// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -119,7 +118,7 @@ __global__ void ...@@ -119,7 +118,7 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, //////// arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -417,6 +416,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
        B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
        CGridDesc_G_M_N c_grid_desc_g_m_n_;
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
+       index_t BatchStrideLSE_;
    };
...@@ -588,6 +588,11 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
            const auto p_z_grid   = static_cast<ZDataType*>(p_z_vec[i]);
            const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);

+           if(p_lse_grid == nullptr)
+           {
+               is_lse_storing_ = false;
+           }

            const auto& problem_desc = problem_desc_vec[i];

            const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
...@@ -621,7 +626,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
-           // typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
-           //     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
-           auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+           const auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                    z_grid_desc_m_n);
...@@ -722,6 +728,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
        unsigned long long offset_;
        GemmAccDataType p_dropout_rescale_;
        bool is_dropout_;
+       bool is_lse_storing_ = true;
    };
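Note the semantics: is_lse_storing_ defaults to true and is cleared if any group passes a null LSE pointer, so LSE storing is all-or-nothing per launch. A sketch of the caller side, with the vector name and group_count as illustrative assumptions rather than the full MakeArgument signature:

    // Inference-style call: no LSE consumer, so pass nullptr for every group.
    // A single nullptr anywhere in the vector disables LSE storing for the launch.
    std::vector<void*> p_lse_vec(group_count, nullptr);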
    // Invoker
...@@ -754,7 +762,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
            float ave_time = 0;

-           auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
+           auto launch_kernel =
+               [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
                    const auto kernel =
                        kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
                                                                         GemmAccDataType,
...@@ -765,7 +774,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
                                                                         B1ElementwiseOperation,
                                                                         CElementwiseOperation,
                                                                         has_main_k_block_loop_,
-                                                                        is_dropout_>;
+                                                                        is_dropout_,
+                                                                        is_lse_storing_>;
                    return launch_and_time_kernel(
                        stream_config,
...@@ -791,30 +801,70 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
            if(all_has_main_k_block_loop)
            {
                if(arg.is_dropout_)
+               {
+                   if(arg.is_lse_storing_)
                    {
                        ave_time = launch_kernel(integral_constant<bool, true>{},
+                                                integral_constant<bool, true>{},
                                                 integral_constant<bool, true>{});
                    }
                    else
                    {
                        ave_time = launch_kernel(integral_constant<bool, true>{},
+                                                integral_constant<bool, true>{},
                                                 integral_constant<bool, false>{});
                    }
                }
+               else
+               {
+                   if(arg.is_lse_storing_)
+                   {
+                       ave_time = launch_kernel(integral_constant<bool, true>{},
+                                                integral_constant<bool, false>{},
+                                                integral_constant<bool, true>{});
+                   }
+                   else
+                   {
+                       ave_time = launch_kernel(integral_constant<bool, true>{},
+                                                integral_constant<bool, false>{},
+                                                integral_constant<bool, false>{});
+                   }
+               }
+           }
            else if(!some_has_main_k_block_loop)
            {
                if(arg.is_dropout_)
+               {
+                   if(arg.is_lse_storing_)
                    {
                        ave_time = launch_kernel(integral_constant<bool, false>{},
+                                                integral_constant<bool, true>{},
                                                 integral_constant<bool, true>{});
                    }
                    else
                    {
                        ave_time = launch_kernel(integral_constant<bool, false>{},
+                                                integral_constant<bool, true>{},
                                                 integral_constant<bool, false>{});
                    }
                }
                else
+               {
+                   if(arg.is_lse_storing_)
+                   {
+                       ave_time = launch_kernel(integral_constant<bool, false>{},
+                                                integral_constant<bool, false>{},
+                                                integral_constant<bool, true>{});
+                   }
+                   else
+                   {
+                       ave_time = launch_kernel(integral_constant<bool, false>{},
+                                                integral_constant<bool, false>{},
+                                                integral_constant<bool, false>{});
+                   }
+               }
+           }
+           else
            {
                throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
                                         "has_main_k_block_loop or no_main_k_block_loop");
...
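The eight-way branch above maps three runtime flags onto compile-time template parameters, so each combination compiles to its own kernel with disabled paths removed. A compact way to express the same dispatch, shown only as a sketch (the actual code keeps the explicit branches, and the main-loop flag additionally feeds the all/some consistency check):

    #include <type_traits>

    // Turn one runtime bool into a compile-time integral_constant.
    template <typename F>
    void dispatch_bool(bool v, F&& f)
    {
        if(v)
            f(std::integral_constant<bool, true>{});
        else
            f(std::integral_constant<bool, false>{});
    }

    // Fold (has_main_loop, is_dropout, is_lse_storing) into template arguments.
    template <typename F>
    void dispatch_flags(bool main_loop, bool dropout, bool lse, F&& launch)
    {
        dispatch_bool(main_loop, [&](auto main_) {
            dispatch_bool(dropout, [&](auto drop_) {
                dispatch_bool(lse, [&](auto lse_) { launch(main_, drop_, lse_); });
            });
        });
    }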
...@@ -139,23 +139,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
    }

-   __host__ __device__ static constexpr auto
-   MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M,
-                                                     const index_t N) ////=> for z use
-   {
-       constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
-       constexpr auto N3   = mfma.num_groups_per_blk;
-       constexpr auto N4   = mfma.num_input_blks;
-       constexpr auto N5   = mfma.group_size;
-
-       return transform_tensor_descriptor(
-           make_naive_tensor_descriptor_packed(make_tuple(M, N)),
-           make_tuple(make_unmerge_transform(
-                          make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
-                      make_unmerge_transform(
-                          make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
-           make_tuple(Sequence<0>{}, Sequence<1>{}),
-           make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
-   }
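The deleted helper duplicated what GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 already builds, which the device code above now calls on z_grid_desc_m_n. For reference, the 10-dimensional view both construct unmerges the flat [M, N] random-value matrix along the block/wave/MFMA hierarchy:

    M = (M / M_PerBlock) * M_XdlPerWave * Gemm0MWaves * M_PerXdl
          = M_0 * M_1 * M_2 * M_3
    N = (N / N_PerBlock) * N_XdlPerWave * Gemm0NWaves * N_3 * N_4 * N_5
          = N_0 * N_1 * N_2 * N_3 * N_4 * N_5

where N_3 = num_groups_per_blk, N_4 = num_input_blks, N_5 = group_size of the selected MFMA instruction.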
    __device__ static auto GetGemm0WaveIdx()
    {
...@@ -290,6 +273,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
        const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);

+       if(Gemm1N != K)
+       {
+           std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
+           return false;
+       }
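The new early-out makes the kernel's shape assumption explicit: gemm0 builds an M-by-N score matrix from M-by-K and N-by-K operands, and gemm1 multiplies the softmaxed scores by an N-by-Gemm1N operand, so the check enforces

    Gemm1N = K,

i.e. the query/key head size must match the value/output head size; mismatched problems are rejected up front at validity-check time.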
        if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
        {
            return false;
...@@ -427,6 +416,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
    template <bool HasMainKBlockLoop,
              bool IsDropout,
+             bool IsLseStoring,
              typename Block2CTileMap,
              typename C0MatrixMask>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
...@@ -851,8 +841,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        // gemm1 K loop
        index_t gemm1_k_block_outer_index = 0;

-       ///////////////////=>z for dropout
+       // z is the random-number matrix used to verify dropout
        //
        // z vgpr copy to global
        //
...@@ -876,11 +865,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
            z_tenor_buffer;
        z_tenor_buffer.Clear();

        // z matrix global desc
-       /*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
-       const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
-       auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
-           MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
        auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
...@@ -922,8 +906,6 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                    0),
                tensor_operation::element_wise::PassThrough{}};

-       ///////////////////=>z for dropout

        do
        {
            auto n_block_data_idx_on_grid =
...@@ -1025,7 +1007,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                    // P_dropped
                    blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
                                                            decltype(z_tenor_buffer),
-                                                           true>(
+                                                           false>(
                        acc_thread_buf, ph, z_tenor_buffer);

                    z_thread_copy_vgpr_to_global.Run(
...@@ -1034,20 +1016,19 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                        z_tenor_buffer,
                        z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                        z_grid_buf);
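                    // editor note, not part of the commit: the added lines below
                    // advance the Z write window by one tile along n0 (descriptor
                    // dimension order is m0, n0, m1, n1, ...), so each outer
                    // j-iteration writes the next NPerBlock column of random values.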
+                   z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                       z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                       make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
                }
                else
                {
                    // P_dropped
-                   blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), true>(
+                   blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), false>(
                        acc_thread_buf, ph);
                }
            }

-           // if constexpr(IsDropout) // dropout
-           //{
-           //    blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
-           //}
            // TODO: may convert to log domain
            running_max_new = mathext::max(max, running_max);
            running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
...@@ -1168,6 +1149,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

        // Calculate max + ln(sum) and write out
+       if constexpr(IsLseStoring)
+       {
            static_for<0, MXdlPerWave, 1>{}(
                [&](auto I) { lse_thread_buf(I) = running_max(I) + math::log(running_sum(I)); });
...@@ -1185,6 +1169,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
                lse_grid_desc_mblock_mrepeat_mwave_mperxdl, make_multi_index(0, 1, 0, 0));
            });
        }
+   }
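The quantity written here is the numerically stable log-sum-exp of each score row, accumulated with a running max across the outer j-loop:

    LSE_i = m_i + log( sum_j exp(s_ij - m_i) ) = log( sum_j exp(s_ij) ),   m_i = max_j s_ij

This is what an attention backward pass can use to recompute the softmax as p_ij = exp(s_ij - LSE_i) without storing the full probability matrix; when no LSE consumer exists, IsLseStoring now skips the write entirely at compile time.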
        // shuffle C and write out
        {
...
...@@ -31,14 +31,14 @@ struct ReferenceDropout : public device::BaseOperator
              in_(in),
              out_(out),
              p_dropout_in_16bits_(p_dropout_in_16bits),
-             rp_dropout_(ck::type_convert<OutDataType>(rp_dropout))
+             rp_dropout_(rp_dropout)
        {
        }

        const Tensor<RefDataType>& ref_;
        const Tensor<InDataType>& in_;
        Tensor<OutDataType>& out_;
        RefDataType p_dropout_in_16bits_;
-       OutDataType rp_dropout_;
+       float rp_dropout_;
    };

    // Invoker
...@@ -48,7 +48,10 @@ struct ReferenceDropout : public device::BaseOperator
    {
        arg.out_.ForEach([&](auto& self, auto idx) {
            self(idx) =
-               arg.ref_(idx) <= arg.p_dropout_in_16bits_ ? arg.in_(idx) * arg.rp_dropout_ : 0;
+               arg.ref_(idx) <= arg.p_dropout_in_16bits_
+                   ? ck::type_convert<OutDataType>(ck::type_convert<float>(arg.in_(idx)) *
+                                                   ck::type_convert<float>(arg.rp_dropout_))
+                   : 0;
        });
        return 0;
    }
...
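The fix keeps rp_dropout (the rescale factor, typically 1/(1-p)) in float and converts only the final product to OutDataType, so a low-precision output type such as bf16 no longer rounds the rescale factor itself. A standalone sketch of the same semantics, with illustrative names rather than the CK API:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Keep an element when its 16-bit random value is <= the keep threshold,
    // rescale kept values by rp_dropout in float, and convert to the
    // (possibly low-precision) output type only at the end.
    template <typename OutT, typename InT>
    std::vector<OutT> dropout_reference(const std::vector<std::uint16_t>& z,
                                        const std::vector<InT>& in,
                                        std::uint16_t p_dropout_in_16bits,
                                        float rp_dropout)
    {
        std::vector<OutT> out(in.size());
        for(std::size_t i = 0; i < in.size(); ++i)
        {
            out[i] = z[i] <= p_dropout_in_16bits
                         ? static_cast<OutT>(static_cast<float>(in[i]) * rp_dropout)
                         : static_cast<OutT>(0);
        }
        return out;
    }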