Commit b76c8e62 authored by letaoqin's avatar letaoqin
Browse files

add type check

parent 4e6fd810
...@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1153,10 +1153,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1153,10 +1153,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0) if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; if(KzRaw % 2 != 0)
return false; {
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
} }
// Check vector load/store requirement // Check vector load/store requirement
......
...@@ -1190,10 +1190,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1190,10 +1190,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0) if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; if(KzRaw % 2 != 0)
return false; {
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
} }
// Check vector load/store requirement // Check vector load/store requirement
......
...@@ -1336,10 +1336,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1336,10 +1336,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0) if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; if(KzRaw % 2 != 0)
return false; {
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
} }
// Check vector load/store requirement // Check vector load/store requirement
......
...@@ -1408,10 +1408,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1408,10 +1408,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0) if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; if(KzRaw % 2 != 0)
return false; {
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
} }
// Check vector load/store requirement // Check vector load/store requirement
......
...@@ -1182,10 +1182,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1182,10 +1182,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0) if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; if(KzRaw % 2 != 0)
return false; {
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
} }
// Check vector load/store requirement // Check vector load/store requirement
......
...@@ -1253,10 +1253,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1253,10 +1253,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2 // saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if(KzRaw % 2 != 0) if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{ {
std::cout << "K_q must be a multiple of 2" << std::endl; if(KzRaw % 2 != 0)
return false; {
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
} }
// Check vector load/store requirement // Check vector load/store requirement
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment