Commit ccf94638 authored by Mateusz Ozga's avatar Mateusz Ozga
Browse files

Pass 4d sequence and convert to 3d

parent 860433ea
...@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance = ...@@ -48,16 +48,16 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl 16, // NPerXdl
1, // MXdlPerWave 1, // MXdlPerWave
1, // NXdlPerWave 1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
......
...@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance = ...@@ -47,16 +47,16 @@ using DeviceConvBwdWeightInstance =
32, // NPerXdl 32, // NPerXdl
2, // MXdlPerWave 2, // MXdlPerWave
2, // NXdlPerWave 2, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
2, // ABlockTransferDstScalarPerVector_K1 2, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_K1 2, // BBlockTransferDstScalarPerVector_K1
......
...@@ -49,16 +49,16 @@ using DeviceConvBwdWeightInstance = ...@@ -49,16 +49,16 @@ using DeviceConvBwdWeightInstance =
16, // NPerXdl 16, // NPerXdl
1, // MXdlPerWave 1, // MXdlPerWave
1, // NXdlPerWave 1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
......
...@@ -315,14 +315,43 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -315,14 +315,43 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
batch); batch);
} }
template <typename SeqType>
constexpr static auto
ShuffleSequenceAndTransformFrom4DTo3D() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Remove first element and,
// Convert 4d->3d sequence.
constexpr auto _I0 = SeqType{}.At(I1);
constexpr auto _I1 = SeqType{}.At(I2);
constexpr auto _I2 = SeqType{}.At(I0);
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
template <typename SeqType>
constexpr static auto
TransformSequenceFrom4DTo3dAndReduceByOne() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto)
{
// Skip first element and
// Convert 4d->3d and take away one from seq.
constexpr index_t one = 1;
constexpr auto _I0 = SeqType{}.At(I1) - one;
constexpr auto _I1 = SeqType{}.At(I2) - one;
constexpr auto _I2 = SeqType{}.At(I3) - one;
constexpr auto _Seq = S<_I0, _I1, _I2>();
return _Seq;
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>()); using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>; using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>; using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm = using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor,
ADataType, ADataType,
...@@ -344,17 +373,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -344,17 +373,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
NPerXdl, NPerXdl,
MXdlPerWave, MXdlPerWave,
NXdlPerWave, NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1, decltype(ShuffleSequenceAndTransformFrom4DTo3D<
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterLengths_K0_M_K1>()),
ABlockTransferSrcAccessOrder, decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
ABlockTransferThreadClusterArrangeOrder>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<ABlockTransferSrcAccessOrder>()),
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
false, false,
ABlockLdsAddExtraM, ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1, decltype(ShuffleSequenceAndTransformFrom4DTo3D<
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterLengths_K0_N_K1>()),
BBlockTransferSrcAccessOrder, decltype(TransformSequenceFrom4DTo3dAndReduceByOne<
BBlockTransferThreadClusterArrangeOrder>()),
decltype(TransformSequenceFrom4DTo3dAndReduceByOne<BBlockTransferSrcAccessOrder>()),
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
......
...@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ...@@ -201,16 +201,16 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
16, // NPerXdl 16, // NPerXdl
1, // MXdlPerWave 1, // MXdlPerWave
1, // NXdlPerWave 1, // NXdlPerWave
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
1, // ABlockTransferSrcVectorDim 1, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_K1 4, // ABlockTransferDstScalarPerVector_K1
false, // ABlockLdsAddExtraM false, // ABlockLdsAddExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim 1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1 4, // BBlockTransferDstScalarPerVector_K1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment