Commit 067e71a8 authored by guangzlu's avatar guangzlu
Browse files

added dropout verify for grouped mha fp16 fwd

parent 937fcc07
...@@ -103,8 +103,8 @@ using DeviceGemmInstance = ...@@ -103,8 +103,8 @@ using DeviceGemmInstance =
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
64, // Gemm1NPerBlock 128, // Gemm1NPerBlock
32, // Gemm1KPerBlock 64, // Gemm1KPerBlock
8, // AK1 8, // AK1
8, // BK1 8, // BK1
2, // B1K1 2, // B1K1
...@@ -112,7 +112,7 @@ using DeviceGemmInstance = ...@@ -112,7 +112,7 @@ using DeviceGemmInstance =
32, // NPerXDL 32, // NPerXDL
1, // MXdlPerWave 1, // MXdlPerWave
4, // NXdlPerWave 4, // NXdlPerWave
2, // Gemm1NXdlPerWave 4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
S<1, 0, 2>, S<1, 0, 2>,
......
...@@ -5,7 +5,7 @@ int run(int argc, char* argv[]) ...@@ -5,7 +5,7 @@ int run(int argc, char* argv[])
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 1;
bool time_kernel = true; bool time_kernel = false;
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
...@@ -45,7 +45,7 @@ int run(int argc, char* argv[]) ...@@ -45,7 +45,7 @@ int run(int argc, char* argv[])
exit(0); exit(0);
} }
float alpha = 0.25; // scaling after 1st gemm float alpha = 1; // scaling after 1st gemm
std::size_t group_count = 8; std::size_t group_count = 8;
...@@ -76,27 +76,17 @@ int run(int argc, char* argv[]) ...@@ -76,27 +76,17 @@ int run(int argc, char* argv[])
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
std::cout << "group count " << group_count << ". printing first 4 groups\n"; // std::cout << "group count " << group_count << ". printing first 4 groups\n";
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int M = 512; int M = 128 * (rand() % 8 + 1);
int N = 512; int N = 128 * (rand() % 8 + 1);
int K = 40; int K = 128;
int O = 40; int O = 128;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G1 = rand() % 5 + 1;
// int M = 128 * (rand() % 8 + 1);
// int N = 128 * (rand() % 8 + 1);
// int K = 40;
// int O = 40 * (rand() % 2 + 1);
// int G0 = rand() % 3 + 1;
// int G1 = rand() % 5 + 1;
std::cout << "group id" << i << " M, N, K, O, G0, G1 is " << M << "," << N << "," << K
<< "," << O << "," << G0 << "," << G1 << std::endl;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O}); g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
...@@ -322,7 +312,6 @@ int run(int argc, char* argv[]) ...@@ -322,7 +312,6 @@ int run(int argc, char* argv[])
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
// Tensor<CDataType> z_gs_ms_ns_host_result(z_gs_ms_os_lengths, z_gs_ms_os_strides);
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1 Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
...@@ -416,11 +405,10 @@ int run(int argc, char* argv[]) ...@@ -416,11 +405,10 @@ int run(int argc, char* argv[])
atol = 1e-2; atol = 1e-2;
} }
printf("group id is %lu \n", i);
// bool pass_ = // bool pass_ =
// ck::utils::check_err(c_gs_ms_os_device_result.mData, // ck::utils::check_err(c_gs_ms_os_device_result.mData,
// c_gs_ms_os_host_result.mData); // c_gs_ms_os_host_result.mData);
bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData, bool pass_ = ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData, c_gs_ms_os_host_result.mData,
"Error: Incorrect results c!", "Error: Incorrect results c!",
...@@ -433,6 +421,10 @@ int run(int argc, char* argv[]) ...@@ -433,6 +421,10 @@ int run(int argc, char* argv[])
atol); atol);
pass &= pass_; pass &= pass_;
} }
if(pass)
{
std::cout << "Verification passed." << std::endl;
}
} }
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -273,6 +273,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -273,6 +273,12 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
if(Gemm1N != K)
{
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{ {
return false; return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment