Commit 768a05a5 authored by letaoqin
Browse files

add check that K_q must be a multiple of 2

parent b23b3d71
......@@ -1152,6 +1152,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1189,6 +1189,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1181,6 +1181,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
......
......@@ -1252,6 +1252,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment