Commit 18be6bc9 (unverified), authored by Dan Yao, committed by GitHub
Browse files

Merge pull request #980 from ROCmSoftwarePlatform/mha-train-develop-fix-issupport

fix mha bwd IsSupportedArgument
parents f27f9158 422a69b2
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1152,6 +1152,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1152,6 +1152,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1189,6 +1189,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1189,6 +1189,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
...@@ -1335,6 +1335,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1335,6 +1335,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1] ? device_arg.a_mz_kz_strides_[1]
......
...@@ -1407,6 +1407,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1407,6 +1407,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1] ? device_arg.a_mz_kz_strides_[1]
......
...@@ -1181,6 +1181,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1181,6 +1181,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1] ? device_arg.a_mz_kz_strides_[1]
......
...@@ -1239,6 +1239,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1239,6 +1239,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false; return false;
} }
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement // Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1] ? device_arg.a_mz_kz_strides_[1]
......
...@@ -88,6 +88,9 @@ template <typename InputDataType, ...@@ -88,6 +88,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock); static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -96,6 +96,10 @@ template <typename InputDataType, ...@@ -96,6 +96,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0); static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -87,6 +87,9 @@ template <typename InputDataType, ...@@ -87,6 +87,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock); static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -95,6 +95,10 @@ template <typename InputDataType, ...@@ -95,6 +95,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0); static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0); static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0); static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
...@@ -97,6 +97,10 @@ template <typename FloatAB, ...@@ -97,6 +97,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 || static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 || D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4, D0BlockTransferSrcScalarPerVector == 4,
......
...@@ -88,6 +88,10 @@ template <typename FloatAB, ...@@ -88,6 +88,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
{ {
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 || static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 || D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4, D0BlockTransferSrcScalarPerVector == 4,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment