Commit ccf94638 authored by Mateusz Ozga's avatar Mateusz Ozga
Browse files

Pass 4d sequence and convert to 3d

parent 860433ea
......@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl
1, // MXdlPerWave
1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
......
......@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance =
32, // NPerXdl
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1
......
......@@ -49,24 +49,24 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl
1, // MXdlPerWave
1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CBlockTransferScalarPerVector_NWaveNPerXdl
4, // ABlockTranstest/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cppferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CBlockTransferScalarPerVector_NWaveNPerXdl
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ComputeTypeA, // ComputeTypeA
......
......@@ -315,59 +315,92 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
batch);
}
template <typename SeqType>
constexpr static auto
ShuffleSequenceAndTransformFrom4DTo3D() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Remove first element and,
// Convert 4d->3d sequence.
constexpr auto _I0 = SeqType{}.At(I1);
constexpr auto _I1 = SeqType{}.At(I2);
constexpr auto _I2 = SeqType{}.At(I0);
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
template <typename SeqType>
constexpr static auto
TransformSequenceFrom4DTo3dAndReduceByOne() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Skip first element and
// Convert 4d->3d and take away one from seq.
constexpr index_t one = 1;
constexpr auto _I0 = SeqType{}.At(I1) - one;
constexpr auto _I1 = SeqType{}.At(I2) - one;
constexpr auto _I2 = SeqType{}.At(I3) - one;
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
CDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
CDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
decltype(ShuffleSequenceAndTransformFrom4DTo3D<
ABlockTransferThreadClusterLengths_K0_M_K1>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
ABlockTransferThreadClusterArrangeOrder>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<ABlockTransferSrcAccessOrder>()),
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
decltype(ShuffleSequenceAndTransformFrom4DTo3D<
BBlockTransferThreadClusterLengths_K0_N_K1>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
BBlockTransferThreadClusterArrangeOrder>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<BBlockTransferSrcAccessOrder>()),
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
......
......@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
16, // NPerXdl
1, // MXdlPerWave
1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment