"example/vscode:/vscode.git/clone" did not exist on "c3fdcafbf400d244e4fa950d04c11512888c0ca4"
Commit 3185cbf9 authored by fsx950223's avatar fsx950223
Browse files

add verification

parent fe6ee651
@@ -244,7 +244,7 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float K = 64;
float K = 128;
float alpha = 1.f / std::sqrt(K);
bool input_permute = false;
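For context, the comment block above describes the forward pass whose gradients this commit verifies: y = Softmax(alpha * Q * K^T) * V per batch slice. A minimal host-side sketch under the g_m_k/g_n_k/g_n_o layouts used below (the helper name attention_fwd_ref is hypothetical; the example itself calls run_attention_fwd_host, which also produces S, P, and the LSE; assumes <cmath>, <limits>, <vector>, and <algorithm> are included):
template <typename T>
void attention_fwd_ref(const Tensor<T>& q_g_m_k, // [G, M, K]
                       const Tensor<T>& k_g_n_k, // [G, N, K]
                       const Tensor<T>& v_g_n_o, // [G, N, O]
                       float alpha,
                       Tensor<T>& y_g_m_o) // [G, M, O]
{
    const int G  = q_g_m_k.GetLengths()[0];
    const int M  = q_g_m_k.GetLengths()[1];
    const int Kd = q_g_m_k.GetLengths()[2];
    const int N  = k_g_n_k.GetLengths()[1];
    const int O  = v_g_n_o.GetLengths()[2];
    std::vector<float> p(N);
    for(int g = 0; g < G; g++)
        for(int m = 0; m < M; m++)
        {
            // S = alpha * Q * K^T, tracking the row max for numerical stability
            float row_max = -std::numeric_limits<float>::infinity();
            for(int n = 0; n < N; n++)
            {
                float acc = 0.f;
                for(int k = 0; k < Kd; k++)
                    acc += float(q_g_m_k(g, m, k)) * float(k_g_n_k(g, n, k));
                p[n]    = alpha * acc;
                row_max = std::max(row_max, p[n]);
            }
            // P = Softmax(S), row-wise
            float sum = 0.f;
            for(int n = 0; n < N; n++)
            {
                p[n] = std::exp(p[n] - row_max);
                sum += p[n];
            }
            // Y = P * V
            for(int o = 0; o < O; o++)
            {
                float acc = 0.f;
                for(int n = 0; n < N; n++)
                    acc += p[n] / sum * float(v_g_n_o(g, n, o));
                y_g_m_o(g, m, o) = T(acc);
            }
        }
}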
@@ -297,6 +297,12 @@ int run(int argc, char* argv[])
std::vector<DataType*> p_vgrad;
std::vector<const DataType*> p_ygrad;
std::vector<Tensor<DataType>> q_g_m_ks;
std::vector<Tensor<DataType>> k_g_n_ks;
std::vector<Tensor<DataType>> v_g_n_os;
std::vector<Tensor<AccDataType>> s_g_m_ns;
std::vector<Tensor<DataType>> p_g_m_ns;
std::vector<Tensor<DataType>> y_g_m_os;
std::vector<Tensor<DataType>> q_tensors;
std::vector<Tensor<DataType>> k_tensors;
std::vector<Tensor<DataType>> v_tensors;
@@ -478,14 +484,20 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k, k_g_n_k, v_g_n_o, alpha, s_g_m_n, p_g_m_n, y_g_m_o, lse_g_m);
y_gs_ms_os.ForEach(
[&](auto& self, auto idx) { self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); });
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
q_g_m_ks.push_back(q_g_m_k);
k_g_n_ks.push_back(k_g_n_k);
v_g_n_os.push_back(v_g_n_o);
s_g_m_ns.push_back(s_g_m_n);
p_g_m_ns.push_back(p_g_m_n);
y_g_m_os.push_back(y_g_m_o);
q_tensors.push_back(q_gs_ms_ks);
k_tensors.push_back(k_gs_ns_ks);
v_tensors.push_back(v_gs_os_ns);
@@ -566,172 +578,119 @@ int run(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
// bool pass = true;
// if(do_verification)
// {
// kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
// vgrad_device_buf.SetZero();
// invoker.Run(argument, StreamConfig{nullptr, false});
// Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
// Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
// Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
// Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
// Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
// Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
// Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
// ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
// ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
// });
// #if PRINT_HOST
// {
// std::cout << "q_g_m_k ref:\n" << q_g_m_k;
// std::cout << "k_g_n_k ref:\n" << k_g_n_k;
// std::cout << "v_g_n_o ref:\n" << v_g_n_o;
// std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
// }
// #endif
// // Gradients
// auto ref_gemm_grad = ReferenceGemmGradInstance{};
// auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
// using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
// // dP = dY * V^T
// auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
// ref_gemm_grad_invoker.Run(RefGemmGradArg{
// ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
// #if PRINT_HOST
// {
// std::cout << "===== dP = dY * V^T\n";
// std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
// std::cout << "v_g_o_n ref:\n" << v_g_o_n;
// std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
// }
// #endif
// // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
// sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
// float ygrad_dot_y = 0;
// for(int o = 0; o < O; o++)
// {
// auto idx_gmo = idx_gmn;
// idx_gmo[2] = o;
// ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_o(idx_gmo);
// }
// self(idx_gmn) = p_g_m_n(idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
// });
// #if PRINT_HOST
// {
// std::cout << "===== dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)\n";
// std::cout << "p_g_m_n ref:\n" << p_g_m_n;
// std::cout << "pgrad_g_m_n ref:\n" << pgrad_g_m_n;
// std::cout << "y_g_m_o ref:\n" << y_g_m_o;
// std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
// std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
// }
// #endif
// // dV = P^T * dY
// auto p_g_n_m = p_g_m_n.Transpose({0, 2, 1});
// ref_gemm_grad_invoker.Run(RefGemmGradArg{
// p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
// #if PRINT_HOST
// {
// std::cout << "===== dV = P^T * dY\n";
// std::cout << "p_g_n_m ref:\n" << p_g_n_m;
// std::cout << "ygrad_g_m_o ref:\n" << ygrad_g_m_o;
// std::cout << "vgrad_g_n_o ref:\n" << vgrad_g_n_o;
// }
// #endif
// // dQ = alpha * dS * K
// ref_gemm_grad_invoker.Run(RefGemmGradArg{
// sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
// #if PRINT_HOST
// {
// std::cout << "===== dQ = alpha * dS * K\n";
// std::cout << "sgrad_g_m_n ref:\n" << sgrad_g_m_n;
// std::cout << "k_g_n_k ref:\n" << k_g_n_k;
// std::cout << "qgrad_g_m_k ref:\n" << qgrad_g_m_k;
// }
// #endif
// // dK = alpha * dS^T * Q
// auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
// ref_gemm_grad_invoker.Run(RefGemmGradArg{
// sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
// #if PRINT_HOST
// {
// std::cout << "===== dK = alpha * dS^T * Q\n";
// std::cout << "sgrad_g_n_m ref:\n" << sgrad_g_n_m;
// std::cout << "q_g_m_k ref:\n" << q_g_m_k;
// std::cout << "kgrad_g_n_k ref:\n" << kgrad_g_n_k;
// }
// #endif
// Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
// Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
// Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
// Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
// Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
// Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
// qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
// kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
// vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
// // permute
// qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
// const size_t& g0 = idx[0];
// const size_t& g1 = idx[1];
// const size_t g = g0 * G1 + g1;
// self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
// });
// kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
// const size_t& g0 = idx[0];
// const size_t& g1 = idx[1];
// const size_t g = g0 * G1 + g1;
// self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
// });
// vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
// const size_t& g0 = idx[0];
// const size_t& g1 = idx[1];
// const size_t g = g0 * G1 + g1;
// self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
// });
// std::cout << "Checking qgrad:\n";
// pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
// qgrad_gs_ms_ks_host_result.mData,
// "error",
// 1e-2,
// 1e-2);
// std::cout << "Checking kgrad:\n";
// pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
// kgrad_gs_ns_ks_host_result.mData,
// "error",
// 1e-2,
// 1e-2);
// std::cout << "Checking vgrad:\n";
// pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
// vgrad_gs_os_ns_host_result.mData,
// "error",
// 1e-2,
// 1e-2);
// }
// return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
return 0;
bool pass = true;
if(do_verification)
{
// reset the gradient accumulation buffers, then rerun for verification
for(int i = 0; i < group_count; i++)
{
qgrad_tensors_device[i]->SetZero();
kgrad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero();
}
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i = 0; i < group_count; i++)
{
int G0 = v_tensors[i].GetLengths()[0];
int G1 = v_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1;
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
auto ref_gemm_grad = ReferenceGemmGradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
// dP = dY * V^T
auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
ygrad_g_m_o, v_g_o_n, pgrad_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
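// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)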
sgrad_g_m_n.ForEach([&](auto& self, auto idx_gmn) {
float ygrad_dot_y = 0;
for(int o = 0; o < O; o++)
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
ygrad_dot_y += ygrad_g_m_o(idx_gmo) * y_g_m_os[i](idx_gmo);
}
self(idx_gmn) = p_g_m_ns[i](idx_gmn) * (pgrad_g_m_n(idx_gmn) - ygrad_dot_y);
});
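// dV = P^T * dY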
auto p_g_n_m = p_g_m_ns[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
p_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.f}});
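// dQ = alpha * dS * K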
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
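// dK = alpha * dS^T * Q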
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{
sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), v_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking kgrad:\n";
pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
kgrad_gs_ns_ks_host_result.mData,
"error",
1e-2,
1e-2);
std::cout << "Checking vgrad:\n";
pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
vgrad_gs_os_ns_host_result.mData,
"error",
1e-2,
1e-2);
}
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
}
int main(int argc, char* argv[]) { return run(argc, argv); }
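For reference, the verification implements the standard gradients of \(Y = P V\) with \(P = \mathrm{softmax}(S)\) and \(S = \alpha\, Q K^{T}\); the only non-obvious step is the softmax backward:
\[
dP = dY\,V^{T},\qquad
dS_{ij} = P_{ij}\Bigl(dP_{ij} - \sum_{o} dY_{io}\,Y_{io}\Bigr),\qquad
dV = P^{T}\,dY,\qquad
dQ = \alpha\,dS\,K,\qquad
dK = \alpha\,dS^{T}\,Q
\]
The row term \(\sum_{o} dY_{io} Y_{io}\) replaces the softmax Jacobian contraction \(\sum_{j} P_{ij}\,dP_{ij}\), since \(\sum_{j} P_{ij}\,(dY V^{T})_{ij} = \sum_{o} dY_{io}\,(P V)_{io} = dY_i \cdot Y_i\); this is exactly the ygrad_dot_y value accumulated in the inner o loop above.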
@@ -50,7 +50,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
index_t left = 0;
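The left = 0 above begins the search that maps the flat block_id onto its group's kernel argument; a sketch of the usual lookup, assuming the kernel also receives group_count and that each GroupKernelArg carries block_start_/block_end_ members (these names are assumptions; the hunk is cut off before the actual loop):
// Sketch only: binary-search the per-group block ranges for the one
// containing block_id (block_start_ / block_end_ are assumed field names).
index_t right    = group_count;
index_t group_id = (left + right) / 2;
while((!(block_id >= arg_ptr[group_id].block_start_ &&
         block_id < arg_ptr[group_id].block_end_)) &&
      left <= right)
{
    if(block_id < arg_ptr[group_id].block_start_)
    {
        right = group_id;
    }
    else
    {
        left = group_id;
    }
    group_id = (left + right) / 2;
}
// arg_ptr[group_id] now holds this workgroup's problem descriptors.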
@@ -718,9 +718,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
const auto p_c_grid = static_cast<const DataType*>(p_Cs[i]);
const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]);
const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
@@ -844,31 +844,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
// ignore = acc1_biases_gs_ms_gemm1ns_strides;
}
// void Print() const
// {
// std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", "
// << a_grid_desc_g_m_k_.GetLength(I1) << ", "
// << a_grid_desc_g_m_k_.GetLength(I2) << '\n';
// // a_grid_desc_g_m_k_.Print();
// std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", "
// << b_grid_desc_g_n_k_.GetLength(I1) << ", "
// << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// // b_grid_desc_g_n_k_.Print();
// std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
// << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
// << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// // b1_grid_desc_g_n_k_.Print();
// std::cout << "c_grid_desc_g_m_o_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", "
// << c_grid_desc_g_m_n_.GetLength(I1) << ", "
// << c_grid_desc_g_m_n_.GetLength(I2) << '\n';
// // c_grid_desc_g_m_n_.Print();
// std::cout << "vgrad_grid_desc_n_o_: " << vgrad_grid_desc_n_o_.GetLength(I0) << ", "
// << vgrad_grid_desc_n_o_.GetLength(I1) << '\n';
// std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
// << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
// << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
// }
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
@@ -914,15 +889,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<
GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
......
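The hunk is truncated inside the launch_and_time_kernel call; for orientation, the launch_kernel lambda above is typically dispatched on the runtime K extent as in the following sketch (CalculateHasMainKBlockLoop and the surrounding names are assumptions, not necessarily this file's exact code):
// Sketch only: pick the compile-time specialization from the runtime K size.
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
    ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
    ave_time = launch_kernel(integral_constant<bool, false>{});
}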