Commit b76c8e62 authored by letaoqin
Browse files

add type check

parent 4e6fd810
......@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1153,10 +1153,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
......
......@@ -1190,10 +1190,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
......
......@@ -1336,10 +1336,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
......
......@@ -1408,10 +1408,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
......
......@@ -1182,10 +1182,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
......
......@@ -1253,10 +1253,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0)
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment