Commit b67a58c0 authored by Anthony Chang's avatar Anthony Chang
Browse files

can validate dV with relaxed error tolerance

parent 8551dd43
...@@ -15,7 +15,9 @@ Outputs: ...@@ -15,7 +15,9 @@ Outputs:
*/ */
#define PRINT_HOST 1 #pragma clang diagnostic ignored "-Wunused-variable"
#define PRINT_HOST 0
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -50,6 +52,7 @@ using YElementOp = PassThrough; ...@@ -50,6 +52,7 @@ using YElementOp = PassThrough;
using DataType = F16; using DataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32;
using Acc0BiasDataType = ck::Tuple<>; using Acc0BiasDataType = ck::Tuple<>;
using Acc1BiasDataType = ck::Tuple<>; using Acc1BiasDataType = ck::Tuple<>;
...@@ -76,6 +79,7 @@ using DeviceGemmInstance = ...@@ -76,6 +79,7 @@ using DeviceGemmInstance =
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
Acc1BiasDataType, Acc1BiasDataType,
AccDataType, AccDataType,
...@@ -172,14 +176,16 @@ template <typename TensorQ, ...@@ -172,14 +176,16 @@ template <typename TensorQ,
typename TensorV, typename TensorV,
typename TensorS, typename TensorS,
typename TensorP, typename TensorP,
typename TensorY> typename TensorY,
typename TensorLSE = TensorP>
void run_attention_fwd_host(const TensorQ& q_g_m_k, void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k, const TensorK& k_g_n_k,
const TensorV& v_g_n_o, const TensorV& v_g_n_o,
const float alpha, const float alpha,
TensorS& s_g_m_n, TensorS& s_g_m_n,
TensorP& p_g_m_n, TensorP& p_g_m_n,
TensorY& y_g_m_o) TensorY& y_g_m_o,
TensorLSE& lse_g_m)
{ {
// S = alpha * Q * K^T // S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1}); auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
...@@ -207,7 +213,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k, ...@@ -207,7 +213,7 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
// [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]]) // [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker(); auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}); auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}, &lse_g_m);
ref_softmax_invoker.Run(ref_softmax_argument); ref_softmax_invoker.Run(ref_softmax_argument);
...@@ -230,10 +236,10 @@ int run(int argc, char* argv[]) ...@@ -230,10 +236,10 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 4; ck::index_t M = 256;
ck::index_t N = 4; ck::index_t N = 256;
ck::index_t K = 4; ck::index_t K = 256;
ck::index_t O = 4; ck::index_t O = 256;
ck::index_t G0 = 1; ck::index_t G0 = 1;
ck::index_t G1 = 1; ck::index_t G1 = 1;
...@@ -242,8 +248,6 @@ int run(int argc, char* argv[]) ...@@ -242,8 +248,6 @@ int run(int argc, char* argv[])
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
const ck::index_t BatchCount = G0 * G1;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -283,6 +287,8 @@ int run(int argc, char* argv[]) ...@@ -283,6 +287,8 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
const ck::index_t BatchCount = G0 * G1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
...@@ -307,16 +313,27 @@ int run(int argc, char* argv[]) ...@@ -307,16 +313,27 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^
// LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -340,19 +357,20 @@ int run(int argc, char* argv[]) ...@@ -340,19 +357,20 @@ int run(int argc, char* argv[])
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
break; break;
default: default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{10}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, n] = m
} }
// calculate y beforehand // calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K}); Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K}); Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<DataType> v_g_n_o({BatchCount, N, O}); Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N}); Tensor<DataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O}); Tensor<DataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach( q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
...@@ -360,27 +378,36 @@ int run(int argc, char* argv[]) ...@@ -360,27 +378,36 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach( v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o); run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
// TODO ANT: make sure K/V gradients are zeroed
q_device_buf.ToDevice(q_gs_ms_ks.mData.data()); q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data()); v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
ygrad_device_buf.ToDevice(y_gs_ms_os.mData.data()); lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero();
// TODO ANT: attention backward kernel
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
...@@ -388,6 +415,7 @@ int run(int argc, char* argv[]) ...@@ -388,6 +415,7 @@ int run(int argc, char* argv[])
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()), static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()), static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()), static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()), static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
...@@ -402,6 +430,7 @@ int run(int argc, char* argv[]) ...@@ -402,6 +430,7 @@ int run(int argc, char* argv[])
v_gs_os_ns_strides, v_gs_os_ns_strides,
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -421,6 +450,7 @@ int run(int argc, char* argv[]) ...@@ -421,6 +450,7 @@ int run(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// TODO ANT: add dQ/dK/dV flops & bytes
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N + std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) * sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
...@@ -445,7 +475,7 @@ int run(int argc, char* argv[]) ...@@ -445,7 +475,7 @@ int run(int argc, char* argv[])
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M}); Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) { ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 * idx[1], idx[3], idx[2]) = self(idx); ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
if(PRINT_HOST) if(PRINT_HOST)
...@@ -456,44 +486,6 @@ int run(int argc, char* argv[]) ...@@ -456,44 +486,6 @@ int run(int argc, char* argv[])
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o; std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
} }
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
#if 0
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
#endif
// P = Softmax(S)
// >>> scipy.special.softmax(numpy.eye(4), 1)
// array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument);
// Y = P * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
// Gradients // Gradients
auto ref_gemm_grad = ReferenceGemmGradInstance{}; auto ref_gemm_grad = ReferenceGemmGradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker(); auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
...@@ -613,7 +605,10 @@ int run(int argc, char* argv[]) ...@@ -613,7 +605,10 @@ int run(int argc, char* argv[])
kgrad_gs_ns_ks_host_result.mData); kgrad_gs_ns_ks_host_result.mData);
std::cout << "Checking vgrad:\n"; std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData, pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData); vgrad_gs_os_ns_host_result.mData,
"error",
1e-2,
1e-2);
} }
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -50,7 +50,8 @@ template <index_t BlockSize, ...@@ -50,7 +50,8 @@ template <index_t BlockSize,
index_t NPerXDL, index_t NPerXDL,
index_t MRepeat, index_t MRepeat,
index_t NRepeat, index_t NRepeat,
index_t KPack> index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -72,7 +73,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{}; static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
...@@ -226,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -226,6 +227,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N)); make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
} }
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{ {
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
......
...@@ -108,6 +108,24 @@ struct BlockwiseSoftmax ...@@ -108,6 +108,24 @@ struct BlockwiseSoftmax
}); });
} }
template <typename CThreadBuffer, typename LSEBuffer>
__host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
const LSEBuffer& lse_thread_buf)
{
// calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...)))
static_for<0, MRepeat, 1>{}([&](auto iM) {
static_for<0, KRepeat, 1>{}([&](auto iK) {
auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
? 0
: math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
});
});
}
BufferType max_value_buf; BufferType max_value_buf;
BufferType sum_value_buf; BufferType sum_value_buf;
}; };
......
...@@ -137,6 +137,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 ...@@ -137,6 +137,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
static_for<0, DstScalarPerVector, 1>{}([&](auto i) { static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset( constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// Sequence<num_access, idx_1d.value, i.value, src_offset>{}.foo();
SrcData v; SrcData v;
......
...@@ -28,8 +28,8 @@ struct SubThreadBlock ...@@ -28,8 +28,8 @@ struct SubThreadBlock
__device__ static constexpr index_t GetNumOfThread() { return kNumThread_; } __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }
template <typename Tuple2> template <typename TupleArg1, typename TupleArg2>
__device__ constexpr bool IsBelong(const Tuple2& mwave_range, const Tuple2& nwave_range) __device__ constexpr bool IsBelong(const TupleArg1& mwave_range, const TupleArg2& nwave_range)
{ {
// wave_range[I0] inclusive, wave_range[I1] exclusive // wave_range[I0] inclusive, wave_range[I1] exclusive
if(mwave_ < mwave_range[I0]) if(mwave_ < mwave_range[I0])
......
...@@ -149,6 +149,13 @@ struct ReferenceSoftmax : public device::BaseOperator ...@@ -149,6 +149,13 @@ struct ReferenceSoftmax : public device::BaseOperator
ck::type_convert<AccDataType>( ck::type_convert<AccDataType>(
arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) + arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))) +
arg.beta_ * self(idx); arg.beta_ * self(idx);
// printf(
// "exponent %f, exp() = %f\n",
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx))),
// std::exp(
// ck::type_convert<AccDataType>(arg.in_(idx)) -
// ck::type_convert<AccDataType>(arg.sm_stats_ptr_[0](to_sm_stats_idx(idx)))));
}); });
return 0; return 0;
......
...@@ -148,7 +148,7 @@ check_err(const Range& out, ...@@ -148,7 +148,7 @@ check_err(const Range& out,
{ {
max_err = err > max_err ? err : max_err; max_err = err > max_err ? err : max_err;
err_count++; err_count++;
if(err_count < 5) if(err_count < 32)
{ {
std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl; << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
add_test_executable(test_space_filling_curve space_filling_curve.cpp) add_test_executable(test_space_filling_curve space_filling_curve.cpp)
add_test_executable(test_threadwise_copy test_threadwise_copy.cpp)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment