Commit 6232c36c authored by guangzlu's avatar guangzlu
Browse files

modified mha fwd check validity in gridwise

parent 773247db
......@@ -273,11 +273,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
if(Gemm1N != K)
{
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
// if(Gemm1N != K)
//{
// std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
// return false;
//}
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment