Commit 1a1a8924 authored by fsx950223

format code

parent e327363f
@@ -399,6 +399,8 @@ int run(int argc, char* argv[])
std::vector<Tensor<AccDataType>> s_g_m_ns;
std::vector<Tensor<DataType>> p_g_m_ns;
std::vector<Tensor<DataType>> y_g_m_os;
std::vector<Tensor<DataType>> p_drop_g_m_ns;
std::vector<Tensor<DataType>> q_tensors;
std::vector<Tensor<DataType>> k_tensors;
std::vector<Tensor<DataType>> v_tensors;
@@ -420,7 +422,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> ygrad_tensors_device;
std::vector<DeviceMemPtr> kgrad_tensors_device;
std::vector<DeviceMemPtr> vgrad_tensors_device;
std::size_t group_count = 3;
std::size_t group_count = 1;
std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++)
{
@@ -629,6 +631,7 @@ int run(int argc, char* argv[])
z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os);
p_drop_g_m_ns.push_back(p_drop_g_m_n);
q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back(
@@ -721,36 +724,36 @@ int run(int argc, char* argv[])
kgrad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero();
}
// p_z = std::vector<void*>(p_z.size(), nullptr);
// argument =
// gemm.MakeArgument(p_q,
// p_k,
// p_z,
// p_v,
// p_y,
// p_lse,
// p_ygrad,
// p_qgrad,
// p_kgrad,
// p_vgrad,
// {}, // std::array<void*, 1> p_acc0_biases;
// {}, // std::array<void*, 1> p_acc1_biases;
// problem_descs,
// QKVElementOp{},
// QKVElementOp{},
// Scale{alpha},
// QKVElementOp{},
// YElementOp{},
// p_drop,
// std::tuple<unsigned long long, unsigned long long>(seed, offset));
// DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
// gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
// if(!gemm.IsSupportedArgument(argument))
// {
// std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
// return 0;
// }
p_z = std::vector<void*>(p_z.size(), nullptr);
argument =
gemm.MakeArgument(p_q,
p_k,
p_z,
p_v,
p_y,
p_lse,
p_ygrad,
p_qgrad,
p_kgrad,
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
problem_descs,
QKVElementOp{},
QKVElementOp{},
Scale{alpha},
QKVElementOp{},
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
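// Launch the grouped multi-head attention backward kernel; the loop below
// recomputes each group's gradients on the host as a reference.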
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i = 0; i < group_count; i++)
{
@@ -767,8 +770,9 @@ int run(int argc, char* argv[])
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
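// Copy dY into the reference layout, flattening the [G0, G1] batch dims
// into a single index g = g0 * G1 + g1.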
ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
@@ -778,7 +782,13 @@ int run(int argc, char* argv[])
// dP = dY * V^T
auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
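// Dropout backward: apply the saved mask z_g_m_ns[i] to dP_drop and rescale
// the surviving entries by rp_dropout (the comparison against
// p_dropout_in_16bits is presumably how the mask is thresholded), yielding
// the pre-dropout gradient pgrad_g_m_n.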
auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker();
auto ref_dropout_argument = ref_dropout.MakeArgument(
z_g_m_ns[i], pgrad_drop_g_m_n, pgrad_g_m_n, p_dropout_in_16bits, rp_dropout);
ref_dropout_invoker.Run(ref_dropout_argument);
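// Softmax backward: dS = P * (dP - ygrad_dot_y) elementwise, where the loop
// over o presumably accumulates ygrad_dot_y = sum_o dY(m, o) * Y(m, o).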
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
@@ -789,9 +799,9 @@ int run(int argc, char* argv[])
}
self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
});
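// dV = P_drop^T * dY: the V-gradient GEMM consumes the dropped-out
// probabilities rather than the raw P, matching the forward pass Y = P_drop * V.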
auto p_g_n_m = p_g_m_ns[i].Transpose({0, 2, 1});
auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
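// dQ = alpha * dS * K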
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
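// dK = alpha * dS^T * Q presumably follows, using the transposed S-gradient.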
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
......