"vscode:/vscode.git/clone" did not exist on "866324b9a5bea99a74a81172704c67008d7cb9fa"
Commit 067e71a8 authored by guangzlu
Browse files

added dropout verify for grouped mha fp16 fwd

parent 937fcc07
......@@ -103,8 +103,8 @@ using DeviceGemmInstance =
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
64, // Gemm1NPerBlock
32, // Gemm1KPerBlock
128, // Gemm1NPerBlock
64, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
......@@ -112,7 +112,7 @@ using DeviceGemmInstance =
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
2, // Gemm1NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
......
......@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
bool time_kernel = false;
bool input_permute = false;
bool output_permute = true;
......@@ -45,7 +45,7 @@ int run(int argc, char* argv[])
exit(0);
}
float alpha = 0.25; // scaling after 1st gemm
float alpha = 1; // scaling after 1st gemm
std::size_t group_count = 8;
......@@ -76,27 +76,17 @@ int run(int argc, char* argv[])
std::size_t flop = 0, num_byte = 0;
std::cout << "group count " << group_count << ". printing first 4 groups\n";
// std::cout << "group count " << group_count << ". printing first 4 groups\n";
for(std::size_t i = 0; i < group_count; i++)
{
int M = 512;
int N = 512;
int K = 40;
int O = 40;
int M = 128 * (rand() % 8 + 1);
int N = 128 * (rand() % 8 + 1);
int K = 128;
int O = 128;
int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1;
// int M = 128 * (rand() % 8 + 1);
// int N = 128 * (rand() % 8 + 1);
// int K = 40;
// int O = 40 * (rand() % 2 + 1);
// int G0 = rand() % 3 + 1;
// int G1 = rand() % 5 + 1;
std::cout << "group id" << i << " M, N, K, O, G0, G1 is " << M << "," << N << "," << K
<< "," << O << "," << G0 << "," << G1 << std::endl;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
......@@ -322,7 +312,6 @@ int run(int argc, char* argv[])
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
// Tensor<CDataType> z_gs_ms_ns_host_result(z_gs_ms_os_lengths, z_gs_ms_os_strides);
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
......@@ -416,11 +405,10 @@ int run(int argc, char* argv[])
atol = 1e-2;
}
printf("group id is %lu \n", i);
// bool pass_ =
// ck::utils::check_err(c_gs_ms_os_device_result.mData,
// c_gs_ms_os_host_result.mData);
bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData,
"Error: Incorrect results c!",
......@@ -433,6 +421,10 @@ int run(int argc, char* argv[])
atol);
pass &= pass_;
}
if(pass)
{
std::cout << "Verification passed." << std::endl;
}
}
return pass ? 0 : 1;
......
......@@ -273,6 +273,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
if(Gemm1N != K)
{
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
return false;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment