Commit 66052232 authored by danyao12

sync attn-bwd-dropout

parents 5eb5e316 bf80ceee
...@@ -10,8 +10,12 @@ add_example_executable(example_batched_multihead_attention_forward_fp16 batched_
add_example_executable(example_grouped_multihead_attention_forward_bf16 grouped_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_forward_bf16 batched_multihead_attention_forward_bf16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16 batched_multihead_attention_backward_fp16.cpp)
<<<<<<< HEAD
add_example_executable(example_batched_multihead_attention_backward_pt1_fp16 batched_multihead_attention_backward_pt1_fp16.cpp)
add_example_executable(example_batched_multihead_attention_backward_fp16_dropout batched_multihead_attention_backward_fp16_dropout.cpp)
=======
>>>>>>> attn-bwd-dropout
add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
......
...@@ -43,23 +43,27 @@ Kernel outputs:
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using VElementOp = Scale;
using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = U16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
...@@ -91,6 +95,7 @@ using DeviceGemmInstance =
NumDimK,
NumDimO,
DataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
...@@ -182,12 +187,16 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
PassThrough,
PassThrough,
Scale>;
// Ref dropout
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
...@@ -197,7 +206,11 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m)
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ushort p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
...@@ -225,11 +238,18 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_softmax_invoker.Run(ref_softmax_argument);
// Y = P * V
// P_dropped
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argument =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argument);
// Y = P_dropout * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
...@@ -256,6 +276,13 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
...@@ -321,6 +348,11 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
...@@ -332,6 +364,7 @@ int run(int argc, char* argv[])
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
...@@ -339,10 +372,12 @@ int run(int argc, char* argv[])
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
switch(init_method)
{
case 0: break;
...@@ -408,9 +443,11 @@ int run(int argc, char* argv[])
// calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
...@@ -418,12 +455,25 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
...@@ -433,6 +483,7 @@ int run(int argc, char* argv[])
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
...@@ -443,6 +494,7 @@ int run(int argc, char* argv[])
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
...@@ -450,11 +502,59 @@ int run(int argc, char* argv[])
kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero();
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
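// First launch below passes the real Z device buffer, so the kernel writes out the
// dropout values; the host-side verification later replays the same mask from that buffer.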
// get z matrix
{
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
invoker.Run(argument, StreamConfig{nullptr, false});
}
// no need to output the Z matrix for this run
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
...@@ -468,6 +568,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
...@@ -481,15 +583,11 @@ int run(int argc, char* argv[])
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{});
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// 5 GEMM ops in total:
...@@ -511,9 +609,32 @@ int run(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
// copy Z matrix data from device
z_device_buf.FromDevice(z_g_m_n.mData.data());
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
if(do_verification)
{
// run the host forward pass again for y, since z_g_m_n has been updated from the device
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
});
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
// call kernel again
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
...@@ -523,6 +644,7 @@ int run(int argc, char* argv[])
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
...@@ -544,20 +666,26 @@ int run(int argc, char* argv[])
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
// dP = dY * V^T
// dP_dropout = dY * V^T
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST
{
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
}
#endif
// dP = dP_dropout x Z
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argument = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argument);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
...@@ -578,15 +706,14 @@ int run(int argc, char* argv[])
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
}
#endif
// dV = P_drop^T * dY
// dV = P^T * dY auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}}); p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST #if PRINT_HOST
{ {
std::cout << "===== dV = P^T * dY\n"; std::cout << "===== dV = P^T * dY\n";
std::cout << "p_g_n_m ref:\n" << p_g_n_m; std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o; std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o; std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
/*
Backprop for the fused Gemm + Softmax + Dropout + Gemm operation used in this example,
where forward prop is defined as:
  Y_g_m_o = Dropout(Softmax(alpha * Q_g_m_k * K_g_k_n)) * V_g_n_o
Computation graph:
          K^T                    V
           |                     |
           |                     |
  Q ----- * ------ Softmax ----- * ----> Y
          S                 P
Kernel inputs:
  Q, K, V, Y, dY, per-row softmax stats (LSE), dropout probability, RNG seed/offset
Kernel outputs:
  dQ, dK, dV (plus the dropout values Z when a Z buffer is supplied)
*/
#define PRINT_HOST 0
#define USING_MASK 1
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_train_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using VElementOp = Scale;
using DataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
using ZDataType = U16;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
static constexpr ck::index_t NumDimG = 2;
static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
#else
static constexpr auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
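// MaskOutUpperTriangle masks the upper triangle of the attention scores (causal attention);
// MaskDisabled runs unmasked attention over the full S matrix.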
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
DataType,
AccDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out
using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>;
// Ref Gemm1: Y = P * V
// fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
DataType,
DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
// Ref Gemm for backward pass
// fp16 in, fp16 out
using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
DataType,
DataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref dropout
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ZDataType, DataType, DataType>;
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorZ,
typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k,
const TensorV& v_g_n_o,
const float alpha,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o,
TensorLSE& lse_g_m,
TensorP& p_drop_g_m_n,
TensorZ& z_g_m_n,
ushort p_dropout_in_16bits,
float rp_dropout)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
#if USING_MASK
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
#endif
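// (masked positions are set to -inf above so that softmax assigns them zero probability)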
// P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
ref_softmax_invoker.Run(ref_softmax_argument);
// P_dropped
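// (the reference dropout keeps or zeroes each element of P according to z_g_m_n and the
//  16-bit keep threshold, then rescales the surviving values by rp_dropout)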
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argument =
ref_dropout.MakeArgument(z_g_m_n, p_g_m_n, p_drop_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argument);
// Y = P_dropout * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_drop_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 2; // method 1 will have slightly higher error; TODO: to investigate
bool time_kernel = true;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = 128;
ck::index_t O = 128;
ck::index_t G0 = 3;
ck::index_t G1 = 2;
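// alpha = 1/sqrt(K): the usual attention scaling for head dimension K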
float alpha = 1.f / std::sqrt(K);
bool input_permute = false;
bool output_permute = false;
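// Dropout setup: p_drop is the drop probability and p_dropout = 1 - p_drop the keep
// probability. The keep probability is also encoded as a 16-bit threshold used when
// comparing against the ZDataType values, and rp_dropout = 1 / p_dropout is the rescale
// factor applied to the kept elements. seed/offset parameterize the device-side RNG.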
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0);
}
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
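// Z holds one dropout value per attention score, so it shares the [G0, G1, M, N]
// extent of S/P and mirrors the chosen input layout.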
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
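// Given LSE, the backward pass can recompute each probability directly as
//   P_i_j = exp(S_i_j - LSE_i)
// without redoing the row-wise max/sum reduction.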
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ks: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
switch(init_method)
{
case 0: break;
case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = [0, 1, 2, ...; 0, 1, 2, ...; ...]
// dO dot O = [127.5; ...]
// dS = P * (dP - dO dot O)
//
break;
default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
// dP = dO V = ones
// dS = P * (dP - (dO dot O))
// = 0.0039 * ones * (ones - 0.0039*256)
// = 0.0039 * ones * (ones - 1)
// = 0
}
// calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero();
// z_device_buf.SetZero();
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
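// The first (timed) launch passes a real Z buffer so the kernel writes out the dropout
// values; the verification section copies Z back to the host, feeds it to the reference
// forward pass, and then reruns the kernel with a null Z pointer (see the comment before
// the second MakeArgument below).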
auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// 5 GEMM ops in total:
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
// 3x MNK + 2x MNO
std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
size_t(2) * BatchCount +
sizeof(LSEDataType) * M * BatchCount;
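// (Z read/write traffic is not included in this byte-count estimate)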
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
// copy Z matrix data from device
std::ofstream file("./z_matrix_txt");
z_device_buf.FromDevice(z_g_m_n.mData.data());
file << z_g_m_n << std::endl;
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
if(do_verification)
{
// run the host forward pass again for y, since z_g_m_n has been updated from the device
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
});
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
//
// call kernel again
//
// example: set the Z matrix pointer to null; the kernel will not output Z matrix data
argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
k_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
v_gs_os_ns_lengths,
v_gs_os_ns_strides,
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false});
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
#if PRINT_HOST
{
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "v_g_n_o ref:\n" << v_g_n_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
}
#endif
// Gradients
auto ref_gemm_grad = ReferenceGemmGradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
// dP_dropout = dY * V^T
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST
{
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_drop_g_m_o ref:\n" << ygrad_drop_g_m_n;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_drop_g_m_n ref:\n" << pgrad_drop_g_m_n;
}
#endif
// dP = dP_dropout x Z
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argument = ref_dropout.MakeArgument(
z_g_m_n, pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argument);
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
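// dY_i dot Y_i is the per-row correction term of the softmax backward pass;
// it is accumulated below by summing ygrad * y over the O dimension.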
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
}
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
});
#if PRINT_HOST
{
std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
std::cout << "p_g_m_n ref:\n" << p_g_m_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "y_g_m_o ref:\n" << y_g_m_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
}
#endif
// dV = P_drop^T * dY
auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST
{
std::cout << "===== dV = P^T * dY\n";
std::cout << "p_drop_g_n_m ref:\n" << p_drop_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
}
#endif
// dQ = alpha * dS * K
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dQ = alpha * dS * K\n";
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
}
#endif
// dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST
{
std::cout << "===== dK = alpha * dS^T * Q\n";
std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
}
#endif
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData,
"error",
1e-2,
1e-2);
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
}
int main(int argc, char* argv[]) { return run(argc, argv); }
...@@ -33,6 +33,7 @@ using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...@@ -42,6 +43,7 @@ using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ZDataType = U16;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>;
...@@ -69,6 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
NumDimG,
NumDimM,
NumDimN,
...@@ -78,6 +81,7 @@ using DeviceGemmInstance =
B0DataType,
B1DataType,
CDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
...@@ -159,4 +163,5 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
#include "run_grouped_multihead_attention_forward.inc"
int main(int argc, char* argv[]) { return run(argc, argv); }
...@@ -48,6 +48,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_b0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
std::vector<void*> p_z;
std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o;
...@@ -55,6 +56,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<B0DataType>> b0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<CDataType>> c_tensors;
std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
...@@ -62,6 +64,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> b0_tensors_device;
std::vector<DeviceMemPtr> b1_tensors_device;
std::vector<DeviceMemPtr> c_tensors_device;
std::vector<DeviceMemPtr> z_tensors_device;
std::vector<DeviceMemPtr> lse_tensors_device;
std::size_t flop = 0, num_byte = 0;
...@@ -101,6 +104,12 @@ int run(int argc, char* argv[])
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides =
...@@ -114,6 +123,8 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
lse_gs_ms_lengths,
lse_gs_ms_strides,
{}, // acc0_biases_gs_ms_ns_lengths
...@@ -126,6 +137,7 @@ int run(int argc, char* argv[])
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
int Batch = G0 * G1;
...@@ -140,10 +152,13 @@ int run(int argc, char* argv[])
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
<< "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
<< std::endl; << std::endl;
} }
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
...@@ -172,6 +187,7 @@ int run(int argc, char* argv[])
b0_tensors.push_back(b0_gs_ns_ks);
b1_tensors.push_back(b1_gs_os_ns);
c_tensors.push_back(c_gs_ms_os_device_result);
z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms_device_result);
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
...@@ -182,6 +198,8 @@ int run(int argc, char* argv[])
sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
z_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()));
lse_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(LSEDataType) * lse_gs_ms_device_result.mDesc.GetElementSpaceSize()));
...@@ -193,6 +211,7 @@ int run(int argc, char* argv[])
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
p_z.push_back(z_tensors_device[i]->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device[i]->GetDeviceBuffer());
}
...@@ -209,6 +228,7 @@ int run(int argc, char* argv[])
p_b0,
p_b1,
p_c,
p_z,
p_lse,
{}, // p_acc0_biases
{}, // p_acc1_biases
......
...@@ -79,6 +79,7 @@ template <index_t NumDimG, ...@@ -79,6 +79,7 @@ template <index_t NumDimG,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
...@@ -104,6 +105,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -104,6 +105,9 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<index_t> c_gs_ms_os_lengths; std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides; std::vector<index_t> c_gs_ms_os_strides;
std::vector<index_t> z_gs_ms_ns_lengths;
std::vector<index_t> z_gs_ms_ns_strides;
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
...@@ -119,6 +123,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator ...@@ -119,6 +123,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
std::vector<const void*> p_b0_vec, std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<std::vector<const void*>> p_acc1_biases_vec,
......
...@@ -29,6 +29,7 @@ namespace device { ...@@ -29,6 +29,7 @@ namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename DataType, typename DataType,
typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -37,6 +38,7 @@ template <typename GridwiseGemm, ...@@ -37,6 +38,7 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
...@@ -50,9 +52,10 @@ __global__ void ...@@ -50,9 +52,10 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid, const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid, const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid, const DataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid, const LSEDataType* __restrict__ p_lse_grid,
...@@ -67,6 +70,8 @@ __global__ void ...@@ -67,6 +70,8 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
...@@ -76,7 +81,10 @@ __global__ void ...@@ -76,7 +81,10 @@ __global__ void
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask) const C0MatrixMask c0_matrix_mask,
const float p_dropout,
const unsigned long long seed,
const unsigned long long offset)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -90,6 +98,8 @@ __global__ void ...@@ -90,6 +98,8 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
...@@ -97,8 +107,13 @@ __global__ void ...@@ -97,8 +107,13 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const index_t global_thread_id = get_thread_global_1d_id();
ck::philox ph(seed, global_thread_id, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
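// Note (illustrative, an assumption about intent rather than text from this file): Philox is a
// counter-based RNG, so seeding it with the host-supplied seed, the global thread id as the
// subsequence, and a user offset gives every thread its own deterministic random stream with no
// shared state. Re-running with the same (seed, offset) pair should reproduce the identical
// dropout pattern, which is what lets the backward pass replay the mask used in the forward pass.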
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
z_matrix_ptr,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
...@@ -114,13 +129,16 @@ __global__ void ...@@ -114,13 +129,16 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
vgrad_grid_desc_n_o, vgrad_grid_desc_n_o,
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask); c0_matrix_mask,
p_dropout,
ph);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -151,6 +169,7 @@ template <index_t NumDimG, ...@@ -151,6 +169,7 @@ template <index_t NumDimG,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename DataType,
typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
...@@ -429,6 +448,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -429,6 +448,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{}); return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
} }
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
}
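// Illustrative host-side sketch (not part of this file): because Z sits in the Gemm0 C position,
// it spans one M x N tile of the attention score matrix per (G0, G1) batch. For a packed
// row-major buffer the lengths/strides handed in by the example would look like
//
//   std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
//   std::vector<ck::index_t> z_gs_ms_ns_strides{G1 * M * N, M * N, N, 1};
//
// so element (g0, g1, m, n) lives at offset g0 * G1 * M * N + g1 * M * N + m * N + n.
// G0, G1, M, N here are placeholders for the example's problem sizes.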
// //
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// //
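// Illustrative host-side reference (not part of this file) for the formula above,
//   dS[i][j] = P[i][j] * (dP[i][j] - dY_i . Y_i),
// where the dot product runs over the output width O. All names below are local to this sketch.
template <typename T>
void reference_dS(const T* P, const T* dP, const T* Y, const T* dY, T* dS, int M, int N, int O)
{
    for(int i = 0; i < M; ++i)
    {
        T y_dot_dy = 0;
        for(int o = 0; o < O; ++o)
        {
            y_dot_dy += dY[i * O + o] * Y[i * O + o]; // dY_i dot Y_i
        }
        for(int j = 0; j < N; ++j)
        {
            dS[i * N + j] = P[i * N + j] * (dP[i * N + j] - y_dot_dy); // elementwise over row i
        }
    }
}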
...@@ -489,9 +514,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -489,9 +514,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {})); using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{})); using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
constexpr static auto make_MaskOutPredicate() constexpr static auto make_MaskOutPredicate()
{ {
...@@ -510,11 +537,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -510,11 +537,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
{ {
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE) BatchStrideLSE_(BatchStrideLSE)
...@@ -531,6 +560,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -531,6 +560,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
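// Note (illustrative): for a packed [G, M, N] Z buffer, CalculateOffset(make_multi_index(g_idx, 0, 0))
// evaluates to g_idx * M * N, i.e. the start of batch g_idx's M x N dropout tile; going through the
// descriptor keeps the offset correct for strided (non-packed) layouts as well.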
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{ {
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -549,13 +583,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -549,13 +583,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t BatchStrideLSE_; index_t BatchStrideLSE_;
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
LSEDataType, LSEDataType,
GemmAccDataType, GemmAccDataType,
...@@ -568,6 +604,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -568,6 +604,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
ZGridDesc_M_N,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
YGridDesc_M_O, YGridDesc_M_O,
LSEGridDesc_M, LSEGridDesc_M,
...@@ -624,6 +661,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -624,6 +661,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
Argument( Argument(
const DataType* p_a_grid, const DataType* p_a_grid,
const DataType* p_b_grid, const DataType* p_b_grid,
ZDataType* p_z_grid,
const DataType* p_b1_grid, const DataType* p_b1_grid,
const DataType* p_c_grid, // for dS const DataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid, const LSEDataType* p_lse_grid,
...@@ -637,6 +675,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -637,6 +675,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
...@@ -652,9 +692,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -652,9 +692,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_lse_grid_{p_lse_grid}, p_lse_grid_{p_lse_grid},
...@@ -666,6 +709,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -666,6 +709,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
...@@ -683,6 +727,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -683,6 +727,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -707,6 +753,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -707,6 +753,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_, a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_, b_grid_desc_g_n_k_,
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_, b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_, c_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
...@@ -729,6 +776,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -729,6 +776,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
y_grid_desc_m_o_); y_grid_desc_m_o_);
} }
p_dropout_ = 1.f - p_drop; // keep probability
const float rp_dropout = 1.f / p_dropout_; // reciprocal of the keep probability, used to rescale kept values
acc_element_op_.Append(rp_dropout);
seed_ = std::get<0>(seeds);
offset_ = std::get<1>(seeds);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
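// Illustrative sketch (assumes standard inverted dropout; numbers are examples only): p_dropout_
// is the keep probability and its reciprocal rescales kept values so the expectation is preserved:
//
//   p_drop = 0.2f  ->  p_dropout_ = 0.8f,  rescale = 1.0f / 0.8f = 1.25f
//   E[mask * x * rescale] = p_keep * x * (1 / p_keep) = x
//
// The Z tensor addressed through the descriptor built above presumably carries the per-element
// dropout randomness in the same M x N layout as the attention scores.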
// Print(); // Print();
} }
...@@ -760,6 +817,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -760,6 +817,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
// pointers // pointers
const DataType* p_a_grid_; const DataType* p_a_grid_;
const DataType* p_b_grid_; const DataType* p_b_grid_;
ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const DataType* p_b1_grid_;
const DataType* p_c_grid_; const DataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
...@@ -771,6 +829,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -771,6 +829,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
...@@ -782,9 +841,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -782,9 +841,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_; y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -807,6 +870,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -807,6 +870,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_dropout_;
unsigned long long seed_;
unsigned long long offset_;
}; };
// Invoker // Invoker
...@@ -831,9 +898,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -831,9 +898,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
DataType, DataType,
ZDataType,
LSEDataType, LSEDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -842,6 +910,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -842,6 +910,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::LSEGridDesc_M, DeviceOp::LSEGridDesc_M,
...@@ -859,6 +928,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -859,6 +928,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
0, 0,
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_z_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_lse_grid_, arg.p_lse_grid_,
...@@ -873,6 +943,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -873,6 +943,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
...@@ -881,7 +952,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -881,7 +952,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_); arg.c0_matrix_mask_,
arg.p_dropout_,
arg.seed_,
arg.offset_);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...@@ -992,6 +1066,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -992,6 +1066,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
static auto MakeArgument( static auto MakeArgument(
const DataType* p_a, const DataType* p_a,
const DataType* p_b, const DataType* p_b,
ZDataType* p_z,
const DataType* p_b1, const DataType* p_b1,
const DataType* p_c, const DataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
...@@ -1005,6 +1080,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1005,6 +1080,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
...@@ -1020,10 +1097,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1020,10 +1097,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
p_z,
p_b1, p_b1,
p_c, p_c,
p_lse, p_lse,
...@@ -1037,6 +1117,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1037,6 +1117,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
b_gs_ns_ks_strides, b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
...@@ -1050,7 +1132,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1050,7 +1132,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
b_element_op, b_element_op,
acc_element_op, acc_element_op,
b1_element_op, b1_element_op,
c_element_op}; c_element_op,
p_drop,
seeds};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
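// Note (illustrative): relative to the non-dropout variant, MakeArgument/MakeArgumentPointer take
// two extra trailing parameters after c_element_op: the dropout probability p_drop and a
// {seed, offset} pair forwarded to the Philox generator, e.g. (hypothetical values)
//   float p_drop = 0.1f;
//   auto  seeds  = std::make_tuple(0x1234ULL, 0ULL);
// With p_drop = 0 the keep probability is 1 and the appended rescale factor is exactly 1.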
...@@ -1060,6 +1144,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1060,6 +1144,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
std::unique_ptr<BaseArgument> MakeArgumentPointer( std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a,
const void* p_b, const void* p_b,
void* p_z,
const void* p_b1, const void* p_b1,
const void* p_c, const void* p_c,
const void* p_lse, const void* p_lse,
...@@ -1073,6 +1158,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1073,6 +1158,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides, const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
...@@ -1088,10 +1175,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1088,10 +1175,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op, AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op, B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) // override CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const DataType*>(p_a), return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
static_cast<const DataType*>(p_b), static_cast<const DataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1), static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c), static_cast<const DataType*>(p_c),
static_cast<const LSEDataType*>(p_lse), static_cast<const LSEDataType*>(p_lse),
...@@ -1105,6 +1195,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1105,6 +1195,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
b_gs_ns_ks_strides, b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
...@@ -1118,7 +1210,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1118,7 +1210,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
b_element_op, b_element_op,
acc_element_op, acc_element_op,
b1_element_op, b1_element_op,
c_element_op); c_element_op,
p_drop,
seeds);
} }
// polymorphic // polymorphic
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp" #include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
// #include "ck/tensor_operation/gpu/device/device_batched_multihead_attention_backward.hpp" // TODO #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" namespace ck {
namespace tensor_operation {
#include "ck/library/utility/host_tensor.hpp" namespace device {
namespace ck { template <typename GridwiseGemm,
namespace tensor_operation { typename GemmAccDataType,
namespace device { typename GroupKernelArg,
typename AElementwiseOperation,
template <typename GridwiseGemm, typename BElementwiseOperation,
typename DataType, typename AccElementwiseOperation,
typename ZDataType, typename B1ElementwiseOperation,
typename LSEDataType, typename CElementwiseOperation,
typename AElementwiseOperation, bool HasMainKBlockLoop,
typename BElementwiseOperation, bool IsDropout>
typename AccElementwiseOperation, __global__ void
typename B1ElementwiseOperation, #if CK_USE_LAUNCH_BOUNDS
typename CElementwiseOperation, __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
typename AGridDesc_AK0_M_AK1, #endif
typename BGridDesc_BK0_N_BK1, kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
typename B1GridDesc_BK0_N_BK1, const index_t group_count,
typename YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, const AElementwiseOperation a_element_op,
typename LSEGridDescriptor_M, const BElementwiseOperation b_element_op,
typename VGradGridDescriptor_N_O, const AccElementwiseOperation acc_element_op,
typename YGradGridDesc_M0_O_M1, const B1ElementwiseOperation b1_element_op,
typename Block2CTileMap, const CElementwiseOperation c_element_op,
typename ComputeBasePtrOfStridedBatch, const ushort p_dropout_in_16bits,
typename C0MatrixMask, const GemmAccDataType p_dropout_rescale,
bool HasMainKBlockLoop> const unsigned long long seed,
__global__ void const unsigned long long offset)
#if CK_USE_LAUNCH_BOUNDS {
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#endif __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid, const index_t block_id = get_block_1d_id();
const DataType* __restrict__ p_b_grid, const index_t global_thread_id = get_thread_global_1d_id();
ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid, ck::philox ph(seed, global_thread_id, offset);
const DataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid, const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
const DataType* __restrict__ p_ygrad_grid, cast_pointer_to_generic_address_space(group_kernel_args));
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, index_t left = 0;
DataType* __restrict__ p_vgrad_grid, index_t right = group_count;
const AElementwiseOperation a_element_op, index_t group_id = index_t((left + right) / 2);
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, while(
const B1ElementwiseOperation b1_element_op, (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_)))
const CElementwiseOperation c_element_op, {
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, if(block_id < arg_ptr[group_id].block_start_)
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, {
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 right = group_id;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, }
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, else
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock {
c_grid_desc_mblock_mperblock_nblock_nperblock, left = group_id;
const LSEGridDescriptor_M lse_grid_desc_m, }
const VGradGridDescriptor_N_O vgrad_grid_desc_n_o, group_id = index_t((left + right) / 2);
const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1, }
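// Illustrative host-side sketch (not part of this file) of the group lookup above: each group owns
// a contiguous block-id range [block_start_, block_end_), and a workgroup bisects those ranges to
// find the group it belongs to. The struct and function below are invented for the sketch only.
struct GroupRange
{
    int block_start_;
    int block_end_; // one past the last block id owned by this group
};

inline int find_group_id(const GroupRange* ranges, int group_count, int block_id)
{
    int left     = 0;
    int right    = group_count;
    int group_id = (left + right) / 2;
    while(!(block_id >= ranges[group_id].block_start_ && block_id < ranges[group_id].block_end_))
    {
        if(block_id < ranges[group_id].block_start_)
            right = group_id; // target range lies to the left
        else
            left = group_id; // target range lies to the right
        group_id = (left + right) / 2;
    }
    return group_id;
}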
const Block2CTileMap block_2_ctile_map,
const index_t batch_count, // per-group batch offset
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const C0MatrixMask c0_matrix_mask, const index_t g_idx = __builtin_amdgcn_readfirstlane(
const float p_dropout, (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
const unsigned long long seed,
const unsigned long long offset) const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
{ static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
const index_t num_blocks_per_batch = const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
// offsets static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); //unsigned short* p_z_grid_in = //
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( // (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx))); // : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( arg_ptr[group_id].p_a_grid_ + a_batch_offset,
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); arg_ptr[group_id].p_b_grid_ + b_batch_offset,
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
const index_t global_thread_id = get_thread_global_1d_id(); : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
ck::philox ph(seed, global_thread_id, offset); arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset); p_shared,
a_element_op,
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, b_element_op,
p_b_grid + b_batch_offset, acc_element_op,
z_matrix_ptr, b1_element_op,
p_b1_grid + b1_batch_offset, c_element_op,
p_c_grid + c_batch_offset, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
p_lse_grid + lse_batch_offset, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
p_ygrad_grid + c_batch_offset, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
p_qgrad_grid + a_batch_offset, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
p_kgrad_grid + b_batch_offset, arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
p_vgrad_grid + b1_batch_offset, arg_ptr[group_id].lse_grid_desc_m_,
p_shared, arg_ptr[group_id].block_2_ctile_map_,
a_element_op, arg_ptr[group_id].c0_matrix_mask_,
b_element_op, p_dropout_in_16bits,
acc_element_op, p_dropout_rescale,
b1_element_op, ph);
c_element_op, #else
a_grid_desc_ak0_m_ak1, ignore = group_kernel_args;
b_grid_desc_bk0_n_bk1, ignore = group_count;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, ignore = a_element_op;
b1_grid_desc_bk0_n_bk1, ignore = b_element_op;
c_grid_desc_mblock_mperblock_nblock_nperblock, ignore = acc_element_op;
lse_grid_desc_m, ignore = b1_element_op;
vgrad_grid_desc_n_o, ignore = c_element_op;
ygrad_grid_desc_m0_o_m1, #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
block_2_ctile_map, }
c0_matrix_mask,
p_dropout, // Computes C = A * B0 * B1
ph); // ^^^^^^ (Acc0)
#else // ^^^^^^^^^^^ (Acc1)
ignore = p_a_grid; template <index_t NumDimG,
ignore = p_b_grid; index_t NumDimM,
ignore = p_b1_grid; index_t NumDimN,
ignore = p_c_grid; index_t NumDimK,
ignore = a_element_op; index_t NumDimO, // NumDimGemm1N
ignore = b_element_op; typename ADataType,
ignore = acc_element_op; typename BDataType,
ignore = b1_element_op; typename B1DataType,
ignore = c_element_op; typename CDataType,
ignore = a_grid_desc_ak0_m_ak1; typename ZDataType,
ignore = b_grid_desc_bk0_n_bk1; typename LSEDataType,
ignore = b1_grid_desc_bk0_n_bk1; typename Acc0BiasDataType,
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; typename Acc1BiasDataType,
ignore = block_2_ctile_map; typename GemmAccDataType,
ignore = batch_count; typename CShuffleDataType,
ignore = compute_base_ptr_of_batch; typename AElementwiseOperation,
ignore = c0_matrix_mask; typename BElementwiseOperation,
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) typename AccElementwiseOperation,
} typename B1ElementwiseOperation,
typename CElementwiseOperation,
// Computes C = A * B0 * B1 GemmSpecialization GemmSpec,
// ^^^^^^ (Acc0) TensorSpecialization ASpec,
// ^^^^^^^^^^^ (Acc1) TensorSpecialization BSpec,
template <index_t NumDimG, TensorSpecialization B1Spec,
index_t NumDimM, TensorSpecialization CSpec,
index_t NumDimN, index_t NumGemmKPrefetchStage,
index_t NumDimK, index_t BlockSize,
index_t NumDimO, // NumDimGemm1N index_t MPerBlock,
typename DataType, index_t NPerBlock, // Gemm0NPerBlock
typename ZDataType, index_t KPerBlock, // Gemm0KPerBlock
typename LSEDataType, index_t Gemm1NPerBlock,
typename Acc0BiasDataType, index_t Gemm1KPerBlock,
typename Acc1BiasDataType, index_t AK1,
typename GemmAccDataType, index_t BK1,
typename CShuffleDataType, index_t B1K1,
typename AElementwiseOperation, index_t MPerXDL,
typename BElementwiseOperation, index_t NPerXDL,
typename AccElementwiseOperation, index_t MXdlPerWave,
typename B1ElementwiseOperation, index_t NXdlPerWave,
typename CElementwiseOperation, index_t Gemm1NXdlPerWave,
GemmSpecialization GemmSpec, typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
TensorSpecialization ASpec, typename ABlockTransferThreadClusterArrangeOrder,
TensorSpecialization BSpec, typename ABlockTransferSrcAccessOrder,
TensorSpecialization B1Spec, index_t ABlockTransferSrcVectorDim,
TensorSpecialization CSpec, index_t ABlockTransferSrcScalarPerVector,
index_t NumGemmKPrefetchStage, index_t ABlockTransferDstScalarPerVector_AK1,
index_t BlockSize, bool ABlockLdsExtraM,
index_t MPerBlock, typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
index_t NPerBlock, // Gemm0NPerBlock typename BBlockTransferThreadClusterArrangeOrder,
index_t KPerBlock, // Gemm0KPerBlock typename BBlockTransferSrcAccessOrder,
index_t Gemm1NPerBlock, index_t BBlockTransferSrcVectorDim,
index_t Gemm1KPerBlock, index_t BBlockTransferSrcScalarPerVector,
index_t AK1, index_t BBlockTransferDstScalarPerVector_BK1,
index_t BK1, bool BBlockLdsExtraN,
index_t B1K1, typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
index_t MPerXDL, typename B1BlockTransferThreadClusterArrangeOrder,
index_t NPerXDL, typename B1BlockTransferSrcAccessOrder,
index_t MXdlPerWave, index_t B1BlockTransferSrcVectorDim,
index_t NXdlPerWave, index_t B1BlockTransferSrcScalarPerVector,
index_t Gemm1NXdlPerWave, index_t B1BlockTransferDstScalarPerVector_BK1,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1, bool B1BlockLdsExtraN,
typename ABlockTransferThreadClusterArrangeOrder, index_t CShuffleMXdlPerWavePerShuffle,
typename ABlockTransferSrcAccessOrder, index_t CShuffleNXdlPerWavePerShuffle,
index_t ABlockTransferSrcVectorDim, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t ABlockTransferSrcScalarPerVector, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
index_t ABlockTransferDstScalarPerVector_AK1, MaskingSpecialization MaskingSpec,
bool ABlockLdsExtraM, LoopScheduler LoopSched = LoopScheduler::Default>
typename BBlockTransferThreadClusterLengths_BK0_N_BK1, struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
typename BBlockTransferThreadClusterArrangeOrder, : public DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
typename BBlockTransferSrcAccessOrder, NumDimM,
index_t BBlockTransferSrcVectorDim, NumDimN,
index_t BBlockTransferSrcScalarPerVector, NumDimK,
index_t BBlockTransferDstScalarPerVector_BK1, NumDimO,
bool BBlockLdsExtraN, ADataType,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, BDataType,
typename B1BlockTransferThreadClusterArrangeOrder, B1DataType,
typename B1BlockTransferSrcAccessOrder, CDataType,
index_t B1BlockTransferSrcVectorDim, ZDataType,
index_t B1BlockTransferSrcScalarPerVector, LSEDataType,
index_t B1BlockTransferDstScalarPerVector_BK1, Acc0BiasDataType,
bool B1BlockLdsExtraN, Acc1BiasDataType,
index_t CShuffleMXdlPerWavePerShuffle, AElementwiseOperation,
index_t CShuffleNXdlPerWavePerShuffle, BElementwiseOperation,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, AccElementwiseOperation,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, B1ElementwiseOperation,
MaskingSpecialization MaskingSpec, CElementwiseOperation,
LoopScheduler LoopSched = LoopScheduler::Default> MaskingSpec>
struct DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle {
: public BaseOperator // TODO inherit attention bwd op once the API stabilizes static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
{ "Number of dimension must be greater than 0");
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); // TODO ANT: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented"); #if 0
// TODO ANT: use alias
#if 0 static constexpr index_t NumDimGemm0M = NumDimM;
// TODO: use alias static constexpr index_t NumDimGemm0N = NumDimN;
static constexpr index_t NumDimGemm0M = NumDimM; static constexpr index_t NumDimGemm0K = NumDimK;
static constexpr index_t NumDimGemm0N = NumDimN; static constexpr index_t NumDimGemm1M = NumDimM;
static constexpr index_t NumDimGemm0K = NumDimK; static constexpr index_t NumDimGemm1N = NumDimO;
static constexpr index_t NumDimGemm1M = NumDimM; static constexpr index_t NumDimGemm1K = NumDimN;
static constexpr index_t NumDimGemm1N = NumDimO; #endif
static constexpr index_t NumDimGemm1K = NumDimN;
#endif using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle;
using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermuteTrain<NumDimG,
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle; NumDimM,
NumDimN,
static constexpr auto I0 = Number<0>{}; NumDimK,
static constexpr auto I1 = Number<1>{}; NumDimO,
static constexpr auto I2 = Number<2>{}; ADataType,
BDataType,
static constexpr index_t Q_K1 = 8; B1DataType,
static constexpr index_t K_K1 = 8; CDataType,
static constexpr index_t V_N1 = 2; ZDataType,
LSEDataType,
static constexpr index_t Q_M1 = 2; Acc0BiasDataType,
static constexpr index_t K_N1 = 2; Acc1BiasDataType,
static constexpr index_t V_O1 = 8; AElementwiseOperation,
static constexpr index_t Y_O1 = 8; BElementwiseOperation,
static constexpr index_t Y_M1 = 2; AccElementwiseOperation,
B1ElementwiseOperation,
static constexpr auto padder = GemmGemmPadder<GemmSpec, CElementwiseOperation,
Number<MPerBlock>, MaskingSpec>::ProblemDesc;
Number<NPerBlock>,
Number<KPerBlock>, static constexpr auto I0 = Number<0>{};
Number<Gemm1NPerBlock>>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>, using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>, Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
GemmSpec, Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
ASpec, GemmSpec,
BSpec, ASpec,
B1Spec, BSpec,
CSpec>; B1Spec,
CSpec>;
/*
Descriptors for inputs: static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
Q, K, V, Y, dY, per-row softmax stats {
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Descriptors for outputs: Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{});
dQ, dK, dV }
*/ static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
// Q in Gemm A position {
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths_vec, return Transform::MakeB0GridDescriptor_BK0_N_BK1(
const std::vector<index_t>& a_gs_ms_ks_strides_vec) Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
{ Number<BK1>{});
return Transform::MakeAGridDescriptor_AK0_M_AK1( }
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{}); static auto
} MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
// K in Gemm B0 position {
static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector<index_t>& b_gs_ns_ks_lengths_vec, return Transform::MakeB1GridDescriptor_BK0_N_BK1(
const std::vector<index_t>& b_gs_ns_ks_strides_vec) Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
{ b1_gs_gemm1ns_gemm1ks_strides_vec),
return Transform::MakeB0GridDescriptor_BK0_N_BK1( Number<B1K1>{});
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), }
Number<BK1>{});
} static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
// V in Gemm B1 position {
static auto return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec);
MakeB1GridDescriptor_BK0_N_BK1(const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths_vec, }
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{ static auto MakeLSEGridDescriptor_M(index_t MRaw)
return Transform::MakeB1GridDescriptor_BK0_N_BK1( {
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
b1_gs_gemm1ns_gemm1ks_strides_vec),
Number<B1K1>{}); const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
} const auto MPad = M - MRaw;
// if constexpr(GemmSpec == GemmSpecialization::MPadding ||
// dV = P^T * dY GemmSpec == GemmSpecialization::MNPadding ||
// GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
// VGrad in Gemm C position {
static auto MakeVGradGridDescriptor_N_O(const std::vector<index_t>& v_gs_os_ns_lengths_vec, // pad M
const std::vector<index_t>& v_gs_os_ns_strides_vec) return transform_tensor_descriptor(lse_grid_desc_mraw,
{ make_tuple(make_right_pad_transform(MRaw, MPad)),
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. make_tuple(Sequence<0>{}),
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce make_tuple(Sequence<0>{}));
// transformation overhead }
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to else
// extract subsequence and shuffle them. {
const index_t num_dims = NumDimG + NumDimN + NumDimO; // not pad M
return lse_grid_desc_mraw;
// 0, 1, .. NumDimG - 1 }
std::vector<index_t> gs_ids(NumDimG); }
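// Note (illustrative) on MakeLSEGridDescriptor_M above: when GemmSpec requests M padding, MRaw is
// rounded up to the next multiple of MPerBlock, e.g. MRaw = 1000 with MPerBlock = 128 gives
// M = integer_divide_ceil(1000, 128) * 128 = 1024 and MPad = 24 right-padded rows; otherwise the
// raw descriptor is returned unchanged.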
std::iota(gs_ids.begin(), gs_ids.end(), 0);
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1 using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
std::vector<index_t> os_ids(NumDimO); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
std::iota(os_ids.begin(), os_ids.end(), NumDimG); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1 using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
std::vector<index_t> ids_old2new; using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end()); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
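// Illustrative example of the index shuffle in MakeVGradGridDescriptor_N_O: with NumDimG = 2,
// NumDimO = 1, NumDimN = 1 the input dims of v_gs_os_ns are ordered [g0, g1, o, n], so
// gs_ids = {0, 1}, os_ids = {2}, ns_ids = {3} and ids_old2new = {0, 1, 3, 2}; the lengths and
// strides are therefore permuted into [g0, g1, n, o] before the N x O dV descriptor is built.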
constexpr static auto make_MaskOutPredicate()
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); {
for(int i = 0; i < num_dims; i++) if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
{ {
index_t id_new = ids_old2new[i]; return MaskDisabledPredicate{};
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; }
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
} {
return MaskOutUpperTrianglePredicate{};
const auto vgrad_desc_nraw_oraw = }
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( }
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>;
.second;
struct ComputeBasePtrOfStridedBatch
return PadTensorDescriptor(vgrad_desc_nraw_oraw, {
make_tuple(NPerBlock, Gemm1NPerBlock), ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
Sequence<padder.PadN, padder.PadO>{}); const BGridDesc_G_N_K& b_grid_desc_g_n_k,
} const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n,
template <typename YGridDesc_M_O> const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
static auto MakeYGradGridDescriptor_M0_O_M1(const YGridDesc_M_O& ygrad_grid_desc_m_o) index_t BatchStrideLSE)
{ : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
const auto M = ygrad_grid_desc_m_o.GetLength(I0); b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
const auto O = ygrad_grid_desc_m_o.GetLength(I1); b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
const auto Y_M0 = M / Y_M1; z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
return transform_tensor_descriptor( {
ygrad_grid_desc_m_o, }
make_tuple(make_unmerge_transform(make_tuple(Y_M0, Y_M1)),
make_pass_through_transform(O)), __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
make_tuple(Sequence<0>{}, Sequence<1>{}), {
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
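// Index arithmetic behind the (M, O) -> (M0, O, M1) unmerge above, assuming M is a multiple
// of kM1; the names are illustrative.
#if 0
#include <array>
#include <cassert>

std::array<int, 3> to_m0_o_m1(int m, int o, int kM1)
{
    assert(kM1 > 0);
    return {m / kM1, o, m % kM1}; // m == m0 * kM1 + m1
}
#endif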
// __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
// dP = dY * V^T {
// return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
// YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths_vec, __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
const std::vector<index_t>& y_gs_ms_os_strides_vec) {
{ return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
return Transform::MakeAGridDescriptor_AK0_M_AK1( }
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths_vec, y_gs_ms_os_strides_vec),
Number<Y_O1>{}); __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
} {
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
// V in Gemm B position }
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths_vec,
const std::vector<index_t>& v_gs_os_ns_strides_vec) __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major. return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce }
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
// extract subsequences and shuffle them. {
const index_t num_dims = NumDimG + NumDimN + NumDimO; return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
}
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG); private:
std::iota(gs_ids.begin(), gs_ids.end(), 0); AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1 B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
std::vector<index_t> os_ids(NumDimO); CGridDesc_G_M_N c_grid_desc_g_m_n_;
std::iota(os_ids.begin(), os_ids.end(), NumDimG); ZGridDesc_G_M_N z_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1 };
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO); // GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle<
std::vector<index_t> ids_old2new; ADataType, // TODO: distinguish A/B datatype
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end()); GemmAccDataType,
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end()); CShuffleDataType,
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end()); CDataType,
LSEDataType,
std::vector<index_t> v_gs_ns_os_lengths_vec(num_dims), v_gs_ns_os_strides_vec(num_dims); AElementwiseOperation,
for(int i = 0; i < num_dims; i++) BElementwiseOperation,
{ AccElementwiseOperation,
index_t id_new = ids_old2new[i]; B1ElementwiseOperation,
v_gs_ns_os_lengths_vec[i] = v_gs_os_ns_lengths_vec[id_new]; CElementwiseOperation,
v_gs_ns_os_strides_vec[i] = v_gs_os_ns_strides_vec[id_new]; InMemoryDataOperationEnum::Set,
} AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
const auto v_grid_desc_nraw_oraw = B1GridDesc_BK0_N_BK1,
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>( CGridDesc_M_N,
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) ZGridDesc_M_N,
.second; LSEGridDesc_M,
NumGemmKPrefetchStage,
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw, BlockSize,
make_tuple(NPerBlock, Gemm1NPerBlock), MPerBlock,
Sequence<padder.PadN, padder.PadO>{}); NPerBlock,
KPerBlock,
// N_O to O0_N_O1; to refactor Gemm1NPerBlock,
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{}); Gemm1KPerBlock,
} AK1,
BK1,
// Z in Gemm0 C position B1K1,
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec, MPerXDL,
const std::vector<index_t>& z_gs_ms_ns_strides_vec) NPerXDL,
{ MXdlPerWave,
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths_vec, z_gs_ms_ns_strides_vec); NXdlPerWave,
} Gemm1NXdlPerWave,
// ABlockTransferThreadClusterLengths_AK0_M_AK1,
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i) ABlockTransferThreadClusterArrangeOrder,
// ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
// ABlockTransferSrcScalarPerVector,
// dQ = alpha * dS * K ABlockTransferDstScalarPerVector_AK1,
// true,
ABlockLdsExtraM,
// QGrad in Gemm C position BBlockTransferThreadClusterLengths_BK0_N_BK1,
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths_vec, BBlockTransferThreadClusterArrangeOrder,
const std::vector<index_t>& q_gs_ms_ks_strides_vec) BBlockTransferSrcAccessOrder,
{ BBlockTransferSrcVectorDim,
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths_vec, q_gs_ms_ks_strides_vec); BBlockTransferSrcScalarPerVector,
} BBlockTransferDstScalarPerVector_BK1,
true,
// BBlockLdsExtraN,
// dK = alpha * dS^T * Q B1BlockTransferThreadClusterLengths_BK0_N_BK1,
// B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
// KGrad in Gemm C position B1BlockTransferSrcVectorDim,
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths_vec, B1BlockTransferSrcScalarPerVector,
const std::vector<index_t>& k_gs_ns_ks_strides_vec) B1BlockTransferDstScalarPerVector_BK1,
{ false,
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); B1BlockLdsExtraN,
} CShuffleMXdlPerWavePerShuffle,
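// A naive host-side reference for the gradient formulas called out in the comments above
// (dV = P^T * dY, dP = dY * V^T, dS = P .* (dP - rowdot(dY, Y)), dQ = alpha * dS * K,
// dK = alpha * dS^T * Q). Row-major layouts and the shapes are assumptions, dropout, bias
// and masking are omitted, and the output buffers are assumed zero-initialized.
#if 0
#include <vector>

void attention_backward_reference(int M, int N, int K, int O, float alpha,
                                  const std::vector<float>& Q,    // M x K
                                  const std::vector<float>& Kmat, // N x K
                                  const std::vector<float>& V,    // N x O
                                  const std::vector<float>& P,    // M x N, softmax output
                                  const std::vector<float>& Y,    // M x O, forward output
                                  const std::vector<float>& dY,   // M x O
                                  std::vector<float>& dQ,         // M x K, zero-initialized
                                  std::vector<float>& dK,         // N x K, zero-initialized
                                  std::vector<float>& dV)         // N x O, zero-initialized
{
    std::vector<float> dP(M * N, 0.f), dS(M * N, 0.f);

    // dV = P^T * dY and dP = dY * V^T
    for(int n = 0; n < N; ++n)
        for(int o = 0; o < O; ++o)
            for(int m = 0; m < M; ++m)
                dV[n * O + o] += P[m * N + n] * dY[m * O + o];
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int o = 0; o < O; ++o)
                dP[m * N + n] += dY[m * O + o] * V[n * O + o];

    // dS_i_j = P_i_j * (dP_i_j - dY_i dot Y_i)
    for(int m = 0; m < M; ++m)
    {
        float dy_dot_y = 0.f;
        for(int o = 0; o < O; ++o)
            dy_dot_y += dY[m * O + o] * Y[m * O + o];
        for(int n = 0; n < N; ++n)
            dS[m * N + n] = P[m * N + n] * (dP[m * N + n] - dy_dot_y);
    }

    // dQ = alpha * dS * K and dK = alpha * dS^T * Q
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            for(int n = 0; n < N; ++n)
                dQ[m * K + k] += alpha * dS[m * N + n] * Kmat[n * K + k];
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            for(int m = 0; m < M; ++m)
                dK[n * K + k] += alpha * dS[m * N + n] * Q[m * K + k];
}
#endif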
CShuffleNXdlPerWavePerShuffle,
static auto MakeLSEGridDescriptor_M(index_t MRaw) CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
{ CShuffleBlockTransferScalarPerVector_NPerBlock,
const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); LoopSched,
Transform::matrix_padder.PadN,
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
const auto MPad = M - MRaw;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding || struct GroupKernelArg
GemmSpec == GemmSpecialization::MKPadding || {
GemmSpec == GemmSpecialization::MNKPadding) // pointers
{ const ADataType* p_a_grid_;
// pad M const BDataType* p_b_grid_;
return transform_tensor_descriptor(lse_grid_desc_mraw, const B1DataType* p_b1_grid_;
make_tuple(make_right_pad_transform(MRaw, MPad)), CDataType* p_c_grid_;
make_tuple(Sequence<0>{}), ZDataType* p_z_grid_;
make_tuple(Sequence<0>{})); LSEDataType* p_lse_grid_;
}
else // tensor descriptors for block/thread-wise copy
{ AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
// not pad M BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
return lse_grid_desc_mraw; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
} typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
} c_grid_desc_mblock_mperblock_nblock_nperblock_;
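// Arithmetic behind the optional right-padding above: MRaw is rounded up to the next
// multiple of MPerBlock and the remainder is padded. The sample values are illustrative.
#if 0
#include <cassert>

int padded_m(int m_raw, int m_per_block)
{
    assert(m_per_block > 0);
    const int m     = ((m_raw + m_per_block - 1) / m_per_block) * m_per_block;
    const int m_pad = m - m_raw; // e.g. m_raw = 100, m_per_block = 64 -> m = 128, m_pad = 28
    return m_raw + m_pad;        // == m
}
#endif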
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); ZGridDesc_M_N z_grid_desc_m_n_;
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); LSEGridDesc_M lse_grid_desc_m_;
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); // batch & stride
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); index_t num_blocks_per_batch_;
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); // check C0 masking and padding
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); C0MatrixMask c0_matrix_mask_;
using VGradGridDesc_N_O = decltype(MakeVGradGridDescriptor_N_O({}, {})); // block-to-c-tile map
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{})); Block2CTileMap block_2_ctile_map_;
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
index_t block_start_, block_end_;
constexpr static auto make_MaskOutPredicate() };
{
if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) struct GroupDeviceArg
{ {
return MaskDisabledPredicate{}; // lengths for the last dimensions of overall problem for sanity check of vector load/store
} std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;
else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
{ // strides for the last dimensions of each tensor for sanity check of vector load/store
return MaskOutUpperTrianglePredicate{}; std::vector<index_t> a_mz_kz_strides_;
} std::vector<index_t> b_nz_kz_strides_;
} std::vector<index_t> b1_nz_kz_strides_;
using C0MatrixMask = C0MatrixMask_impl<decltype(make_MaskOutPredicate())>; std::vector<index_t> c_mz_gemm1nz_strides_;
struct ComputeBasePtrOfStridedBatch // for gridwise gemm check
{ CGridDesc_M_N c_grid_desc_m_n_;
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, };
const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, // Argument
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, // FIXME: constness
const CGridDesc_G_M_N& c_grid_desc_g_m_n, struct Argument : public BaseArgument
index_t BatchStrideLSE) {
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), Argument(std::vector<const void*> p_a_vec,
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), std::vector<const void*> p_b_vec,
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), std::vector<const void*> p_b1_vec,
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), std::vector<void*> p_c_vec,
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), std::vector<void*> p_z_vec,
BatchStrideLSE_(BatchStrideLSE) std::vector<void*> p_lse_vec,
{ std::vector<std::vector<const void*>> p_acc0_biases_vec,
} std::vector<std::vector<const void*>> p_acc1_biases_vec,
std::vector<ProblemDesc> problem_desc_vec,
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const AElementwiseOperation a_element_op,
{ BElementwiseOperation b_element_op,
return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); AccElementwiseOperation acc_element_op,
} B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const float p_dropout,
{ std::tuple<unsigned long long, unsigned long long> seeds)
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); : a_element_op_{a_element_op},
} b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const b1_element_op_{b1_element_op},
{ c_element_op_{c_element_op}
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); {
} // TODO ANT: implement bias addition
group_count_ = problem_desc_vec.size();
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{ if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size()))
} {
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const }
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size()))
} {
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
__host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const }
{
return g_idx * static_cast<long_index_t>(BatchStrideLSE_); grid_size_ = 0;
}
for(std::size_t i = 0; i < group_count_; i++)
private: {
AGridDesc_G_M_K a_grid_desc_g_m_k_; const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
BGridDesc_G_N_K b_grid_desc_g_n_k_; const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
ZGridDesc_G_M_N z_grid_desc_g_m_n_; const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
CGridDesc_G_M_N c_grid_desc_g_m_n_; const auto p_z_grid = static_cast<ZDataType*>(p_z_vec[i]);
const auto p_lse_grid = static_cast<LSEDataType*>(p_lse_vec[i]);
index_t BatchStrideLSE_;
}; const auto& problem_desc = problem_desc_vec[i];
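// Sketch of the per-batch base offsets computed above, for the simple case of a packed,
// row-major G x M x K layout; the real code derives the offset from the descriptor's
// strides via CalculateOffset, so these helpers are illustrative only.
#if 0
long long a_base_offset(int g_idx, int M, int K)
{
    return static_cast<long long>(g_idx) * M * K; // batch g starts g * M * K elements in
}

long long lse_base_offset(int g_idx, int batch_stride_lse)
{
    return static_cast<long long>(g_idx) * batch_stride_lse;
}
#endif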
// GridwiseGemm const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
DataType, // TODO: distinguish A/B datatype const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
LSEDataType, problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
GemmAccDataType, const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
CShuffleDataType, problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
AElementwiseOperation, const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
BElementwiseOperation, problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
AccElementwiseOperation, const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
B1ElementwiseOperation, problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
CElementwiseOperation, const auto lse_grid_desc_m =
InMemoryDataOperationEnum::Set, DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K(
ZGridDesc_M_N, problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
B1GridDesc_BK0_N_BK1, const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
YGridDesc_M_O, problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides);
LSEGridDesc_M, const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
NumGemmKPrefetchStage, problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
BlockSize, const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
MPerBlock, problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
NPerBlock, const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
KPerBlock, problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
Gemm1NPerBlock,
Gemm1KPerBlock, const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
AK1, GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
BK1, c_grid_desc_m_n);
B1K1,
MPerXDL, //typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
NPerXDL, // z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
MXdlPerWave,
NXdlPerWave, auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
Gemm1NXdlPerWave, GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
ABlockTransferThreadClusterLengths_AK0_M_AK1, z_grid_desc_m_n);
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, const index_t BlockStart = grid_size_;
ABlockTransferSrcVectorDim, const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
ABlockTransferSrcScalarPerVector, const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
ABlockTransferDstScalarPerVector_AK1, const index_t grid_size_grp =
true, block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count;
ABlockLdsExtraM, const index_t BlockEnd = grid_size_ + grid_size_grp;
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, // batch stride
BBlockTransferSrcAccessOrder, const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
BBlockTransferSrcVectorDim, a_grid_desc_g_m_k,
BBlockTransferSrcScalarPerVector, b_grid_desc_g_n_k,
BBlockTransferDstScalarPerVector_BK1, b1_grid_desc_g_n_k,
true, c_grid_desc_g_m_n,
BBlockLdsExtraN, z_grid_desc_g_m_n,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder, // C0 mask
B1BlockTransferSrcVectorDim, const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1));
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1, grid_size_ += grid_size_grp;
false,
B1BlockLdsExtraN, // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and
CShuffleMXdlPerWavePerShuffle, // so on
CShuffleNXdlPerWavePerShuffle, if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias &&
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias &&
CShuffleBlockTransferScalarPerVector_NPerBlock, problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias &&
LoopSched, problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias))
Transform::matrix_padder.PadN, {
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; throw std::runtime_error(
"wrong! number of biases in function argument does not "
// Argument "match that in template argument");
struct Argument : public BaseArgument }
{
Argument( group_kernel_args_.push_back({p_a_grid,
const DataType* p_a_grid, p_b_grid,
const DataType* p_b_grid, p_b1_grid,
ZDataType* p_z_grid, p_c_grid,
const DataType* p_b1_grid, p_z_grid,
const DataType* p_c_grid, // for dS p_lse_grid,
const LSEDataType* p_lse_grid, a_grid_desc_ak0_m_ak1,
const DataType* p_ygrad_grid, b_grid_desc_bk0_n_bk1,
DataType* p_qgrad_grid, b1_grid_desc_bk0_n_bk1,
DataType* p_kgrad_grid, c_grid_desc_mblock_mperblock_nblock_nperblock,
DataType* p_vgrad_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const std::array<void*, NumAcc0Bias> p_acc0_biases, z_grid_desc_m_n,
const std::array<void*, NumAcc1Bias> p_acc1_biases, lse_grid_desc_m,
const std::vector<index_t>& a_gs_ms_ks_lengths, block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
const std::vector<index_t>& a_gs_ms_ks_strides, compute_base_ptr_of_batch,
const std::vector<index_t>& b_gs_ns_ks_lengths, c0_matrix_mask,
const std::vector<index_t>& b_gs_ns_ks_strides, block_2_ctile_map,
const std::vector<index_t>& z_gs_ms_ns_lengths, BlockStart,
const std::vector<index_t>& z_gs_ms_ns_strides, BlockEnd});
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides group_device_args_.push_back(
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
const std::vector<index_t>& lse_gs_ms_lengths, problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1],
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]},
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
const std::array<std::vector<ck::index_t>, NumAcc1Bias> problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1],
const std::array<std::vector<ck::index_t>, NumAcc1Bias> problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]},
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1],
AElementwiseOperation a_element_op, problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]},
BElementwiseOperation b_element_op, {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
AccElementwiseOperation acc_element_op, problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
B1ElementwiseOperation b1_element_op, c_grid_desc_m_n});
CElementwiseOperation c_element_op, }
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) is_dropout_ = p_dropout > 0.0;
: p_a_grid_{p_a_grid}, p_dropout_ = 1.f - p_dropout;
p_b_grid_{p_b_grid}, p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
p_z_grid_{p_z_grid}, p_dropout_ = 1.f / p_dropout_;
p_b1_grid_{p_b1_grid}, p_dropout_rescale_ = type_convert<GemmAccDataType>(p_dropout_);
p_c_grid_{p_c_grid},
p_lse_grid_{p_lse_grid}, seed_ = std::get<0>(seeds);
p_ygrad_grid_{p_ygrad_grid}, offset_ = std::get<1>(seeds);
p_qgrad_grid_{p_qgrad_grid}, }
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid}, std::vector<GroupKernelArg> group_kernel_args_;
a_grid_desc_ak0_m_ak1_{ std::vector<GroupDeviceArg> group_device_args_;
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ std::size_t group_count_;
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, index_t grid_size_;
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( AElementwiseOperation a_element_op_;
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, BElementwiseOperation b_element_op_;
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, AccElementwiseOperation acc_element_op_;
c_gs_ms_gemm1ns_strides)}, B1ElementwiseOperation b1_element_op_;
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, CElementwiseOperation c_element_op_;
vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, float p_dropout_;
ygrad_grid_desc_m0_o_m1_{DeviceOp::MakeYGradGridDescriptor_M0_O_M1(y_grid_desc_m_o_)}, ushort p_dropout_in_16bits_;
// batch offsets unsigned long long seed_;
a_grid_desc_g_m_k_{ unsigned long long offset_;
Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, GemmAccDataType p_dropout_rescale_;
b_grid_desc_g_n_k_{ bool is_dropout_;
Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, };
b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, // Invoker
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, struct Invoker : public BaseInvoker
c_gs_ms_gemm1ns_strides)}, {
z_grid_desc_g_m_n_{ using Argument = DeviceOp::Argument;
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(y_grid_desc_m_o_)}, {
a_element_op_{a_element_op}, if(!DeviceOp::IsSupportedArgument(arg))
b_element_op_{b_element_op}, {
acc_element_op_{acc_element_op}, throw std::runtime_error("wrong! unsupported argument");
b1_element_op_{b1_element_op}, }
c_element_op_{c_element_op},
c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, bool all_has_main_k_block_loop = true;
raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], bool some_has_main_k_block_loop = false;
b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], for(std::size_t i = 0; i < arg.group_count_; i++)
b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], {
b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]}, const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1], all_has_main_k_block_loop &= y;
b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, some_has_main_k_block_loop |= y;
b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1], }
b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], hipGetErrorString(hipMemcpy(arg.p_workspace_,
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, arg.group_kernel_args_.data(),
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
compute_base_ptr_of_batch_{ hipMemcpyHostToDevice));
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_, float ave_time = 0;
z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_, auto launch_kernel = [&](auto has_main_k_block_loop_, auto is_dropout_) {
c_grid_desc_g_m_n_, const auto kernel =
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())} kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
{ GemmAccDataType,
// TODO: implement bias addition GroupKernelArg,
ignore = p_acc0_biases; AElementwiseOperation,
ignore = p_acc1_biases; BElementwiseOperation,
ignore = acc0_biases_gs_ms_ns_lengths; AccElementwiseOperation,
ignore = acc0_biases_gs_ms_ns_strides; B1ElementwiseOperation,
ignore = acc1_biases_gs_ms_gemm1ns_lengths; CElementwiseOperation,
ignore = acc1_biases_gs_ms_gemm1ns_strides; has_main_k_block_loop_,
is_dropout_>;
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_, return launch_and_time_kernel(
b1_grid_desc_bk0_n_bk1_, stream_config,
y_grid_desc_m_o_, kernel,
block_2_ctile_map_)) dim3(arg.grid_size_),
{ dim3(BlockSize),
y_grid_desc_mblock_mperblock_oblock_operblock_ = 0,
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( cast_pointer_to_constant_address_space(arg.p_workspace_),
y_grid_desc_m_o_); arg.group_count_,
} arg.a_element_op_,
arg.b_element_op_,
p_dropout_ = 1.f - p_drop; arg.acc_element_op_,
float rp_dropout_ = 1.f / p_dropout_; arg.b1_element_op_,
acc_element_op_.Append(rp_dropout_); arg.c_element_op_,
arg.p_dropout_in_16bits_,
seed_ = std::get<0>(seeds); arg.p_dropout_rescale_,
offset_ = std::get<1>(seeds); arg.seed_,
arg.offset_);
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = };
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
// Print(); // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
} // to consider Gemm0's loop
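// Host-side sketch of the dropout bookkeeping set up in the Argument above: the keep
// probability 1 - p_drop, its 16-bit comparison threshold, and the 1 / keep rescale.
// std::mt19937 stands in for the device-side Philox generator, the keep/drop comparison
// direction is illustrative, and p_drop < 1 is assumed.
#if 0
#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

void dropout_reference(std::vector<float>& x, float p_drop, unsigned long long seed)
{
    const float    keep_prob   = 1.f - p_drop;
    const uint16_t threshold16 = static_cast<uint16_t>(std::floor(keep_prob * 65535.0));
    const float    rescale     = 1.f / keep_prob;

    std::mt19937                            gen(static_cast<unsigned>(seed));
    std::uniform_int_distribution<uint32_t> dist(0, 65535);

    for(auto& v : x)
    {
        const uint16_t r = static_cast<uint16_t>(dist(gen));
        v = (r <= threshold16) ? v * rescale : 0.f; // keep and rescale, else drop
    }
}
#endif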
if(all_has_main_k_block_loop)
void Print() const {
{ if(arg.is_dropout_)
std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", " {
<< a_grid_desc_g_m_k_.GetLength(I1) << ", " ave_time = launch_kernel(integral_constant<bool, true>{},
<< a_grid_desc_g_m_k_.GetLength(I2) << '\n'; integral_constant<bool, true>{});
// a_grid_desc_g_m_k_.Print(); }
std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", " else
<< b_grid_desc_g_n_k_.GetLength(I1) << ", " {
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n'; ave_time = launch_kernel(integral_constant<bool, true>{},
// b_grid_desc_g_n_k_.Print(); integral_constant<bool, false>{});
std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " }
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", " }
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; else if(!some_has_main_k_block_loop)
// b1_grid_desc_g_n_k_.Print(); {
std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", " if(arg.is_dropout_)
<< c_grid_desc_g_m_n_.GetLength(I1) << ", " {
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n'; ave_time = launch_kernel(integral_constant<bool, false>{},
// c_grid_desc_g_m_n_.Print(); integral_constant<bool, true>{});
std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", " }
<< vgrad_grid_desc_n_o_.GetLength(I1) << '\n'; else
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) {
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", " ave_time = launch_kernel(integral_constant<bool, false>{},
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n'; integral_constant<bool, false>{});
} }
}
// pointers else
const DataType* p_a_grid_; {
const DataType* p_b_grid_; throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
ZDataType* p_z_grid_; "has_main_k_block_loop or no_main_k_block_loop");
const DataType* p_b1_grid_; }
const DataType* p_c_grid_;
const LSEDataType* p_lse_grid_; return ave_time;
const DataType* p_ygrad_grid_; }
DataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; // polymorphic
DataType* p_vgrad_grid_; float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
// tensor descriptor {
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; }
ZGridDesc_M_N z_grid_desc_m_n_; };
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; static constexpr bool IsValidCompilationParameter()
LSEGridDesc_M lse_grid_desc_m_; {
VGradGridDesc_N_O vgrad_grid_desc_n_o_; // TODO: properly implement this check
YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_; return true;
}
// batch offsets
AGridDesc_G_M_K a_grid_desc_g_m_k_; static bool IsSupportedArgument(const Argument& arg)
BGridDesc_G_N_K b_grid_desc_g_n_k_; {
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
CGridDesc_G_M_N c_grid_desc_g_m_n_; {
ZGridDesc_G_M_N z_grid_desc_g_m_n_; return false;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock }
y_grid_desc_mblock_mperblock_oblock_operblock_;
// TODO ANT: Check if tensor specialization & strides mismatch
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; for(std::size_t i = 0; i < arg.group_count_; i++)
{
// element-wise op const auto& kernel_arg = arg.group_kernel_args_[i];
AElementwiseOperation a_element_op_; const auto& device_arg = arg.group_device_args_[i];
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_; // Check if C permute dimension matches GEMM + GEMM shape
B1ElementwiseOperation b1_element_op_; const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
CElementwiseOperation c_element_op_; const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
// check C0 masking and padding const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
C0MatrixMask c0_matrix_mask_; if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
{
// For robust IsSupportedArgument() check return false;
std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_; }
std::vector<index_t> a_mz_kz_strides_;
std::vector<index_t> b_nz_kz_strides_; // Check if having main loop
std::vector<index_t> b1_nz_kz_strides_; const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
std::vector<index_t> c_mz_gemm1nz_strides_; kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
index_t batch_count_; all_has_main_k_block_loop &= y;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; some_has_main_k_block_loop |= y;
float p_dropout_; // Note: we need raw lengths since threadwise copy cannot handle vector load when
unsigned long long seed_; // part of vector is out of bounds
unsigned long long offset_; const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0];
}; const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1];
const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
// Invoker const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3];
struct Invoker : public BaseInvoker
{ // Check scalar per vector requirement
using Argument = DeviceOp::Argument; const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw;
{ const auto c_extent_lowest = Gemm1NzRaw;
if(!DeviceOp::IsSupportedArgument(arg))
{ if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
throw std::runtime_error("wrong! unsupported argument"); b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
} b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
const index_t grid_size = {
arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_; return false;
}
// Gemm0_K
const auto K = // Check vector load/store requirement
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
float ave_time = 0; : device_arg.a_mz_kz_strides_[0];
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
auto launch_kernel = [&](auto has_main_k_block_loop_) { ? device_arg.b_nz_kz_strides_[1]
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v2< : device_arg.b_nz_kz_strides_[0];
GridwiseGemm, const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
DataType, ? device_arg.b1_nz_kz_strides_[1]
ZDataType, : device_arg.b1_nz_kz_strides_[0];
LSEDataType, const auto c_stride_lowest =
AElementwiseOperation, device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be
BElementwiseOperation, // contiguous
AccElementwiseOperation,
B1ElementwiseOperation, if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 ||
CElementwiseOperation, c_stride_lowest == 1))
DeviceOp::AGridDesc_AK0_M_AK1, {
DeviceOp::BGridDesc_BK0_N_BK1, return false;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, }
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
DeviceOp::LSEGridDesc_M, kernel_arg.b_grid_desc_bk0_n_bk1_,
DeviceOp::VGradGridDesc_N_O, kernel_arg.b1_grid_desc_bk0_n_bk1_,
DeviceOp::YGradGridDesc_M0_O_M1, device_arg.c_grid_desc_m_n_,
typename GridwiseGemm::DefaultBlock2CTileMap, kernel_arg.block_2_ctile_map_))
ComputeBasePtrOfStridedBatch, {
C0MatrixMask, return false;
has_main_k_block_loop_>; }
}
return launch_and_time_kernel(stream_config,
kernel, // all gemm problems have to simultaneously meet has_main_k_block_loop or
dim3(grid_size), // no_main_k_block_loop
dim3(BlockSize), if(!(all_has_main_k_block_loop || !some_has_main_k_block_loop))
0, {
arg.p_a_grid_, return false;
arg.p_b_grid_, }
arg.p_z_grid_,
arg.p_b1_grid_, return true;
arg.p_c_grid_, }
arg.p_lse_grid_,
arg.p_ygrad_grid_, // polymorphic
arg.p_qgrad_grid_, bool IsSupportedArgument(const BaseArgument* p_arg) override
arg.p_kgrad_grid_, {
arg.p_vgrad_grid_, return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
arg.a_element_op_, }
arg.b_element_op_,
arg.acc_element_op_, static auto MakeArgument(std::vector<const void*> p_a_vec,
arg.b1_element_op_, std::vector<const void*> p_b_vec,
arg.c_element_op_, std::vector<const void*> p_b1_vec,
arg.a_grid_desc_ak0_m_ak1_, std::vector<void*> p_c_vec,
arg.b_grid_desc_bk0_n_bk1_, std::vector<void*> p_z_vec,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, std::vector<void*> p_lse_vec,
arg.b1_grid_desc_bk0_n_bk1_, std::vector<std::vector<const void*>> p_acc0_biases_vec,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, std::vector<std::vector<const void*>> p_acc1_biases_vec,
arg.lse_grid_desc_m_, std::vector<ProblemDesc> problem_desc_vec,
arg.vgrad_grid_desc_n_o_, AElementwiseOperation a_element_op,
arg.ygrad_grid_desc_m0_o_m1_, BElementwiseOperation b_element_op,
arg.block_2_ctile_map_, AccElementwiseOperation acc_element_op,
arg.batch_count_, B1ElementwiseOperation b1_element_op,
arg.compute_base_ptr_of_batch_, CElementwiseOperation c_element_op,
arg.c0_matrix_mask_, float p_dropout,
arg.p_dropout_, std::tuple<unsigned long long, unsigned long long> seeds)
arg.seed_, {
arg.offset_); return Argument{p_a_vec,
}; p_b_vec,
p_b1_vec,
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need p_c_vec,
// to consider Gemm0's loop p_z_vec,
#if 1 p_lse_vec,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) p_acc0_biases_vec,
{ p_acc1_biases_vec,
ave_time = launch_kernel(integral_constant<bool, true>{}); problem_desc_vec,
} a_element_op,
else b_element_op,
{ acc_element_op,
ave_time = launch_kernel(integral_constant<bool, false>{}); b1_element_op,
} c_element_op,
#endif p_dropout,
return ave_time; seeds};
} }
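// Minimal sketch of the integral_constant dispatch used by the invokers above: runtime
// flags pick one fully specialized instantiation, so the flags become compile-time
// constants inside the kernel template. run_variant and the printf stand in for the real
// kernel launch and are illustrative.
#if 0
#include <cstdio>
#include <type_traits>

template <bool HasMainKBlockLoop, bool IsDropout>
void run_variant()
{
    std::printf("main_k_loop=%d dropout=%d\n", int(HasMainKBlockLoop), int(IsDropout));
}

void dispatch(bool has_main_k_block_loop, bool is_dropout)
{
    auto launch = [](auto has_loop_, auto dropout_) {
        run_variant<decltype(has_loop_)::value, decltype(dropout_)::value>();
    };

    if(has_main_k_block_loop)
        is_dropout ? launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, true>{})
                   : launch(std::integral_constant<bool, true>{}, std::integral_constant<bool, false>{});
    else
        is_dropout ? launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, true>{})
                   : launch(std::integral_constant<bool, false>{}, std::integral_constant<bool, false>{});
}
#endif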
// polymorphic static auto MakeInvoker() { return Invoker{}; }
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override // polymorphic
{ std::unique_ptr<BaseArgument>
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config); MakeArgumentPointer(std::vector<const void*> p_a_vec,
} std::vector<const void*> p_b_vec,
}; std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
static constexpr bool IsValidCompilationParameter() std::vector<void*> p_z_vec,
{ std::vector<void*> p_lse_vec,
// TODO: properly implement this check std::vector<std::vector<const void*>> p_acc0_biases_vec,
return true; std::vector<std::vector<const void*>> p_acc1_biases_vec,
} std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
static bool IsSupportedArgument(const Argument& arg) BElementwiseOperation b_element_op,
{ AccElementwiseOperation acc_element_op,
#if 0 B1ElementwiseOperation b1_element_op,
arg.Print(); CElementwiseOperation c_element_op,
#endif float p_dropout,
std::tuple<unsigned long long, unsigned long long> seeds) override
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) {
{ return std::make_unique<Argument>(p_a_vec,
return false; p_b_vec,
} p_b1_vec,
p_c_vec,
// TODO: Check if tensor specialization & strides mismatch p_z_vec,
p_lse_vec,
// Check if C permute dimension matches GEMM + GEMM shape p_acc0_biases_vec,
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded p_acc1_biases_vec,
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); problem_desc_vec,
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1); a_element_op,
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); b_element_op,
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); acc_element_op,
b1_element_op,
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) c_element_op,
{ p_dropout,
return false; seeds);
} }
// Note: we need raw lengths since threadwise copy cannot handle vector load when part of // polymorphic
// vector is out of bounds std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
// Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O {
const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; return std::make_unique<Invoker>(Invoker{});
const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; }
const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2];
const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; // polymorphic
std::string GetTypeString() const override
// Check scalar per vector requirement {
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; auto str = std::stringstream();
const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw;
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; // clang-format off
const auto c_extent_lowest = Gemm1NzRaw; str << "DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle"
<< "<"
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && << BlockSize << ", "
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && << MPerBlock << ", "
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && << NPerBlock << ", "
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) << KPerBlock << ", "
{ << AK1 << ", "
return false; << BK1 << ", "
} << MPerBlock << ", "
<< Gemm1NPerBlock << ", "
// Check vector load/store requirement << Gemm1KPerBlock << ", "
const auto a_stride_lowest = << B1K1 << ", "
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; << getGemmSpecializationString(GemmSpec) << ", "
const auto b_stride_lowest = << "ASpec" << getTensorSpecializationString(ASpec) << ", "
BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0]; << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
const auto b1_stride_lowest = << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0]; << "CSpec" << getTensorSpecializationString(CSpec) << ", "
const auto c_stride_lowest = << getMaskingSpecializationString(MaskingSpec) << ">";
arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous // clang-format on
if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || return str.str();
c_stride_lowest == 1)) }
{
return false; size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
} {
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, }
arg.b_grid_desc_bk0_n_bk1_, };
arg.b1_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_m_o_, } // namespace device
arg.block_2_ctile_map_); } // namespace tensor_operation
} } // namespace ck
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const DataType* p_a,
const DataType* p_b,
ZDataType* p_z,
const DataType* p_b1,
const DataType* p_c,
const LSEDataType* p_lse,
const DataType* p_ygrad_grid,
DataType* p_qgrad_grid,
DataType* p_kgrad_grid,
DataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds)
{
return Argument{p_a,
p_b,
p_z,
p_b1,
p_c,
p_lse,
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_vgrad_grid,
p_acc0_biases,
p_acc1_biases,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
// FIXME: constness
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
void* p_z,
const void* p_b1,
const void* p_c,
const void* p_lse,
const void* p_ygrad_grid,
void* p_qgrad_grid,
void* p_kgrad_grid,
void* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
const std::vector<index_t>& b_gs_ns_ks_strides,
const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op,
float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override
{
return std::make_unique<Argument>(static_cast<const DataType*>(p_a),
static_cast<const DataType*>(p_b),
static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
z_gs_ms_ns_lengths,
z_gs_ms_ns_strides,
b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths,
acc1_biases_gs_ms_gemm1ns_strides,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
p_drop,
seeds);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() // override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Train_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
<< "B0Spec" << getTensorSpecializationString(BSpec) << ", "
<< "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
<< "CSpec" << getTensorSpecializationString(CSpec) << ", "
<< getMaskingSpecializationString(MaskingSpec) << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
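// Sketch of how a single flat grid is split into per-group [block_start, block_end)
// ranges, mirroring the BlockStart/BlockEnd bookkeeping in the grouped Argument above;
// the tile and batch counts are illustrative.
#if 0
#include <cstddef>
#include <vector>

struct BlockRange
{
    int begin;
    int end;
};

std::vector<BlockRange> partition_grid(const std::vector<int>& tiles_per_group,
                                       const std::vector<int>& batch_per_group)
{
    std::vector<BlockRange> ranges;
    int grid_size = 0;
    for(std::size_t g = 0; g < tiles_per_group.size(); ++g)
    {
        const int grp = tiles_per_group[g] * batch_per_group[g];
        ranges.push_back({grid_size, grid_size + grp});
        grid_size += grp;
    }
    return ranges; // a block with id b belongs to the group whose range contains b
}
#endif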
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -15,6 +16,7 @@ ...@@ -15,6 +16,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" #include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck { namespace ck {
...@@ -30,6 +32,7 @@ template <typename DataType, ...@@ -30,6 +32,7 @@ template <typename DataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename QGridDesc_K0_M_K1, typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1, typename KGridDesc_K0_N_K1,
typename ZGridDesc_M_N,
typename VGridDesc_N0_O_N1, typename VGridDesc_N0_O_N1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename LSEGridDesc_M, typename LSEGridDesc_M,
...@@ -80,8 +83,23 @@ template <typename DataType, ...@@ -80,8 +83,23 @@ template <typename DataType,
bool PadN, bool PadN,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{ {
template <typename T>
struct TypeMap
{
using type = T;
};
#if defined(__gfx90a__)
template <>
struct TypeMap<ck::half_t>
{
using type = ck::bhalf_t;
};
#endif
using LDSDataType = typename TypeMap<DataType>::type;
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -93,7 +111,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -93,7 +111,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...> // K1 should be Number<...>
// Gemm0 // Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
...@@ -113,6 +134,65 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -113,6 +134,65 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
{
constexpr auto mfma = MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(M, N)),
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
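// Plain-integer sketch of the wave-index decomposition above, assuming
// thread_id = (wave_m * n_waves + wave_n) * wave_size + lane with a 64-lane wavefront;
// the names are illustrative.
#if 0
#include <array>

std::array<int, 3> gemm0_wave_idx(int thread_id, int n_waves, int wave_size = 64)
{
    const int lane   = thread_id % wave_size;
    const int wave   = thread_id / wave_size;
    const int wave_n = wave % n_waves;
    const int wave_m = wave / n_waves; // wave_m < m_waves by construction of the launch
    return {wave_m, wave_n, lane};
}
#endif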
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
...@@ -347,6 +427,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -347,6 +427,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
// S / dP Gemm (type 1 rcr) // S / dP Gemm (type 1 rcr)
struct Gemm0 struct Gemm0
{ {
...@@ -388,7 +471,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -388,7 +471,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, DataType,
DataType, LDSDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -413,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -413,7 +496,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, DataType,
DataType, LDSDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
...@@ -428,13 +511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -428,13 +511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
true, // DstResetCoord true, // DstResetCoord
NumGemmKPrefetchStage>; NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max( static constexpr index_t KPack =
math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::max(math::lcm(AK1, BK1),
MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
// Blockwise gemm with transposed XDL output // Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
DataType, LDSDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -496,7 +580,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -496,7 +580,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc, FloatGemmAcc,
DataType, LDSDataType,
decltype(a_src_thread_desc_k0_m_k1), decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1), decltype(a_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -515,7 +599,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -515,7 +599,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
DataType, DataType,
DataType, LDSDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder, B1BlockTransferSrcAccessOrder,
...@@ -546,11 +630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -546,11 +630,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size // therefore we may just as well assign Gemm1KPack = group_size
static constexpr index_t GemmKPack = static constexpr index_t GemmKPack =
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size; MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2< using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize, BlockSize,
DataType, LDSDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1), decltype(a_thread_desc_k0_m_k1),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -566,7 +650,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -566,7 +650,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
GemmKPack, GemmKPack,
true, // TransposeC true, // TransposeC
GemmKPack, // AMmaKStride GemmKPack, // AMmaKStride
GemmKPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmKPack, false>{} GemmKPack * XdlopsGemm<LDSDataType, MPerXdl, NPerXdl, GemmKPack, false>{}
.K0PerXdlops /* BMmaKStride */>; .K0PerXdlops /* BMmaKStride */>;
}; };
...@@ -598,7 +682,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -598,7 +682,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl; static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMPack = static constexpr index_t GemmMPack =
math::max(math::lcm(A_M1, B_M1), math::max(math::lcm(A_M1, B_M1),
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>; using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
using BThreadClusterLengths = using BThreadClusterLengths =
...@@ -720,12 +804,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -720,12 +804,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1, typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
false>; false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3< using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, LDSDataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough, ElementwiseOp,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At( Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1), Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
...@@ -752,7 +837,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -752,7 +837,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
typename Gemm2Params_N_O_M::BThreadClusterLengths, typename Gemm2Params_N_O_M::BThreadClusterLengths,
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
DataType, DataType,
DataType, LDSDataType,
GridDesc_M0_O_M1, GridDesc_M0_O_M1,
decltype(b_block_desc_m0_o_m1), decltype(b_block_desc_m0_o_m1),
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
...@@ -769,7 +854,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -769,7 +854,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
using BlockwiseGemm = using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType, LDSDataType,
FloatGemmAcc, FloatGemmAcc,
decltype(a_block_desc_m0_n_m1), decltype(a_block_desc_m0_n_m1),
decltype(b_block_desc_m0_o_m1), decltype(b_block_desc_m0_o_m1),
...@@ -836,7 +921,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -836,7 +921,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_O_
{ {
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType); static constexpr index_t SrcScalarPerVector = 16 / sizeof(FloatGemmAcc);
static constexpr auto ThreadClusterLength_O = static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{}; Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{}; static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
...@@ -848,7 +933,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -848,7 +933,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, ""); static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr, using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType, FloatGemmAcc,
ThreadSliceLength_M * ThreadSliceLength_O, ThreadSliceLength_M * ThreadSliceLength_O,
true>; true>;
...@@ -1010,7 +1095,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1010,7 +1095,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr auto b2_block_desc_m0_o_m1 = static constexpr auto b2_block_desc_m0_o_m1 =
GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>(); GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{}; static constexpr auto max_lds_align = Number<16 / sizeof(LDSDataType)>{};
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
...@@ -1046,13 +1131,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1046,13 +1131,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
{ {
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) * SharedMemTrait::b_block_space_size_aligned) *
sizeof(DataType); sizeof(LDSDataType);
const index_t gemm1_bytes_end = const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(DataType); sizeof(LDSDataType);
const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned + const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
SharedMemTrait::ygrad_block_space_size_aligned) * SharedMemTrait::ygrad_block_space_size_aligned) *
sizeof(DataType); sizeof(LDSDataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) * SharedMemTrait::reduction_space_size_aligned) *
...@@ -1074,6 +1159,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1074,6 +1159,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid, __device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid, const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid, const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid, const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid, const FloatLSE* __restrict__ p_lse_grid,
...@@ -1089,6 +1175,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1089,6 +1175,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
const CElementwiseOperation& c_element_op, const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1, const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1, const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1, const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock& const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock, y_grid_desc_mblock_mperblock_oblock_operblock,
...@@ -1096,8 +1184,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1096,8 +1184,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o, const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1, const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask) const C0MatrixMask& c0_matrix_mask,
FloatGemmAcc p_dropout,
ck::philox& ph)
{ {
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
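// Illustrative arithmetic only: with p_dropout = 0.8 the 16-bit comparison threshold is
// floor(0.8 * 65535) = 52428 and the rescale factor rp_dropout = 1 / 0.8 = 1.25.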
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize()); p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -1147,11 +1240,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1147,11 +1240,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
// Gemm0: LDS allocation for A and B: be careful of alignment // Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset, static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize()); Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset, static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize()); Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// Gemm0: gridwise GEMM pipeline // Gemm0: gridwise GEMM pipeline
...@@ -1243,11 +1336,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1243,11 +1336,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>; decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Gemm1: VGPR allocation for A and LDS allocation for B // Gemm1: VGPR allocation for A and LDS allocation for B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>( auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, LDSDataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize()); Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset, static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize()); Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// dQ: transform input and output tensor descriptors // dQ: transform input and output tensor descriptors
...@@ -1331,6 +1424,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1331,6 +1424,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
decltype(thread_cluster_desc_m_n), decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{}; decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
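// blockwise_dropout carries the 16-bit keep threshold and the 1/p_dropout rescale factor;
// ApplyDropout (further below) draws 16-bit randoms from the philox state ph, compares them
// against the threshold to decide which elements to drop, and rescales the survivors
// (inverted dropout).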
auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m); MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
...@@ -1360,6 +1456,75 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1360,6 +1456,75 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
acc0_thread_origin[I2], // mwave acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl acc0_thread_origin[I4])}; // mperxdl
//
// z vgpr copy to global
//
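// The per-thread dropout randoms generated by ApplyDropout are staged in a small VGPR buffer
// and, when a Z output pointer is provided (see the if(p_z_grid) branch below), written out
// to global memory.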
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
StaticBuffer<AddressSpaceEnum::Vgpr,

unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
n4, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
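// The destination origin above points at this block's M tile and this wave/lane's position
// inside the xdl tile; NBlockId starts at 0 and the window is advanced one NBlock per outer
// loop iteration via MoveDstSliceWindow at the end of the j loop.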
// //
// set up dV / dK Gemm (type 3 crr) // set up dV / dK Gemm (type 3 crr)
// //
...@@ -1367,11 +1532,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1367,11 +1532,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
// Gemm2: LDS allocation for A and B: be careful of alignment // Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::a2_block_space_offset, static_cast<LDSDataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize()); Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b2_block_space_offset, static_cast<LDSDataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize()); Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());
// dV: transform input and output tensor descriptors // dV: transform input and output tensor descriptors
...@@ -1379,10 +1544,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1379,10 +1544,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o); Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dV: A matrix VGPR-to-LDS blockwise copy // dV: A matrix VGPR-to-LDS blockwise copy
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{ auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds =
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(), Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
tensor_operation::element_wise::PassThrough{}}; Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
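// ApplyDropout appears to mark dropped P entries by negating them (cf. the sign test when
// forming dS further below), so the Relu element-wise op zeroes them out on the way into the
// dV gemm; the accumulated dV is then rescaled by rp_dropout via the Scale op on the output
// copy just below.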
// dV: B matrix global-to-LDS blockwise copy // dV: B matrix global-to-LDS blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy = auto vgrad_gemm_tile_ygrad_blockwise_copy =
...@@ -1407,11 +1573,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1407,11 +1573,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
make_multi_index( make_multi_index(
I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0); I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
auto vgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<decltype( auto vgrad_thread_copy_vgpr_to_global =
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>( typename Gemm2::template CBlockwiseCopy<decltype(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4, tensor_operation::element_wise::Scale>(
tensor_operation::element_wise::PassThrough{}); vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
tensor_operation::element_wise::Scale{rp_dropout});
// dK: transform input and output tensor descriptors // dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 = const auto q_grid_desc_m0_k_m1 =
...@@ -1422,10 +1590,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1422,10 +1590,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k); Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
// dK: A matrix VGPR-to-LDS blockwise copy // dK: A matrix VGPR-to-LDS blockwise copy
auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{ auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(), Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
tensor_operation::element_wise::PassThrough{}}; Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::PassThrough{}};
// dK: B matrix global-to-LDS blockwise copy // dK: B matrix global-to-LDS blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy = auto kgrad_gemm_tile_q_blockwise_copy =
...@@ -1487,7 +1656,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1487,7 +1656,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
// performs double duty for both y and ygrad // performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, DataType,
DataType, FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1), decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()), decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
...@@ -1496,8 +1665,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1496,8 +1665,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector 1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */, true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock, false /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx); y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{}; auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{}; auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
...@@ -1574,7 +1743,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1574,7 +1743,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) { static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM; const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
y_dot_ygrad_block_accum_buf.AtomicAdd( y_dot_ygrad_block_accum_buf.AtomicAdd(
idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM]); idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM] * p_dropout); // p_dropoutD1
}); });
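// The row-wise y*dy sums (often called D) accumulated into LDS above are pre-multiplied by
// p_dropout before being consumed in the dS computation further below.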
block_sync_lds(); block_sync_lds();
...@@ -1595,6 +1764,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1595,6 +1764,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock; const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scalar = 1.0f / std::sqrt(K);
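// 1/sqrt(K) softmax scaling, with K the attention head size (K0 * K1 of the K tensor);
// it is applied directly to the S accumulator below in place of s_element_op.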
// Initialize dQ // Initialize dQ
qgrad_thread_buf.Clear(); qgrad_thread_buf.Clear();
...@@ -1675,14 +1846,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1675,14 +1846,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
} }
else else
{ {
s_element_op(s_slash_p_thread_buf(i), s_slash_p_thread_buf[i]); s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
} }
}); });
} }
else else
{ {
static_for<0, s_slash_p_thread_buf.Size(), 1>{}( static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_element_op(acc_thread_buf(i), s_slash_p_thread_buf[i]); }); [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
} }
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
...@@ -1691,6 +1862,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1691,6 +1862,28 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
// scaling is already performed in the preceding statements with s_element_op // scaling is already performed in the preceding statements with s_element_op
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf); blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
true>(
s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
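// Either way the P tile is converted to P_dropped in place; the z path additionally saves the
// drawn 16-bit randoms to global memory (e.g. so a host-side reference can reproduce the
// mask), which is only done when a Z buffer was supplied.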
block_sync_lds(); // wait for gemm1 LDS read block_sync_lds(); // wait for gemm1 LDS read
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0], SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
...@@ -1701,7 +1894,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1701,7 +1894,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
""); "");
// TODO: tune gemm2 pipeline // TODO: tune gemm2 pipeline
// dV = P^T * dY // dV = P_drop^T * dY
v_slash_k_grad_thread_buf.Clear(); v_slash_k_grad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
// load VGrad Gemm B // load VGrad Gemm B
...@@ -1781,8 +1974,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1781,8 +1974,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
constexpr auto m = constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0]; pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P has same thread buf layout // dS and P has same thread buf layout
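// The sign of s_slash_p distinguishes kept entries (>= 0) from dropped ones (< 0, negated by
// ApplyDropout); note that y_dot_ygrad already carries the p_dropout factor applied above.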
sgrad_thread_buf(i) = s_slash_p_thread_buf[i] * if(s_slash_p_thread_buf[i] >= 0)
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]); {
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
}
else
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
}
}); });
// gemm dQ // gemm dQ
...@@ -1922,6 +2124,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1922,6 +2124,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow( kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// shuffle dQ and write // shuffle dQ and write
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_dropout.hpp"
namespace ck {
template <typename DataType,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatLSE,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename SElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename QGridDesc_K0_M_K1,
typename KGridDesc_K0_N_K1,
typename ZGridDesc_M_N,
typename VGridDesc_N0_O_N1,
typename CGridDesc_M_N,
typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t B1K1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
index_t BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1ThreadTransferSrcResetCoordinateAfterRun,
index_t B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I9 = Number<9>{};
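// XDL (MFMA) kernels run with a 64-lane wavefront on gfx9 / CDNA targets.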
static constexpr auto WaveSize = 64;
// K1 should be Number<...>
// Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
// Gemm1
static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
static constexpr auto B1K1 = Number<B1K1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// C desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N)
{
constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(M, N)),
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, 1, 1>(ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
template <typename AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4>
__host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
const AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4& acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
{
// acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to a_src_thread_desc_k0_m_k1
// n0_n1_n2_n3 -> k0
// m0_m1_m2 -> m
// n4 -> k1
// NOTE: merge_v3 must be used here, otherwise compilation errors result
const auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
const auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
const auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
const auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
const auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
const auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
const auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
const auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
return transform_tensor_descriptor(
acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
make_pass_through_transform(n4)),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1K1),
make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetA2BlockDescriptor_M0_N_M1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::A_M0>{},
Number<Gemm2Param::Free0_N>{},
Number<Gemm2Param::A_M1>{}),
make_tuple(Number<Gemm2Param::Free0_N + Gemm2Param::A_LdsPad>{} *
Number<Gemm2Param::A_M1>{},
Number<Gemm2Param::A_M1>{},
I1));
}
template <typename Gemm2Param>
__host__ __device__ static constexpr auto GetB2BlockDescriptor_M0_O_M1()
{
return make_naive_tensor_descriptor(
make_tuple(Number<Gemm2Param::B_M0>{},
Number<Gemm2Param::Free1_O>{},
Number<Gemm2Param::B_M1>{}),
make_tuple(Number<Gemm2Param::Free1_O + Gemm2Param::B_LdsPad>{} *
Number<Gemm2Param::B_M1>{},
Number<Gemm2Param::B_M1>{},
I1));
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);
// This assumption reduces implementation complexity by categorizing the 6 separate GEMMs into 3
// types of GEMM operation, so some code can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K)
{
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
return false;
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0))
{
return false;
}
// check gemm0 gridwise gemm pipeline
const auto num_gemm0_k_loop = K / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
{
return false;
}
// check gemm1 gridwise gemm pipeline
if(!(NPerBlock % Gemm1KPerBlock == 0))
{
return false;
}
const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
{
return false;
}
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / Gemm1NPerBlock;
const auto y_grid_desc_mblock_mperblock_oblock_operblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return y_grid_desc_mblock_mperblock_oblock_operblock;
}
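// LSE is the per-row softmax statistic (log-sum-exp) precomputed by the forward pass and
// consumed by RunWithPreCalcStats; unmerge its length-M dimension to match the accumulator's
// (MBlock, MRepeat, MWave, MPerXdl) decomposition.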
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
lse_grid_desc_m,
make_tuple(make_unmerge_transform(
make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3>{}));
return lse_grid_desc_mblock_mrepeat_mwave_mperxdl;
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
// S / dP Gemm (type 1 rcr)
struct Gemm0
{
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
BBlockDesc_BK0_N_BK1{});
}
template <typename GridDesc_K0_M_K1>
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
true, // SrcResetCoord
true, // DstResetCoord
NumGemmKPrefetchStage>;
static constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
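// KPack is the per-mfma K vector width: it must cover both LDS K1 widths (hence the lcm of
// AK1 and BK1) and be at least the selected mfma instruction's k_per_blk.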
// Blockwise gemm with transposed XDL output
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
FloatGemmAcc,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
true>; // TransposeC
static constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
static constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
};
// Y / dQ Gemm (type 2 rrr)
template <typename ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4,
typename ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4>
struct Gemm1
{
private:
static constexpr auto m0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I0);
static constexpr auto n0 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I1);
static constexpr auto m1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I2);
static constexpr auto n1 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I3);
static constexpr auto m2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I4);
static constexpr auto n2 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I5);
static constexpr auto n3 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);
static constexpr auto n4 = ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I7);
// N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
static constexpr auto N3 = ASrcBlockDesc_M0_N0_M1_N1_M2_N2_N3_N4{}.GetLength(I6);
public:
static constexpr auto AThreadSliceLength_K0 = Number<Gemm1KPerBlock / n4 / N3>{};
static constexpr auto AThreadSliceLength_M = Number<m0 * m1 * m2>{};
static constexpr auto AThreadSliceLength_K1 = Number<n4>{};
// A source matrix layout in AccVGPR
static constexpr auto a_src_thread_desc_k0_m_k1 =
GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});
// A matrix in VGPR memory, dst of AccVGPR-to-VGPR copy
static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
make_tuple(AThreadSliceLength_K0, AThreadSliceLength_M, AThreadSliceLength_K1));
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto ASrcScalarPerVector = n4;
using AThreadSliceLengths_K0_M_K1 = decltype(a_thread_desc_k0_m_k1.GetLengths());
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_k0_m_k1),
decltype(a_thread_desc_k0_m_k1),
tensor_operation::element_wise::PassThrough,
AThreadSliceLengths_K0_M_K1,
Sequence<1, 0, 2>,
2,
ASrcScalarPerVector>;
template <typename GridDesc_K0_N_K1>
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1),
B1BlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
B1BlockTransferSrcVectorDim,
2,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
1,
1,
B1ThreadTransferSrcResetCoordinateAfterRun,
true, // DstResetCoord
NumGemmKPrefetchStage>;
// for a_block_slice_copy_step to be able to address static buffers, it MUST be a
// tuple-based container and contain ONLY integral constants
static constexpr auto a_block_slice_copy_step = make_tuple(AThreadSliceLength_K0, I0, I0);
static constexpr auto b_block_slice_copy_step =
make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale to Gemm0KPack, let Gemm1KPack be the lowest common
// multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. In this case,
// however, Gemm1KPack cannot exceed A1K1 itself because the A1 matrix is distributed in VGPRs
// with 'group_size' contiguous elements per thread. A Gemm1KPack greater than A1K1 would
// cause a mismatch in the summation index, e.g. c[0:7] = a1[0:3, 8:11] * b1[0:7].
// Therefore we may just as well assign Gemm1KPack = group_size.
static constexpr index_t GemmKPack =
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.group_size;
using BlockwiseGemm = BlockwiseGemmXdlops_v2<
BlockSize,
DataType,
FloatGemmAcc,
decltype(a_thread_desc_k0_m_k1),
decltype(b_block_desc_bk0_n_bk1),
decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a_thread_desc_k0_m_k1)),
decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
MPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
Gemm1NXdlPerWave,
GemmKPack,
true, // TransposeC
GemmKPack, // AMmaKStride
GemmKPack * XdlopsGemm<DataType, MPerXdl, NPerXdl, GemmKPack, false>{}
.K0PerXdlops /* BMmaKStride */>;
};
// dV / dK Gemm (type 3 crr)
// Describes tuning parameter for C2_n_o = A2_n_m * B2_m_o
template <index_t Sum_M_ = MPerXdl * 2>
struct Gemm2Params_N_O_M_
{
static constexpr index_t Free0_N = NPerBlock;
static constexpr index_t Free1_O = Gemm1NPerBlock;
static constexpr index_t Sum_M = Sum_M_;
static constexpr index_t A_M1 = 8; // P will be row-major
static constexpr index_t A_M0 = Sum_M / A_M1;
static constexpr index_t A_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static constexpr index_t B_M1 = 2; // dY assumed row-major, typically =2 for fp16
static constexpr index_t B_M0 = Sum_M / B_M1;
static constexpr index_t B_LdsPad = 0; // how many multiples of M1 per N * M1 elements
static_assert(Sum_M % MPerXdl == 0, "");
static constexpr index_t BSrcVectorDim = 1; // Free1_O dimension
static constexpr index_t BSrcScalarPerVector = 4;
static constexpr index_t GemmNWave = 2;
static constexpr index_t GemmOWave = BlockSize / get_warp_size() / GemmNWave;
static constexpr index_t GemmNRepeat = Free0_N / GemmNWave / MPerXdl;
static constexpr index_t GemmORepeat = Free1_O / GemmOWave / NPerXdl;
static constexpr index_t GemmMPack =
math::max(math::lcm(A_M1, B_M1),
MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using BBlockSliceLengths = Sequence<B_M0, Free1_O, B_M1>;
using BThreadClusterLengths =
Sequence<BlockSize / (Free1_O / BSrcScalarPerVector), Free1_O / BSrcScalarPerVector, 1>;
using BThreadClusterArrangeOrder = Sequence<0, 2, 1>;
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1_M2_N2()
{
// perform manual unmerge: m -> m_repeat, m_waves, m_per_xdl
constexpr index_t m = Gemm2Params_N_O_M::Sum_M - 1;
constexpr index_t m2 = m % MPerXdl;
constexpr index_t m1 = m / MPerXdl % Gemm0MWaves;
constexpr index_t m0 = m / MPerXdl / Gemm0MWaves % MXdlPerWave;
// perform manual unmerge: n -> n_repeat, n_waves, n_per_xdl
constexpr index_t n = Gemm2Params_N_O_M::Free0_N - 1;
constexpr index_t n2 = n % NPerXdl;
constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
// 1d idx (256 - 1) -> 3d idx 1, 3, 31 -> 3d dim 2 x 4 x 32
return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
}
__host__ __device__ static constexpr auto GetABlockSliceLengths_M0_N0_M1_N1()
{
return generate_sequence_v2(
[](auto I) { return GetABlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
Number<4>{});
}
using ABlockSliceLengths_M0_N0_M1_N1 = decltype(GetABlockSliceLengths_M0_N0_M1_N1());
};
using Gemm2Params_N_O_M = Gemm2Params_N_O_M_<>; // tune later
// dV / dK Gemm (type 3 crr)
template <typename Gemm2Params_N_O_M, typename ASrcBlockwiseGemm>
struct Gemm2
{
private:
static constexpr auto a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
ASrcBlockwiseGemm::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
static constexpr auto M0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); // repeat
static constexpr auto N0 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
static constexpr auto M1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); // wave
static constexpr auto N1 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
static constexpr auto M2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); // xdl
static constexpr auto N2 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
static constexpr auto N3 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
static constexpr auto N4 = a_src_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
public:
// A source matrix layout in VGPR, src of VGPR-to-LDS copy
static constexpr auto a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
ASrcBlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// A matrix in LDS memory, dst of blockwise copy
static constexpr auto a_block_desc_m0_n_m1 =
GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
// B matrix in LDS memory, dst of blockwise copy
static constexpr auto b_block_desc_m0_o_m1 =
GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
__host__ __device__ static constexpr auto MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4()
{
const auto M0_ = a_block_desc_m0_n_m1.GetLength(I0);
const auto N_ = a_block_desc_m0_n_m1.GetLength(I1);
const auto M1_ = a_block_desc_m0_n_m1.GetLength(I2);
const auto a_block_desc_m_n = transform_tensor_descriptor(
a_block_desc_m0_n_m1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(M0_, M1_)),
make_pass_through_transform(N_)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// HACK: for the unmerge transform, the length of the highest dim is irrelevant, so we put a
// dummy variable I1 there
return transform_tensor_descriptor(
a_block_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(I1, M1, M2)),
make_unmerge_transform(make_tuple(I1, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
}
// Note: we will perform sub-workgroup VGPR-to-LDS copy to save LDS space, therefore the
// destination coordinate can overlap between wavefronts in a workgroup as seen in the mod
// operation before returning the values
__host__ __device__ static auto MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4()
{
const auto a_thread_origin_on_block_idx =
ASrcBlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0);
constexpr auto c_block_slice_lengths_m0_n0_m1_n1 =
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{}; // mrepeat, nrepeat,
// mwaves, nwaves,
return make_tuple(
a_thread_origin_on_block_idx[I0], // mrepeat
a_thread_origin_on_block_idx[I1], // nrepeat
a_thread_origin_on_block_idx[I2] % c_block_slice_lengths_m0_n0_m1_n1[I2], // mwave
a_thread_origin_on_block_idx[I3] % c_block_slice_lengths_m0_n0_m1_n1[I3], // nwave
a_thread_origin_on_block_idx[I4], // xdlops
a_thread_origin_on_block_idx[I5],
a_thread_origin_on_block_idx[I6],
a_thread_origin_on_block_idx[I7]);
}
static constexpr auto a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
MakeABlockDesc_M0_N0_M1_N1_M2_N2_N3_N4();
using ASrcBlockSliceWindowIterator =
SpaceFillingCurve<Sequence<M0, N0, M1, N1>,
Sequence<0, 1, 2, 3>,
typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1,
false>;
template <typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
ElementwiseOp,
Sequence<Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(
I0), // ThreadSliceLengths
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I1),
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7, // DstVectorDim
1, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>;
template <typename GridDesc_M0_O_M1>
using BBlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
typename Gemm2Params_N_O_M::BBlockSliceLengths,
typename Gemm2Params_N_O_M::BThreadClusterLengths,
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
DataType,
DataType,
GridDesc_M0_O_M1,
decltype(b_block_desc_m0_o_m1),
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, // access order == thread order
Sequence<1, 0, 2>,
Gemm2Params_N_O_M::BSrcVectorDim,
2, // DstVectorDim
Gemm2Params_N_O_M::BSrcScalarPerVector,
Gemm2Params_N_O_M::B_M1,
1,
1,
true,
true,
1>;
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
FloatGemmAcc,
decltype(a_block_desc_m0_n_m1),
decltype(b_block_desc_m0_o_m1),
MPerXdl,
NPerXdl,
Gemm2Params_N_O_M::GemmNRepeat,
Gemm2Params_N_O_M::GemmORepeat,
Gemm2Params_N_O_M::GemmMPack,
true>; // TransposeC
static constexpr auto b_block_slice_copy_step =
make_multi_index(Gemm2Params_N_O_M::B_M0, 0, 0);
static constexpr auto c_block_slice_copy_step =
make_multi_index(Gemm2Params_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
static constexpr auto b_block_reset_copy_step =
make_multi_index(-MPerBlock / Gemm2Params_N_O_M::B_M1, 0, 0);
template <typename CGradDesc_N_O>
__host__ __device__ static auto
MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(const CGradDesc_N_O& c_grid_desc_n_o)
{
// HACK: for the unmerge transform, the length of the highest dim is irrelevant, so we put a
// dummy variable I1 there
const auto c_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
c_grid_desc_n_o,
make_tuple(
make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmNWave, MPerXdl)),
make_unmerge_transform(make_tuple(I1, Gemm2Params_N_O_M::GemmOWave, NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
const auto c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
BlockwiseGemm{}.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
c_grid_desc_n0_o0_n1_o1_n2_o2);
return c_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4;
}
static constexpr auto c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
BlockwiseGemm::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
__host__ __device__ static auto GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4()
{
return to_multi_index(BlockwiseGemm::CalculateCThreadOriginDataIndex8D(I0, I0, I0, I0));
}
template <typename CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
DataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
ElementwiseOp, // CElementwiseOperation
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim
2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
};
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_
{
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
static constexpr auto ThreadSliceLength_O = Number<SrcScalarPerVector>{};
static constexpr auto ThreadSliceLength_M =
Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
static_assert(ThreadClusterLength_O * ThreadSliceLength_O == BlockSliceLength_O_, "");
static_assert(ThreadClusterLength_M * ThreadSliceLength_M == BlockSliceLength_M_, "");
using SrcBufType = StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
ThreadSliceLength_M * ThreadSliceLength_O,
true>;
using DstBufType =
StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, ThreadSliceLength_M, true>;
};
using YDotYGrad_M_O = YDotYGrad_M_O_<BlockSize, MPerBlock, Gemm1NPerBlock>;
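// Example (hypothetical tile sizes, for illustration only): with BlockSize = 256,
// MPerBlock = 128, Gemm1NPerBlock = 64 and fp16 data, SrcScalarPerVector = 16 / 2 = 8,
// ThreadClusterLength_O = 64 / 8 = 8, ThreadClusterLength_M = 256 / 8 = 32,
// ThreadSliceLength_O = 8 and ThreadSliceLength_M = 128 * 8 / 256 = 4, which satisfies both
// static_asserts in YDotYGrad_M_O_ above.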
// PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
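// i.e. dP = dY * V^T: dY takes the A role (M x O, row-major) and V the B role (N x O,
// treated as col-major), so the S = Q * K^T tile machinery can be reused unchanged.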
struct PGradGemmTile_M_N_O
{
// TODO:
// Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
// things more concise
template <typename YGradGridDesc_M0_O_M1_>
__device__ static auto
MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
{
const auto M0 = ygrad_grid_desc_m0_o_m1.GetLength(I0);
const auto O = ygrad_grid_desc_m0_o_m1.GetLength(I1);
const auto M1 = ygrad_grid_desc_m0_o_m1.GetLength(I2);
constexpr auto Y_O1 = AK1;
const auto Y_O0 = O / Y_O1;
const auto ygrad_grid_desc_o0_m_o1 = transform_tensor_descriptor(
ygrad_grid_desc_m0_o_m1,
make_tuple(make_unmerge_transform(make_tuple(Y_O0, Y_O1)),
make_merge_transform_v3_division_mod(make_tuple(M0, M1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return ygrad_grid_desc_o0_m_o1;
}
template <typename VGridDesc_N0_O_N1_>
__device__ static auto MakeVGridDesc_O0_N_O1(const VGridDesc_N0_O_N1_& v_grid_desc_n0_o_n1)
{
const auto N0 = v_grid_desc_n0_o_n1.GetLength(I0);
const auto O = v_grid_desc_n0_o_n1.GetLength(I1);
const auto N1 = v_grid_desc_n0_o_n1.GetLength(I2);
constexpr auto V_O1 = BK1;
const auto V_O0 = O / V_O1;
const auto v_grid_desc_o0_n_o1 = transform_tensor_descriptor(
v_grid_desc_n0_o_n1,
make_tuple(make_unmerge_transform(make_tuple(V_O0, V_O1)),
make_merge_transform_v3_division_mod(make_tuple(N0, N1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return v_grid_desc_o0_n_o1;
}
};
// QGrad Gemm has the same layout as Y = P * V Gemm (A in acc B row-major)
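// i.e. dQ = dS * K: dS stays in accumulator VGPRs as the A operand while K (N x K,
// row-major) is streamed through LDS as B, mirroring V in the forward Y = P * V gemm.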
struct QGradGemmTile_M_K_N
{
template <typename QGridDesc_K0_M_K1_>
__device__ static auto MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
{
const auto K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto K1 = q_grid_desc_k0_m_k1.GetLength(I2);
const auto K = K0 * K1;
const auto MBlock = M / MPerBlock;
const auto KBlock = K / Gemm1NPerBlock; // NOTE: QGrad gemm is similar to Y gemm
const auto q_grid_desc_m_k = transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_pass_through_transform(M),
make_merge_transform_v3_division_mod(make_tuple(K0, K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return transform_tensor_descriptor(
q_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(KBlock, Number<Gemm1NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
}
template <typename KGridDesc_K0_N_K1_>
__device__ static auto MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
{
const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
constexpr auto K_N1 = B1K1;
const auto K_N0 = N / K_N1;
const auto k_grid_desc_n0_k_n1 = transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(make_tuple(K_N0, K_N1)),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return k_grid_desc_n0_k_n1;
}
};
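// dK Gemm: dK = dS^T * Q, reusing the Gemm2 (crr) path; dS^T is supplied from VGPR and Q is
// staged through LDS as the B operand, and the result is written out as an N x K tile.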
struct KGradGemmTile_N_K_M
{
// B position
template <typename QGridDesc_K0_M_K1_>
__device__ static auto MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
{
const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
constexpr auto Q_M1 = B1K1;
const auto Q_M0 = M / Q_M1;
const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
q_grid_desc_k0_m_k1,
make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return q_grid_desc_m0_k_m1;
}
// C position
template <typename KGridDesc_K0_N_K1_>
__device__ static auto MakeKGradGridDesc_N_K(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
{
const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
return transform_tensor_descriptor(
k_grid_desc_k0_n_k1,
make_tuple(make_pass_through_transform(N),
make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
};
struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static constexpr auto a_block_desc_ak0_m_ak1 =
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
static constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto a2_block_desc_m0_n_m1 =
GetA2BlockDescriptor_M0_N_M1<Gemm2Params_N_O_M>();
static constexpr auto b2_block_desc_m0_o_m1 =
GetB2BlockDescriptor_M0_O_M1<Gemm2Params_N_O_M>();
static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_block_space_size_aligned = math::integer_least_multiple(
a2_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
b2_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0;
static constexpr auto a2_block_space_offset = 0;
static constexpr auto b2_block_space_offset = p_block_space_size_aligned.value;
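// Note: the b1 / a2 / b2 regions alias the gemm0 region (their offsets restart at 0). This
// appears to be safe because the gemm phases are separated by block_sync_lds() barriers and
// GetSharedMemoryNumberOfByte() below sizes LDS to the maximum over all phases.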
// LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned =
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto reduction_space_offset = 0;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
static constexpr auto c_block_space_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
SharedMemTrait::b_block_space_size_aligned) *
sizeof(DataType);
const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(DataType);
const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
SharedMemTrait::ygrad_block_space_size_aligned) *
sizeof(DataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc);
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end,
gemm1_bytes_end,
vgrad_gemm_bytes_end,
softmax_bytes_end,
c_block_bytes_end);
}
template <bool HasMainKBlockLoop,
typename Block2CTileMap,
typename C0MatrixMask,
typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const SElementwiseOperation& s_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
const KGridDesc_K0_N_K1& k_grid_desc_k0_n_k1,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const VGridDesc_N0_O_N1& v_grid_desc_n0_o_n1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock&
y_grid_desc_mblock_mperblock_oblock_operblock,
const LSEGridDesc_M& lse_grid_desc_m,
const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
FloatGemmAcc p_dropout,
ck::philox& ph)
{
const ushort p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
const FloatGemmAcc rp_dropout = 1.0f / p_dropout;
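// Assumed semantics: p_dropout is the keep probability. It is mapped to a 16-bit threshold
// for comparison against philox-generated uint16 values, and rp_dropout = 1 / p_dropout is
// the rescale factor applied to kept elements.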
const auto q_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_q_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
const auto k_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_k_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
const auto v_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());
// divide block work by [M, O]
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I0),
y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2))))
{
return;
}
// HACK: this force m/o_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t o_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
// 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
// S_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr)
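// High-level dataflow per row-block (summary; scaling and dropout factors are applied where
// noted in the code below):
//   S  = scale * Q * K^T,  P = dropout(softmax(S))
//   dV = P^T * dY
//   dP = dY * V^T,  dS = P .* (dP - rowsum(Y .* dY))
//   dQ = scale * dS * K,  dK = scale * dS^T * Q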
//
// set up S / dP Gemm (type 1 rcr)
//
// Gemm0: LDS allocation for A and B: be careful of alignment
auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// Gemm0: gridwise GEMM pipeline
// Only supports LoopScheduler::Default
const auto gemm0_gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopScheduler::Default>();
// S: A matrix blockwise copy
auto s_gemm_tile_q_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
q_grid_desc_k0_m_k1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: B matrix blockwise copy
auto s_gemm_tile_k_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
k_grid_desc_k0_n_k1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
b_element_op,
Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// S: blockwise gemm
auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC
auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();
const auto s_gemm_tile_a_block_reset_copy_step =
make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
const auto s_gemm_tile_b_block_reset_copy_step =
make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
// dP: transform input and output tensor descriptors
const auto ygrad_grid_desc_o0_m_o1 =
PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
const auto v_grid_desc_o0_n_o1 =
PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
// dP: A matrix blockwise copy
auto pgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
ygrad_grid_desc_o0_m_o1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
tensor_operation::element_wise::PassThrough{},
Gemm0::a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: B matrix blockwise copy
auto pgrad_gemm_tile_v_blockwise_copy =
typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
v_grid_desc_o0_n_o1,
make_multi_index(0, 0, 0), // will loop over GemmN dimension
tensor_operation::element_wise::PassThrough{},
Gemm0::b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dP: blockwise gemm
// we need a separate blockwise gemm object because we need a separate thread buffer
auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
auto pgrad_thread_buf = pgrad_blockwise_gemm.GetCThreadBuffer();
const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
const auto pgrad_gemm_tile_v_block_reset_copy_step =
make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
(ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
KPerBlock);
//
// set up Y / dQ Gemm (type 2 rrr)
//
// Note: Y is pre-calculated in forward pass and loaded to backward pass kernel
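// The dQ gemm reuses the forward-pass Gemm1 machinery: the A operand (dS) lives in VGPR
// thread buffers, so only the B operand (K) is staged through LDS.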
using Gemm1 =
Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()),
decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;
// Gemm1: VGPR allocation for A and LDS allocation for B
auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
// dQ: transform input and output tensor descriptors
const auto k_grid_desc_n0_k_n1 =
QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
q_grid_desc_k0_m_k1);
// dQ: A matrix blockwise copy
auto qgrad_gemm_tile_sgrad_blockwise_copy =
typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
// dQ: B matrix blockwise copy
auto qgrad_gemm_tile_k_blockwise_copy =
typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
k_grid_desc_n0_k_n1,
make_multi_index(0, o_block_data_idx_on_grid, 0),
b1_element_op,
Gemm1::b_block_desc_bk0_n_bk1, // here the descriptor's N dimension is actually K (and vice
// versa), so a clearer name would be b_block_desc_bn0_k_bn1
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dQ: blockwise gemm
auto qgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
//
// Blockwise softmax
//
// get acc0 8D thread cluster
constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
// get acc0 thread map
constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)),
make_pass_through_transform(I1)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto threadid_to_m_n_thread_cluster_adaptor =
chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
// get acc0 2D thread cluster & 2D thread slice
constexpr auto thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto m0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
constexpr auto n0 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
constexpr auto m1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
constexpr auto n1 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
constexpr auto m2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto n2 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto n3 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto n4 = thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
constexpr auto thread_slice_desc_m_n =
make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
FloatGemmAcc,
decltype(threadid_to_m_n_thread_cluster_adaptor),
decltype(thread_cluster_desc_m_n),
decltype(thread_slice_desc_m_n)>{};
auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
p_dropout_in_16bits, rp_dropout};
auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
constexpr auto lse_thread_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2));
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
auto acc0_thread_origin = s_blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatLSE,
FloatLSE,
decltype(lse_grid_desc_mblock_mrepeat_mwave_mperxdl),
decltype(lse_thread_desc_mblock_mrepeat_mwave_mperxdl),
Sequence<1, m0, m1, m2>,
Sequence<0, 1, 2, 3>,
3,
m2,
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx[I0], // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
//
// z vgpr copy to global
//
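// The z tensor (written only when p_z_grid is non-null) appears to hold the per-element
// 16-bit philox random values, so the dropout mask can be reproduced or validated outside
// the kernel.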
// z matrix threadwise desc
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tensor_buffer;
z_tensor_buffer.Clear();
// z matrix global desc
/*const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(M, N);*/
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
n4, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
//
// set up dV / dK Gemm (type 3 crr)
//
using Gemm2 = Gemm2<Gemm2Params_N_O_M, decltype(s_blockwise_gemm)>;
// Gemm2: LDS allocation for A and B: be careful of alignment
auto gemm2_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::a2_block_space_offset,
Gemm2::a_block_desc_m0_n_m1.GetElementSpaceSize());
auto gemm2_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::b2_block_space_offset,
Gemm2::b_block_desc_m0_o_m1.GetElementSpaceSize());
// dV: transform input and output tensor descriptors
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(vgrad_grid_desc_n_o);
// dV: A matrix VGPR-to-LDS blockwise copy
auto vgrad_gemm_tile_p_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::Relu>{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::Relu{}}; // relu(P-dropped)
// dV: B matrix global-to-LDS blockwise copy
auto vgrad_gemm_tile_ygrad_blockwise_copy =
typename Gemm2::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
ygrad_grid_desc_m0_o_m1,
make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
o_block_data_idx_on_grid,
0),
tensor_operation::element_wise::PassThrough{},
Gemm2::b_block_desc_m0_o_m1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dV: blockwise gemm
auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
auto v_slash_k_grad_thread_buf = v_slash_k_grad_blockwise_gemm.GetCThreadBuffer();
// dV: C VGPR-to-global copy
const auto vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
make_multi_index(
I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
auto vgrad_thread_copy_vgpr_to_global =
typename Gemm2::template CBlockwiseCopy<decltype(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
tensor_operation::element_wise::Scale>(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
tensor_operation::element_wise::Scale{rp_dropout});
// dK: transform input and output tensor descriptors
const auto q_grid_desc_m0_k_m1 =
KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
const auto kgrad_grid_desc_n_k =
KGradGemmTile_N_K_M::MakeKGradGridDesc_N_K(k_grid_desc_k0_n_k1);
const auto kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
// dK: A matrix VGPR-to-LDS blockwise copy
auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds =
typename Gemm2::template ABlockwiseCopy<tensor_operation::element_wise::PassThrough>{
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
tensor_operation::element_wise::PassThrough{}};
// dK: B matrix global-to-LDS blockwise copy
auto kgrad_gemm_tile_q_blockwise_copy =
typename Gemm2::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
q_grid_desc_m0_k_m1,
make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
o_block_data_idx_on_grid,
0),
tensor_operation::element_wise::PassThrough{},
Gemm2::b_block_desc_m0_o_m1,
make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{});
// dK: blockwise gemm
/* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
// dK: C VGPR-to-global copy
const auto kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
make_multi_index(
I0, block_work_idx[I1] * Gemm2Params_N_O_M::GemmORepeat, I0, I0, I0, I0, I0, I0);
auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
decltype(s_element_op)>(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
s_element_op);
//
// set up Y dot dY
//
// m0, n0 are m/n repeat per wave
// m1, n1 are number of waves
constexpr auto p_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto P_M0 = p_block_lengths[I0]; // repeats
constexpr auto P_M1 = p_block_lengths[I2]; // waves
constexpr auto P_M2 = p_block_lengths[I4]; // xdl
constexpr auto y_thread_desc_m0_m1_o0_o1 = make_naive_tensor_descriptor_packed(make_tuple(
I1, YDotYGrad_M_O::ThreadSliceLength_M, I1, YDotYGrad_M_O::ThreadSliceLength_O));
constexpr auto y_thread_cluster_desc =
make_cluster_descriptor(Sequence<I1,
YDotYGrad_M_O::ThreadClusterLength_M,
I1,
YDotYGrad_M_O::ThreadClusterLength_O>{},
Sequence<0, 1, 2, 3>{});
const auto y_thread_cluster_idx =
y_thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto y_thread_data_on_block_idx =
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx[I0], I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1),
decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
Sequence<0, 1, 2, 3>,
3, // SrcVectorDim
YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
1, // SrcScalarStrideInVector
true /* ResetCoordAfterRun */,
true /* InvalidElementAsNaN */>(y_grid_desc_mblock_mperblock_oblock_operblock,
y_thread_data_on_grid_idx);
auto y_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
auto y_dot_ygrad_thread_accum_buf = typename YDotYGrad_M_O::DstBufType{};
auto y_dot_ygrad_block_accum_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatGemmAcc*>(p_shared), MPerBlock);
constexpr auto y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor(make_tuple(I1, P_M0, P_M1, P_M2),
make_tuple(P_M0 * P_M1 * P_M2, P_M1 * P_M2, P_M2, I1));
constexpr auto y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl =
lse_thread_desc_mblock_mrepeat_mwave_mperxdl; // reuse LSE thread descriptor because
// per-thread LSE data and y_dot_ygrad is
// tiled the same way
auto y_dot_ygrad_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
FloatGemmAcc,
FloatGemmAcc,
decltype(y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl),
decltype(y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl),
Sequence<1, m0, m1, m2>,
Sequence<0, 1, 2, 3>,
3,
m2,
1,
false>{y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(I0, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
//
// calculate Y dot dY
//
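// Each thread accumulates a partial D_i = sum_o Y[i, o] * dY[i, o] over its slice; the
// partial sums are then reduced across the block into LDS with atomic adds and
// redistributed below in the same per-row layout as the softmax statistics.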
// clear accum buffers
y_dot_ygrad_thread_accum_buf.Clear();
y_dot_ygrad_block_accum_buf.Clear();
index_t oblock_idx = 0;
do
{
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
y_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
y_thread_buf);
yygrad_threadwise_copy.Run(y_grid_desc_mblock_mperblock_oblock_operblock,
ygrad_grid_buf,
y_thread_desc_m0_m1_o0_o1,
make_tuple(I0, I0, I0, I0),
ygrad_thread_buf);
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
static_for<0, YDotYGrad_M_O::ThreadSliceLength_O, 1>{}([&](auto iO) {
constexpr auto offset =
y_thread_desc_m0_m1_o0_o1.CalculateOffset(make_multi_index(I0, iM, I0, iO));
y_dot_ygrad_thread_accum_buf(iM) +=
y_thread_buf[Number<offset>{}] * ygrad_thread_buf[Number<offset>{}];
});
});
yygrad_threadwise_copy.MoveSrcSliceWindow(y_grid_desc_mblock_mperblock_oblock_operblock,
make_multi_index(0, 0, 1, 0));
oblock_idx++;
} while(oblock_idx < y_grid_desc_mblock_mperblock_oblock_operblock.GetLength(I2));
// blockwise reduction using atomic_add
block_sync_lds();
static_for<0, YDotYGrad_M_O::ThreadSliceLength_M, 1>{}([&](auto iM) {
const auto idx_on_block = y_thread_data_on_block_idx[I1] + iM;
y_dot_ygrad_block_accum_buf.AtomicAdd(
idx_on_block, true, y_dot_ygrad_thread_accum_buf[iM] * p_dropout); // pre-scale D_i by p_dropout
});
block_sync_lds();
// distribute y_dot_ygrad to threads; LDS accum buffer can be safely reused after barrier
y_dot_ygrad_thread_copy_lds_to_vgpr.Run(
y_dot_ygrad_block_desc_mblock_mrepeat_mwave_mperxdl,
y_dot_ygrad_block_accum_buf,
y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, I0, I0, I0),
y_dot_ygrad_thread_buf);
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf,
lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, I0, I0, I0),
lse_thread_buf);
const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
const float scalar = 1.0f / std::sqrt(K);
// Initialize dQ
qgrad_thread_buf.Clear();
// gemm1 K loop
index_t gemm1_k_block_outer_index = 0;
do
{
auto n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if(c0_matrix_mask.IsTileSkippable(
m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
{
continue;
}
// S = Q * K^T
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
q_grid_desc_k0_m_k1,
Gemm0::a_block_desc_ak0_m_ak1,
s_gemm_tile_q_blockwise_copy,
q_grid_buf,
gemm0_a_block_buf,
Gemm0::a_block_slice_copy_step,
k_grid_desc_k0_n_k1,
Gemm0::b_block_desc_bk0_n_bk1,
s_gemm_tile_k_blockwise_copy,
k_grid_buf,
gemm0_b_block_buf,
Gemm0::b_block_slice_copy_step,
s_blockwise_gemm,
s_slash_p_thread_buf,
num_k_block_main_loop);
// do MNK padding or upper triangular masking
if constexpr(MaskOutUpperTriangle || PadN)
{
// 8d thread_desc in thread scope
constexpr auto c_thread_lengths =
s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
// 8d block_desc in block scope
constexpr auto c_block_lengths =
s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
constexpr auto M0 = c_block_lengths[I0];
constexpr auto N0 = c_block_lengths[I1];
constexpr auto M1 = c_block_lengths[I2];
constexpr auto N1 = c_block_lengths[I3];
constexpr auto M2 = c_block_lengths[I4];
constexpr auto N2 = c_block_lengths[I5];
constexpr auto N3 = c_block_lengths[I6];
constexpr auto N4 = c_block_lengths[I7];
// works like a multi-dimensional static_for (static_ford), but provides both the linear
// index and the n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
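// Masked positions are set to -inf so they contribute exp(-inf) = 0 after softmax;
// unmasked positions are pre-scaled by 1/sqrt(K) here instead of inside the softmax.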
static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
auto m_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
auto n_local =
block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
auto m_global = m_local + m_block_data_idx_on_grid;
auto n_global = n_local + n_block_data_idx_on_grid;
if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
{
s_slash_p_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
}
else
{
s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
}
});
}
else
{
static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
[&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
}
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// P_i: = softmax(scalar * S_i:)
// scaling by 1/sqrt(K) was already applied to S in the masking/scaling loop above
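// With the forward-pass LSE_i at hand, P is recomputed as P_ij = exp(s_ij - LSE_i), where
// s_ij is the already-scaled logit, without another row-wise max/sum reduction (hence
// RunWithPreCalcStats).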
blockwise_softmax.RunWithPreCalcStats(s_slash_p_thread_buf, lse_thread_buf);
// save z to global
if(p_z_grid)
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf),
decltype(z_tensor_buffer),
true>(
s_slash_p_thread_buf, ph, z_tensor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tensor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
// P_dropped
blockwise_dropout.template ApplyDropout<decltype(s_slash_p_thread_buf), true>(
s_slash_p_thread_buf, ph);
}
block_sync_lds(); // wait for gemm1 LDS read
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
"");
// TODO: tune gemm2 pipeline
// dV = P_drop^T * dY
v_slash_k_grad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
// load VGrad Gemm B
vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
ygrad_grid_buf);
// load VGrad Gemm A
const auto p_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range = make_tuple(
p_slice_idx[I2],
p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
constexpr auto nwave_range = make_tuple(
p_slice_idx[I3],
p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
s_slash_p_thread_buf,
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
gemm2_a_block_buf);
}
// the ygrad (VGrad gemm B) slice window is moved with MoveSrcSliceWindow() since it is a
// dynamic buffer; the P slice window is moved by the loop index
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf);
block_sync_lds(); // sync before read
v_slash_k_grad_blockwise_gemm.Run(
gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
}); // end gemm dV
// atomic_add dV
vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
v_slash_k_grad_thread_buf,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
vgrad_grid_buf);
// gemm dP
block_sync_lds();
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
ygrad_grid_desc_o0_m_o1,
Gemm0::a_block_desc_ak0_m_ak1, // reuse
pgrad_gemm_tile_ygrad_blockwise_copy,
ygrad_grid_buf,
gemm0_a_block_buf, // reuse
Gemm0::a_block_slice_copy_step, // reuse
v_grid_desc_o0_n_o1,
Gemm0::b_block_desc_bk0_n_bk1, // reuse
pgrad_gemm_tile_v_blockwise_copy,
v_grid_buf,
gemm0_b_block_buf, // reuse
Gemm0::b_block_slice_copy_step, // reuse
pgrad_blockwise_gemm,
pgrad_thread_buf,
num_o_block_main_loop);
// dS = P * (dP - Y_dot_dY)
auto& sgrad_thread_buf = pgrad_thread_buf;
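// Elementwise softmax/dropout backward: dS_ij = P_ij * (dP_ij - D_i), where D_i is the
// pre-computed rowsum(Y .* dY). The branch on the sign of the stored P value below appears
// to distinguish kept from dropped entries (ApplyDropout seems to encode the mask in the
// sign).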
constexpr auto pgrad_thread_tile_iterator =
pgrad_blockwise_gemm.MakeCThreadTileIterator();
constexpr auto pgrad_thread_idx_to_m_n_adaptor =
pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
constexpr auto m =
pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// dS and P have the same thread buffer layout
if(s_slash_p_thread_buf[i] >= 0)
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] *
(pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
}
else
{
sgrad_thread_buf(i) =
s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
}
});
// gemm dQ
// dQ = scalar * dS * K
{
// TODO: explore using dynamic buffer for a1 thread buffer
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
// the A1 source buffer is static buffer holding the output of first GEMM and
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset
// explicitly in Run() below.
// preload data into LDS
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
Gemm1::b_block_slice_copy_step);
block_sync_lds(); // wait for previous LDS read
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
// main body
if constexpr(num_gemm1_k_block_inner_loop > 1)
{
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * i,
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
block_sync_lds();
qgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
block_sync_lds();
qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_n0_k_n1, Gemm1::b_block_slice_copy_step);
qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
gemm1_b_block_buf);
});
}
// tail
{
qgrad_gemm_tile_sgrad_blockwise_copy.Run(
Gemm1::a_src_thread_desc_k0_m_k1,
Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
sgrad_thread_buf,
Gemm1::a_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
gemm1_a_thread_buf);
block_sync_lds();
qgrad_blockwise_gemm.Run(
gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
}
} // end gemm dQ
// dK = scalar * dS^T * Q
v_slash_k_grad_thread_buf.Clear();
static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// load KGrad Gemm B
kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
// load KGrad Gemm A
const auto sgrad_slice_idx =
Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
constexpr auto mwave_range =
make_tuple(sgrad_slice_idx[I2],
sgrad_slice_idx[I2] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
constexpr auto nwave_range =
make_tuple(sgrad_slice_idx[I3],
sgrad_slice_idx[I3] +
Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(
sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
sgrad_thread_buf,
Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
gemm2_a_block_buf);
}
// the q (KGrad gemm B) slice window is moved with MoveSrcSliceWindow() since it is a
// dynamic buffer; the sgrad slice window is moved by the loop index
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
Gemm2::b_block_slice_copy_step);
block_sync_lds(); // sync before write
kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
gemm2_b_block_buf);
block_sync_lds(); // sync before read
v_slash_k_grad_blockwise_gemm.Run(
gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
}); // end gemm dK
// atomic_add dK
kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
v_slash_k_grad_thread_buf,
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
kgrad_grid_buf);
// move slice window
s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_k0_m_k1,
s_gemm_tile_a_block_reset_copy_step); // rewind K
s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
k_grid_desc_k0_n_k1,
s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_m0_o_m1,
Gemm2::b_block_reset_copy_step); // rewind M
vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
ygrad_grid_desc_o0_m_o1, pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
v_grid_desc_o0_n_o1,
pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
q_grid_desc_m0_k_m1,
Gemm2::b_block_reset_copy_step); // rewind M
kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
// shuffle dQ and write
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
qgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
qgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2)), // M2 = MPerXdl
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2, // N2 * N3 * N4 = NPerXdl
N3,
N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
qgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
FloatCShuffle,
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
SElementwiseOperation,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
I1,
N2,
I1,
N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3],
n_thread_data_on_block_idx[I4]),
s_element_op};
// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
DataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, Gemm1NXdlPerWave, 1, 1, 1, N2, 1, N4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
1,
N2,
1,
N4>>{};
// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, Gemm1NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
qgrad_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
qgrad_grid_buf);
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
qgrad_grid_desc_mblock_mperblock_kblock_kperblock, c_global_step);
}
});
}
}
};
} // namespace ck
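The epilogue above follows the usual CShuffle write-out pattern: the accumulated QGrad tile lives in per-thread registers, and each space-filling-curve access stages one shuffle tile through LDS before a blockwise copy writes it to its slice of the global QGrad tensor, with both curves issuing the same number of accesses. The host-side sketch below (plain C++, not CK code; buffer names and tile sizes are illustrative assumptions) walks through the same stage-and-advance loop with flat arrays standing in for VGPR, LDS, and global memory.

    // Host-side sketch of the stage-and-advance epilogue loop (illustrative sizes).
    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main()
    {
        constexpr int tile_elems = 8; // elements moved per access (stand-in for one CShuffle tile)
        constexpr int num_access = 4; // stand-in for sfc_c_vgpr.GetNumOfAccess()

        std::vector<float> vgpr(tile_elems * num_access);   // qgrad_thread_buf
        std::vector<float> lds(tile_elems);                  // c_shuffle_block_buf
        std::vector<float> global(tile_elems * num_access);  // qgrad_grid_buf

        for(int i = 0; i < tile_elems * num_access; ++i)
            vgpr[i] = static_cast<float>(i);

        int dst_offset = 0; // advanced like MoveDstSliceWindow
        for(int access_id = 0; access_id < num_access; ++access_id)
        {
            // block_sync_lds(); c_thread_copy_vgpr_to_lds.Run(...)
            for(int e = 0; e < tile_elems; ++e)
                lds[e] = vgpr[access_id * tile_elems + e];

            // block_sync_lds(); c_shuffle_block_copy_lds_to_global.Run(...)
            for(int e = 0; e < tile_elems; ++e)
                global[dst_offset + e] = lds[e];

            // MoveDstSliceWindow(..., c_global_step)
            if(access_id < num_access - 1)
                dst_offset += tile_elems;
        }

        for(int i = 0; i < tile_elems * num_access; ++i)
            assert(global[i] == vgpr[i]);
        std::printf("staged %d accesses through the LDS buffer\n", num_access);
    }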
...@@ -35,6 +35,7 @@ template <typename FloatAB, ...@@ -35,6 +35,7 @@ template <typename FloatAB,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename ZGridDesc_M_N,
typename LSEGridDesc_M, typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -97,6 +98,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -97,6 +98,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto WaveSize = 64;
// K1 should be Number<...> // K1 should be Number<...>
// Gemm0 // Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
...@@ -116,6 +119,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -116,6 +119,65 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
// descriptor for the Z (dropout mask) tensor: splits M and N into the XDL/MFMA tile hierarchy
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const ZGridDesc_M_N& z_grid_desc_m_n)
{
const auto M = z_grid_desc_m_n.GetLength(I0);
const auto N = z_grid_desc_m_n.GetLength(I1);
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
z_grid_desc_m_n,
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__host__ __device__ static constexpr auto
MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) // same Z split, built from raw M and N
{
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(M, N)),
make_tuple(make_unmerge_transform(
make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
make_unmerge_transform(
make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{}));
}
__device__ static auto GetGemm0WaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
{
constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
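The two helpers above are just integer splits of a flat thread id: GetGemm0WaveIdx peels off (MWave, NWave, lane) with the lane varying fastest, and GetGemm0WaveMNIdx further splits a lane into (input-block index, lane within MPerXdl). A minimal sketch of the same arithmetic follows, assuming Gemm0MWaves = 2, Gemm0NWaves = 2, WaveSize = 64 and MPerXdl = 32; the real values come from the kernel's template parameters and MFMA selection.

    // Sketch of the two index splits (assumed sizes; plain C++, not CK code).
    #include <array>
    #include <cstdio>

    constexpr int kMWaves   = 2;  // Gemm0MWaves (assumed)
    constexpr int kNWaves   = 2;  // Gemm0NWaves (assumed)
    constexpr int kWaveSize = 64; // WaveSize
    constexpr int kMPerXdl  = 32; // MPerXdl (assumed)

    // GetGemm0WaveIdx: flat block-local thread id -> (m_wave, n_wave, lane), lane fastest.
    std::array<int, 3> wave_idx(int thread_id)
    {
        const int lane   = thread_id % kWaveSize;
        const int n_wave = (thread_id / kWaveSize) % kNWaves;
        const int m_wave = thread_id / (kWaveSize * kNWaves);
        return {m_wave, n_wave, lane};
    }

    // GetGemm0WaveMNIdx: lane id -> (input-block index, lane within MPerXdl).
    std::array<int, 2> wave_mn_idx(int lane)
    {
        return {lane / kMPerXdl, lane % kMPerXdl};
    }

    int main()
    {
        const auto w  = wave_idx(200); // thread 200 of a 256-thread block
        const auto mn = wave_mn_idx(w[2]);
        std::printf("m_wave=%d n_wave=%d lane=%d -> input_blk=%d lane_in_xdl=%d\n",
                    w[0], w[1], w[2], mn[0], mn[1]);
        // prints: m_wave=1 n_wave=1 lane=8 -> input_blk=0 lane_in_xdl=8
    }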
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
...@@ -323,6 +385,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -323,6 +385,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ZGridDesc_M_N{}))>;
struct SharedMemTrait struct SharedMemTrait
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
...@@ -367,6 +432,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -367,6 +432,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
unsigned short* __restrict__ p_z_grid,
FloatLSE* __restrict__ p_lse_grid, FloatLSE* __restrict__ p_lse_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -379,6 +445,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -379,6 +445,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5&
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const LSEGridDesc_M& lse_grid_desc_m, const LSEGridDesc_M& lse_grid_desc_m,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask, const C0MatrixMask& c0_matrix_mask,
...@@ -782,6 +850,79 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -782,6 +850,79 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// gemm1 K loop // gemm1 K loop
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
///////////////////=> Z tensor for dropout
//
// Stage the per-element dropout values in VGPR and copy them to the Z output in
// global memory.
//
// z matrix threadwise descriptor
constexpr auto z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4)); // registerNum
StaticBuffer<AddressSpaceEnum::Vgpr,
unsigned short,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
true>
z_tenor_buffer;
z_tenor_buffer.Clear();
// z matrix global descriptor: passed in as z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// no longer rebuilt here from M and N via MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
auto z_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_z_grid, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize());
const auto wave_id = GetGemm0WaveIdx();
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // wave_id[I2] is the lane id within the wave (0~63)
auto z_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
ushort,
ushort,
decltype(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
decltype(z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
tensor_operation::element_wise::PassThrough,
Sequence<I1, // MBlockId
I1, // NBlockID
m0, // MRepeat
n0, // NRepeat
m1, // MWaveId
n1, // NWaveId
m2, // MPerXdl
n2, // NGroupNum
n3, // NInputNum
n4>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
9, // DstVectorDim
n4, // DstScalarPerVector
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
///////////////////<= Z tensor for dropout
do do
{ {
auto n_block_data_idx_on_grid = auto n_block_data_idx_on_grid =
...@@ -876,9 +1017,35 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -876,9 +1017,35 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if constexpr(IsDropout) // dropout if constexpr(IsDropout) // dropout
{ {
blockwise_dropout.ApplyDropout(acc_thread_buf, ph);
// save z to global
if(p_z_grid)
{
// compute P_dropped and record the per-element dropout values into the Z thread buffer
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
decltype(z_tenor_buffer),
true>(
acc_thread_buf, ph, z_tenor_buffer);
z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
z_tenor_buffer,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_buf);
}
else
{
// compute P_dropped only; no Z output requested
blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf), true>(
acc_thread_buf, ph);
}
} }
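When IsDropout is set, ApplyDropout computes P_dropped in place and, when a Z output is provided, also spills a 16-bit value per element into the Z thread buffer so the backward pass can reproduce the same keep/drop pattern. A minimal host-side sketch of that idea follows (plain C++, not the CK implementation; the comparison direction, the 1/(1-p) rescale, and whether Z stores the raw draw or a derived mask are assumptions here).

    // Inverted dropout with a recorded 16-bit draw per element (illustrative).
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <random>
    #include <vector>

    void apply_dropout(std::vector<float>& acc, std::vector<std::uint16_t>& z,
                       float p_drop, std::uint64_t seed)
    {
        std::mt19937_64 rng(seed); // stand-in for the per-thread Philox stream `ph`
        const auto threshold = static_cast<std::uint16_t>(p_drop * 65535.0f);
        const float rescale  = 1.0f / (1.0f - p_drop);

        for(std::size_t i = 0; i < acc.size(); ++i)
        {
            const auto draw = static_cast<std::uint16_t>(rng() & 0xFFFFu);
            z[i]   = draw;                                         // saved for the backward pass
            acc[i] = (draw < threshold) ? 0.0f : acc[i] * rescale; // P_dropped
        }
    }

    int main()
    {
        std::vector<float> acc(8, 1.0f);
        std::vector<std::uint16_t> z(acc.size());
        apply_dropout(acc, z, /*p_drop=*/0.25f, /*seed=*/42);
        for(float v : acc)
            std::printf("%g ", v);
        std::printf("\n");
    }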
// TODO: may convert to log domain // TODO: may convert to log domain
running_max_new = mathext::max(max, running_max); running_max_new = mathext::max(max, running_max);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
......
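The running_max_new / running_sum_new lines at the end of this hunk are the standard streaming-softmax update used by flash-attention style kernels: each K/V tile contributes a block maximum and a block sum of exponentials, and the previously accumulated sum (and, in such kernels, the partial output) is rescaled by exp(old_max - new_max) whenever the maximum grows. Below is a host-side sketch of just that bookkeeping (plain C++, example numbers, not CK code); the streaming result matches the one-shot denominator, which is why the kernel can process K/V tiles one block at a time.

    // Streaming softmax bookkeeping over K/V tiles (example numbers).
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <limits>
    #include <vector>

    int main()
    {
        const std::vector<std::vector<float>> tiles = {{1.0f, 3.0f}, {5.0f, 2.0f}, {0.5f, 4.0f}};

        float running_max = -std::numeric_limits<float>::infinity();
        float running_sum = 0.0f;

        for(const auto& tile : tiles)
        {
            const float max = *std::max_element(tile.begin(), tile.end()); // this tile's max
            const float running_max_new = std::max(max, running_max);

            float tile_sum = 0.0f;
            for(float s : tile)
                tile_sum += std::exp(s - running_max_new);

            // running_sum_new = exp(running_max - running_max_new) * running_sum + tile_sum
            running_sum = std::exp(running_max - running_max_new) * running_sum + tile_sum;
            running_max = running_max_new;
        }

        // one-shot reference over all scores
        float global_max = -std::numeric_limits<float>::infinity();
        float global_sum = 0.0f;
        for(const auto& tile : tiles)
            for(float s : tile)
                global_max = std::max(global_max, s);
        for(const auto& tile : tiles)
            for(float s : tile)
                global_sum += std::exp(s - global_max);

        std::printf("streaming: max=%g sum=%g | one-shot: max=%g sum=%g\n",
                    running_max, running_sum, global_max, global_sum);
    }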
...@@ -1010,6 +1010,42 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float ...@@ -1010,6 +1010,42 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
} }
// convert fp16 to bf16
template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
union
{
float fp32;
uint32_t int32;
} u = {static_cast<float>(x)};
return uint16_t(u.int32 >> 16);
}
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
{
// convert each f16 lane to f32, then keep the upper 16 bits of each f32
// (a truncating, round-toward-zero bf16 conversion) when packing the result
float y0{0}, y1{0};
bhalf2_t y{0};
asm volatile("\n \
v_cvt_f32_f16 %0, %1 \n \
"
: "=v"(y0)
: "v"(x));
asm volatile("\n \
v_cvt_f32_f16 %0, %1 src0_sel:WORD_1\n \
"
: "=v"(y1)
: "v"(x));
asm volatile("\n \
v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
"
: "=v"(y)
: "v"(y0), "v"(y1));
return y;
}
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
......
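Both new type_convert specializations reduce to the same operation: widen f16 to f32 (which is exact) and keep the top 16 bits of the f32 bit pattern, i.e. a truncating bf16 conversion; the packed bhalf2_t variant just does this for both lanes with v_cvt_f32_f16 / v_pack_b32_f16. Below is a portable sketch of the scalar path (plain C++, not the GCN asm; example inputs).

    // Truncating f32 -> bf16; the f16 path first widens to f32 (exactly) and then truncates.
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    std::uint16_t f32_to_bf16_trunc(float f)
    {
        std::uint32_t u;
        std::memcpy(&u, &f, sizeof(u)); // bit-exact view of the f32
        return static_cast<std::uint16_t>(u >> 16);
    }

    int main()
    {
        // stand-ins for two widened f16 lanes of a half2_t
        const float widened[2] = {1.5f, -0.3125f};
        std::printf("0x%04x 0x%04x\n",
                    static_cast<unsigned>(f32_to_bf16_trunc(widened[0])),
                    static_cast<unsigned>(f32_to_bf16_trunc(widened[1])));
    }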
...@@ -109,12 +109,9 @@ class philox ...@@ -109,12 +109,9 @@ class philox
__device__ uint2 u32_high_low_multi(const unsigned int a, const unsigned int b) __device__ uint2 u32_high_low_multi(const unsigned int a, const unsigned int b)
{ {
uint2* res; uint2* res;
uint2 tmp_res; unsigned long long tmp;
asm("v_mul_hi_u32 %0, %2, %3\n\t" tmp = static_cast<unsigned long long>(a) * b;
"v_mul_lo_u32 %1, %2, %3\n\t" res = reinterpret_cast<uint2*>(&tmp);
: "=v"(tmp_res.x), "=v"(tmp_res.y)
: "v"(a), "v"(b));
res = &tmp_res;
return *res; return *res;
} }
......
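The new u32_high_low_multi simply forms the full 32x32 -> 64-bit product and reads the two 32-bit halves back through a uint2 view; on a little-endian target, .x ends up holding the low word and .y the high word. A host-side sketch of the same split follows (plain C++, not the device code; the struct and example inputs are illustrative). The second printf shows the shift-based split, which avoids the type-punning view entirely.

    // 32x32 -> 64-bit multiply split into low/high words (host-side, little-endian assumption).
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    struct u32x2
    {
        std::uint32_t x, y;
    };

    u32x2 u32_high_low_multi(std::uint32_t a, std::uint32_t b)
    {
        const std::uint64_t tmp = static_cast<std::uint64_t>(a) * b;
        u32x2 res;
        std::memcpy(&res, &tmp, sizeof(res)); // little-endian: x = low word, y = high word
        return res;
    }

    int main()
    {
        const u32x2 r = u32_high_low_multi(0xD2511F53u, 0x9E3779B9u); // arbitrary example inputs
        std::printf("lo=0x%08x hi=0x%08x\n", static_cast<unsigned>(r.x), static_cast<unsigned>(r.y));

        // equivalent split without the type-punning view
        const std::uint64_t t = static_cast<std::uint64_t>(0xD2511F53u) * 0x9E3779B9u;
        std::printf("lo=0x%08x hi=0x%08x\n",
                    static_cast<unsigned>(t & 0xFFFFFFFFu), static_cast<unsigned>(t >> 32));
    }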