Commit f358055f authored by Mateusz Ozga's avatar Mateusz Ozga
Browse files

Rewrite Sequence to old style

parent ccf94638
...@@ -221,12 +221,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -221,12 +221,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
// TODO make A/B datatype different // TODO make A/B datatype different
using ABDataType = InDataType; using ABDataType = InDataType;
static constexpr auto I0 = Number<0>{}; static inline constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static inline constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static inline constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static inline constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{}; static inline constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static inline constexpr auto I5 = Number<5>{};
static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;
static constexpr auto K1Number = Number<K1>{}; static constexpr auto K1Number = Number<K1>{};
...@@ -241,7 +241,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -241,7 +241,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
ConvBackwardWeightSpecialization>{}; ConvBackwardWeightSpecialization>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc() -> decltype(auto)
{ {
const ck::index_t dim = 1; const ck::index_t dim = 1;
const ck::index_t batch = 1; const ck::index_t batch = 1;
...@@ -266,7 +266,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -266,7 +266,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc() -> decltype(auto)
{ {
const ck::index_t dim = 1; const ck::index_t dim = 1;
const ck::index_t batch = 1; const ck::index_t batch = 1;
...@@ -291,7 +291,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -291,7 +291,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc() static auto GetABCGridDesc() -> decltype(auto)
{ {
const ck::index_t dim = 1; const ck::index_t dim = 1;
const ck::index_t batch = 1; const ck::index_t batch = 1;
...@@ -316,7 +316,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -316,7 +316,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
} }
template <typename SeqType> template <typename SeqType>
constexpr static auto [[nodiscard, gnu::always_inline]] inline constexpr static auto
ShuffleSequenceAndTransformFrom4DTo3D() noexcept(noexcept(SeqType{}.Size() == 4)) ShuffleSequenceAndTransformFrom4DTo3D() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto) -> decltype(auto)
{ {
...@@ -325,12 +325,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -325,12 +325,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
constexpr auto _I0 = SeqType{}.At(I1); constexpr auto _I0 = SeqType{}.At(I1);
constexpr auto _I1 = SeqType{}.At(I2); constexpr auto _I1 = SeqType{}.At(I2);
constexpr auto _I2 = SeqType{}.At(I0); constexpr auto _I2 = SeqType{}.At(I0);
constexpr auto _Seq = S<_I0, _I1, _I2>(); constexpr auto _Seq = Sequence<_I0, _I1, _I2>();
return _Seq; return _Seq;
} }
template <typename SeqType> template <typename SeqType>
constexpr static auto [[nodiscard, gnu::always_inline]] inline constexpr static auto
TransformSequenceFrom4DTo3dAndReduceByOne() noexcept(noexcept(SeqType{}.Size() == 4)) TransformSequenceFrom4DTo3dAndReduceByOne() noexcept(noexcept(SeqType{}.Size() == 4))
-> decltype(auto) -> decltype(auto)
{ {
...@@ -340,7 +340,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -340,7 +340,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
constexpr auto _I0 = SeqType{}.At(I1) - one; constexpr auto _I0 = SeqType{}.At(I1) - one;
constexpr auto _I1 = SeqType{}.At(I2) - one; constexpr auto _I1 = SeqType{}.At(I2) - one;
constexpr auto _I2 = SeqType{}.At(I3) - one; constexpr auto _I2 = SeqType{}.At(I3) - one;
constexpr auto _Seq = S<_I0, _I1, _I2>(); constexpr auto _Seq = Sequence<_I0, _I1, _I2>();
return _Seq; return _Seq;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment