Commit 6b8957a0 authored by fsx950223's avatar fsx950223
Browse files

fix a bug

parent 8ef97116
......@@ -274,11 +274,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
if(Gemm1N != K)
{
std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false;
}
// if(Gemm1N != K)
// {
// std::cout << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
// return false;
// }
if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment