Commit 0eb75e21 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into ck_tile/moe

parents 1b4b640b c8b6b642
...@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1 ...@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>, std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!"); "wrong!");
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK, KPerBlock == BlockGemmShape::kK,
// "wrong!"); "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
......
...@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 = ...@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
using WarpGemmMfmaF16F16F32M16N16K32 = using WarpGemmMfmaF16F16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
...@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = ...@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
using WarpGemmMfmaBf16Bf16F32M16N16K32 = using WarpGemmMfmaBf16Bf16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
......
...@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK ...@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
static_for<0, kKIter, 1>{}([&](auto iKIter) { static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec, Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec) reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter], .template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec) reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]); .template get_as<typename Impl::BVecType>()[iKIter]);
}); });
} }
...@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK ...@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
// c = a * b // c = a * b
auto c_vec = Impl{}( auto c_vec = Impl{}(
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0], reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]); reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
// c += a * b // c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) { static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec, Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec) reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter], .template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec) reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]); .template get_as<typename Impl::BVecType>()[iKIter]);
}); });
......
...@@ -15,7 +15,8 @@ template <typename AType, ...@@ -15,7 +15,8 @@ template <typename AType,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPerWave,
bool TransposeC> bool TransposeC,
bool SwizzleA = false>
struct WarpGemmMfmaDispatcher; struct WarpGemmMfmaDispatcher;
// clang-format off // clang-format off
...@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float ...@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16 // bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
...@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float ...@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8 // fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
...@@ -58,8 +65,15 @@ template <typename AType, ...@@ -58,8 +65,15 @@ template <typename AType,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPerWave,
bool TransposeC> bool TransposeC,
using WarpGemmMfmaDispatcher = typename impl:: bool SwizzleA = false>
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type; using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
BType,
CType,
MPerWave,
NPerWave,
KPerWave,
TransposeC,
SwizzleA>::Type;
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
public: public:
Argument(const Tensor<InDataType>& input, Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
: input_{input}, : input_{input},
output_{output}, output_{output},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
...@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
const Tensor<InDataType>& input_; const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_; Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_; std::vector<long_index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_; std::vector<long_index_t> output_spatial_lengths_;
private: private:
void initOutputSpatialLengths() void initOutputSpatialLengths()
{ {
constexpr auto input_offset_to_spatial = 3; constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i) for(ck::long_index_t i = 0; i < NDimSpatial; ++i)
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back( output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + (output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
...@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
const index_t G = arg.output_.GetLengths()[0]; const long_index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1]; const long_index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2]; const long_index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Wo + wo; long_index_t row = n * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
...@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
const index_t Ho = arg.output_spatial_lengths_[0]; const long_index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1]; const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho) for(long_index_t ho = 0; ho < Ho; ++ho)
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Ho * Wo + ho * Wo + wo; long_index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(hi >= 0 && if(hi >= 0 &&
...@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
const index_t Do = arg.output_spatial_lengths_[0]; const long_index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1]; const long_index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2]; const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o) for(long_index_t d_o = 0; d_o < Do; ++d_o)
{ {
for(index_t ho = 0; ho < Ho; ++ho) for(long_index_t ho = 0; ho < Ho; ++ho)
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{ {
auto di = auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * static_cast<ck::long_index_t>(ho *
...@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>(y * static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) - arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2];
++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
...@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) - x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]); static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
...@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
const ck::index_t G = arg.output_.GetLengths()[0]; const ck::long_index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1]; const ck::long_index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2]; const ck::long_index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo = const long_index_t NDoHoWo =
N * ck::accumulate_n<index_t>( N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX = const long_index_t CZYX =
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) && if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
...@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
{ {
return Argument{input, return Argument{input,
output, output,
......
...@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input, Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input, Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi, const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x, Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo, const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi, const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x, Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo, const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<InDataType>& input, const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<ck::long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<ck::long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<ck::long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<ck::long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<InDataType>& input, const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
public: public:
Argument(const Tensor<InDataType>& input, Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
: input_{input}, : input_{input},
output_{output}, output_{output},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
...@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator
const Tensor<InDataType>& input_; const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_; Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_; std::vector<long_index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_; std::vector<long_index_t> output_spatial_lengths_;
private: private:
void initOutputSpatialLengths() void initOutputSpatialLengths()
...@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back( output_spatial_lengths_.push_back(
(input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + (input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
...@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
const index_t G = arg.input_.GetLengths()[0]; const long_index_t G = arg.input_.GetLengths()[0];
const index_t N = arg.input_.GetLengths()[1]; const long_index_t N = arg.input_.GetLengths()[1];
const index_t C = arg.input_.GetLengths()[2]; const long_index_t C = arg.input_.GetLengths()[2];
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) { auto func = [&](auto g, auto n, auto wo) {
index_t row = n * Wo + wo; long_index_t row = n * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
...@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
const index_t Ho = arg.output_spatial_lengths_[0]; const long_index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1]; const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) { auto func = [&](auto g, auto n, auto ho, auto wo) {
index_t row = n * Ho * Wo + ho * Wo + wo; long_index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(hi >= 0 && if(hi >= 0 &&
...@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
const index_t Do = arg.output_spatial_lengths_[0]; const long_index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1]; const long_index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2]; const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{ {
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) + auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]); static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
...@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
const ck::index_t G = arg.input_.GetLengths()[0]; const ck::long_index_t G = arg.input_.GetLengths()[0];
const ck::index_t N = arg.input_.GetLengths()[1]; const ck::long_index_t N = arg.input_.GetLengths()[1];
const ck::index_t C = arg.input_.GetLengths()[2]; const ck::long_index_t C = arg.input_.GetLengths()[2];
const index_t NDoHoWo = const long_index_t NDoHoWo =
N * ck::accumulate_n<index_t>( N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX = const long_index_t CZYX =
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) && if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
...@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
{ {
return Argument{input, return Argument{input,
output, output,
......
...@@ -18,134 +18,82 @@ namespace device { ...@@ -18,134 +18,82 @@ namespace device {
namespace instance { namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col, Col,
Tuple<Row, Col>, Tuple<Row, Col>,
Row, Row,
F8, F8,
F8, F8,
Tuple<F32, F32>, Tuple<F32, F32>,
BF16, BF16,
PassThrough, PassThrough,
PassThrough, PassThrough,
MultiplyMultiply>>>& instances); MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif #endif
template <typename ADataType, template <typename ADataType,
...@@ -154,7 +102,7 @@ template <typename ADataType, ...@@ -154,7 +102,7 @@ template <typename ADataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout, ALayout,
BLayout, BLayout,
Tuple<Row, Col>, Tuple<Row, Col>,
...@@ -167,17 +115,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -167,17 +115,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>> ck::tensor_operation::element_wise::MultiplyMultiply>>
{ {
using DeviceOp = DeviceGemmMultipleD<ALayout, using DeviceOp =
BLayout, DeviceGemmMultipleDSplitK<ALayout,
Tuple<Row, Col>, BLayout,
CLayout, Tuple<Row, Col>,
ADataType, CLayout,
BDataType, ADataType,
Tuple<F32, F32>, BDataType,
CDataType, Tuple<F32, F32>,
ck::tensor_operation::element_wise::PassThrough, CDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>; ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>;
static auto GetInstances() static auto GetInstances()
{ {
...@@ -194,24 +143,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -194,24 +143,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
} }
} }
#endif #endif
......
...@@ -77,16 +77,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( ...@@ -77,16 +77,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -97,11 +87,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance ...@@ -97,11 +87,6 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -111,13 +96,8 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance ...@@ -111,13 +96,8 @@ void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif #endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -177,16 +157,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( ...@@ -177,16 +157,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -196,12 +166,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances ...@@ -196,12 +166,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -212,10 +176,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances ...@@ -212,10 +176,6 @@ void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -275,16 +235,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( ...@@ -275,16 +235,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -295,11 +245,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances ...@@ -295,11 +245,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
...@@ -309,11 +254,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances ...@@ -309,11 +254,6 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif #endif
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
...@@ -376,93 +316,98 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instanc ...@@ -376,93 +316,98 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instanc
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances( void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances( void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances(
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances( void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances); instances);
...@@ -532,28 +477,20 @@ struct DeviceOperationInstanceFactory< ...@@ -532,28 +477,20 @@ struct DeviceOperationInstanceFactory<
{ {
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs); add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs); add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
} }
} }
#endif #endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>) is_same_v<CDataType, half_t>)
{ {
...@@ -562,21 +499,14 @@ struct DeviceOperationInstanceFactory< ...@@ -562,21 +499,14 @@ struct DeviceOperationInstanceFactory<
{ {
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs); add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs); add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs); add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
...@@ -608,21 +538,14 @@ struct DeviceOperationInstanceFactory< ...@@ -608,21 +538,14 @@ struct DeviceOperationInstanceFactory<
{ {
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
} }
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
...@@ -684,51 +607,55 @@ struct DeviceOperationInstanceFactory< ...@@ -684,51 +607,55 @@ struct DeviceOperationInstanceFactory<
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
} }
} }
#endif #endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) #if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> && if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>) is_same_v<CDataType, bhalf_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>) is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{ {
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs); add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs); op_ptrs);
add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
} }
} }
#endif #endif
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -29,8 +29,6 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -29,8 +29,6 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvFwdDefault = static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial, template <index_t NDimSpatial,
...@@ -39,16 +37,16 @@ template <index_t NDimSpatial, ...@@ -39,16 +37,16 @@ template <index_t NDimSpatial,
typename DsLayout, typename DsLayout,
typename ELayout, typename ELayout,
ConvolutionForwardSpecialization ConvSpec> ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple<
// clang-format off // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1 // generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
...@@ -58,16 +56,16 @@ template <index_t NDimSpatial, ...@@ -58,16 +56,16 @@ template <index_t NDimSpatial,
typename DsLayout, typename DsLayout,
typename ELayout, typename ELayout,
ConvolutionForwardSpecialization ConvSpec> ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple<
// clang-format off // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1 // generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
...@@ -77,19 +75,18 @@ template <index_t NDimSpatial, ...@@ -77,19 +75,18 @@ template <index_t NDimSpatial,
typename DsLayout, typename DsLayout,
typename ELayout, typename ELayout,
ConvolutionForwardSpecialization ConvSpec> ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple<
// clang-format off // clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Instances with NumGroupsPerBatch > 1 // generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
// clang-format on // clang-format on
>; >;
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#endif #endif
#ifdef CK_USE_XDL #ifdef CK_USE_XDL
#include "grouped_convolution_forward_xdl.inc" #include "grouped_convolution_forward_xdl.inc"
#include "grouped_convolution_forward_xdl_merged_groups.inc" #include "grouped_convolution_forward_xdl_large_tensor.inc"
#include "grouped_convolution_forward_comp_xdl.inc" #include "grouped_convolution_forward_comp_xdl.inc"
#include "grouped_convolution_forward_mem_inter_xdl.inc" #include "grouped_convolution_forward_mem_inter_xdl.inc"
#include "grouped_convolution_forward_mem_intra_xdl.inc" #include "grouped_convolution_forward_mem_intra_xdl.inc"
...@@ -200,7 +200,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -200,7 +200,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, float>) is_same_v<BComputeType, float>)
{ {
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(
...@@ -215,7 +215,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -215,7 +215,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, half_t>) is_same_v<BComputeType, half_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances( add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(
...@@ -232,7 +232,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -232,7 +232,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, ck::bhalf_t>) is_same_v<BComputeType, ck::bhalf_t>)
{ {
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
...@@ -291,7 +291,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -291,7 +291,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, float>) is_same_v<BComputeType, float>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(
...@@ -347,7 +347,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -347,7 +347,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, half_t>) is_same_v<BComputeType, half_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances( add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(
...@@ -364,7 +364,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -364,7 +364,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<BComputeType, ck::bhalf_t>) is_same_v<BComputeType, ck::bhalf_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs); op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
......
...@@ -10,7 +10,7 @@ namespace instance { ...@@ -10,7 +10,7 @@ namespace instance {
// grouped conv2d forward, NHWGC/GKYXC/NHWGK // grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
...@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst ...@@ -26,7 +26,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_inst
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
...@@ -42,7 +42,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta ...@@ -42,7 +42,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC, NHWGC,
GKYXC, GKYXC,
...@@ -59,7 +59,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta ...@@ -59,7 +59,7 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
#ifdef CK_ENABLE_BF16 #ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
...@@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_i ...@@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_i
#endif #endif
#ifdef CK_ENABLE_FP16 #ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
...@@ -91,7 +91,7 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_in ...@@ -91,7 +91,7 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_in
#endif #endif
#ifdef CK_ENABLE_FP32 #ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
......
...@@ -146,7 +146,7 @@ check_err(const Range& out, ...@@ -146,7 +146,7 @@ check_err(const Range& out,
bool res{true}; bool res{true};
int err_count = 0; int err_count = 0;
double err = 0; double err = 0;
double max_err = std::numeric_limits<ranges::range_value_t<Range>>::min(); double max_err = NumericLimits<ranges::range_value_t<Range>>::Min();
for(std::size_t i = 0; i < ref.size(); ++i) for(std::size_t i = 0; i < ref.size(); ++i)
{ {
const double o = type_convert<float>(*std::next(std::begin(out), i)); const double o = type_convert<float>(*std::next(std::begin(out), i));
...@@ -178,7 +178,9 @@ check_err(const Range& out, ...@@ -178,7 +178,9 @@ check_err(const Range& out,
template <typename Range, typename RefRange> template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> && std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
std::is_integral_v<ranges::range_value_t<Range>> && std::is_integral_v<ranges::range_value_t<Range>> &&
!std::is_same_v<ranges::range_value_t<Range>, bhalf_t>) !std::is_same_v<ranges::range_value_t<Range>, bhalf_t> &&
!std::is_same_v<ranges::range_value_t<Range>, f8_t> &&
!std::is_same_v<ranges::range_value_t<Range>, bf8_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| std::is_same_v<ranges::range_value_t<Range>, int4_t> || std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif #endif
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -31,23 +31,35 @@ struct ConvParam ...@@ -31,23 +31,35 @@ struct ConvParam
const std::vector<ck::index_t>& left_pads, const std::vector<ck::index_t>& left_pads,
const std::vector<ck::index_t>& right_pads); const std::vector<ck::index_t>& right_pads);
ck::index_t num_dim_spatial_; ConvParam(ck::long_index_t n_dim,
ck::index_t G_; ck::long_index_t group_count,
ck::index_t N_; ck::long_index_t n_batch,
ck::index_t K_; ck::long_index_t n_out_channels,
ck::index_t C_; ck::long_index_t n_in_channels,
const std::vector<ck::long_index_t>& filters_len,
std::vector<ck::index_t> filter_spatial_lengths_; const std::vector<ck::long_index_t>& input_len,
std::vector<ck::index_t> input_spatial_lengths_; const std::vector<ck::long_index_t>& strides,
std::vector<ck::index_t> output_spatial_lengths_; const std::vector<ck::long_index_t>& dilations,
const std::vector<ck::long_index_t>& left_pads,
std::vector<ck::index_t> conv_filter_strides_; const std::vector<ck::long_index_t>& right_pads);
std::vector<ck::index_t> conv_filter_dilations_;
ck::long_index_t num_dim_spatial_;
std::vector<ck::index_t> input_left_pads_; ck::long_index_t G_;
std::vector<ck::index_t> input_right_pads_; ck::long_index_t N_;
ck::long_index_t K_;
std::vector<ck::index_t> GetOutputSpatialLengths() const; ck::long_index_t C_;
std::vector<ck::long_index_t> filter_spatial_lengths_;
std::vector<ck::long_index_t> input_spatial_lengths_;
std::vector<ck::long_index_t> output_spatial_lengths_;
std::vector<ck::long_index_t> conv_filter_strides_;
std::vector<ck::long_index_t> conv_filter_dilations_;
std::vector<ck::long_index_t> input_left_pads_;
std::vector<ck::long_index_t> input_right_pads_;
std::vector<ck::long_index_t> GetOutputSpatialLengths() const;
std::size_t GetFlops() const; std::size_t GetFlops() const;
......
...@@ -96,9 +96,16 @@ struct HostTensorDescriptor ...@@ -96,9 +96,16 @@ struct HostTensorDescriptor
this->CalculateStrides(); this->CalculateStrides();
} }
HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens)
: mLens(lens.begin(), lens.end())
{
this->CalculateStrides();
}
template <typename Lengths, template <typename Lengths,
typename = std::enable_if_t< typename = std::enable_if_t<
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t>>> std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> ||
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t>>>
HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end()) HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
{ {
this->CalculateStrides(); this->CalculateStrides();
...@@ -114,11 +121,19 @@ struct HostTensorDescriptor ...@@ -114,11 +121,19 @@ struct HostTensorDescriptor
{ {
} }
HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens,
const std::initializer_list<ck::long_index_t>& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{
}
template <typename Lengths, template <typename Lengths,
typename Strides, typename Strides,
typename = std::enable_if_t< typename = std::enable_if_t<
std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> && (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>>> std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>) ||
(std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t> &&
std::is_convertible_v<ck::ranges::range_value_t<Strides>, ck::long_index_t>)>>
HostTensorDescriptor(const Lengths& lens, const Strides& strides) HostTensorDescriptor(const Lengths& lens, const Strides& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{ {
......
...@@ -64,6 +64,13 @@ function(add_instance_library INSTANCE_NAME) ...@@ -64,6 +64,13 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}") list(REMOVE_ITEM ARGN "${source}")
endif() endif()
endforeach() endforeach()
# Do not build mha instances if gfx94 targets are not on the target list
foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
message("removing mha instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(ARGN) if(ARGN)
set(INST_OBJ) set(INST_OBJ)
...@@ -74,9 +81,11 @@ function(add_instance_library INSTANCE_NAME) ...@@ -74,9 +81,11 @@ function(add_instance_library INSTANCE_NAME)
set(INST_TARGETS ${GPU_TARGETS}) set(INST_TARGETS ${GPU_TARGETS})
endif() endif()
if(source MATCHES "_xdl") if(source MATCHES "_xdl")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(ARGN MATCHES "_wmma") elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "mha")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
endif() endif()
set(offload_targets) set(offload_targets)
foreach(target IN LISTS INST_TARGETS) foreach(target IN LISTS INST_TARGETS)
...@@ -86,7 +95,29 @@ function(add_instance_library INSTANCE_NAME) ...@@ -86,7 +95,29 @@ function(add_instance_library INSTANCE_NAME)
list(APPEND INST_OBJ ${source}) list(APPEND INST_OBJ ${source})
endforeach() endforeach()
add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ}) add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ})
# Allow comparing floating points directly in order to check sentinel values
if(${INSTANCE_NAME} STREQUAL "device_mha_instance")
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
set(FMHA_FWD_FAST_EXP2 true)
endif()
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(device_mha_instance PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
endif()
target_compile_features(${INSTANCE_NAME} PUBLIC) target_compile_features(${INSTANCE_NAME} PUBLIC)
# flags to compress the library
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
message("Adding --offload-compress flag for ${INSTANCE_NAME}")
target_compile_options(${INSTANCE_NAME} PRIVATE --offload-compress)
endif()
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME}) clang_tidy_check(${INSTANCE_NAME})
set(result 0) set(result 0)
...@@ -286,20 +317,22 @@ if(CK_DEVICE_CONV_INSTANCES) ...@@ -286,20 +317,22 @@ if(CK_DEVICE_CONV_INSTANCES)
) )
endif() endif()
if(CK_DEVICE_MHA_INSTANCES) if(CK_DEVICE_MHA_INSTANCES)
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES}) set(gpu_list ${INST_TARGETS})
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations) list(FILTER gpu_list INCLUDE REGEX "^gfx94")
target_compile_features(device_mha_operations PUBLIC) if(gpu_list)
set_target_properties(device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
target_include_directories(device_mha_operations PUBLIC add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/ck/library/tensor_operation_instance/gpu/mha> target_compile_features(device_mha_operations PUBLIC)
) set_target_properties(device_mha_operations PROPERTIES POSITION_INDEPENDENT_CODE ON)
rocm_install(TARGETS device_mha_operations
EXPORT device_mha_operationsTargets) rocm_install(TARGETS device_mha_operations
rocm_install(EXPORT device_mha_operationsTargets EXPORT device_mha_operationsTargets)
FILE composable_kerneldevice_mha_operationsTargets.cmake rocm_install(EXPORT device_mha_operationsTargets
NAMESPACE composable_kernel:: FILE composable_kerneldevice_mha_operationsTargets.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel NAMESPACE composable_kernel::
) DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
endif()
endif() endif()
if(CK_DEVICE_CONTRACTION_INSTANCES) if(CK_DEVICE_CONTRACTION_INSTANCES)
add_library(device_contraction_operations STATIC ${CK_DEVICE_CONTRACTION_INSTANCES}) add_library(device_contraction_operations STATIC ${CK_DEVICE_CONTRACTION_INSTANCES})
......
...@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES ...@@ -111,6 +111,7 @@ list(APPEND GEMM_INSTANCES
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES list(APPEND GEMM_INSTANCES
device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp
......
...@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES ...@@ -11,4 +11,9 @@ list(APPEND GEMM_AB_SCALE_INSTANCES
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
) )
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES}) add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment