"...composable_kernel_rocm.git" did not exist on "836b7e557d028cc2d7c6b341352253fd81003e54"
Commit b1e544e2 authored by Anthony Chang

ready to plug in kernel

parent 4f6d52c1
@@ -15,6 +15,8 @@ Outputs:
*/
#define PRINT_HOST 1
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -97,38 +99,60 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
PassThrough,
PassThrough,
Scale>;
template <typename TensorQ,
typename TensorK,
typename TensorV,
typename TensorS,
typename TensorP,
typename TensorY>
void run_attention_fwd_host(const TensorQ& q_g_m_k,
const TensorK& k_g_n_k,
const TensorV& v_g_n_o,
const float alpha,
TensorS& s_g_m_n,
TensorP& p_g_m_n,
TensorY& y_g_m_o)
{
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
q_g_m_k, k_g_k_n, s_g_m_n, PassThrough{}, PassThrough{}, Scale{alpha});
ref_gemm0_invoker.Run(ref_gemm0_argument);
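// Note: ReferenceBatchedGemm applies the A/B element ops to the inputs and the
// C op to the accumulator, so Scale{alpha} realizes the alpha * Q * K^T scaling.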
// masking
#if 0
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
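// Masked logits are set to -inf so that softmax assigns them zero probability
// (e.g. for causal attention).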
#endif
// P = Softmax(S)
// >>> scipy.special.softmax(numpy.eye(4), 1)
// array([[0.47536689, 0.1748777 , 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.47536689, 0.1748777 , 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.47536689, 0.1748777 ],
// [0.1748777 , 0.1748777 , 0.1748777 , 0.47536689]])
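// Each row of eye(4) has one logit of 1 and three of 0, so each softmax row is
// e / (e + 3) ~= 0.4754 on the diagonal and 1 / (e + 3) ~= 0.1749 elsewhere.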
auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(s_g_m_n, p_g_m_n, 1, 0, {2}); // alpha = 1, beta = 0, softmax over dim 2 (N)
ref_softmax_invoker.Run(ref_softmax_argument);
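// Optional sanity check (a sketch; assumes HostTensorDescriptor::GetLengths()
// returns the dimension lengths): each softmax row must sum to ~1 up to fp16
// rounding.
if(PRINT_HOST)
{
    const auto p_lengths = p_g_m_n.mDesc.GetLengths();
    for(std::size_t g = 0; g < p_lengths[0]; ++g)
        for(std::size_t m = 0; m < p_lengths[1]; ++m)
        {
            float row_sum = 0.f;
            for(std::size_t n = 0; n < p_lengths[2]; ++n)
                row_sum += static_cast<float>(p_g_m_n(g, m, n));
            if(row_sum < 0.99f || row_sum > 1.01f)
                std::cout << "softmax row (" << g << ", " << m << ") sums to " << row_sum << '\n';
        }
}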
// Y = P * V
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
p_g_m_n, v_g_n_o, y_g_m_o, PassThrough{}, PassThrough{}, PassThrough{});
ref_gemm1_invoker.Run(ref_gemm1_argument);
}
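// For intuition, a minimal CK-free sketch of the same forward math on flat
// row-major arrays. Names and layout here are illustrative only, not part of
// this example:
#include <cmath>
#include <limits>
#include <vector>

inline void naive_attention_fwd(const std::vector<float>& q, // M x K, row-major
                                const std::vector<float>& k, // N x K, row-major
                                const std::vector<float>& v, // N x O, row-major
                                float alpha, int M, int N, int K, int O,
                                std::vector<float>& y) // M x O, row-major
{
    std::vector<float> p(N);
    for(int m = 0; m < M; ++m)
    {
        // S row: s[n] = alpha * sum_k q[m,k] * k[n,k], tracking the row max
        float row_max = -std::numeric_limits<float>::infinity();
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int kk = 0; kk < K; ++kk)
                acc += q[m * K + kk] * k[n * K + kk];
            p[n]    = alpha * acc;
            row_max = p[n] > row_max ? p[n] : row_max;
        }
        // numerically stable softmax over the row
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            p[n] = std::exp(p[n] - row_max);
            sum += p[n];
        }
        for(int n = 0; n < N; ++n)
            p[n] /= sum;
        // Y row: y[m,o] = sum_n p[m,n] * v[n,o]
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += p[n] * v[n * O + o];
            y[m * O + o] = acc;
        }
    }
}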
int run(int argc, char* argv[])
{
bool do_verification = true;
@@ -136,9 +160,9 @@ int run(int argc, char* argv[])
bool time_kernel = false;
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
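// i.e. G0 * G1 independent attention heads; M and N are the query/key sequence
// lengths, K the qk head dimension, and O the value head dimension.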
ck::index_t M = 4;
ck::index_t N = 4;
ck::index_t K = 4;
@@ -219,20 +243,13 @@ int run(int argc, char* argv[])
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> qgrad_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> vgrad_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
switch(init_method)
{
@@ -262,15 +279,38 @@ int run(int argc, char* argv[])
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{10});
}
#if 0
// calculate y beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o);
#endif

// qkv gradients use the same descriptors as qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
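// GetElementSpaceSize() is the buffer span implied by lengths and strides, so
// padded or permuted layouts still get correctly sized allocations.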
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
// TODO ANT: attention backward kernel
#if 0
@@ -280,7 +320,7 @@ int run(int argc, char* argv[])
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths,
@@ -323,44 +363,31 @@ int run(int argc, char* argv[])
<< gemm.GetTypeString() << std::endl;
#endif
bool pass = true;
if(do_verification)
{
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<DataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<DataType> p_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<DataType> y_g_m_o({BatchCount, M, O}); // scratch object after gemm1
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
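// Backward pass verified below, via the chain rule through
// Y = P * V, P = Softmax(S), S = alpha * Q * K^T:
//   dP = dY * V^T
//   dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)   (softmax Jacobian)
//   dV = P^T * dY
//   dQ = alpha * dS * K
//   dK = alpha * dS^T * Q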
// permute
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); // TODO ANT: os_ns -> ns_os ?
});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "v_g_n_o ref:\n" << v_g_n_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
if(PRINT_HOST)
{
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "v_g_n_o ref:\n" << v_g_n_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
}
// S = alpha * Q * K^T
auto k_g_k_n = k_g_n_k.Transpose({0, 2, 1});
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
@@ -400,77 +427,126 @@ int run(int argc, char* argv[])
// Gradients
auto ref_gemm_grad = ReferenceGemmGradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
// dP = dY * V^T
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
if(PRINT_HOST)
{
std::cout << "===== dP = dY * V^T\n";
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "v_g_o_n ref:\n" << v_g_o_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
}
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
}
self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
});
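// Why dY_i dot Y_i: the softmax Jacobian gives
// dS_i_j = P_i_j * (dP_i_j - sum_n dP_i_n * P_i_n), and since dP = dY * V^T and
// Y = P * V, sum_n dP_i_n * P_i_n = dY_i dot Y_i, which is what the loop above
// accumulates.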
std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
std::cout << "p_g_m_n ref:\n" << p_g_m_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "y_g_m_o ref:\n" << y_g_m_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
if(PRINT_HOST)
{
std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
std::cout << "p_g_m_n ref:\n" << p_g_m_n;
std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
std::cout << "y_g_m_o ref:\n" << y_g_m_o;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
}
// dV = P^T * dY
auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
std::cout << "===== dV = P^T * dY\n";
std::cout << "p_g_n_m ref:\n" << p_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
if(PRINT_HOST)
{
std::cout << "===== dV = P^T * dY\n";
std::cout << "p_g_n_m ref:\n" << p_g_n_m;
std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
}
// dQ = alpha * dS * K
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
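// Scale{alpha} propagates the alpha from S = alpha * Q * K^T through the chain
// rule; the same factor appears in dK below.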
std::cout << "===== dQ = alpha * dS * K\n";
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
if(PRINT_HOST)
{
std::cout << "===== dQ = alpha * dS * K\n";
std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
std::cout << "k_g_n_k ref:\n" << k_g_n_k;
std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
}
// dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
std::cout << "===== dK = alpha * dS^T * Q\n";
std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
if(PRINT_HOST)
{
std::cout << "===== dK = alpha * dS^T * Q\n";
std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
std::cout << "q_g_m_k ref:\n" << q_g_m_k;
std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
}
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
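// Undo the earlier flattening: the flat batch index g = g0 * G1 + g1 is split
// back into (g0, g1) so the host results match the device gs_ms_ks layouts.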
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData);
}
return pass ? 0 : 1;
}
int main(int argc, char* argv[]) { return run(argc, argv); }