Commit 627016c1 authored by fsx950223

Merge remote-tracking branch 'origin/attn-bwd-develop' into grouped_api

parents 83b53ec8 906bbc60
@@ -25,11 +25,13 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 1
#define USING_K128 1
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...@@ -60,6 +62,7 @@ using YElementOp = PassThrough; ...@@ -60,6 +62,7 @@ using YElementOp = PassThrough;
using VElementOp = Scale; using VElementOp = Scale;
using DataType = F16; using DataType = F16;
using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
@@ -87,6 +90,7 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#if USING_K128
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
NumDimG,
@@ -95,6 +99,7 @@ using DeviceGemmInstance =
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
@@ -115,9 +120,9 @@ using DeviceGemmInstance =
256,
128, // MPerBlock
128, // NPerBlock
- 32, // KPerBlock
+ 64, // KPerBlock
128, // Gemm1NPerBlock
- 64, // Gemm1KPerBlock
+ 32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
@@ -126,6 +131,7 @@ using DeviceGemmInstance =
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -153,6 +159,75 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#else
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
2,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#endif
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
@@ -222,14 +297,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
- #if USING_MASK
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
- #endif
// P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{};
@@ -264,10 +337,15 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
#if USING_K128
ck::index_t K = 128;
ck::index_t O = 128;
#else
ck::index_t K = 64;
ck::index_t O = 64;
#endif
ck::index_t G0 = 3;
ck::index_t G1 = 2;
......
@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
- bool time_kernel = false;
+ bool time_kernel = true;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
@@ -175,7 +175,7 @@ int run(int argc, char* argv[])
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
- static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
+ static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
@@ -228,6 +228,44 @@ int run(int argc, char* argv[])
if(do_verification)
{
// run for storing z tensor
argument = gemm.MakeArgument(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number
// of elements on a thread
c_device_buf.SetZero();
lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
lse_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
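// The {seed, offset} pair passed to MakeArgument above feeds the per-thread philox generator;
// per the comment in the argument list, offset should be at least the number of dropout
// elements a single thread produces. A minimal sketch of one way to derive such a value (an
// editor's illustration only; the helper name and the thread-count parameter are assumptions,
// not part of this example):
auto make_philox_offset = [](ck::index_t total_dropout_elements, ck::index_t total_threads) {
    // ceil-divide so every thread owns a disjoint, sufficiently long slice of the random stream
    return static_cast<unsigned long long>((total_dropout_elements + total_threads - 1) /
                                           total_threads);
};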
......
@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
- bool time_kernel = false;
+ bool time_kernel = true;
bool input_permute = false;
bool output_permute = true;
@@ -56,7 +56,8 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
- std::vector<void*> p_z;
+ std::vector<void*> p_z;         // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o;
@@ -221,6 +222,7 @@ int run(int argc, char* argv[])
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr);
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
}
@@ -233,12 +235,13 @@ int run(int argc, char* argv[])
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(p_a,
p_b0,
p_b1,
p_c,
- p_z,
+ p_z_nullptr,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
@@ -252,7 +255,6 @@ int run(int argc, char* argv[])
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
@@ -277,6 +279,31 @@ int run(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
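// The timed run above passed p_z_nullptr, so the kernel skipped writing the dropout mask;
// the argument is rebuilt here with p_z and run once more solely to materialize the z tensor
// for host-side verification.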
argument =
gemm.MakeArgument(p_a,
p_b0,
p_b1,
p_c,
p_z,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
problem_descs,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op,
p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be
// at least the number of elements on a thread
// specify workspace for problem_desc
DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i = 0; i < group_count; i++)
{
const int& G0 = g0_g1_m_n_k_o[i][0];
......
@@ -17,7 +17,7 @@ struct BlockwiseDropout
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
template <typename CThreadBuffer, bool using_sign_bit = false>
- __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph)
+ __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph)
{
auto execute_dropout = [&](bool keep, DataType val) {
@@ -52,7 +52,7 @@ struct BlockwiseDropout
template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
__host__ __device__ void
- ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox ph, ZThreadBuffer& z_thread_buf)
+ ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
@@ -86,6 +86,42 @@ struct BlockwiseDropout
});
}
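// Note: ck::philox is now taken by reference rather than by value, so the generator state
// advances across successive ApplyDropout calls from the same thread; the chunked overload
// added below relies on this when it is invoked once per N0 slice with the same generator.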
template <typename CThreadBuffer,
typename ZThreadBuffer,
bool using_sign_bit,
typename N0,
typename Offset>
__host__ __device__ void
ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)
{
auto execute_dropout = [&](bool keep, DataType val) {
if constexpr(using_sign_bit)
return keep ? val : -val;
else
return keep ? val * p_dropout_rescale : float(0);
};
constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
int philox_calls = tmp_size / 8;
ushort tmp[tmp_size];
for(int i = 0; i < philox_calls; i++)
{
ph.get_random_8x16((tmp + i * 8));
}
block_sync_lds();
constexpr auto iOffset = Number<tmp_size>{} * Offset{};
static_for<0, tmp_size, 1>{}([&](auto i) {
in_thread_buf(i + iOffset) =
execute_dropout(tmp[i.value] <= p_dropout_16bits, in_thread_buf(i + iOffset));
z_thread_buf(i) = tmp[i.value];
});
}
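// To make the threshold test above concrete: when using_sign_bit is false, a value is kept
// iff its uniform 16-bit philox sample does not exceed the 16-bit keep-probability threshold,
// and kept values are rescaled by 1 / (1 - p_drop). A minimal standalone sketch of that rule
// (an editor's illustration, not code from this file):
template <typename T>
__host__ __device__ static T dropout_one_element(T val, ushort rand16, ushort keep_threshold, T rescale)
{
    // keep with probability ~(keep_threshold + 1) / 65536, rescale survivors
    return rand16 <= keep_threshold ? val * rescale : T(0);
}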
ushort p_dropout_16bits;
DataType p_dropout_rescale;
};
......
@@ -49,7 +49,7 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
- __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid,
@@ -171,6 +171,7 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
@@ -202,6 +203,7 @@ template <index_t NumDimG,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
@@ -595,9 +597,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype
- LSEDataType,
+ GemmDataType,
GemmAccDataType,
CShuffleDataType,
LSEDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
@@ -625,6 +628,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
Gemm2NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
......
@@ -44,7 +44,8 @@ template <typename GridwiseGemm,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop,
- bool IsDropout>
+ bool IsDropout,
+ bool IsLseStoring>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -100,13 +101,13 @@ __global__ void
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
- GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
+ GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset,
nullptr ? nullptr : p_z_grid + z_batch_offset,
- p_lse_grid + lse_batch_offset,
+ nullptr ? nullptr : p_lse_grid + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
@@ -596,6 +597,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
if(p_lse_grid == nullptr)
{
is_lse_storing_ = false;
}
}
void Print() const
@@ -669,6 +675,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
unsigned long long seed_;
unsigned long long offset_;
bool is_dropout_;
bool is_lse_storing_ = true;
};
// Invoker
@@ -692,7 +700,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0;
- auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
+ auto launch_kernel = [&](auto has_main_k_block_loop_,
+                          auto is_dropout_,
+                          auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
@@ -715,7 +725,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_,
- is_dropout_>;
+ is_dropout_,
+ is_lse_storing_>;
return launch_and_time_kernel(stream_config,
kernel,
@@ -755,26 +766,66 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
{
if(arg.is_dropout_)
{
- ave_time = launch_kernel(integral_constant<bool, true>{},
-                          integral_constant<bool, true>{});
+ if(arg.is_lse_storing_)
+ {
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
- ave_time = launch_kernel(integral_constant<bool, true>{},
-                          integral_constant<bool, false>{});
+ if(arg.is_lse_storing_)
+ {
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
else
{
if(arg.is_dropout_)
{
- ave_time = launch_kernel(integral_constant<bool, false>{},
-                          integral_constant<bool, true>{});
+ if(arg.is_lse_storing_)
+ {
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
}
else
{
- ave_time = launch_kernel(integral_constant<bool, false>{},
-                          integral_constant<bool, false>{});
+ if(arg.is_lse_storing_)
+ {
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
}
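// The nested branches above exist only to turn the runtime flags into compile-time
// integral_constant template arguments for launch_kernel. The same dispatch can be folded with
// a small helper; a sketch in comment form (an editor's illustration with an assumed helper
// name, not code from this file; the callable must return the same type for both branches):
//
//   template <typename F>
//   auto dispatch_bool(bool value, F&& f)
//   {
//       return value ? f(integral_constant<bool, true>{}) : f(integral_constant<bool, false>{});
//   }
//
//   ave_time = dispatch_bool(arg.is_lse_storing_, [&](auto is_lse_storing_) {
//       return launch_kernel(integral_constant<bool, true>{}, // main K-block-loop case
//                            integral_constant<bool, true>{}, // dropout case
//                            is_lse_storing_);
//   });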
......
@@ -21,6 +21,7 @@
namespace ck {
template <typename DataType,
typename GemmDataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatLSE,
@@ -51,6 +52,7 @@ template <typename DataType,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
index_t Gemm2NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
@@ -85,21 +87,6 @@ template <typename DataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
- template <typename T>
- struct TypeMap
- {
-     using type = T;
- };
- #if defined(__gfx90a__)
- template <>
- struct TypeMap<ck::half_t>
- {
-     using type = ck::bhalf_t;
- };
- #endif
- using LDSDataType = typename TypeMap<DataType>::type;
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
@@ -141,7 +128,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
- constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
+ constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
@@ -157,7 +144,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
{
- constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
+ constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
@@ -471,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
- LDSDataType,
+ GemmDataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
@@ -496,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
- LDSDataType,
+ GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
@@ -513,12 +500,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static constexpr index_t KPack =
math::max(math::lcm(AK1, BK1),
- MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+ MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
- LDSDataType,
+ GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
@@ -580,7 +567,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
- LDSDataType,
+ GemmDataType,
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
@@ -599,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
DataType,
- LDSDataType,
+ GemmDataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
@@ -630,11 +617,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack =
- MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
+ MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
- LDSDataType,
+ GemmDataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_block_desc_bk0_n_bk1),
@@ -650,7 +637,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
GemmKPack,
true, // TransposeC
GemmKPack, // AMmaKStride
- GemmKPack * XdlopsGemm<LDSDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
+ GemmKPack * XdlopsGemm<GemmDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
.K0PerXdlops /* BMmaKStride */>;
};
@@ -676,13 +663,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static constexpr index_t BSrcVectorDim = 1; // Free1_O dimension
static constexpr index_t BSrcScalarPerVector = 4;
- static constexpr index_t GemmNWave = 2;
+ static constexpr index_t GemmNWave = Free0_N / Gemm2NXdlPerWave / MPerXdl;
static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave;
- static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
+ static constexpr index_t GemmNRepeat = Gemm2NXdlPerWave;
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMPack =
math::max(math::lcm(A_M1, B_M1),
- MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+ MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
using BThreadClusterLengths =
@@ -807,7 +794,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
- LDSDataType,
+ GemmDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
ElementwiseOp,
@@ -837,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename Gemm2Params_N_O_M::BThreadClusterLengths,
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
DataType,
- LDSDataType,
+ GemmDataType,
GridDesc_M0_O_M1,
decltype(b_block_desc_m0_o_m1),
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
@@ -854,7 +841,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
- LDSDataType,
+ GemmDataType,
FloatGemmAcc,
decltype(a_block_desc_m0_n_m1),
decltype(b_block_desc_m0_o_m1),
@@ -1095,7 +1082,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static constexpr auto b2_block_desc_m0_o_m1 =
GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
- static constexpr auto max_lds_align = Number<16 / sizeof(LDSDataType)>{};
+ static constexpr auto max_lds_align = Number<16 / sizeof(GemmDataType)>{};
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -1131,13 +1118,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
- sizeof(LDSDataType);
+ sizeof(GemmDataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
- sizeof(LDSDataType);
+ sizeof(GemmDataType);
const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
SharedMemTrait::ygrad_block_space_size_aligned) *
- sizeof(LDSDataType);
+ sizeof(GemmDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
@@ -1188,9 +1175,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const float p_drop,
ck::philox& ph)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
- const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+ const ushort p_dropout_in_16bits =
+     __builtin_amdgcn_readfirstlane(std::floor(p_dropout * 65535.0));
const tensor_operation::element_wise::Scale scale_rp_dropout(s_element_op.Value() *
rp_dropout);
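// Worked example of the 16-bit threshold above (editor's illustration): with p_drop = 0.1 the
// keep probability is p_dropout = 0.9, so p_dropout_in_16bits = floor(0.9 * 65535) = 58981 and
// a uniform 16-bit philox sample satisfies sample <= 58981 roughly 90% of the time; kept
// values are later rescaled by rp_dropout = 1/0.9. __builtin_amdgcn_readfirstlane broadcasts
// the scalar from the first active lane so the threshold can live in an SGPR instead of
// per-lane VGPRs.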
@@ -1243,11 +1231,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+ static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+ static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// Gemm0: gridwise GEMM pipeline
@@ -1339,11 +1327,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Gemm1: VGPR allocation for A and LDS allocation for B
- auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, LDSDataType>(
+ auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, GemmDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+ static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// dQ: transform input and output tensor descriptors
@@ -1467,7 +1455,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
- n0, // NRepeat
+ I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
@@ -1503,7 +1491,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
- n0, // NRepeat
+ I1, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
@@ -1535,11 +1523,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
+ static_cast<GemmDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
- static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
+ static_cast<GemmDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());
// dV: transform input and output tensor descriptors
@@ -1868,19 +1856,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
if(p_z_grid)
{
// P_dropped
- blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
-                                         decltype(z_tenor_buffer),
-                                         true>(
-     s_slash_p_thread_buf, ph, z_tenor_buffer);
- z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
-                                  z_tenor_buffer,
-                                  z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                                  z_grid_buf);
+ static_for<0, n0, 1>{}([&](auto i) {
+     blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
+                                             decltype(z_tenor_buffer),
+                                             true,
+                                             decltype(n0),
+                                             decltype(i)>(
+         s_slash_p_thread_buf, ph, z_tenor_buffer);
+     z_thread_copy_vgpr_to_global.Run(
+         z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
+         z_tenor_buffer,
+         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+         z_grid_buf);
+     z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+         z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+         make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
+ });
+ z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+     z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+     make_multi_index(0, 0, 0, -n0.value, 0, 0, 0, 0, 0, 0));
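// Each pass of the static_for above generates and stores the dropout mask for one of the n0
// N-repeats, advancing the destination slice window along the NRepeat axis by one per pass;
// the final MoveDstSliceWindow with -n0 rewinds the window so later uses of the copy start
// from the original position.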
}
else
{
ignore = z_grid_buf;
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
......
@@ -273,11 +273,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
- if(Gemm1N != K)
- {
-     std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
-     return false;
- }
+ // if(Gemm1N != K)
+ //{
+ //    std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
+ //    return false;
+ //}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
......