Unverified commit 18be6bc9, authored by Dan Yao, committed by GitHub
Browse files

Merge pull request #980 from ROCmSoftwarePlatform/mha-train-develop-fix-issupport

fix mha bwd IsSupportedArgument
parents f27f9158 422a69b2
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......
......@@ -1293,6 +1293,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1325,6 +1325,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1152,6 +1152,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1189,6 +1189,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest =
ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
......
......@@ -1335,6 +1335,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
......
......@@ -1407,6 +1407,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
......
......@@ -1181,6 +1181,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
......
......@@ -1239,6 +1239,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return false;
}
// saving dQ data with atomic_add instruction, so KzRaw must be a multiple of 2
if constexpr(is_same<OutputDataType, half_t>::value ||
is_same<OutputDataType, bhalf_t>::value)
{
if(KzRaw % 2 != 0)
{
std::cout << "K_q must be a multiple of 2" << std::endl;
return false;
}
}
// Check vector load/store requirement
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
? device_arg.a_mz_kz_strides_[1]
......
......@@ -88,6 +88,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -96,6 +96,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -87,6 +87,9 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -95,6 +95,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
......
......@@ -97,6 +97,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
......
......@@ -88,6 +88,10 @@ template <typename FloatAB,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseMultiHeadFlashAttentionInfer_Xdl_CShuffle
{
static_assert(AK1Value % ABlockTransferDstScalarPerVector_AK1 == 0);
static_assert(BK1Value % BBlockTransferDstScalarPerVector_BK1 == 0);
static_assert(B1K1Value % B1BlockTransferDstScalarPerVector_BK1 == 0);
static_assert(D0BlockTransferSrcScalarPerVector == 1 ||
D0BlockTransferSrcScalarPerVector == 2 ||
D0BlockTransferSrcScalarPerVector == 4,
......
Markdown is supported
Attach a file by drag & drop or click to upload (0%).
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment