"configs/vscode:/vscode.git/clone" did not exist on "343f785b07a90c4ef4df5621544b210641f56f81"
Commit 7b01d8f5 authored by danyao12's avatar danyao12
Browse files

remove comments

parent 4d720be3
...@@ -531,32 +531,8 @@ int run(int argc, char* argv[]) ...@@ -531,32 +531,8 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach( k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
// z_gs_ms_ns.ForEach(
// [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
// });
v_gs_os_ns.ForEach( v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
// lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
// run_attention_fwd_host(q_g_m_k,
// k_g_n_k,
// v_g_n_o,
// alpha,
// s_g_m_n,
// p_g_m_n,
// y_g_m_o,
// lse_g_m,
// p_drop_g_m_n,
// z_g_m_n,
// p_dropout_in_16bits,
// rp_dropout);
// y_gs_ms_os.ForEach(
// [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
// });
// lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -574,11 +550,7 @@ int run(int argc, char* argv[]) ...@@ -574,11 +550,7 @@ int run(int argc, char* argv[])
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data()); z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data()); v_device_buf.ToDevice(v_gs_os_ns.mData.data());
// y_device_buf.ToDevice(y_gs_ms_os.mData.data());
// lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data()); ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
// kgrad_device_buf.SetZero();
// vgrad_device_buf.SetZero();
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
......
...@@ -531,32 +531,8 @@ int run(int argc, char* argv[]) ...@@ -531,32 +531,8 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach( k_gs_ns_ks.ForEach(
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
// z_gs_ms_ns.ForEach(
// [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
// });
v_gs_os_ns.ForEach( v_gs_os_ns.ForEach(
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
// lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
// run_attention_fwd_host(q_g_m_k,
// k_g_n_k,
// v_g_n_o,
// alpha,
// s_g_m_n,
// p_g_m_n,
// y_g_m_o,
// lse_g_m,
// p_drop_g_m_n,
// z_g_m_n,
// p_dropout_in_16bits,
// rp_dropout);
// y_gs_ms_os.ForEach(
// [&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
// });
// lse_gs_ms.ForEach(
// [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -574,11 +550,7 @@ int run(int argc, char* argv[]) ...@@ -574,11 +550,7 @@ int run(int argc, char* argv[])
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data()); z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data()); v_device_buf.ToDevice(v_gs_os_ns.mData.data());
// y_device_buf.ToDevice(y_gs_ms_os.mData.data());
// lse_device_buf.ToDevice(lse_gs_ms.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data()); ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
// kgrad_device_buf.SetZero();
// vgrad_device_buf.SetZero();
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment