Commit b76c8e62 authored by letaoqin's avatar letaoqin
Browse files

add type check

parent 4e6fd810
...@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false; return false;
} }
// dQ is saved with the atomic_add instruction; when OutputDataType is half_t or bhalf_t, KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return false; return false;
} }
// dQ is saved with the atomic_add instruction; when OutputDataType is half_t or bhalf_t, KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1153,11 +1153,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1153,11 +1153,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0) if(KzRaw % 2 != 0)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; std::cout << "K_q must be a multiple of 2" << std::endl;
return false; return false;
} }
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
......
...@@ -1190,11 +1190,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1190,11 +1190,15 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0) if(KzRaw % 2 != 0)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; std::cout << "K_q must be a multiple of 2" << std::endl;
return false; return false;
} }
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
......
...@@ -1336,11 +1336,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1336,11 +1336,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0) if(KzRaw % 2 != 0)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; std::cout << "K_q must be a multiple of 2" << std::endl;
return false; return false;
} }
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
......
...@@ -1408,11 +1408,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1408,11 +1408,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0) if(KzRaw % 2 != 0)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; std::cout << "K_q must be a multiple of 2" << std::endl;
return false; return false;
} }
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
......
...@@ -1182,11 +1182,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1182,11 +1182,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0) if(KzRaw % 2 != 0)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; std::cout << "K_q must be a multiple of 2" << std::endl;
return false; return false;
} }
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
......
...@@ -1253,11 +1253,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1253,11 +1253,15 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0) if(KzRaw % 2 != 0)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; std::cout << "K_q must be a multiple of 2" << std::endl;
return false; return false;
} }
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment