Commit b9eb4de3 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents be3fbf7f d22713a7
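The net effect of this diff: TransformConvFwdToGemm gains SplitN/ADataType/CDataType template parameters and per-NDimSpatial constructors that capture the convolution geometry once into members (Wi_, Wo_, ConvStrideW_, NDoHoWo_, and so on, with N_ optionally reduced via GetSplitedNSize), so the MakeADescriptor_M_K overloads become nullary const member functions instead of static helpers taking ten arrays. Below is a minimal host-side sketch of the new interface; the include path, shape values, and packed strides are illustrative assumptions, not part of the commit:

// Sketch only: include path, shapes, and strides are assumed for illustration;
// the class, constructor, and MakeADescriptor_M_K come from this diff.
#include <array>
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" // assumed path

using ck::index_t;
using ck::tensor_operation::TransformConvFwdToGemm;
using ck::tensor_operation::device::ConvolutionForwardSpecialization;

int main()
{
    // 1D NWGC problem: G = 1, N = 4, C = 8, Wi = 16, K = 32, X = 3 -> Wo = 14.
    std::array<index_t, 4> a_lengths{1, 4, 8, 16};    // [G, N, C, Wi]
    std::array<index_t, 4> a_strides{8, 128, 1, 8};   // packed NWGC strides
    std::array<index_t, 4> b_lengths{1, 32, 8, 3};    // [G, K, C, X]
    std::array<index_t, 4> b_strides{768, 24, 1, 8};  // packed GKXC strides
    std::array<index_t, 4> c_lengths{1, 4, 32, 14};   // [G, N, K, Wo]
    std::array<index_t, 4> c_strides{32, 448, 1, 32}; // packed NWGK strides
    std::array<index_t, 1> conv_strides{1}, dilations{1}, left_pads{0}, right_pads{0};

    // Geometry is captured once by the NDimSpatial == 1 constructor added here.
    TransformConvFwdToGemm<1, ConvolutionForwardSpecialization::Default> transform{
        a_lengths, a_strides, b_lengths, b_strides, c_lengths, c_strides,
        conv_strides, dilations, left_pads, right_pads};

    // Descriptor builders are now nullary const members reading the stored state.
    const auto a_m_k = transform.MakeADescriptor_M_K<ck::tensor_layout::convolution::NWGC>();
    (void)a_m_k;
    return 0;
}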
@@ -14,22 +14,15 @@
 namespace ck {
 namespace tensor_operation {
-// function to be used on device, emulates std::accumulate
-template <typename T, typename ForwardIterator, typename Size>
-__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init)
-{
-    for(ForwardIterator x = first; x != first + count; x++)
-    {
-        init *= *x;
-    }
-    return init;
-}
 template <index_t NDimSpatial,
           device::ConvolutionForwardSpecialization ConvForwardSpecialization,
+          bool SplitN              = false,
+          typename ADataType       = float,
+          typename CDataType       = float,
           index_t NumGroupsToMerge = 1>
 struct TransformConvFwdToGemm
 {
+    private:
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -37,10 +30,10 @@ struct TransformConvFwdToGemm
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};
-    static long_index_t
-    calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
-                                      const std::array<index_t, NDimSpatial + 3>& strides,
-                                      index_t i)
+    template <typename ConvDimsType>
+    static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
+                                                          const ConvDimsType& strides,
+                                                          index_t i)
     {
         long_index_t acc = 1;
         for(; i < (NDimSpatial + 3); i++)
@@ -52,11 +45,11 @@ struct TransformConvFwdToGemm
         return acc;
     }
-    template <typename ADataType, typename CDataType>
-    static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+    template <typename ConvDimsType>
+    static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
+                                   const ConvDimsType& a_g_n_c_wis_strides,
+                                   const ConvDimsType& c_g_n_k_wos_lengths,
+                                   const ConvDimsType& c_g_n_k_wos_strides)
     {
         const long_index_t a_element_space_size =
             calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
@@ -102,6 +95,216 @@ struct TransformConvFwdToGemm
         }
     }
+    public:
+    __host__ __device__ constexpr TransformConvFwdToGemm() {}
+    template <typename ConvDimsType,
+              typename ConvSpatialDimsType,
+              index_t NDim = NDimSpatial,
+              typename std::enable_if<NDim == 1, bool>::type = false>
+    __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
+                                               const ConvDimsType& a_g_n_c_wis_strides,
+                                               const ConvDimsType& b_g_k_c_xs_lengths,
+                                               const ConvDimsType& b_g_k_c_xs_strides,
+                                               const ConvDimsType& c_g_n_k_wos_lengths,
+                                               const ConvDimsType& c_g_n_k_wos_strides,
+                                               const ConvSpatialDimsType& conv_filter_strides,
+                                               const ConvSpatialDimsType& conv_filter_dilations,
+                                               const ConvSpatialDimsType& input_left_pads,
+                                               const ConvSpatialDimsType& input_right_pads)
+        : Di_{I1},
+          Hi_{I1},
+          Wi_{a_g_n_c_wis_lengths[I3]},
+          Do_{I1},
+          Ho_{I1},
+          Wo_{c_g_n_k_wos_lengths[I3]},
+          Z_{I1},
+          Y_{I1},
+          X_{b_g_k_c_xs_lengths[I3]},
+          K_{c_g_n_k_wos_lengths[I2]},
+          C_{b_g_k_c_xs_lengths[I2]},
+          DiStride_{I1},
+          HiStride_{I1},
+          WiStride_{a_g_n_c_wis_strides[I3]},
+          WoStride_{c_g_n_k_wos_strides[I3]},
+          XStride_{b_g_k_c_xs_strides[I3]},
+          CStrideTensorA_{a_g_n_c_wis_strides[I2]},
+          CStrideTensorB_{b_g_k_c_xs_strides[I2]},
+          KStrideTensorB_{b_g_k_c_xs_strides[I1]},
+          KStrideTensorC_{c_g_n_k_wos_strides[I2]},
+          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+          GStrideTensorA_{a_g_n_c_wis_strides[I0]},
+          GStrideTensorB_{b_g_k_c_xs_strides[I0]},
+          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
+          ConvStrideD_{I1},
+          ConvStrideH_{I1},
+          ConvStrideW_{conv_filter_strides[I0]},
+          ConvDilationD_{I1},
+          ConvDilationH_{I1},
+          ConvDilationW_{conv_filter_dilations[I0]},
+          InLeftPadD_{I0},
+          InLeftPadH_{I0},
+          InLeftPadW_{input_left_pads[I0]},
+          InRightPadD_{I0},
+          InRightPadH_{I0},
+          InRightPadW_{input_right_pads[I0]},
+          ZYX_{X_}
+    {
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+        if constexpr(SplitN)
+        {
+            N_ = GetSplitedNSize(
+                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+        }
+        else
+        {
+            N_ = c_g_n_k_wos_lengths[I1];
+        }
+        NDoHoWo_ = N_ * Wo_;
+    }
+    template <typename ConvDimsType,
+              typename ConvSpatialDimsType,
+              index_t NDim = NDimSpatial,
+              typename std::enable_if<NDim == 2, bool>::type = false>
+    __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
+                                               const ConvDimsType& a_g_n_c_wis_strides,
+                                               const ConvDimsType& b_g_k_c_xs_lengths,
+                                               const ConvDimsType& b_g_k_c_xs_strides,
+                                               const ConvDimsType& c_g_n_k_wos_lengths,
+                                               const ConvDimsType& c_g_n_k_wos_strides,
+                                               const ConvSpatialDimsType& conv_filter_strides,
+                                               const ConvSpatialDimsType& conv_filter_dilations,
+                                               const ConvSpatialDimsType& input_left_pads,
+                                               const ConvSpatialDimsType& input_right_pads)
+        : Di_{I1},
+          Hi_{a_g_n_c_wis_lengths[I3]},
+          Wi_{a_g_n_c_wis_lengths[I4]},
+          Do_{I1},
+          Ho_{c_g_n_k_wos_lengths[I3]},
+          Wo_{c_g_n_k_wos_lengths[I4]},
+          Z_{I1},
+          Y_{b_g_k_c_xs_lengths[I3]},
+          X_{b_g_k_c_xs_lengths[I4]},
+          K_{c_g_n_k_wos_lengths[I2]},
+          C_{b_g_k_c_xs_lengths[I2]},
+          DiStride_{I1},
+          HiStride_{a_g_n_c_wis_strides[I3]},
+          WiStride_{a_g_n_c_wis_strides[I4]},
+          WoStride_{c_g_n_k_wos_strides[I4]},
+          XStride_{b_g_k_c_xs_strides[I4]},
+          CStrideTensorA_{a_g_n_c_wis_strides[I2]},
+          CStrideTensorB_{b_g_k_c_xs_strides[I2]},
+          KStrideTensorB_{b_g_k_c_xs_strides[I1]},
+          KStrideTensorC_{c_g_n_k_wos_strides[I2]},
+          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+          GStrideTensorA_{a_g_n_c_wis_strides[I0]},
+          GStrideTensorB_{b_g_k_c_xs_strides[I0]},
+          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
+          ConvStrideD_{I1},
+          ConvStrideH_{conv_filter_strides[I0]},
+          ConvStrideW_{conv_filter_strides[I1]},
+          ConvDilationD_{I1},
+          ConvDilationH_{conv_filter_dilations[I0]},
+          ConvDilationW_{conv_filter_dilations[I1]},
+          InLeftPadD_{I0},
+          InLeftPadH_{input_left_pads[I0]},
+          InLeftPadW_{input_left_pads[I1]},
+          InRightPadD_{I0},
+          InRightPadH_{input_right_pads[I0]},
+          InRightPadW_{input_right_pads[I1]},
+          ZYX_{Y_ * X_}
+    {
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+        if constexpr(SplitN)
+        {
+            N_ = GetSplitedNSize(
+                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+        }
+        else
+        {
+            N_ = c_g_n_k_wos_lengths[I1];
+        }
+        NDoHoWo_ = N_ * Ho_ * Wo_;
+    }
+    template <typename ConvDimsType,
+              typename ConvSpatialDimsType,
+              index_t NDim = NDimSpatial,
+              typename std::enable_if<NDim == 3, bool>::type = false>
+    __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
+                                               const ConvDimsType& a_g_n_c_wis_strides,
+                                               const ConvDimsType& b_g_k_c_xs_lengths,
+                                               const ConvDimsType& b_g_k_c_xs_strides,
+                                               const ConvDimsType& c_g_n_k_wos_lengths,
+                                               const ConvDimsType& c_g_n_k_wos_strides,
+                                               const ConvSpatialDimsType& conv_filter_strides,
+                                               const ConvSpatialDimsType& conv_filter_dilations,
+                                               const ConvSpatialDimsType& input_left_pads,
+                                               const ConvSpatialDimsType& input_right_pads)
+        : Di_{a_g_n_c_wis_lengths[I3]},
+          Hi_{a_g_n_c_wis_lengths[I4]},
+          Wi_{a_g_n_c_wis_lengths[I5]},
+          Do_{c_g_n_k_wos_lengths[I3]},
+          Ho_{c_g_n_k_wos_lengths[I4]},
+          Wo_{c_g_n_k_wos_lengths[I5]},
+          Z_{b_g_k_c_xs_lengths[I3]},
+          Y_{b_g_k_c_xs_lengths[I4]},
+          X_{b_g_k_c_xs_lengths[I5]},
+          K_{c_g_n_k_wos_lengths[I2]},
+          C_{b_g_k_c_xs_lengths[I2]},
+          DiStride_{a_g_n_c_wis_strides[I3]},
+          HiStride_{a_g_n_c_wis_strides[I4]},
+          WiStride_{a_g_n_c_wis_strides[I5]},
+          WoStride_{c_g_n_k_wos_strides[I5]},
+          XStride_{b_g_k_c_xs_strides[I5]},
+          CStrideTensorA_{a_g_n_c_wis_strides[I2]},
+          CStrideTensorB_{b_g_k_c_xs_strides[I2]},
+          KStrideTensorB_{b_g_k_c_xs_strides[I1]},
+          KStrideTensorC_{c_g_n_k_wos_strides[I2]},
+          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+          GStrideTensorA_{a_g_n_c_wis_strides[I0]},
+          GStrideTensorB_{b_g_k_c_xs_strides[I0]},
+          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
+          ConvStrideD_{conv_filter_strides[I0]},
+          ConvStrideH_{conv_filter_strides[I1]},
+          ConvStrideW_{conv_filter_strides[I2]},
+          ConvDilationD_{conv_filter_dilations[I0]},
+          ConvDilationH_{conv_filter_dilations[I1]},
+          ConvDilationW_{conv_filter_dilations[I2]},
+          InLeftPadD_{input_left_pads[I0]},
+          InLeftPadH_{input_left_pads[I1]},
+          InLeftPadW_{input_left_pads[I2]},
+          InRightPadD_{input_right_pads[I0]},
+          InRightPadH_{input_right_pads[I1]},
+          InRightPadW_{input_right_pads[I2]},
+          ZYX_{Z_ * Y_ * X_}
+    {
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+        if constexpr(SplitN)
+        {
+            N_ = GetSplitedNSize(
+                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+        }
+        else
+        {
+            N_ = c_g_n_k_wos_lengths[I1];
+        }
+        NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
+    }
 // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
 // properties
     template <typename ALayout,
@@ -110,53 +313,26 @@ struct TransformConvFwdToGemm
                   is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                   is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                  bool>::type = false>
-    static auto
-    MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
-                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
-                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
-                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                        const std::array<index_t, NDimSpatial>& input_right_pads,
-                        const index_t N)
+    __host__ __device__ auto MakeADescriptor_M_K() const
     {
-        const index_t C = a_g_n_c_wis_lengths[I2];
-        const index_t Wi = a_g_n_c_wis_lengths[I3];
-        const index_t Wo = c_g_n_k_wos_lengths[I3];
-        const index_t ConvStrideW = conv_filter_strides[I0];
-        const index_t GStride = a_g_n_c_wis_strides[I0];
-        const index_t NStride = a_g_n_c_wis_strides[I1];
-        const auto CStride = a_g_n_c_wis_strides[I2];
-        const index_t WiStride = a_g_n_c_wis_strides[I3];
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo =
-                N * ck::accumulate_n<index_t>(
-                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
             if constexpr(NumGroupsToMerge == 1)
            {
-                return make_naive_tensor_descriptor(make_tuple(NHoWo, C),
-                                                    make_tuple(WiStride, CStride));
+                return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
+                                                    make_tuple(WiStride_, CStrideTensorA_));
            }
            else
            {
                const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                    make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride));
+                    make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
+                    make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
                return transform_tensor_descriptor(
                    in_gemmm_groups_gemmk_desc,
-                    make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                    make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
@@ -164,35 +340,30 @@ struct TransformConvFwdToGemm
         else if constexpr(ConvForwardSpecialization ==
                           device::ConvolutionForwardSpecialization::Filter3x3)
         {
-            const index_t ConvDilationW = conv_filter_dilations[0];
-            const index_t InLeftPadW = input_left_pads[0];
-            const index_t InRightPadW = input_right_pads[0];
             if constexpr(NumGroupsToMerge == 1)
             {
-                const auto in_n_wi_c_desc =
-                    make_naive_tensor_descriptor(make_tuple(N, Wi), make_tuple(NStride, WiStride));
+                const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
+                    make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW))),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_))),
                     make_tuple(Sequence<0>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
                                make_pass_through_transform(Number<3>{})),
                     make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -200,28 +371,29 @@ struct TransformConvFwdToGemm
             else
             {
                 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, NumGroupsToMerge), make_tuple(NStride, WiStride, GStride));
+                    make_tuple(N_, Wi_, NumGroupsToMerge),
+                    make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
                                make_pass_through_transform(Number<3>{})),
                     make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -233,110 +405,108 @@ struct TransformConvFwdToGemm
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
+                    make_tuple(N_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
                 const auto in_n_wo_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 return transform_tensor_descriptor(
                     in_n_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
-                const auto in_n_wi_c_desc =
-                    make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C),
-                                                 make_tuple(NStride, WiStride, GStride, CStride));
+                const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
+                    make_tuple(N_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_wo_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }
         else
         {
-            const index_t X = b_g_k_c_xs_lengths[3];
-            const index_t ConvDilationW = conv_filter_dilations[0];
-            const index_t InLeftPadW = input_left_pads[0];
-            const index_t InRightPadW = input_right_pads[0];
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
+                    make_tuple(N_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(X_, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo)),
-                               make_merge_transform(make_tuple(X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
+                               make_merge_transform(make_tuple(X_, C_))),
                     make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
-                const auto in_n_wi_c_desc =
-                    make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C),
-                                                 make_tuple(NStride, WiStride, GStride, CStride));
+                const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
+                    make_tuple(N_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(X_, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
-                               make_merge_transform(make_tuple(X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
+                               make_merge_transform(make_tuple(X_, C_))),
                     make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -349,57 +519,27 @@ struct TransformConvFwdToGemm
                   is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                   is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
                  bool>::type = false>
-    static auto
-    MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
-                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
-                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
-                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                        const std::array<index_t, NDimSpatial>& input_right_pads,
-                        const index_t N)
+    __host__ __device__ auto MakeADescriptor_M_K() const
     {
-        const index_t C = a_g_n_c_wis_lengths[2];
-        const index_t Hi = a_g_n_c_wis_lengths[3];
-        const index_t Wi = a_g_n_c_wis_lengths[4];
-        const index_t Ho = c_g_n_k_wos_lengths[3];
-        const index_t Wo = c_g_n_k_wos_lengths[4];
-        const index_t ConvStrideH = conv_filter_strides[0];
-        const index_t ConvStrideW = conv_filter_strides[1];
-        const index_t GStride = a_g_n_c_wis_strides[I0];
-        const index_t NStride = a_g_n_c_wis_strides[I1];
-        const index_t CStride = a_g_n_c_wis_strides[I2];
-        const index_t HiStride = a_g_n_c_wis_strides[I3];
-        const index_t WiStride = a_g_n_c_wis_strides[I4];
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo =
-                N * ck::accumulate_n<index_t>(
-                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
             if constexpr(NumGroupsToMerge == 1)
            {
-                return make_naive_tensor_descriptor(make_tuple(NHoWo, C),
-                                                    make_tuple(WiStride, CStride));
+                return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
+                                                    make_tuple(WiStride_, CStrideTensorA_));
            }
            else
            {
                const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                    make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride));
+                    make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
+                    make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
                return transform_tensor_descriptor(
                    in_gemmm_groups_gemmk_desc,
-                    make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                    make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
@@ -407,73 +547,65 @@ struct TransformConvFwdToGemm
         else if constexpr(ConvForwardSpecialization ==
                           device::ConvolutionForwardSpecialization::Filter3x3)
         {
-            const index_t ConvDilationH = conv_filter_dilations[0];
-            const index_t ConvDilationW = conv_filter_dilations[1];
-            const index_t InLeftPadH = input_left_pads[0];
-            const index_t InLeftPadW = input_left_pads[1];
-            const index_t InRightPadH = input_right_pads[0];
-            const index_t InRightPadW = input_right_pads[1];
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi), make_tuple(NStride, HiStride, WiStride));
+                    make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW))),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
+                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_))),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}));
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
                                make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
-                const auto in_n_hi_wi_groups_c_desc =
-                    make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, NumGroupsToMerge),
-                                                 make_tuple(NStride, HiStride, WiStride, GStride));
+                const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
+                    make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
                 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
+                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_groups_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
                                make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -485,37 +617,39 @@ struct TransformConvFwdToGemm
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
                 const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_ho_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
                 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, HiStride, WiStride, GStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(
+                        NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(
                         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(
@@ -523,55 +657,44 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_ho_wo_groups_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }
         else
         {
-            const index_t Y = b_g_k_c_xs_lengths[3];
-            const index_t X = b_g_k_c_xs_lengths[4];
-            const index_t ConvDilationH = conv_filter_dilations[0];
-            const index_t ConvDilationW = conv_filter_dilations[1];
-            const index_t InLeftPadH = input_left_pads[0];
-            const index_t InLeftPadW = input_left_pads[1];
-            const index_t InRightPadH = input_right_pads[0];
-            const index_t InRightPadW = input_right_pads[1];
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Y, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Y_, Ho_),
+                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(X_, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
-                               make_merge_transform(make_tuple(Y, X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
+                               make_merge_transform(make_tuple(Y_, X_, C_))),
                     make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -579,16 +702,17 @@ struct TransformConvFwdToGemm
             {
                 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, HiStride, WiStride, GStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(
+                        NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(
                         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(
@@ -596,13 +720,13 @@ struct TransformConvFwdToGemm
                 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Y, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Y_, Ho_),
+                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(X_, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(
                         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(Sequence<0>{},
@@ -613,8 +737,8 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_groups_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
-                               make_merge_transform(make_tuple(Y, X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
+                               make_merge_transform(make_tuple(Y_, X_, C_))),
                     make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -627,63 +751,27 @@ struct TransformConvFwdToGemm
                   is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                   is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
                  bool>::type = false>
-    static auto
-    MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
-                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
-                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
-                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                        const std::array<index_t, NDimSpatial>& input_right_pads,
-                        const index_t N)
+    __host__ __device__ auto MakeADescriptor_M_K() const
     {
-        const index_t C = a_g_n_c_wis_lengths[2];
-        const index_t Di = a_g_n_c_wis_lengths[3];
-        const index_t Hi = a_g_n_c_wis_lengths[4];
-        const index_t Wi = a_g_n_c_wis_lengths[5];
-        const index_t Do = c_g_n_k_wos_lengths[3];
-        const index_t Ho = c_g_n_k_wos_lengths[4];
-        const index_t Wo = c_g_n_k_wos_lengths[5];
-        const index_t ConvStrideD = conv_filter_strides[0];
-        const index_t ConvStrideH = conv_filter_strides[1];
-        const index_t ConvStrideW = conv_filter_strides[2];
-        const index_t GStride = a_g_n_c_wis_strides[I0];
-        const index_t NStride = a_g_n_c_wis_strides[I1];
-        const index_t CStride = a_g_n_c_wis_strides[I2];
-        const index_t DiStride = a_g_n_c_wis_strides[I3];
-        const index_t HiStride = a_g_n_c_wis_strides[I4];
-        const index_t WiStride = a_g_n_c_wis_strides[I5];
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NDoHoWo =
-                N * ck::accumulate_n<index_t>(
-                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
             if constexpr(NumGroupsToMerge == 1)
             {
-                return make_naive_tensor_descriptor(make_tuple(NDoHoWo, C),
-                                                    make_tuple(WiStride, CStride));
+                return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
+                                                    make_tuple(WiStride_, CStrideTensorA_));
             }
             else
             {
-                const auto in_gemmm_groups_gemmk_desc =
-                    make_naive_tensor_descriptor(make_tuple(NDoHoWo, NumGroupsToMerge, C),
-                                                 make_tuple(WiStride, GStride, CStride));
+                const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
+                    make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
+                    make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
                 return transform_tensor_descriptor(
                     in_gemmm_groups_gemmk_desc,
-                    make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -691,41 +779,30 @@ struct TransformConvFwdToGemm
         else if constexpr(ConvForwardSpecialization ==
                           device::ConvolutionForwardSpecialization::Filter3x3)
         {
-            const index_t ConvDilationD = conv_filter_dilations[0];
-            const index_t ConvDilationH = conv_filter_dilations[1];
-            const index_t ConvDilationW = conv_filter_dilations[2];
-            const index_t InLeftPadD = input_left_pads[0];
-            const index_t InLeftPadH = input_left_pads[1];
-            const index_t InLeftPadW = input_left_pads[2];
-            const index_t InRightPadD = input_right_pads[0];
-            const index_t InRightPadH = input_right_pads[1];
-            const index_t InRightPadW = input_right_pads[2];
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi), make_tuple(NStride, DiStride, HiStride, WiStride));
+                    make_tuple(N_, Di_, Hi_, Wi_),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Di, InLeftPadD, InRightPadD),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Do),
-                                                    make_tuple(ConvDilationD, ConvStrideD)),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW))),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Do_),
+                                                    make_tuple(ConvDilationD_, ConvStrideD_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
+                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_))),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(
                         Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
@@ -733,7 +810,7 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_z_do_y_ho_x_wo_c_desc,
                     make_tuple(
-                        make_merge_transform(make_tuple(N, Do, Ho, Wo)),
+                        make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
                         make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -741,15 +818,15 @@ struct TransformConvFwdToGemm
             else
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, NumGroupsToMerge),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, GStride));
+                    make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Di, InLeftPadD, InRightPadD),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(
                         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
@@ -758,13 +835,13 @@ struct TransformConvFwdToGemm
                 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Do),
-                                                    make_tuple(ConvDilationD, ConvStrideD)),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Do_),
+                                                    make_tuple(ConvDilationD_, ConvStrideD_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
+                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(
                         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
@@ -777,7 +854,7 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_z_do_y_ho_x_wo_c_desc,
                     make_tuple(
-                        make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)),
+                        make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
                         make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -789,16 +866,16 @@ struct TransformConvFwdToGemm
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, C),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
+                    make_tuple(N_, Di_, Hi_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
                 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(
                         Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(
@@ -806,25 +883,30 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_do_ho_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride));
+                    make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_,
+                               DiStride_,
+                               HiStride_,
+                               WiStride_,
+                               GStrideTensorA_,
+                               CStrideTensorA_));
                 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{},
                                Sequence<1>{},
                                Sequence<2>{},
...@@ -840,43 +922,28 @@ struct TransformConvFwdToGemm ...@@ -840,43 +922,28 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor( return transform_tensor_descriptor(
in_n_do_ho_wo_c_desc, in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), make_tuple(
make_pass_through_transform(C)), make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
} }
else else
{ {
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
if constexpr(NumGroupsToMerge == 1) if constexpr(NumGroupsToMerge == 1)
{ {
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C), make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc, in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple( make_tuple(
...@@ -884,14 +951,14 @@ struct TransformConvFwdToGemm ...@@ -884,14 +951,14 @@ struct TransformConvFwdToGemm
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc, in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z, Do), make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD, ConvStrideD)), make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y, Ho), make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X, Wo), make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple( make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
...@@ -902,25 +969,30 @@ struct TransformConvFwdToGemm ...@@ -902,25 +969,30 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor( return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc, in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_merge_transform(make_tuple(Z, Y, X, C))), make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
else else
{ {
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride)); make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc, in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi, InLeftPadW, InRightPadW), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -936,15 +1008,15 @@ struct TransformConvFwdToGemm ...@@ -936,15 +1008,15 @@ struct TransformConvFwdToGemm
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc, in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N), make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z, Do), make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD, ConvStrideD)), make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y, Ho), make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH, ConvStrideH)), make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X, Wo), make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW, ConvStrideW)), make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)), make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1>{}, Sequence<1>{},
Sequence<2>{}, Sequence<2>{},
...@@ -960,8 +1032,9 @@ struct TransformConvFwdToGemm ...@@ -960,8 +1032,9 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor( return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc, in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)), make_tuple(
make_merge_transform(make_tuple(Z, Y, X, C))), make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}), make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -973,19 +1046,8 @@ struct TransformConvFwdToGemm ...@@ -973,19 +1046,8 @@ struct TransformConvFwdToGemm
is_same_v<BLayout, tensor_layout::convolution::GKYXC> || is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>, is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
bool>::type = false> bool>::type = false>
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, __host__ __device__ auto MakeBDescriptor_N_K() const
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const index_t GStride = b_g_k_c_xs_strides[I0];
const index_t KStride = b_g_k_c_xs_strides[I1];
const index_t CStride = b_g_k_c_xs_strides[I2];
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3) device::ConvolutionForwardSpecialization::Filter3x3)
{ {
...@@ -996,17 +1058,17 @@ struct TransformConvFwdToGemm ...@@ -996,17 +1058,17 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1) if constexpr(NumGroupsToMerge == 1)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(K, FilterSizeNumType{})); return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
} }
else else
{ {
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}), make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
make_tuple(KStride, GStride, CStride)); make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
return transform_tensor_descriptor( return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc, wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_pass_through_transform(FilterSizeNumType{})), make_pass_through_transform(FilterSizeNumType{})),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -1016,16 +1078,17 @@ struct TransformConvFwdToGemm ...@@ -1016,16 +1078,17 @@ struct TransformConvFwdToGemm
{ {
if constexpr(NumGroupsToMerge == 1) if constexpr(NumGroupsToMerge == 1)
{ {
return make_naive_tensor_descriptor_packed(make_tuple(K, YX * C)); return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
} }
else else
{ {
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K, NumGroupsToMerge, YX * C), make_tuple(KStride, GStride, CStride)); make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
return transform_tensor_descriptor( return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc, wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_pass_through_transform(YX * C)), make_pass_through_transform(ZYX_ * C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}), make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -1041,25 +1104,14 @@ struct TransformConvFwdToGemm ...@@ -1041,25 +1104,14 @@ struct TransformConvFwdToGemm
is_same_v<BLayout, tensor_layout::convolution::KYXGC> || is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KZYXGC>, is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
bool>::type = false> bool>::type = false>
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, __host__ __device__ auto MakeBDescriptor_N_K() const
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride)); make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_));
const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
wei_k_yx_c_desc, wei_k_yx_c_desc,
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))), make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
...@@ -1071,24 +1123,14 @@ struct TransformConvFwdToGemm ...@@ -1071,24 +1123,14 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::GNHWK> || is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>, is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
bool>::type = false> bool>::type = false>
static auto __host__ __device__ auto MakeCDescriptor_M_N() const
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const index_t N)
{ {
const index_t K = c_g_n_k_wos_lengths[2]; return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
return out_gemmm_gemmn_desc;
} }
template < template <
typename CLayout, typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> || typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> || is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> || is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
...@@ -1096,39 +1138,28 @@ struct TransformConvFwdToGemm ...@@ -1096,39 +1138,28 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::NHWGK> || is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>, is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false> bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, __host__ __device__ auto MakeCDescriptor_M_N() const
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{ {
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = I1;
const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
const index_t GStride = c_g_n_k_wos_strides[0];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
if constexpr(NumGroupsToMerge == 1) if constexpr(NumGroupsToMerge == 1)
{ {
return make_naive_tensor_descriptor(make_tuple(NHoWo, K), return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
make_tuple(WoStride, KStride)); make_tuple(WoStride_, KStrideTensorC_));
} }
else else
{ {
const auto nhwo_groups_k_1_desc = const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
make_naive_tensor_descriptor(make_tuple(NHoWo, NumGroupsToMerge, K, 1), make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
make_tuple(WoStride, GStride, KStride, GStride)); make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
// Pad the unit dimension from 1 up to NumGroupsToMerge // Pad the unit dimension from 1 up to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor( const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc, nhwo_groups_k_1_desc,
make_tuple(make_pass_through_transform(NHoWo), make_tuple(make_pass_through_transform(NDoHoWo_),
make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K), make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)), make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We only need the matrices on the diagonal. XOR returns 0 for equal // We only need the matrices on the diagonal. XOR returns 0 for equal
// values, so any matrix off the diagonal ends up stored in the padding. // values, so any matrix off the diagonal ends up stored in the padding.
// To avoid a modulo after the xor we require NumGroupsToMerge to be a power of 2. // To avoid a modulo after the xor we require NumGroupsToMerge to be a power of 2.
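// For example, with NumGroupsToMerge = 4 the xor transform below maps
// (g0, g1) to (g0, g0 ^ g1): the diagonal entry (2, 2) lands at (2, 0),
// inside the single valid element of the dimension padded from 1 to
// NumGroupsToMerge, while the off-diagonal entry (2, 3) lands at (2, 1),
// i.e. in the padding.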
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
...@@ -1136,16 +1167,16 @@ struct TransformConvFwdToGemm ...@@ -1136,16 +1167,16 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64); NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor( const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc, padded_desc,
make_tuple(make_pass_through_transform(NHoWo), make_tuple(make_pass_through_transform(NDoHoWo_),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K)), make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// Merge To M, N // Merge To M, N
return transform_tensor_descriptor( return transform_tensor_descriptor(
unmerged_padded_desc, unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)), make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(K, NumGroupsToMerge))), make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
} }
...@@ -1155,542 +1186,34 @@ struct TransformConvFwdToGemm ...@@ -1155,542 +1186,34 @@ struct TransformConvFwdToGemm
template <typename CLayout, template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>, typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false> bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, __host__ __device__ auto MakeCDescriptor_M_N() const
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{ {
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
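// The I0 (= 0) stride along the M dimension broadcasts the single length-K
// bias row to every output row.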
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
// Overloaded functions for hipRTC purposes public:
template <typename ALayout, index_t N_;
typename std::enable_if<NDimSpatial == 1 &&
(is_same_v<ALayout, tensor_layout::convolution::G_NW_C> || private:
is_same_v<ALayout, tensor_layout::convolution::NWGC> || const index_t Di_, Hi_, Wi_;
is_same_v<ALayout, tensor_layout::convolution::GNWC>), const index_t Do_, Ho_, Wo_;
bool>::type = false> const index_t Z_, Y_, X_;
__host__ __device__ static auto const index_t K_, C_;
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, const index_t DiStride_, HiStride_, WiStride_;
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const index_t WoStride_;
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const index_t XStride_;
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */, const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, const index_t NStrideTensorA_;
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */, const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
const ck::Array<index_t, NDimSpatial>& conv_filter_strides, const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations, const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
const ck::Array<index_t, NDimSpatial>& input_left_pads, const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
const ck::Array<index_t, NDimSpatial>& input_right_pads) const index_t InRightPadD_, InRightPadH_, InRightPadW_;
{ const index_t ZYX_;
const index_t N = a_g_n_c_wis_lengths[1]; index_t NDoHoWo_;
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
template <typename ALayout,
typename std::enable_if<
NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = c_g_n_k_wos_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
template <typename ALayout,
typename std::enable_if<
NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = c_g_n_k_wos_lengths[3];
const index_t Ho = c_g_n_k_wos_lengths[4];
const index_t Wo = c_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
// This is different: the C stride is taken as the compile-time constant I1 here
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
template <typename BLayout,
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
{
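// For these C-contiguous filter layouts the weights already form a dense
// (GemmN = K) x (GemmK = Y * X * C) matrix, so a packed 2-D descriptor
// suffices.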
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX =
mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
const auto wei_gemmn_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
return wei_gemmn_gemmk_desc;
}
template <
typename BLayout,
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX =
mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
wei_k_yx_c_desc,
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return wei_gemmn_gemmk_desc;
}
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo =
N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
return out_gemmm_gemmn_desc;
}
template <
typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const auto KStride = I1;
const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
const index_t NHoWo =
N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
return out_gemmm_gemmn_desc;
}
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2];
const index_t NHoWo =
N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
return out_gemmm_gemmn_desc;
}
}; };
// wrapper class to call member functions of the TransformConvFwdToGemm struct at runtime // wrapper class to call member functions of the TransformConvFwdToGemm struct at runtime
...@@ -1702,26 +1225,22 @@ struct TransformConv ...@@ -1702,26 +1225,22 @@ struct TransformConv
template <index_t NDimSpatial, template <index_t NDimSpatial,
device::ConvolutionForwardSpecialization ConvForwardSpecialization> device::ConvolutionForwardSpecialization ConvForwardSpecialization>
auto auto
transform_func(ck::Array<index_t, NDimSpatial + 3> out_lengths, transform_func(TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
ck::Array<index_t, NDimSpatial + 3> out_strides,
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
{ {
if(NDimSpatial == 2) if(NDimSpatial == 2)
{ {
return conv_fwd_to_gemm return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>(out_lengths, .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
out_strides);
} }
else if(NDimSpatial == 3) else if(NDimSpatial == 3)
{ {
return conv_fwd_to_gemm return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>(out_lengths, .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
out_strides);
} }
else if(NDimSpatial == 1) else if(NDimSpatial == 1)
{ {
return conv_fwd_to_gemm.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>( return conv_fwd_to_gemm
out_lengths, out_strides); .template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
} }
} }
}; };
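// A minimal usage sketch (hypothetical caller-side code; the specialization
// value is an assumption):
//   TransformConv transform;
//   const auto out_desc =
//       transform.transform_func<2, device::ConvolutionForwardSpecialization::Default>(
//           conv_fwd_to_gemm); // conv_fwd_to_gemm already holds lengths/strides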
......
...@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu; ...@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
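// MultiplyMultiply combines the accumulator with two auxiliary D tensors by
// elementwise multiplication (e = c * d0 * d1), e.g. for f8 GEMMs with
// per-row and per-column F32 scales.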
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu; using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename A0DataType,
typename A1DataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
128,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemmMultipleD_ABScale<ALayout,
BLayout,
Tuple<>,
CLayout,
A0DataType,
A1DataType,
B0DataType,
B1DataType,
Tuple<>,
CDataType,
128,
128,
128,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
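// Usage sketch (assumes the common CK device-op interface; argument details
// elided): enumerate the instances and keep those that accept the problem.
//   auto op_ptrs = DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//   for(auto& op : op_ptrs)
//   {
//       // auto argument = op->MakeArgumentPointer(/* sizes, strides, pointers */);
//       // if(op->IsSupportedArgument(argument.get()))
//       //     op->MakeInvokerPointer()->Run(argument.get());
//   }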
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -315,7 +315,7 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instanc ...@@ -315,7 +315,7 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instanc
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
...@@ -416,6 +416,57 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_ins ...@@ -416,6 +416,57 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_ins
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif #endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
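// Illustrative usage (a minimal sketch, not part of this header): each hook
// above appends the pre-instantiated kernels for one padding/pipeline variant
// into a caller-owned vector. Assuming the Row/Col/F8/BF16/PassThrough aliases
// used by these declarations are in scope, direct collection of the
// f8 x f8 -> bf16 instances looks like:
//
//   std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row, F8, F8, BF16,
//                                            PassThrough, PassThrough, PassThrough>>>
//       instances;
//   add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(instances);
//   add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(instances);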
template <typename ADataType,
typename BDataType,
...@@ -596,7 +647,7 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{
...@@ -653,6 +704,33 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using DsLayout = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmV2R1<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmV2R1<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemmV2R1<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
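// Illustrative usage of the factory above — a minimal sketch assuming this
// header is included and CK_ENABLE_FP16 was defined when the instance library
// was built; the Op type below mirrors the FP16 Row/Row/Row branch of
// GetInstances(), and GetTypeString() is the BaseOperator identifier for the
// tile/pipeline configuration:
//
//   #include <iostream>
//
//   int main()
//   {
//       namespace dev = ck::tensor_operation::device;
//       using PT      = ck::tensor_operation::element_wise::PassThrough;
//       using Op      = dev::DeviceGemmV2R1<dev::instance::Row, dev::instance::Row,
//                                           dev::instance::DsLayout, dev::instance::Row,
//                                           dev::instance::F16, dev::instance::F16,
//                                           dev::instance::DsDataType, dev::instance::F16,
//                                           PT, PT, PT>;
//       const auto op_ptrs =
//           dev::instance::DeviceOperationInstanceFactory<Op>::GetInstances();
//       std::cout << "registered instances: " << op_ptrs.size() << '\n';
//       for(const auto& op : op_ptrs)
//           std::cout << op->GetTypeString() << '\n';
//   }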
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
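// Illustrative registration (a minimal sketch; the library wires this up in its
// own instance .cpp files): the tuples above are consumed with
// add_device_operation_instances(), the NumGroupsToMerge values of 8/16/32
// covering the "NumGroupsPerBatch > 1" cases noted in the tables. E.g. for 2D
// NHWGC/GKYXC/NHWGK f16 with the default forward specialization:
//
//   template <typename InstanceVector>
//   void add_merged_groups_f16_2d(InstanceVector& instances)
//   {
//       add_device_operation_instances(
//           instances,
//           device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
//                                                                   NHWGC,
//                                                                   GKYXC,
//                                                                   Empty_Tuple,
//                                                                   NHWGK,
//                                                                   ConvFwdDefault>{});
//   }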
...@@ -17,7 +17,6 @@
#endif
#ifdef CK_USE_XDL
#include "grouped_convolution_forward_xdl.inc"
#include "grouped_convolution_forward_xdl_merged_groups.inc"
#include "grouped_convolution_forward_comp_xdl.inc"
#include "grouped_convolution_forward_mem_inter_xdl.inc"
#include "grouped_convolution_forward_mem_intra_xdl.inc"
...@@ -200,8 +199,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
op_ptrs);
...@@ -215,8 +212,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
op_ptrs);
...@@ -232,8 +227,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
op_ptrs);
...@@ -291,8 +284,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
op_ptrs);
...@@ -347,8 +338,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
op_ptrs);
...@@ -364,8 +353,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
op_ptrs);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
F32,
Empty_Tuple,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
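// Illustrative consumption (a minimal sketch): each hook above fills a
// caller-owned vector with the merged-groups instances declared in this
// header, e.g. for the 3D bf16 case:
//
//   std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
//                                                               NDHWGC,
//                                                               GKZYXC,
//                                                               Empty_Tuple,
//                                                               NDHWGK,
//                                                               BF16,
//                                                               BF16,
//                                                               Empty_Tuple,
//                                                               BF16,
//                                                               PassThrough,
//                                                               PassThrough,
//                                                               PassThrough>>>
//       instances;
//   add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
//       instances);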
# ONLY XDL_KERNELS
set(GEMM_AB_SCALE_INSTANCES)
list(APPEND GEMM_AB_SCALE_INSTANCES
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
)
add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
template <GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################|        |        |        |        |    |      |      |      |        |        |  Operation|  Operation|  Operation|              |      |     M|     N|     K|      |      |      |     |     |    |     | Wave| Wave| Lengths_K0_M_K1|   ArrangeOrder|              |             | PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|              |             | PerVector|   PerVector_K1|          |  PerShuffle|  PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl|   _NWaveNPerXdl|      Scheduler|        Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
// Spill in current compiler
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
// clang-format on
>;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################|        |        |        |        |    |      |      |      |        |        |  Operation|  Operation|  Operation|              |      |     M|     N|     K|      |      |      |     |     |    |     | Wave| Wave| Lengths_K0_M_K1|   ArrangeOrder|              |             | PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|              |             | PerVector|   PerVector_K1|          |  PerShuffle|  PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl|   _NWaveNPerXdl|      Scheduler|        Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// Memory friendly
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
Col,
Tuple<>,
Row,
F8,
F32,
F8,
F32,
Tuple<>,
BF16,
128,
128,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
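// Illustrative consumption of the AB-scale entry points above (a minimal
// sketch): each translation unit registers one GemmSpecialization/scheduler
// variant of the f8 x f8 -> bf16 GEMM, with 128/128/128 as the scale-block
// M/N/K granularities from the instance tables. Collecting two variants
// directly:
//
//   std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row,
//                                                           F8, F32, F8, F32,
//                                                           Tuple<>, BF16,
//                                                           128, 128, 128,
//                                                           PassThrough, PassThrough,
//                                                           PassThrough>>>
//       instances;
//   add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
//       instances);
//   add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
//       instances);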
# ONLY XDL_KERNELS
set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
)
add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using BF16 = bhalf_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
using MultiplyMultiply = element_wise::MultiplyMultiply;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################|        |        |        |        |     |      |       |      |        |         |   Operation|   Operation|   Operation|              |      |      |      |      |      |      |     |     |    |     | Wave| Wave| Lengths_K0_M_K1|   ArrangeOrder|              |             | PerVector|   PerVector_K1|          | Lengths_K0_N_K1|   ArrangeOrder|              |             | PerVector|   PerVector_K1|          |  PerShuffle|  PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl|   _NWaveNPerXdl|      Scheduler|        Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
// clang-format on
>;
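
// Note: the list above hard-codes the scheduler and pipeline version per
// instance; the latency/memory-oriented list below is additionally
// parameterized on BlockGemmPipelineScheduler, so callers can select
// Intrawave or Interwave scheduling at instantiation time.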
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer|BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler|  Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly: small per-block tiles (down to 16x16) with pipeline v1, minimizing overhead on very small problem sizes
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
// Memory friendly: skinny tiles (one narrow block dimension) with pipeline v2, targeting memory-bound (e.g. tall-and-skinny) GEMM shapes
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>
// clang-format on
>;
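
// Illustrative sketch (not part of this header): instance tuples like the two
// aliases above are typically consumed by a registration function that appends
// one device operation per tuple element to a caller-owned vector via
// add_device_operation_instances(). The function name and the exact
// DeviceGemmMultipleD template arguments below are assumptions for
// illustration only, as is the GemmSpecialization template argument (the
// aliases' use of GemmSpec suggests they are templated on it):
//
// void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances(
//     std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
//                                                     Col,
//                                                     Tuple<Row, Col>,
//                                                     Row,
//                                                     F8,
//                                                     F8,
//                                                     Tuple<F32, F32>,
//                                                     BF16,
//                                                     PassThrough,
//                                                     PassThrough,
//                                                     MultiplyMultiply>>>& instances)
// {
//     add_device_operation_instances(
//         instances,
//         device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances<
//             GemmSpecialization::Default>{});
// }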
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck