Commit 1a1a8924 authored by fsx950223

format code

parent e327363f
@@ -399,6 +399,8 @@ int run(int argc, char* argv[])
     std::vector<Tensor<AccDataType>> s_g_m_ns;
     std::vector<Tensor<DataType>> p_g_m_ns;
     std::vector<Tensor<DataType>> y_g_m_os;
+    std::vector<Tensor<DataType>> p_drop_g_m_ns;
     std::vector<Tensor<DataType>> q_tensors;
     std::vector<Tensor<DataType>> k_tensors;
     std::vector<Tensor<DataType>> v_tensors;
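The new container p_drop_g_m_ns collects, per problem group, the host-side probability matrix after dropout (P_drop); further down it replaces p_g_m_ns in the reference dV GEMM. As a rough, self-contained illustration of what such a tensor holds, here is an inverted-dropout sketch over a flat buffer, using the same p_dropout_in_16bits / rp_dropout constants that appear in the reference-dropout call later in this diff. This is illustrative only, not CK's ReferenceDropout implementation:

    // Illustrative sketch only (not CK code): an element of P survives when its
    // 16-bit random companion value z falls below the keep threshold, and surviving
    // elements are rescaled so the expectation is preserved (inverted dropout).
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::vector<float> dropout_reference(const std::vector<float>& p,
                                         const std::vector<uint16_t>& z,
                                         uint16_t p_dropout_in_16bits, // keep threshold
                                         float rp_dropout)             // 1 / keep probability
    {
        std::vector<float> out(p.size());
        for(std::size_t i = 0; i < p.size(); ++i)
            out[i] = (z[i] <= p_dropout_in_16bits) ? p[i] * rp_dropout : 0.f;
        return out;
    }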
@@ -420,7 +422,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> ygrad_tensors_device;
     std::vector<DeviceMemPtr> kgrad_tensors_device;
     std::vector<DeviceMemPtr> vgrad_tensors_device;
-    std::size_t group_count = 3;
+    std::size_t group_count = 1;
     std::size_t flop = 0, num_byte = 0;
     for(std::size_t i = 0; i < group_count; i++)
     {
@@ -629,6 +631,7 @@ int run(int argc, char* argv[])
         z_tensors.push_back(z_gs_ms_ns);
         lse_tensors.push_back(lse_gs_ms);
         ygrad_tensors.push_back(ygrad_gs_ms_os);
+        p_drop_g_m_ns.push_back(p_drop_g_m_n);
         q_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
         k_tensors_device.emplace_back(
@@ -721,36 +724,36 @@ int run(int argc, char* argv[])
         kgrad_tensors_device[i]->SetZero();
         vgrad_tensors_device[i]->SetZero();
     }
-    // p_z = std::vector<void*>(p_z.size(), nullptr);
-    // argument =
-    //     gemm.MakeArgument(p_q,
-    //                       p_k,
-    //                       p_z,
-    //                       p_v,
-    //                       p_y,
-    //                       p_lse,
-    //                       p_ygrad,
-    //                       p_qgrad,
-    //                       p_kgrad,
-    //                       p_vgrad,
-    //                       {}, // std::array<void*, 1> p_acc0_biases;
-    //                       {}, // std::array<void*, 1> p_acc1_biases;
-    //                       problem_descs,
-    //                       QKVElementOp{},
-    //                       QKVElementOp{},
-    //                       Scale{alpha},
-    //                       QKVElementOp{},
-    //                       YElementOp{},
-    //                       p_drop,
-    //                       std::tuple<unsigned long long, unsigned long long>(seed, offset));
-    // DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
-    // gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
-    // if(!gemm.IsSupportedArgument(argument))
-    // {
-    //     std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
-    //     return 0;
-    // }
+    p_z = std::vector<void*>(p_z.size(), nullptr);
+    argument =
+        gemm.MakeArgument(p_q,
+                          p_k,
+                          p_z,
+                          p_v,
+                          p_y,
+                          p_lse,
+                          p_ygrad,
+                          p_qgrad,
+                          p_kgrad,
+                          p_vgrad,
+                          {}, // std::array<void*, 1> p_acc0_biases;
+                          {}, // std::array<void*, 1> p_acc1_biases;
+                          problem_descs,
+                          QKVElementOp{},
+                          QKVElementOp{},
+                          Scale{alpha},
+                          QKVElementOp{},
+                          YElementOp{},
+                          p_drop,
+                          std::tuple<unsigned long long, unsigned long long>(seed, offset));
+    DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
+    gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
+        return 0;
+    }
     invoker.Run(argument, StreamConfig{nullptr, false});
     for(std::size_t i = 0; i < group_count; i++)
     {
@@ -767,8 +770,9 @@ int run(int argc, char* argv[])
         Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
         Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
         Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
+        Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
         Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
+        Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
         ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
             ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
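This hunk adds pgrad_drop_g_m_n (the dP of the dropped path, filled by the dP GEMM in the next hunk) and ygrad_dot_y_g_m, presumably a cache for the per-row reduction sum_o dY[g,m,o] * Y[g,m,o] that the softmax-backward loop below otherwise recomputes for every (g, m, n). The shown hunks declare but do not fill ygrad_dot_y_g_m; a hedged fragment in the file's own Tensor/ForEach style (the names i, O, ygrad_g_m_o, y_g_m_os come from the surrounding code) could look like:

    // Assumption, not part of this diff: cache rowsum(dY * Y) once per (g, m) row
    // instead of recomputing it inside sgrad_g_m_n.ForEach for every (g, m, n).
    ygrad_dot_y_g_m.ForEach([&](auto& self, auto idx) {
        float acc = 0.f;
        for(int o = 0; o < O; o++)
        {
            acc += ygrad_g_m_o(idx[0], idx[1], o) * y_g_m_os[i](idx[0], idx[1], o);
        }
        self(idx) = acc; // idx = (g, m)
    });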
@@ -778,7 +782,13 @@ int run(int argc, char* argv[])
         // dP = dY * V^T
         auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+            ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
+        auto ref_dropout          = ReferenceDropoutInstance{};
+        auto ref_dropout_invoker  = ref_dropout.MakeInvoker();
+        auto ref_dropout_argument = ref_dropout.MakeArgument(
+            z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
+        ref_dropout_invoker.Run(ref_dropout_argument);
         sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
             float ygrad_dot_y = 0;
             for(int o = 0; o < O; o++)
@@ -789,9 +799,9 @@ int run(int argc, char* argv[])
             }
             self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
         });
-        auto p_g_n_m = p_g_m_ns[i].Transpose({0, 2, 1});
+        auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
-            p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
+            p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
         ref_gemm_grad_invoker.Run(RefGemmGradArg{
             sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
         auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
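The reference-dropout call above consumes p_dropout_in_16bits and rp_dropout, which are defined outside the hunks shown here; the same drop probability p_drop is also handed to the device kernel in the re-enabled MakeArgument call, together with the (seed, offset) state. A plausible host-side derivation of the two constants, as an assumption for illustration rather than something taken from this commit:

    // Assumption: p_drop is the probability of dropping an element, Z holds uniform
    // 16-bit random values, and kept elements are rescaled by 1 / (1 - p_drop).
    #include <cmath>
    #include <cstdint>

    struct DropoutParams
    {
        uint16_t p_dropout_in_16bits; // keep threshold in Z's 16-bit value range
        float rp_dropout;             // rescale factor applied to kept elements
    };

    DropoutParams make_dropout_params(float p_drop) // e.g. p_drop = 0.1f
    {
        const float keep = 1.f - p_drop;
        return DropoutParams{static_cast<uint16_t>(std::floor(keep * 65535.f)), 1.f / keep};
    }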
...
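Taken together, the reference path now checks the attention backward chain with dropout inserted between the dP GEMM and the softmax backward. In the notation of the host tensors above (a summary of the shown hunks; dK, computed past the truncation, follows analogously from dS^T):

    \begin{aligned}
    dP^{\text{drop}} &= dY \, V^{\mathsf{T}} \\
    dP &= \text{dropout\_bwd}\!\left(dP^{\text{drop}}; Z\right) \\
    dS_{mn} &= P_{mn}\left(dP_{mn} - \sum_{o} dY_{mo} \, Y_{mo}\right) \\
    dV &= \left(P^{\text{drop}}\right)^{\mathsf{T}} dY \\
    dQ &= \alpha \, dS \, K
    \end{aligned}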