Commit 4d720be3 authored by danyao12
Browse files

remove unnecessary host run

parent 4d140b5d
...@@ -50,8 +50,8 @@ template <ck::index_t... Is> ...@@ -50,8 +50,8 @@ template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F32 = float; using F32 = float;
using U16 = unsigned short; using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
...@@ -160,7 +160,7 @@ using DeviceGemmInstance = ...@@ -160,7 +160,7 @@ using DeviceGemmInstance =
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#else #else
//2nd template // 2nd template
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG, NumDimG,
...@@ -531,30 +531,32 @@ int run(int argc, char* argv[]) ...@@ -531,30 +531,32 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach( k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
z_gs_ms_ns.ForEach( // z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); // [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
// });
v_gs_os_ns.ForEach( v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach( // lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); }); // [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, // run_attention_fwd_host(q_g_m_k,
k_g_n_k, // k_g_n_k,
v_g_n_o, // v_g_n_o,
alpha, // alpha,
s_g_m_n, // s_g_m_n,
p_g_m_n, // p_g_m_n,
y_g_m_o, // y_g_m_o,
lse_g_m, // lse_g_m,
p_drop_g_m_n, // p_drop_g_m_n,
z_g_m_n, // z_g_m_n,
p_dropout_in_16bits, // p_dropout_in_16bits,
rp_dropout); // rp_dropout);
y_gs_ms_os.ForEach( // y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); }); // [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
lse_gs_ms.ForEach( // });
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); }); // lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -572,11 +574,11 @@ int run(int argc, char* argv[]) ...@@ -572,11 +574,11 @@ int run(int argc, char* argv[])
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data()); z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data()); v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); // y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data()); // lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data()); ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
kgrad_device_buf.SetZero(); // kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero(); // vgrad_device_buf.SetZero();
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
...@@ -708,7 +710,10 @@ int run(int argc, char* argv[]) ...@@ -708,7 +710,10 @@ int run(int argc, char* argv[])
y_gs_ms_os.ForEach([&](auto& self, auto idx) { y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
}); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
// call kernel again // call kernel again
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
...@@ -768,9 +773,12 @@ int run(int argc, char* argv[]) ...@@ -768,9 +773,12 @@ int run(int argc, char* argv[])
{ {
auto idx_gmo = idx_gmn; auto idx_gmo = idx_gmn;
idx_gmo[2] = o; idx_gmo[2] = o;
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) * ck::type_convert<AccDataType>(y_g_m_o(idx_gmo)); ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
} }
self(idx_gmn) = ck::type_convert<DataType>(ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) * (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y)); self(idx_gmn) = ck::type_convert<DataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
#if PRINT_HOST #if PRINT_HOST
{ {
......
...@@ -344,8 +344,8 @@ int run(int argc, char* argv[]) ...@@ -344,8 +344,8 @@ int run(int argc, char* argv[])
ck::index_t N = 512; // 512 ck::index_t N = 512; // 512
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 64; ck::index_t O = 64;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 54; // 54
ck::index_t G1 = 6; // 16 ck::index_t G1 = 16; // 16
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
...@@ -531,30 +531,32 @@ int run(int argc, char* argv[]) ...@@ -531,30 +531,32 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach( k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
z_gs_ms_ns.ForEach( // z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); // [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
// });
v_gs_os_ns.ForEach( v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach( // lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); }); // [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, // run_attention_fwd_host(q_g_m_k,
k_g_n_k, // k_g_n_k,
v_g_n_o, // v_g_n_o,
alpha, // alpha,
s_g_m_n, // s_g_m_n,
p_g_m_n, // p_g_m_n,
y_g_m_o, // y_g_m_o,
lse_g_m, // lse_g_m,
p_drop_g_m_n, // p_drop_g_m_n,
z_g_m_n, // z_g_m_n,
p_dropout_in_16bits, // p_dropout_in_16bits,
rp_dropout); // rp_dropout);
y_gs_ms_os.ForEach( // y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); }); // [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
lse_gs_ms.ForEach( // });
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); }); // lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -572,11 +574,11 @@ int run(int argc, char* argv[]) ...@@ -572,11 +574,11 @@ int run(int argc, char* argv[])
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data()); z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data()); v_device_buf.ToDevice(v_gs_os_ns.mData.data());
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); // y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data()); // lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data()); ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
kgrad_device_buf.SetZero(); // kgrad_device_buf.SetZero();
vgrad_device_buf.SetZero(); // vgrad_device_buf.SetZero();
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
...@@ -708,7 +710,10 @@ int run(int argc, char* argv[]) ...@@ -708,7 +710,10 @@ int run(int argc, char* argv[])
y_gs_ms_os.ForEach([&](auto& self, auto idx) { y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
}); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
// call kernel again // call kernel again
kgrad_device_buf.SetZero(); // reset global accum buffer and rerun kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment