Commit ccf94638 authored by Mateusz Ozga's avatar Mateusz Ozga
Browse files

Pass 4d sequence and convert to 3d

parent 860433ea
...@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance = ...@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl 16, // NPerXdl
1, // MXdlPerWave 1, // MXdlPerWave
1, // NXdlPerWave 1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
......
...@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance = ...@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance =
32, // NPerXdl 32, // NPerXdl
2, // MXdlPerWave 2, // MXdlPerWave
2, // NXdlPerWave 2, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1 2, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1 2, // BBlockTransferDstScalarPerVector_K1
......
...@@ -49,24 +49,24 @@ using DeviceConvBwdWeightInstance = ...@@ -49,24 +49,24 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl 16, // NPerXdl
1, // MXdlPerWave 1, // MXdlPerWave
1, // NXdlPerWave 1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
false, // BBlockLdsAddExtraN false, // BBlockLdsAddExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 8, 1, 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CBlockTransferScalarPerVector_NWaveNPerXdl 2, // CBlockTransferScalarPerVector_NWaveNPerXdl
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
ComputeTypeA, // ComputeTypeA ComputeTypeA, // ComputeTypeA
......
...@@ -315,59 +315,92 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -315,59 +315,92 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
batch); batch);
} }
template <typename SeqType>
constexpr static auto
ShuffleSequenceAndTransformFrom4DTo3D() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Remove first element and,
// Convert 4d->3d sequence.
constexpr auto _I0 = SeqType{}.At(I1);
constexpr auto _I1 = SeqType{}.At(I2);
constexpr auto _I2 = SeqType{}.At(I0);
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
template <typename SeqType>
constexpr static auto
TransformSequenceFrom4DTo3dAndReduceByOne() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Skip first element and
// Convert 4d->3d and take away one from seq.
constexpr index_t one = 1;
constexpr auto _I0 = SeqType{}.At(I1) - one;
constexpr auto _I1 = SeqType{}.At(I2) - one;
constexpr auto _I2 = SeqType{}.At(I3) - one;
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>()); using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm = using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor,
ADataType, ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
CDataType, CDataType,
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
GemmSpec, GemmSpec,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock, K0PerBlock,
K1, K1,
K1, K1,
MPerXdl, MPerXdl,
NPerXdl, NPerXdl,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1, decltype(ShuffleSequenceAndTransformFrom4DTo3D<
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterLengths_K0_M_K1>()),
ABlockTransferSrcAccessOrder, decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
ABlockTransferSrcVectorDim, ABlockTransferThreadClusterArrangeOrder>()),
ABlockTransferSrcScalarPerVector, decltype(TransformSequenceFrom4DTo3dAndReduceByOne<ABlockTransferSrcAccessOrder>()),
ABlockTransferDstScalarPerVector_K1, ABlockTransferSrcVectorDim,
false, ABlockTransferSrcScalarPerVector,
ABlockLdsAddExtraM, ABlockTransferDstScalarPerVector_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, false,
BBlockTransferThreadClusterArrangeOrder, ABlockLdsAddExtraM,
BBlockTransferSrcAccessOrder, decltype(ShuffleSequenceAndTransformFrom4DTo3D<
BBlockTransferSrcVectorDim, BBlockTransferThreadClusterLengths_K0_N_K1>()),
BBlockTransferSrcScalarPerVector, decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
BBlockTransferDstScalarPerVector_K1, BBlockTransferThreadClusterArrangeOrder>()),
false, decltype(TransformSequenceFrom4DTo3dAndReduceByOne<BBlockTransferSrcAccessOrder>()),
BBlockLdsAddExtraN, BBlockTransferSrcVectorDim,
CShuffleMXdlPerWavePerShuffle, BBlockTransferSrcScalarPerVector,
CShuffleNXdlPerWavePerShuffle, BBlockTransferDstScalarPerVector_K1,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, false,
CBlockTransferScalarPerVector_NWaveNPerXdl, BBlockLdsAddExtraN,
BlkGemmPipeSched, CShuffleMXdlPerWavePerShuffle,
BlkGemmPipelineVer, CShuffleNXdlPerWavePerShuffle,
ComputeTypeA, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
ComputeTypeB>; CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Argument // Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
......
...@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
16, // NPerXdl 16, // NPerXdl
1, // MXdlPerWave 1, // MXdlPerWave
1, // NXdlPerWave 1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment