Commit b67a58c0 authored by Anthony Chang

can validate dV with relaxed error tolerance

parent 8551dd43
...@@ -15,7 +15,9 @@ Outputs: ...@@ -15,7 +15,9 @@ Outputs:
*/ */
#define PRINT_HOST 1 #pragma clang diagnostic ignored "-Wunused-variable"
#define PRINT_HOST 0
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -50,6 +52,7 @@ using YElementOp = PassThrough; ...@@ -50,6 +52,7 @@ using YElementOp = PassThrough;
using DataType = F16; using DataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
...@@ -76,6 +79,7 @@ using DeviceGemmInstance = ...@@ -76,6 +79,7 @@ using DeviceGemmInstance =
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc1BiasDataType, Acc1BiasDataType,
AccDataType, AccDataType,
...@@ -172,14 +176,16 @@ template <typename TensorQ, ...@@ -172,14 +176,16 @@ template <typename TensorQ,
typename TensorV, typename TensorV,
typename TensorS, typename TensorS,
typename TensorP, typename TensorP,
typename TensorY> typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k, void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k, const TensorK& k_g_n_k,
const TensorV& v_g_n_o, const TensorV& v_g_n_o,
const float alpha, const float alpha,
TensorS& s_g_m_n, TensorS& s_g_m_n,
TensorP& p_g_m_n, TensorP& p_g_m_n,
TensorY& y_g_m_o) TensorY& y_g_m_o,
TensorLSE& lse_g_m)
{ {
// S = alpha * Q * K^T // S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1}); auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
...@@ -207,7 +213,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -207,7 +213,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
// [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]]) // [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker(); auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}); auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
ref_softmax_invoker.Run(ref_softmax_argument); ref_softmax_invoker.Run(ref_softmax_argument);
...@@ -230,10 +236,10 @@ int run(int argc, char* argv[]) ...@@ -230,10 +236,10 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 4; ck::index_t M = 256;
ck::index_t N = 4; ck::index_t N = 256;
ck::index_t K = 4; ck::index_t K = 256;
ck::index_t O = 4; ck::index_t O = 256;
ck::index_t G0 = 1; ck::index_t G0 = 1;
ck::index_t G1 = 1; ck::index_t G1 = 1;
...@@ -242,8 +248,6 @@ int run(int argc, char* argv[]) ...@@ -242,8 +248,6 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
const ck::index_t BatchCount = G0 * G1;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -283,6 +287,8 @@ int run(int argc, char* argv[]) ...@@ -283,6 +287,8 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
...@@ -307,16 +313,27 @@ int run(int argc, char* argv[]) ...@@ -307,16 +313,27 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
// The softmax statistic log-sum-exp (LSE) is used to speed up the softmax calculation in the backward pass
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - log(exp(S0) + exp(S1) + ...))
//               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//               LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
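A minimal host-side sketch of the identity in the comment above (standalone C++, names made up for illustration; not part of the kernel): compute LSE for one row of scores in the numerically stable form max + log(sum(exp(s - max))), which equals log(sum(exp(s))), then recover each probability as exp(s_i - LSE).

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<float> s{0.5f, 1.0f, -2.0f, 3.0f}; // one row of scaled scores S

        // numerically stable log-sum-exp of the row
        float row_max = *std::max_element(s.begin(), s.end());
        float sum     = 0.f;
        for(float v : s)
            sum += std::exp(v - row_max);
        float lse = row_max + std::log(sum);

        // P_i = exp(S_i - LSE); the recovered probabilities sum to 1
        float total = 0.f;
        for(float v : s)
            total += std::exp(v - lse);
        std::printf("lse = %f, sum(P) = %f\n", lse, total);
    }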
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -340,19 +357,20 @@ int run(int argc, char* argv[]) ...@@ -340,19 +357,20 @@ int run(int argc, char* argv[])
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
break; break;
default: default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{10}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, n] = m
} }
// calculate y beforehand // calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K}); Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K}); Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<DataType> v_g_n_o({BatchCount, N, O}); Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N}); Tensor<DataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O}); Tensor<DataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach( q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
...@@ -360,27 +378,36 @@ int run(int argc, char* argv[]) ...@@ -360,27 +378,36 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach( v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o); run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
// TODO ANT: make sure K/V gradients are zeroed
q_device_buf.ToDevice(q_gs_ms_ks.mData.data()); q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data()); v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
ygrad_device_buf.ToDevice(y_gs_ms_os.mData.data()); lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero();
// TODO ANT: attention backward kernel
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
...@@ -388,6 +415,7 @@ int run(int argc, char* argv[]) ...@@ -388,6 +415,7 @@ int run(int argc, char* argv[])
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()), static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()), static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()), static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()), static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
...@@ -402,6 +430,7 @@ int run(int argc, char* argv[]) ...@@ -402,6 +430,7 @@ int run(int argc, char* argv[])
v_gs_os_ns_strides, v_gs_os_ns_strides,
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -421,6 +450,7 @@ int run(int argc, char* argv[]) ...@@ -421,6 +450,7 @@ int run(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// TODO ANT: add dQ/dK/dV flops & bytes
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N + std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) * sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
...@@ -445,7 +475,7 @@ int run(int argc, char* argv[]) ...@@ -445,7 +475,7 @@ int run(int argc, char* argv[])
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M}); Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) { ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 * idx[1], idx[3], idx[2]) = self(idx); ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
if(PRINT_HOST) if(PRINT_HOST)
...@@ -456,44 +486,6 @@ int run(int argc, char* argv[]) ...@@ -456,44 +486,6 @@ int run(int argc, char* argv[])
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o; std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
} }
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
#if 0
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
#endif
// P = Softmax(S)
// >>> scipy.special.softmax(numpy.eye(4), 1)
// array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// Y = P * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// Gradients // Gradients
auto ref_gemm_grad = ReferenceGemmGradInstance{}; auto ref_gemm_grad = ReferenceGemmGradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker(); auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
...@@ -613,7 +605,10 @@ int run(int argc, char* argv[]) ...@@ -613,7 +605,10 @@ int run(int argc, char* argv[])
kgrad_gs_ns_ks_host_result.mData); kgrad_gs_ns_ks_host_result.mData);
std::cout << "Checking vgrad:\n"; std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData, pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData); vgrad_gs_os_ns_host_result.mData,
"error",
1e-2,
1e-2);
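The two extra 1e-2 arguments are the relaxed tolerance mentioned in the commit message. A rough sketch of that kind of mixed relative/absolute comparison (illustrative only; this is an assumption about check_err's semantics, not CK's actual implementation):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // hypothetical helper: pass if |out - ref| <= atol or |out - ref| <= rtol * |ref|
    bool almost_equal(const std::vector<float>& out,
                      const std::vector<float>& ref,
                      float rtol = 1e-2f,
                      float atol = 1e-2f)
    {
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const float err = std::fabs(out[i] - ref[i]);
            if(err > atol && err > rtol * std::fabs(ref[i]))
                return false;
        }
        return true;
    }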
} }
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -50,7 +50,8 @@ template <index_t BlockSize, ...@@ -50,7 +50,8 @@ template <index_t BlockSize,
index_t NPerXDL, index_t NPerXDL,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack> index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
...@@ -226,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -226,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N)); make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
} }
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
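The transposed descriptor rests on the usual transpose identity; writing it out for the comment above (with ' denoting transpose, as in the comment):

    (A_xdl * B_xdl)' = B_xdl' * A_xdl'

so feeding the MFMA the swapped, transposed operands yields the transposed output tile C_xdl' directly, which is presumably what the new TransposeC path is for.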
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{ {
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
......
...@@ -108,6 +108,24 @@ struct BlockwiseSoftmax ...@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
}); });
} }
template <typename CThreadBuffer, typename LSEBuffer>
__host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
const LSEBuffer& lse_thread_buf)
{
// exponentiate each element using the pre-calculated LSE (log-sum-exp) statistic
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - LSE),  where LSE = log(exp(S0) + exp(S1) + ...)
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
});
});
}
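A scalar model of the loop body above (hypothetical standalone helper, host C++ for illustration):

    #include <cmath>

    // p = exp(s - lse); NaN scores are zeroed when IgnoreNaN is requested
    inline float recompute_p(float s, float lse, bool ignore_nan)
    {
        if(ignore_nan && std::isnan(s))
            return 0.f;
        return std::exp(s - lse);
    }

Because lse already encodes the row sum from the forward pass, no second reduction over the row is needed here.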
BufferType max_value_buf; BufferType max_value_buf;
BufferType sum_value_buf; BufferType sum_value_buf;
}; };
......
...@@ -18,12 +18,15 @@ ...@@ -18,12 +18,15 @@
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename DataType, typename DataType,
typename LSEDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
...@@ -33,6 +36,7 @@ template <typename GridwiseGemm, ...@@ -33,6 +36,7 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename LSEGridDescriptor_M,
typename VGradGridDescriptor_N_O, typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1, typename YGradGridDesc_M0_O_M1,
typename Block2CTileMap, typename Block2CTileMap,
...@@ -48,6 +52,7 @@ __global__ void ...@@ -48,6 +52,7 @@ __global__ void
const DataType* __restrict__ p_b_grid, const DataType* __restrict__ p_b_grid,
const DataType* __restrict__ p_b1_grid, const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid, const DataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, DataType* __restrict__ p_kgrad_grid,
...@@ -62,6 +67,7 @@ __global__ void ...@@ -62,6 +67,7 @@ __global__ void
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m,
// const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args // const QGradGridDescriptor_M_K qgrad_grid_desc_m_k, // TODO ANT: add dQ/dK args
// const KGradGridDescriptor_N_K kgrad_grid_desc_n_k, // const KGradGridDescriptor_N_K kgrad_grid_desc_n_k,
const VGradGridDescriptor_N_O vgrad_grid_desc_n_o, const VGradGridDescriptor_N_O vgrad_grid_desc_n_o,
...@@ -87,11 +93,14 @@ __global__ void ...@@ -87,11 +93,14 @@ __global__ void
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset, p_b_grid + b_batch_offset,
p_b1_grid + b1_batch_offset, p_b1_grid + b1_batch_offset,
p_c_grid + c_batch_offset, p_c_grid + c_batch_offset,
p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + b_batch_offset,
...@@ -106,6 +115,7 @@ __global__ void ...@@ -106,6 +115,7 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m,
vgrad_grid_desc_n_o, vgrad_grid_desc_n_o,
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
block_2_ctile_map, block_2_ctile_map,
...@@ -140,6 +150,7 @@ template <index_t NumDimG, ...@@ -140,6 +150,7 @@ template <index_t NumDimG,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename DataType,
typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
typename GemmAccDataType, typename GemmAccDataType,
...@@ -350,9 +361,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -350,9 +361,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec) v_gs_ns_os_lengths_vec, v_gs_ns_os_strides_vec)
.second; .second;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_lengths_vec: ", v_gs_os_ns_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_os_ns_strides_vec: ", v_gs_os_ns_strides_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_lengths_vec: ", v_gs_ns_os_lengths_vec, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "v_gs_ns_os_strides_vec: ", v_gs_ns_os_strides_vec, ",") << std::endl;
return PadTensorDescriptor(vgrad_desc_nraw_oraw, return PadTensorDescriptor(vgrad_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock), make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadM, padder.PadO>{}); Sequence<padder.PadN, padder.PadO>{});
} }
template <typename YGridDesc_M_O, typename Number> template <typename YGridDesc_M_O, typename Number>
...@@ -415,10 +430,36 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -415,10 +430,36 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec); return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths_vec, k_gs_ns_ks_strides_vec);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw)
{
const auto lse_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
const auto MPad = M - MRaw;
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M
return transform_tensor_descriptor(lse_grid_desc_mraw,
make_tuple(make_right_pad_transform(MRaw, MPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
}
else
{
// not pad M
return lse_grid_desc_mraw;
}
}
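A quick worked example of the padding arithmetic above (numbers made up; MPerBlock is whatever the tuned instance uses):

    MRaw = 300, MPerBlock = 128
    M    = integer_divide_ceil(300, 128) * 128 = 3 * 128 = 384
    MPad = 384 - 300 = 84

so the right-pad transform appends 84 dummy rows and the LSE descriptor covers a whole number of block tiles; without an M-padding GemmSpec the raw descriptor is returned unchanged.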
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
...@@ -446,11 +487,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -446,11 +487,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n) const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n) c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE)
{ {
} }
...@@ -474,16 +517,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -474,16 +517,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
}
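Since the Argument constructor below passes lse_grid_desc_m_.GetElementSpaceSize() as BatchStrideLSE, each batch's LSE rows sit back to back. A quick worked example (sizes made up): with a padded M of 256,

    GetLSEBasePtr(0) = 0 * 256 = 0
    GetLSEBasePtr(1) = 1 * 256 = 256
    GetLSEBasePtr(2) = 2 * 256 = 512

i.e. batch g starts g * 256 elements into the LSE buffer.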
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t BatchStrideLSE_;
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
LSEDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
AElementwiseOperation, AElementwiseOperation,
...@@ -496,6 +546,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -496,6 +546,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1,
CGridDesc_M_N, CGridDesc_M_N,
LSEGridDesc_M,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -552,6 +603,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -552,6 +603,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const DataType* p_b_grid, const DataType* p_b_grid,
const DataType* p_b1_grid, const DataType* p_b1_grid,
const DataType* p_c_grid, // for dS const DataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid, const DataType* p_ygrad_grid,
DataType* p_qgrad_grid, DataType* p_qgrad_grid,
DataType* p_kgrad_grid, DataType* p_kgrad_grid,
...@@ -566,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -566,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumAcc1Bias>
...@@ -581,6 +634,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -581,6 +634,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_lse_grid_{p_lse_grid},
p_ygrad_grid_{p_ygrad_grid},
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid}, p_vgrad_grid_{p_vgrad_grid},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
...@@ -590,9 +647,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -590,9 +647,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
// dV = P^T * dY // dV = P^T * dY
vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(c_gs_ms_gemm1ns_lengths, vgrad_grid_desc_n_o_{DeviceOp::MakeVGradGridDescriptor_N_O(
c_gs_ms_gemm1ns_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
/* PTrans descriptor will be constructed in kernel */ /* PTrans descriptor will be constructed in kernel */
ygrad_grid_desc_m0_o_m1_{ ygrad_grid_desc_m0_o_m1_{
DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})}, DeviceOp::MakeYGradGridDescriptor_M0_O_M1(c_grid_desc_m_n_, Number<Y_M1>{})},
...@@ -627,7 +685,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -627,7 +685,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_} a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_, type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())}
{ {
// TODO ANT: implement bias addition // TODO ANT: implement bias addition
ignore = p_acc0_biases; ignore = p_acc0_biases;
...@@ -647,6 +705,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -647,6 +705,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n_); c_grid_desc_m_n_);
} }
Print();
} }
void Print() const void Print() const
...@@ -659,14 +718,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -659,14 +718,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
<< b_grid_desc_g_n_k_.GetLength(I1) << ", " << b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print(); // b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", " << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print(); // b1_grid_desc_g_n_k_.Print();
std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", " std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
<< c_grid_desc_g_m_n_.GetLength(I1) << ", " << c_grid_desc_g_m_n_.GetLength(I1) << ", "
<< c_grid_desc_g_m_n_.GetLength(I2) << '\n'; << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// c_grid_desc_g_m_n_.Print(); // c_grid_desc_g_m_n_.Print();
std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", " << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
} }
// pointers // pointers
...@@ -674,6 +737,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -674,6 +737,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const DataType* p_b_grid_; const DataType* p_b_grid_;
const DataType* p_b1_grid_; const DataType* p_b1_grid_;
const DataType* p_c_grid_; const DataType* p_c_grid_;
const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const DataType* p_ygrad_grid_;
DataType* p_vgrad_grid_; DataType* p_vgrad_grid_;
DataType* p_qgrad_grid_; DataType* p_qgrad_grid_;
...@@ -684,6 +748,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -684,6 +748,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
LSEGridDesc_M lse_grid_desc_m_;
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
...@@ -732,7 +797,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -732,7 +797,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
std::cout << "grid size = " << grid_size << '\n';
// Gemm0_K // Gemm0_K
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -743,6 +808,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -743,6 +808,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
DataType, DataType,
LSEDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
...@@ -752,6 +818,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -752,6 +818,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::LSEGridDesc_M,
DeviceOp::VGradGridDesc_N_O, DeviceOp::VGradGridDesc_N_O,
DeviceOp::YGradGridDesc_M0_O_M1, DeviceOp::YGradGridDesc_M0_O_M1,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
...@@ -768,6 +835,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -768,6 +835,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.p_b_grid_, arg.p_b_grid_,
arg.p_b1_grid_, arg.p_b1_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.p_lse_grid_,
arg.p_ygrad_grid_, arg.p_ygrad_grid_,
arg.p_qgrad_grid_, arg.p_qgrad_grid_,
arg.p_kgrad_grid_, arg.p_kgrad_grid_,
...@@ -781,6 +849,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -781,6 +849,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.lse_grid_desc_m_,
arg.vgrad_grid_desc_n_o_, arg.vgrad_grid_desc_n_o_,
arg.ygrad_grid_desc_m0_o_m1_, arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
...@@ -791,6 +860,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -791,6 +860,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop // to concern Gemm0's loop
#if 1
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
ave_time = launch_kernel(integral_constant<bool, true>{}); ave_time = launch_kernel(integral_constant<bool, true>{});
...@@ -799,7 +869,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -799,7 +869,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
{ {
ave_time = launch_kernel(integral_constant<bool, false>{}); ave_time = launch_kernel(integral_constant<bool, false>{});
} }
#endif
return ave_time; return ave_time;
} }
...@@ -898,6 +968,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -898,6 +968,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const DataType* p_b, const DataType* p_b,
const DataType* p_b1, const DataType* p_b1,
const DataType* p_c, const DataType* p_c,
const LSEDataType* p_lse,
const DataType* p_ygrad_grid, const DataType* p_ygrad_grid,
DataType* p_qgrad_grid, DataType* p_qgrad_grid,
DataType* p_kgrad_grid, DataType* p_kgrad_grid,
...@@ -912,6 +983,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -912,6 +983,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumAcc1Bias>
...@@ -928,6 +1000,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -928,6 +1000,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
p_b, p_b,
p_b1, p_b1,
p_c, p_c,
p_lse,
p_ygrad_grid, p_ygrad_grid,
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
...@@ -942,6 +1015,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -942,6 +1015,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
...@@ -962,6 +1036,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -962,6 +1036,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const void* p_b, const void* p_b,
const void* p_b1, const void* p_b1,
const void* p_c, const void* p_c,
const void* p_lse,
const void* p_ygrad_grid, const void* p_ygrad_grid,
void* p_qgrad_grid, void* p_qgrad_grid,
void* p_kgrad_grid, void* p_kgrad_grid,
...@@ -976,6 +1051,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -976,6 +1051,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector<index_t>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumAcc1Bias>
...@@ -992,6 +1068,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -992,6 +1068,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
static_cast<const DataType*>(p_b), static_cast<const DataType*>(p_b),
static_cast<const DataType*>(p_b1), static_cast<const DataType*>(p_b1),
static_cast<const DataType*>(p_c), static_cast<const DataType*>(p_c),
static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid), static_cast<const DataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid), static_cast<DataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid), static_cast<DataType*>(p_kgrad_grid),
...@@ -1006,6 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -1006,6 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths,
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides, acc0_biases_gs_ms_ns_strides,
acc1_biases_gs_ms_gemm1ns_lengths, acc1_biases_gs_ms_gemm1ns_lengths,
......
...@@ -21,6 +21,7 @@ namespace ck { ...@@ -21,6 +21,7 @@ namespace ck {
template <typename DataType, template <typename DataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatLSE,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename AccElementwiseOperation, typename AccElementwiseOperation,
...@@ -31,6 +32,7 @@ template <typename DataType, ...@@ -31,6 +32,7 @@ template <typename DataType,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDesc_M_N, typename CGridDesc_M_N,
typename LSEGridDesc_M,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -170,7 +172,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -170,7 +172,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr index_t n = Free0_N - 1; constexpr index_t n = Free0_N - 1;
constexpr index_t n2 = n % NPerXdl; constexpr index_t n2 = n % NPerXdl;
constexpr index_t n1 = n / NPerXdl % Gemm0NWaves; constexpr index_t n1 = n / NPerXdl % Gemm0NWaves;
constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % MXdlPerWave; constexpr index_t n0 = n / NPerXdl / Gemm0NWaves % NXdlPerWave;
// assume 256 decomposed into 2 x 4 x 32 // assume 256 decomposed into 2 x 4 x 32
// 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32 // 1d idx ( 32 - 1) -> 3d idx 0, 0, 31 -> 3d dim 1 x 1 x 32
...@@ -178,6 +180,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -178,6 +180,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{}; return Sequence<m0, n0, m1, n1, m2, n2>{} + Sequence<1, 1, 1, 1, 1, 1>{};
} }
__host__ __device__ static constexpr auto GetPBlockSliceLengths_M0_N0_M1_N1()
{
return generate_sequence_v2(
[](auto I) { return GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2().At(I); },
Number<4>{});
}
// template <typename PBlockDesc_M0_N_M1> // template <typename PBlockDesc_M0_N_M1>
// __host__ __device__ static constexpr auto // __host__ __device__ static constexpr auto
// MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&) // MakePMmaTileDescriptor_N0_N1_N2_M(const PBlockDesc_M0_N_M1&)
...@@ -303,13 +312,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -303,13 +312,21 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t gemm1_bytes_end = const index_t gemm1_bytes_end =
(SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) *
sizeof(DataType); sizeof(DataType);
const index_t vgrad_gemm_bytes_end = (SharedMemTrait::p_block_space_size_aligned +
SharedMemTrait::ygrad_block_space_size_aligned) *
sizeof(DataType);
const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
SharedMemTrait::reduction_space_size_aligned) * SharedMemTrait::reduction_space_size_aligned) *
sizeof(FloatGemmAcc); sizeof(FloatGemmAcc);
const index_t c_block_bytes_end = const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); return math::max(gemm0_bytes_end,
gemm1_bytes_end,
vgrad_gemm_bytes_end,
softmax_bytes_end,
c_block_bytes_end);
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...@@ -395,6 +412,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -395,6 +412,22 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
return c_grid_desc_mblock_mperblock_nblock_nperblock; return c_grid_desc_mblock_mperblock_nblock_nperblock;
} }
__host__ __device__ static constexpr auto
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(const LSEGridDesc_M& lse_grid_desc_m)
{
const index_t M = lse_grid_desc_m.GetLength(I0);
const index_t MBlock = M / MPerBlock;
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
const auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl = transform_tensor_descriptor(
lse_grid_desc_m,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MXdlPerWave>{}, MWave, Number<MPerXdl>{}))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2, 3>{}));
return lse_grid_desc_mblock_mrepeat_mwave_mperxdl;
}
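A worked example of the unmerge above (tile sizes are made-up assumptions for illustration): with MPerBlock = 128, MXdlPerWave = 4 and MPerXdl = 32, MWave = 128 / (4 * 32) = 1; for M = 256 the grid has MBlock = 2, and a flat row index m decomposes as

    mblock  = m / 128
    mrepeat = (m / 32) % 4
    mwave   = 0
    mperxdl = m % 32

which matches the (mblock, mrepeat, mwave, mperxdl) origin used by the LSE thread copy further down.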
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
...@@ -418,8 +451,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -418,8 +451,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto b1_block_desc_bk0_n_bk1 = static constexpr auto b1_block_desc_bk0_n_bk1 =
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
static constexpr auto p_block_desc_m0_n_m1 =
VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
static constexpr auto ygrad_block_desc_m0_o_m1 =
VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); static constexpr auto max_lds_align = Number<16 / sizeof(DataType)>{};
static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( static constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
...@@ -427,10 +464,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -427,10 +464,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple(
b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
static constexpr auto p_block_space_size_aligned =
math::integer_least_multiple(p_block_desc_m0_n_m1.GetElementSpaceSize(), max_lds_align);
static constexpr auto ygrad_block_space_size_aligned = math::integer_least_multiple(
ygrad_block_desc_m0_o_m1.GetElementSpaceSize(), max_lds_align);
static constexpr auto a_block_space_offset = 0; static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; static constexpr auto b_block_space_offset = a_block_space_size_aligned.value;
static constexpr auto b1_block_space_offset = 0; static constexpr auto b1_block_space_offset = 0;
static constexpr auto p_block_space_offset = 0;
static constexpr auto ygrad_block_space_offset = p_block_space_size_aligned.value;
// LDS allocation for reduction // LDS allocation for reduction
static constexpr index_t reduction_space_size_aligned = static constexpr index_t reduction_space_size_aligned =
...@@ -454,6 +497,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -454,6 +497,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const DataType* __restrict__ p_b_grid, const DataType* __restrict__ p_b_grid,
const DataType* __restrict__ p_b1_grid, const DataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid, const DataType* __restrict__ p_c_grid,
const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const DataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, DataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, DataType* __restrict__ p_kgrad_grid,
...@@ -469,6 +513,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -469,6 +513,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDesc_M& lse_grid_desc_m,
const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o, const VGradGridDescriptor_N_O& vgrad_grid_desc_n_o,
const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1, const YGradGridDesc_M0_O_M1& ygrad_grid_desc_m0_o_m1,
const Block2CTileMap& block_2_ctile_map, const Block2CTileMap& block_2_ctile_map,
...@@ -482,6 +527,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -482,6 +527,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
...@@ -830,6 +881,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -830,6 +881,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max = NumericLimits<FloatGemmAcc>::Lowest(); running_max = NumericLimits<FloatGemmAcc>::Lowest();
running_max_new = NumericLimits<FloatGemmAcc>::Lowest(); running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
auto lse_grid_desc_mblock_mrepeat_mwave_mperxdl =
MakeLSEGridDescriptor_MBlock_MRepeat_NWave_MPerXdl(lse_grid_desc_m);
constexpr auto lse_thread_desc_mblock_mrepeat_mwave_mperxdl =
make_naive_tensor_descriptor_packed(make_tuple(I1, m0, m1, m2));
auto lse_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatLSE>(
lse_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());
auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
auto lse_thread_copy_global_to_vgpr =
ThreadwiseTensorSliceTransfer_v2<FloatLSE,
FloatLSE,
decltype(lse_grid_desc_mblock_mrepeat_mwave_mperxdl),
decltype(lse_thread_desc_mblock_mrepeat_mwave_mperxdl),
Sequence<1, m0, m1, m2>,
Sequence<0, 1, 2, 3>,
3,
m2,
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx[I0], // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
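For context, a host-side sketch (plain C++, not the kernel code) of what the saved LSE statistic enables: with lse = log(sum_j exp(s_j)) from the forward pass, the backward pass can recompute the softmax row P directly from S, without redoing the max/sum reduction.

#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_from_lse(const std::vector<float>& s_row, float lse)
{
    std::vector<float> p_row(s_row.size());
    for(std::size_t j = 0; j < s_row.size(); ++j)
        p_row[j] = std::exp(s_row[j] - lse); // exp(s - logsumexp) == exp(s) / sum_j exp(s_j)
    return p_row;
}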
// //
// dV // dV
// //
...@@ -837,11 +917,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -837,11 +917,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1 // P vgpr to lds: writes vgprs of a subgroup to LDS and transform into m0_n_m1
// m0, n0 are m/n repeat per wave // m0, n0 are m/n repeat per wave
// m1, n1 are number of waves // m1, n1 are number of waves
constexpr auto p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = constexpr auto p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto p_dst_block_desc_m0_n_m1 = constexpr auto p_block_desc_m0_n_m1 = VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
VGradGemmTile_N_O_M::GetPBlockDescriptor_M0_N_M1();
constexpr auto p_block_lengths = constexpr auto p_block_lengths =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
...@@ -854,30 +933,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -854,30 +933,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr auto P_N3 = p_block_lengths[I6]; constexpr auto P_N3 = p_block_lengths[I6];
constexpr auto P_N4 = p_block_lengths[I7]; constexpr auto P_N4 = p_block_lengths[I7];
constexpr auto p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr constexpr auto p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = [&]() constexpr
{ {
constexpr auto p_dst_block_desc_m_n = transform_tensor_descriptor( constexpr auto p_block_desc_m_n = transform_tensor_descriptor(
p_dst_block_desc_m0_n_m1, p_block_desc_m0_n_m1,
make_tuple(make_merge_transform_v3_division_mod( make_tuple(make_merge_transform_v3_division_mod(
make_tuple(VGradGemmTile_N_O_M::P_M0, VGradGemmTile_N_O_M::P_M1)), make_tuple(VGradGemmTile_N_O_M::P_M0, VGradGemmTile_N_O_M::P_M1)),
make_pass_through_transform(VGradGemmTile_N_O_M::Free0_N)), make_pass_through_transform(VGradGemmTile_N_O_M::Free0_N)),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
return transform_tensor_descriptor( return transform_tensor_descriptor(
p_dst_block_desc_m_n, p_block_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(P_M0, P_M1, P_M2)), make_tuple(make_unmerge_transform(make_tuple(I1, P_M1, P_M2)),
make_unmerge_transform(make_tuple(P_N0, P_N1, P_N2, P_N3, P_N4))), make_unmerge_transform(make_tuple(I1, P_N1, P_N2, P_N3, P_N4))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
} }
(); ();
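The merge/unmerge transforms above are mixed-radix index maps between a flat m index and (m0, m1, m2) coordinates; a small standalone sketch with illustrative lengths M1, M2:

#include <array>

constexpr std::array<int, 3> unmerge_m(int m, int M1, int M2)
{
    return {m / (M1 * M2), (m / M2) % M1, m % M2}; // (m0, m1, m2)
}

constexpr int merge_m(int m0, int m1, int m2, int M1, int M2)
{
    return (m0 * M1 + m1) * M2 + m2; // back to the flat m index
}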
// TODO ANT: check lds offset const auto p_thread_origin_nd_idx_on_block = [&]() {
auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared), p_dst_block_desc_m0_n_m1.GetElementSpaceSize());
const auto p_dst_thread_origin = [&]() {
const auto c_thread_mtx_on_block = const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
...@@ -904,8 +981,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -904,8 +981,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block)); make_multi_index(n_thread_data_on_block));
return make_tuple(0, // mrepeat return make_tuple(m_thread_data_on_block_idx[I0], // mrepeat
0, // nrepeat n_thread_data_on_block_idx[I0], // nrepeat
m_thread_data_on_block_idx[I1], // mwave m_thread_data_on_block_idx[I1], // mwave
n_thread_data_on_block_idx[I1], // nwave n_thread_data_on_block_idx[I1], // nwave
m_thread_data_on_block_idx[I2], // xdlops m_thread_data_on_block_idx[I2], // xdlops
...@@ -914,19 +991,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -914,19 +991,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
n_thread_data_on_block_idx[I4]); n_thread_data_on_block_idx[I4]);
}(); }();
constexpr auto p_block_slice_lengths_m0_n0_m1_n1_m2_n2 = // mrepeat, nrepeat, mwaves, constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =
// nwaves, mperxdl, nperxdl VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1(); // mrepeat, nrepeat,
VGradGemmTile_N_O_M::GetPBlockSliceLengths_M0_N0_M1_N1_M2_N2(); // mwaves, nwaves,
// how to properly perform copy for a sub-workgroup? // how to properly perform copy for a sub-workgroup?
auto p_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< auto p_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, DataType,
decltype(p_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
decltype(p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4), decltype(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
Sequence<p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I0], // ThreadSliceLengths Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0], // ThreadSliceLengths
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I1], p_block_slice_lengths_m0_n0_m1_n1[I1],
I1, I1,
I1, I1,
I1, I1,
...@@ -939,21 +1016,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -939,21 +1016,35 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector 1, // DstScalarStrideInVector
true>{ true>{
p_dst_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_multi_index(p_dst_thread_origin[I0], make_multi_index(
p_dst_thread_origin[I1], p_thread_origin_nd_idx_on_block[I0],
p_dst_thread_origin[I2] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I2], p_thread_origin_nd_idx_on_block[I1],
p_dst_thread_origin[I3] % p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I3], p_thread_origin_nd_idx_on_block[I2] % p_block_slice_lengths_m0_n0_m1_n1[I2],
p_dst_thread_origin[I4], p_thread_origin_nd_idx_on_block[I3] % p_block_slice_lengths_m0_n0_m1_n1[I3],
p_dst_thread_origin[I5], p_thread_origin_nd_idx_on_block[I4],
p_dst_thread_origin[I6], p_thread_origin_nd_idx_on_block[I5],
p_dst_thread_origin[I7]), p_thread_origin_nd_idx_on_block[I6],
p_thread_origin_nd_idx_on_block[I7]),
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
// construct space filling curve // Sequence<p_block_slice_lengths_m0_n0_m1_n1[I0],
// p_thread_copy_vgpr_to_lds.Run(); // p_block_slice_lengths_m0_n0_m1_n1[I1],
// I1,
constexpr auto ygrad_dst_block_desc_m0_o_m1 = // I1,
// I1,
// P_N2,
// I1,
// P_N4>{}
// .foo();
// 1, 4, 1, 1, 1, 4, 1, 4
constexpr auto sfc_p_m0_n0_m1_n1_m2_n2 =
SpaceFillingCurve<Sequence<P_M0, P_N0, P_M1, P_N1>,
Sequence<0, 1, 2, 3>,
decltype(p_block_slice_lengths_m0_n0_m1_n1)>{};
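As a rough mental model only (the real SpaceFillingCurve may visit slices in a different order), the curve enumerates one (mrepeat, nrepeat) slice origin of the P block per dV inner iteration, much like a row-major walk with a fixed slice shape:

#include <cstdio>

void walk_p_slices(int M0, int N0, int slice_m0, int slice_n0)
{
    for(int m0 = 0; m0 < M0; m0 += slice_m0)
        for(int n0 = 0; n0 < N0; n0 += slice_n0)
            std::printf("slice origin = (%d, %d)\n", m0, n0); // one copy + GEMM per access
}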
constexpr auto ygrad_block_desc_m0_o_m1 =
VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1(); VGradGemmTile_N_O_M::GetYGradBlockDescriptor_M0_O_M1();
auto ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< auto ygrad_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
...@@ -967,7 +1058,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -967,7 +1058,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
DataType, DataType,
DataType, DataType,
decltype(ygrad_grid_desc_m0_o_m1), decltype(ygrad_grid_desc_m0_o_m1),
decltype(ygrad_dst_block_desc_m0_o_m1), decltype(ygrad_block_desc_m0_o_m1),
typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder, // access order == thread typename VGradGemmTile_N_O_M::YGrad_ThreadClusterArrangeOrder, // access order == thread
// order // order
Sequence<1, 0, 2>, Sequence<1, 0, 2>,
...@@ -980,113 +1071,165 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -980,113 +1071,165 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
true, true,
true, true,
1>(ygrad_grid_desc_m0_o_m1, 1>(ygrad_grid_desc_m0_o_m1,
make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0), make_multi_index(m_block_data_idx_on_grid / VGradGemmTile_N_O_M::YGrad_M1,
gemm1_n_block_data_idx_on_grid,
0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
ygrad_dst_block_desc_m0_o_m1, ygrad_block_desc_m0_o_m1,
make_multi_index(0, 0, 0), make_multi_index(0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto p_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::p_block_space_offset,
p_block_desc_m0_n_m1.GetElementSpaceSize());
auto ygrad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + SharedMemTrait::ygrad_block_space_offset,
ygrad_block_desc_m0_o_m1.GetElementSpaceSize());
auto vgrad_blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< auto vgrad_blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
BlockSize, BlockSize,
DataType, DataType,
FloatGemmAcc, FloatGemmAcc,
decltype(p_dst_block_desc_m0_n_m1), decltype(p_block_desc_m0_n_m1),
decltype(ygrad_dst_block_desc_m0_o_m1), decltype(ygrad_block_desc_m0_o_m1),
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
VGradGemmTile_N_O_M::GemmNRepeat, // NRepeat VGradGemmTile_N_O_M::GemmNRepeat,
VGradGemmTile_N_O_M::GemmORepeat, // ORepeat VGradGemmTile_N_O_M::GemmORepeat,
VGradGemmTile_N_O_M::GemmMPack>{}; VGradGemmTile_N_O_M::GemmMPack,
true>{}; // TransposeC
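Written out as a naive host-side reference (shapes illustrative, plain C++ rather than the blockwise GEMM), the tile computation above is dV = P^T * dY, accumulated over the M dimension:

#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// P: M x N softmax tile, dY: M x O output-gradient tile, dV: N x O accumulator
void dv_tile_reference(const Mat& P, const Mat& dY, Mat& dV)
{
    for(std::size_t m = 0; m < P.size(); ++m)
        for(std::size_t n = 0; n < P[m].size(); ++n)
            for(std::size_t o = 0; o < dY[m].size(); ++o)
                dV[n][o] += P[m][n] * dY[m][o]; // accumulate across M tiles
}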
constexpr auto vgrad_block_lengths = auto vgrad_acc_thread_buf = vgrad_blockwise_gemm.GetCThreadBuffer();
vgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
// HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
// variable I1 there
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor( const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2 = transform_tensor_descriptor(
vgrad_grid_desc_n_o, vgrad_grid_desc_n_o,
make_tuple( make_tuple(
make_unmerge_transform(make_tuple(I1, // may place a dummy variable
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I2],
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I4])),
make_unmerge_transform(make_tuple(I1, make_unmerge_transform(make_tuple(I1,
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I3], VGradGemmTile_N_O_M::GemmNWave,
p_block_slice_lengths_m0_n0_m1_n1_m2_n2[I5]))), MPerXdl)),
make_unmerge_transform(make_tuple(I1,
VGradGemmTile_N_O_M::GemmOWave,
NPerXdl))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
constexpr auto vgrad_thread_desc_n0_o0_n1_o1_n2_n3_n4_o2 = constexpr auto vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
const auto vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2 = const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
vgrad_blockwise_gemm.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( vgrad_blockwise_gemm.xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2); vgrad_grid_desc_n0_o0_n1_o1_n2_o2);
const auto vgrad_thread_mtx_on_block_n_o = const auto vgrad_thread_mtx_on_block_n_o =
vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); vgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
constexpr auto vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2 = constexpr auto vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
decltype(vgrad_blockwise_gemm)::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); decltype(vgrad_blockwise_gemm)::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
constexpr auto VGrad_N0 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I0); constexpr auto VGrad_N0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I0);
constexpr auto VGrad_O0 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I1); constexpr auto VGrad_O0 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I1);
constexpr auto VGrad_N1 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I2); constexpr auto VGrad_N1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I2);
constexpr auto VGrad_O1 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I3); constexpr auto VGrad_O1 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I3);
constexpr auto VGrad_N2 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I4); constexpr auto VGrad_N2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I4);
constexpr auto VGrad_N3 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I5); constexpr auto VGrad_O2 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I5);
constexpr auto VGrad_N4 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I6); constexpr auto VGrad_O3 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I6);
constexpr auto VGrad_O2 = vgrad_block_desc_n0_o0_n1_o1_n2_n3_n4_o2.GetLength(I7); constexpr auto VGrad_O4 = vgrad_block_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLength(I7);
const index_t n_thread_data_idx_on_grid = const index_t n_thread_data_idx_on_grid = vgrad_thread_mtx_on_block_n_o[I0];
vgrad_thread_mtx_on_block_n_o[I0]; // TODO ANT: step n after each Gemm1 outer loop
const index_t o_thread_data_idx_on_grid = const index_t o_thread_data_idx_on_grid =
vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid; vgrad_thread_mtx_on_block_n_o[I1] + gemm1_n_block_data_idx_on_grid;
const auto n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor = const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform( make_tuple(make_merge_transform(
make_tuple(VGrad_N0, VGrad_N1, VGrad_N2, VGrad_N3, VGrad_N4))), make_tuple(VGrad_N0, VGrad_N1, VGrad_N2))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
const auto n_thread_data_nd_idx_on_grid = const auto n_thread_data_nd_idx_on_grid =
n_thread_data_on_grid_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_idx_on_grid)); make_multi_index(n_thread_data_idx_on_grid));
const auto o_thread_data_on_grid_to_o0_o1_o2_adaptor = make_single_stage_tensor_adaptor( const auto o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor =
make_tuple(make_merge_transform(make_tuple(VGrad_O0, VGrad_O1, VGrad_O2))), make_single_stage_tensor_adaptor(
make_tuple(Sequence<0, 1, 2>{}), make_tuple(make_merge_transform(
make_tuple(Sequence<0>{})); make_tuple(VGrad_O0, VGrad_O1, VGrad_O2, VGrad_O3, VGrad_O4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto o_thread_data_nd_idx_on_grid = const auto o_thread_data_nd_idx_on_grid =
o_thread_data_on_grid_to_o0_o1_o2_adaptor.CalculateBottomIndex( o_thread_data_on_grid_to_o0_o1_o2_o3_o4_adaptor.CalculateBottomIndex(
make_multi_index(o_thread_data_idx_on_grid)); make_multi_index(o_thread_data_idx_on_grid));
auto vgrad_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< auto vgrad_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, DataType,
decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_n3_n4_o2), decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2), decltype(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4),
tensor_operation::element_wise::PassThrough, // CElementwiseOperation tensor_operation::element_wise::PassThrough, // CElementwiseOperation
decltype(vgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() decltype(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4.GetLengths()), // SliceLengths
.GetLengths()), // SliceLengths
Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder Sequence<0, 1, 2, 3, 4, 5, 6, 7>, // AccessOrder
7, // VectorDim 7, // VectorDim
1, // ScalarPerVector 2, // ScalarPerVector
InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
1, 1, // DstScalarStrideInVector
true>(vgrad_grid_desc_n0_o0_n1_o1_n2_n3_n4_o2, true>(vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
make_multi_index(n_thread_data_nd_idx_on_grid[I0], make_multi_index(n_thread_data_nd_idx_on_grid[I0],
o_thread_data_nd_idx_on_grid[I0], o_thread_data_nd_idx_on_grid[I0],
n_thread_data_nd_idx_on_grid[I1], n_thread_data_nd_idx_on_grid[I1],
o_thread_data_nd_idx_on_grid[I1], o_thread_data_nd_idx_on_grid[I1],
n_thread_data_nd_idx_on_grid[I2], n_thread_data_nd_idx_on_grid[I2],
n_thread_data_nd_idx_on_grid[I3], o_thread_data_nd_idx_on_grid[I2],
n_thread_data_nd_idx_on_grid[I4], o_thread_data_nd_idx_on_grid[I3],
o_thread_data_nd_idx_on_grid[I2]), o_thread_data_nd_idx_on_grid[I4]),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
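Because several workgroups, and successive outer-loop iterations, contribute partial sums to the same dV elements, the global write above uses AtomicAdd instead of a plain store. A host-side analogue of that accumulate, sketched with std::atomic (the device path uses hardware atomics, not this loop):

#include <atomic>
#include <cstddef>
#include <vector>

void atomic_accumulate(std::vector<std::atomic<float>>& dv, std::size_t i, float partial)
{
    float cur = dv[i].load(std::memory_order_relaxed);
    // CAS loop: retry until our partial sum is applied on top of the latest value
    while(!dv[i].compare_exchange_weak(cur, cur + partial, std::memory_order_relaxed))
    {
    }
}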
// TODO ANT: ygrad slice window step size
#if 0 #if 0
if(hipThreadIdx_x % 32 < 4)
{
printf("wid %zd tid %zd _n0_o0_n1_o1_n2_o2_o3_o4 %d %d %d %d %d %d %d %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
n_thread_data_nd_idx_on_grid[I0],
o_thread_data_nd_idx_on_grid[I0],
n_thread_data_nd_idx_on_grid[I1],
o_thread_data_nd_idx_on_grid[I1],
n_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I2],
o_thread_data_nd_idx_on_grid[I3],
o_thread_data_nd_idx_on_grid[I4]);
}
#endif
// p_thread_slice_copy_step will be in for loop
constexpr auto ygrad_block_slice_copy_step =
make_multi_index(VGradGemmTile_N_O_M::YGrad_M0, 0, 0);
constexpr auto ygrad_block_reset_copy_step =
make_multi_index(-MPerBlock / VGradGemmTile_N_O_M::YGrad_M1, 0, 0);
// vgrad gemm output tile
const auto vgrad_block_slice_copy_step =
make_multi_index(VGradGemmTile_N_O_M::GemmNRepeat, 0, 0, 0, 0, 0, 0, 0);
#if 0
if(hipThreadIdx_x == 0)
{
printf("bid %zd, n_grid = %d, o_grid = %d, step N0 = %d\n",
hipBlockIdx_x,
n_thread_data_idx_on_grid,
o_thread_data_idx_on_grid,
n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(NPerBlock))[I0]);
}
#endif
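The copy-step constants above follow a common windowing pattern: the dY read window advances by a fixed stride for every inner dV iteration and is rewound by the accumulated amount before the next outer iteration. A tiny sketch of that bookkeeping with illustrative values:

#include <cassert>

struct Window
{
    int origin = 0;
};

void one_outer_iteration(Window& w, int step, int inner_iters)
{
    for(int i = 0; i < inner_iters; ++i)
        w.origin += step;           // MoveSrcSliceWindow(+step) per inner iteration
    w.origin -= step * inner_iters; // the reset step rewinds to the tile start
    assert(w.origin == 0);
}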
constexpr index_t num_vgrad_gemm_loop = MPerBlock / VGradGemmTile_N_O_M::Sum_M;
lse_thread_copy_global_to_vgpr.Run(lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
lse_grid_buf,
lse_thread_desc_mblock_mrepeat_mwave_mperxdl,
make_tuple(I0, I0, I0, I0),
lse_thread_buf);
// gemm1 K loop // gemm1 K loop
index_t gemm1_k_block_outer_index = 0; index_t gemm1_k_block_outer_index = 0;
do do
...@@ -1178,131 +1321,123 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1178,131 +1321,123 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
block_sync_lds(); // wait for lds read in gemm0 blockwise gemm block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
// softmax #if 0
SoftmaxBuf& max = blockwise_softmax.max_value_buf; if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; {
printf("tid %zd, S[0:3] = %f, %f, %f, %f\n",
blockwise_softmax.Run(acc_thread_buf, workspace_buf); hipThreadIdx_x,
acc_thread_buf[I0],
acc_thread_buf[I1],
acc_thread_buf[I2],
acc_thread_buf[I3]);
}
#endif
// TODO: may convert to log domain // softmax
running_max_new = mathext::max(max, running_max); blockwise_softmax.RunWithPreCalcStats(acc_thread_buf, lse_thread_buf);
running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
mathext::exp(max - running_max_new) * sum;
// gemm1 #if 0
if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{ {
// TODO: explore using dynamic buffer for a1 thread buffer printf("tid %zd, P[0:3] = %f, %f, %f, %f\n",
// For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), hipThreadIdx_x,
// RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that acc_thread_buf[I0],
// the A1 source buffer is static buffer holding the output of first GEMM and acc_thread_buf[I1],
// requires constexpr offset by design. Therefore, we pass tensor coordinate offset acc_thread_buf[I2],
// explicitly in Run() below. acc_thread_buf[I3]);
}
#endif
// Initialize acc1 block_sync_lds(); // wait for gemm1 LDS read
acc1_thread_buf.Clear();
// preload data into LDS SubThreadBlock<BlockSize> p_thread_copy_subgroup(blockwise_gemm.GetWaveIdx()[I0],
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); blockwise_gemm.GetWaveIdx()[I1]);
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, static_assert(sfc_p_m0_n0_m1_n1_m2_n2.GetNumOfAccess() == num_vgrad_gemm_loop, "");
b1_block_slice_copy_step);
block_sync_lds(); // wait for reduction LDS read vgrad_acc_thread_buf.Clear();
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); // TODO ANT: single buffer prefetch pipeline
static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
// load VGrad Gemm B
ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
// main body // load VGrad Gemm A
if constexpr(num_gemm1_k_block_inner_loop > 1) const auto p_nd_idx =
sfc_p_m0_n0_m1_n1_m2_n2.GetIndexTupleOfNumber(vgrad_gemm_loop_idx);
constexpr auto mwave_range = make_tuple(
p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
constexpr auto nwave_range = make_tuple(
p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
if (p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
{ {
static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { p_thread_copy_vgpr_to_lds.Run(
a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, p_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0), make_tuple(
acc_thread_buf, p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
a1_thread_desc_k0_m_k1, acc_thread_buf,
make_tuple(I0, I0, I0), p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
a1_thread_buf); p_block_buf);
}
b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
block_sync_lds();
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
block_sync_lds();
b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, // ygrad slice window is moved with MoveSrcSliceWindow() since it is a dynamic buffer
b1_block_slice_copy_step); // p slice window is moved by loop index
ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
ygrad_block_slice_copy_step);
b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); block_sync_lds(); // sync before write
}); ygrad_blockwise_copy.RunWrite(ygrad_block_desc_m0_o_m1, ygrad_block_buf);
#if 0
if (hipBlockIdx_x == 0)
{
debug::print_shared(
p_block_buf.p_data_,
index_t(p_block_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize()));
} }
// tail #endif
#if 0
if (hipBlockIdx_x == 0)
{ {
a1_blockwise_copy.Run( debug::print_shared(ygrad_block_buf.p_data_,
acc_thread_desc_k0_m_k1, index_t(ygrad_block_desc_m0_o_m1.GetElementSpaceSize()));
make_tuple( }
Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), #endif
acc_thread_buf,
a1_thread_desc_k0_m_k1,
make_tuple(I0, I0, I0),
a1_thread_buf);
block_sync_lds(); block_sync_lds(); // sync before read
vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_acc_thread_buf);
gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); #if 1
if(hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 4)
{
printf("outer %d inner %d tid %zd, dV[0:3] = %f, %f, %f, %f\n",
gemm1_k_block_outer_index,
vgrad_gemm_loop_idx.value,
hipThreadIdx_x,
vgrad_acc_thread_buf[I0],
vgrad_acc_thread_buf[I1],
vgrad_acc_thread_buf[I2],
vgrad_acc_thread_buf[I3]);
} }
} // end gemm1 #endif
}); // end gemm dV
// workaround compiler issue; see ck/ck.hpp
if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
Gemm1NPerBlock == 128)
{
__builtin_amdgcn_sched_barrier(0);
}
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = // atomic_add vgrad
gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); vgrad_acc_thread_buf,
constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); vgrad_grid_buf);
constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
FloatGemmAcc c = c_thread_buf[I]; // O
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
running_sum_new[iM]; // Formula by Dao et al.,
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
c_thread_buf(I) = c_new; // O_new
});
});
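The rescaling above is the forward-pass online-softmax update cited in the code (Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf, section 3.1). A scalar sketch of the same update for one output element of a row:

#include <algorithm>
#include <cmath>

float rescale_output(float o_old, float pv_new,
                     float running_max, float running_sum,
                     float tile_max, float tile_sum)
{
    const float new_max = std::max(running_max, tile_max);
    const float new_sum = std::exp(running_max - new_max) * running_sum +
                          std::exp(tile_max - new_max) * tile_sum;
    return (running_sum * std::exp(running_max - new_max) * o_old +
            std::exp(tile_max - new_max) * pv_new) /
           new_sum; // O_new
}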
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
a_block_reset_copy_step); // rewind K a_block_reset_copy_step); // rewind K
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
b_block_reset_copy_step); // rewind K and step N b_block_reset_copy_step); // rewind K and step N
ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
ygrad_block_reset_copy_step); // rewind M
vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, vgrad_block_slice_copy_step); // step N
// update before next j iteration
running_max = running_max_new;
running_sum = running_sum_new;
block_sync_lds(); // wait for gemm1 LDS read
} while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
#endif
// TODO ANT: // TODO ANT:
// shuffle dQ and write // shuffle dQ and write
......
...@@ -137,6 +137,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -137,6 +137,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// Sequence<num_access, idx_1d.value, i.value, src_offset>{}.foo();
SrcData v; SrcData v;
......
...@@ -28,8 +28,8 @@ struct SubThreadBlock ...@@ -28,8 +28,8 @@ struct SubThreadBlock
__device__ static constexpr index_t GetNumOfThread() { return kNumThread_; } __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }
template <typename Tuple2> template <typename TupleArg1, typename TupleArg2>
__device__ constexpr bool IsBelong(const Tuple2& mwave_range, const Tuple2& nwave_range) __device__ constexpr bool IsBelong(const TupleArg1& mwave_range, const TupleArg2& nwave_range)
{ {
// wave_range[I0] inclusive, wave_range[I1] exclusive // wave_range[I0] inclusive, wave_range[I1] exclusive
if(mwave_ < mwave_range[I0]) if(mwave_ < mwave_range[I0])
......
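A scalar sketch of the membership test IsBelong performs above: a wave is part of the sub-block when its (mwave, nwave) coordinates fall inside half-open ranges, range[I0] inclusive and range[I1] exclusive.

bool is_belong(int mwave, int nwave, int m_begin, int m_end, int n_begin, int n_end)
{
    return mwave >= m_begin && mwave < m_end && // [I0] inclusive, [I1] exclusive
           nwave >= n_begin && nwave < n_end;
}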
...@@ -149,6 +149,13 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -149,6 +149,13 @@ struct ReferenceSoftmax : public device::BaseOperator
ck::type_convert<AccDataType>( ck::type_convert<AccDataType>(
arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) + arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
arg.beta_ * self(idx); arg.beta_ * self(idx);
// printf(
// "exponent %f, exp() = %f\n",
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx))),
// std::exp(
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))));
}); });
return 0; return 0;
......
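The reference softmax above exponentiates against saved per-row statistics; a numerically stable host sketch of producing such a statistic (max-shifted log-sum-exp) looks like this:

#include <algorithm>
#include <cmath>
#include <vector>

float logsumexp(const std::vector<float>& s_row)
{
    const float m = *std::max_element(s_row.begin(), s_row.end());
    float sum = 0.f;
    for(float s : s_row)
        sum += std::exp(s - m); // shift by the row max for numerical stability
    return m + std::log(sum);
}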
...@@ -148,7 +148,7 @@ check_err(const Range& out, ...@@ -148,7 +148,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
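For context on the relaxed validation mentioned in the commit title, a simplified host-side sketch of this style of check; the real check_err takes its own tolerance parameters, and the change above only raises the number of reported mismatches from 5 to 32.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

bool check_err_sketch(const std::vector<float>& out, const std::vector<float>& ref,
                      float rtol, float atol, int max_report = 32)
{
    int err_count = 0;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const float err = std::fabs(out[i] - ref[i]);
        if(err > atol + rtol * std::fabs(ref[i]))
        {
            if(++err_count < max_report)
                std::fprintf(stderr, "out[%zu] != ref[%zu]: %f != %f\n", i, i, out[i], ref[i]);
        }
    }
    return err_count == 0;
}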
add_test_executable(test_space_filling_curve space_filling_curve.cpp) add_test_executable(test_space_filling_curve space_filling_curve.cpp)
add_test_executable(test_threadwise_copy test_threadwise_copy.cpp)