Commit ae3d6cb6 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed block_sync_lds

parents 760b0c75 b62926dc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace ck {
namespace tensor_operation {
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG,
index_t NumDimM,
index_t NumDimN,
device::TensorSpecialization TensorSpec>
__host__ __device__ static auto
MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
{
// if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
// gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
// {
// throw std::runtime_error("wrong! dimension must match input lengths");
// }
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto gs_ms_ns_lengths =
to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto gs_ms_ns_strides =
to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
if constexpr(TensorSpec == device::TensorSpecialization::Packed)
{
auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(G, M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
make_tuple(M, N),
make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
else
{
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
// Note: This does not require padding as it only provides G offset calculation. Technically
// descriptor for only G is needed. Here we opt for backward compatibility purpose to return
// G_M_N
const auto grid_desc_g_mraw_nraw =
transform_tensor_descriptor(grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto c_ms_ns_lengths = to_tuple(
gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
}
}
template <typename NumDims_G_M_N_K_O, // Sequence<>
typename PerBlock_M_N_K_O, // Sequence<>
device::GemmSpecialization GemmSpec,
device::TensorSpecialization ASpec,
device::TensorSpecialization B0Spec,
device::TensorSpecialization B1Spec,
device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);
static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);
static constexpr auto matrix_padder =
device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, OPerBlock};
//
// A
//
__host__ __device__ static auto MakeAGridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
a_gs_ms_ks_strides_vec);
}
// TODO: rename to G_MRaw_KRaw
__host__ __device__ static auto MakeAGridDescriptor_G_M_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
}
__host__ __device__ static auto MakeAGridDescriptor_M_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{
return matrix_padder.PadADescriptor_M_K(
MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
}
template <typename AGridDesc_M_K, typename Number>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename AGridDesc_M_K,
typename WmmaK,
typename MRepeat,
typename MWaves,
typename MPerWmma,
typename AK1>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
const AGridDesc_M_K& a_grid_desc_m_k,
const WmmaK&,
const MRepeat&,
const MWaves&,
const MPerWmma&,
const AK1&)
{
const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AKWmma = K / WmmaK{};
constexpr auto AKRow = 2;
constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(
make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
//
// B (alias of B0)
//
__host__ __device__ static auto MakeB0GridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
b0_gs_ns_ks_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
__host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
}
__host__ __device__ static auto MakeB0GridDescriptor_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
{
// alias of matrix_padder.PadB0Descriptor_N_K
return matrix_padder.PadBDescriptor_N_K(
MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
}
template <typename BGridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_L_K,
typename WmmaK,
typename LRepeat,
typename LWaves,
typename LPerWmma,
typename BK1>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
const BGridDesc_L_K& b_grid_desc_l_k,
const WmmaK&,
const LRepeat&,
const LWaves&,
const LPerWmma&,
const BK1&)
{
const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
const auto K = b_grid_desc_l_k.GetLength(I1);
const auto BKWmma = K / WmmaK{};
constexpr auto BKRow = 2;
constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};
return transform_tensor_descriptor(
b_grid_desc_l_k,
make_tuple(make_unmerge_transform(
make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
//
// B1
//
__host__ __device__ static auto MakeB1GridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
b1_gs_os_ns_strides_vec);
}
// TODO: rename to G_NRaw_KRaw
__host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
}
__host__ __device__ static auto MakeB1GridDescriptor_N_K(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
{
// alias of matrix_padder.PadB1Descriptor_O_N
return matrix_padder.PadB1Descriptor_N_K(
MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
}
template <typename B1GridDesc_N_K, typename Number>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
{
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename BGridDesc_N_L,
typename WmmaL,
typename NRepeat,
typename NWaves,
typename NPerWmma,
typename BL1>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
const BGridDesc_N_L& b_grid_desc_n_l,
const WmmaL&,
const NRepeat&,
const NWaves&,
const NPerWmma&,
const BL1&)
{
const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
const auto L = b_grid_desc_n_l.GetLength(I1);
const auto BLWmma = L / WmmaL{};
constexpr auto BLRow = 2;
constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};
return transform_tensor_descriptor(
b_grid_desc_n_l,
make_tuple(make_unmerge_transform(
make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
//
// C
//
__host__ __device__ static auto MakeCGridDescriptorPair(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
c_gs_ms_os_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
__host__ __device__ static auto MakeCGridDescriptor_G_M_N(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
}
__host__ __device__ static auto MakeCGridDescriptor_M_N(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
{
return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
}
};
} // namespace tensor_operation
} // namespace ck
......@@ -417,7 +417,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using r_t = typename vector_type<T, N>::type;
......
......@@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
"0"(c0),
"1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}
......@@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
"2"(c2),
"3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}
......@@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
c3);
}
// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
#if defined(__gfx11__)
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
#else
ignore = a;
ignore = b;
ignore = c;
#endif
}
} // namespace ck
#endif
......@@ -133,6 +133,13 @@ struct scalar_type<int8_t>
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<uint8_t>
{
using type = uint8_t;
static constexpr index_t vector_size = 1;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
......@@ -1037,6 +1044,14 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8
// i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type;
using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
template <typename T>
struct NumericLimits
......
......@@ -99,6 +99,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32);
}
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, float>(float x)
{
union
{
float fp32;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr float type_convert_sp<float, int>(int x)
{
union
{
int int32;
float fp32;
} u = {x};
return u.fp32;
}
template <>
inline __host__ __device__ constexpr int type_convert_sp<int, half_t>(half_t x)
{
union
{
half_t fp16;
int int32;
} u = {x};
return u.int32;
}
template <>
inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
{
union
{
int int32;
half_t fp16;
} u = {x};
return u.fp16;
}
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
......
......@@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
Sequence<false>,
Sequence<false>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
Sequence<true>,
Sequence<true>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
......@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{
// Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin =
const auto dst_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
[&](auto I) {
if constexpr(I == VectorDim)
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
auto transfer = ThreadwiseTensorSliceTransfer_v2<
std::remove_const_t<typename SrcTensorType::TensorElementType>,
std::remove_const_t<typename DstTensorType::TensorElementType>,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
I1,
false,
false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(),
out_grid_desc,
src_dst_slice_origin,
dst_slice_origin_idxs,
dst_tensor.GetBuffer());
}
else
......@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType,
typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] ThreadLayoutTuple& thread_layout)
typename ThreadShape,
typename ThreadUnrolledDesc>
__device__ void
blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
......@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,
constexpr auto tile_lengths_seq =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq = generate_sequence_v2(
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq =
generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;
// Perform copy between DynamicBuffers
auto transfer = ThreadGroupTensorSliceTransfer_v7<
......
......@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor()
/**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
* data layout must be (NPerBlock, KPerBlock).
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
* (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
* or (K0PerBlock, NPerBlock, K1).
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
......@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor()
* \tparam BlockSize Tensor to pad.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* (MPerBlock, KPerBlock) layout.
* (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* (NPerBlock, KPerBlock) layout.
* (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
*/
template <typename DataType,
......@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const BTensorType& b_local_tile_tensor,
CTensorType& c_reg_tensor)
{
constexpr auto I3 = Number<3>{};
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
......@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
......@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
......@@ -233,19 +252,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
auto sliced_desc = transform_tensor_descriptor(
partition_desc,
make_tuple(
make_slice_transform(partition_shape.At(Number<0>{}),
m_thread_data_on_grid_idx[I0],
partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<1>{}),
n_thread_data_on_grid_idx[I0],
partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<2>{}),
m_thread_data_on_grid_idx[I1],
partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<3>{}),
n_thread_data_on_grid_idx[I1],
partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<4>{}),
m_thread_data_on_grid_idx[I2],
partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
make_slice_transform(partition_shape.At(Number<5>{}),
m_thread_data_on_grid_idx[I3],
partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
make_slice_transform(partition_shape.At(Number<6>{}),
m_thread_data_on_grid_idx[I4],
partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
make_slice_transform(partition_shape.At(Number<7>{}),
n_thread_data_on_grid_idx[I2],
partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
lower_upper_dims,
lower_upper_dims);
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
partition_shape, partition_desc);
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, sliced_desc);
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
c_local_tile_tensor.GetPointer(), partition_layout);
partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]));
return partition_tensor;
}
......@@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
......@@ -326,9 +379,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
vgpr_shape, vgpr_desc);
// Get vector type for Vgpr
using BlockwiseGemmCThreadBufferType =
remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
vgpr_layout);
}
......
......@@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
}
}
template <typename... Ts, typename Shape, typename FlattenDescriptor>
template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
const Shape& shape,
const FlattenDescriptor& flatten_desc)
const UnrolledDescriptor& flatten_desc)
{
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
......
......@@ -20,48 +20,57 @@ namespace wrapper {
* \tparam K1Value The number of K-dim elements that are packed together as
* a separate logical dimension. Usually aligns with vector load size.
*/
template <index_t MPerXDLValue,
index_t NPerXDLValue,
index_t MXdlPerWaveValue,
index_t NXdlPerWaveValue,
index_t K1Value>
template <typename MPerXDLValue,
typename NPerXDLValue,
typename MXdlPerWaveValue,
typename NXdlPerWaveValue,
typename K1Value>
struct BlockwisGemmXdlTraits
{
static constexpr index_t MPerXDL = MPerXDLValue;
static constexpr index_t NPerXDL = NPerXDLValue;
static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
static constexpr index_t K1 = K1Value;
static constexpr auto MPerXDL = MPerXDLValue{};
static constexpr auto NPerXDL = NPerXDLValue{};
static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
static constexpr auto K1 = K1Value{};
};
// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
{
};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
{
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace wrapper {
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
} // namespace wrapper
} // namespace ck
......@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
namespace wrapper {
......@@ -29,6 +30,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
namespace detail {
/**
* \brief Generate packed (column-major) strides if not passed
*
......@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace detail
} // namespace
/// @endcond
......@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, strides));
}
/**
......@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
}
// Layout helpers
// get
/**
* \private
* \brief Get dim.
......@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout.
* \return Requsted sub layout.
*/
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
template <index_t idx, typename Shape, typename UnrolledDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
{
const auto& shape = layout.GetShape();
const auto new_shape = get<idx>(shape);
......@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
return layout.GetShape();
}
// pad
/**
* \brief Pad layout shapes to be adjusted to tile lengths.
*
*
* \param layout Layout to pad.
* \param tile_lengths Tile lengths to align layout shape.
* \return Padded layout.
*/
template <typename Shape, typename UnrolledDesc, typename TileLengths>
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
const TileLengths& tile_lengths)
{
auto& unrolled_desc = layout.GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
// Create layout
return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
}
// unmerge
/**
* \brief Unmerge selected dim in layout.
*
* \tparam Idx Index to dimension being unmerged.
* \param layout Layout to pad.
* \param new_lengths Dimensions into which the indicated dimension will be divided.
* \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested.
* \return Unmerged layout.
*/
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
const NewLengths& new_lengths,
[[maybe_unused]] const NewIdxs& new_indexes)
{
const auto& layout_shape = shape(layout);
auto& unrolled_desc = layout.GetUnrolledDescriptor();
constexpr auto dims = Shape::Size();
// Generate transforms
const auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == Idx)
{
return make_unmerge_transform(new_lengths);
}
else
{
return make_pass_through_transform(layout_shape.At(i));
}
},
Number<dims>{});
constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
{
constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
return to_sequence(idxs_tuple);
}
else
{
constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
return Sequence<index>{};
}
},
Number<dims>{});
const auto unmerged_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
const auto unmerged_shape =
generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
// Create layout
return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
}
} // namespace wrapper
} // namespace ck
......@@ -6,7 +6,6 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
......@@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
* \brief Apply projection.
*
* \param base_tuple Tuple to apply projection.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Multi index after projection.
*/
template <typename MultiIndex, typename ProjectionTuple>
......@@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
}
else
{
return base_tuple.At(i_num);
return make_tuple(base_tuple.At(i_num));
}
},
Number<MultiIndex::Size()>{});
......@@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
* \brief Calculate shape with dims from projection.
*
* \param shape Base tensor shape.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Shape with dims from projection
*/
template <typename... Ts, typename... Ps>
......@@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tuple with blocks number.
*/
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape,
const Tuple<Ps...>& projection)
const Tuple<Ls...>& tile_shape)
{
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
return generate_tuple(
[&](auto i) {
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
size<i>(tile_shape));
},
[&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
Number<Tuple<Ls...>::Size()>{});
}
......@@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
/**
* \brief Select dims to partition (skip if slice).
*
* \param block_idxs Input block indexes.
* \return Partitioned dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
{
const auto dims_to_partition = generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return Number<i>{};
}
else
{
return Tuple<>{};
}
},
Number<BlockIdxs::Size()>{});
// Remove empty tuples
return UnrollNestedTuple<0, 1>(dims_to_partition);
}
/**
* \brief Replace slices with zeros (Slice dims are not partitioned).
*
* \param block_idxs Input block indexes.
* \return Parsed dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
{
return generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return block_idxs.At(i);
}
else
{
return Number<0>{};
}
},
Number<BlockIdxs::Size()>{});
}
/**
* \brief Calculate default projection.
*
......@@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}
/**
* \brief Calculate thread multi index from 1d thread index.
*
* \param thread_layout Layout of threads (could not be nested).
* \param thread_id Thread index represented as integer.
* \return Multi index.
*/
template <typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id)
{
static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
"Thread layout should not be transformed.");
constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
constexpr auto shape = ThreadShape{};
constexpr auto strides = embed_transform.coefficients_;
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
return (thread_id / strides.At(num_i)) % shape.At(num_i);
},
Number<ThreadShape::Size()>{});
}
} // namespace detail
} // namespace
......@@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
* is supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (could not be nested).
* \param thread_layout Layout of threads (could not be transformed).
* \param thread_id Thread index represented as integer.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
template <typename TensorType,
typename ThreadShape,
typename ThreadUnrolledDesc,
typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
static_assert(!IsNestedTuple(ThreadShape{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
// Calculate projected thread lengths
constexpr auto projected_thread_lengths =
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
constexpr auto partition_shape =
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
// Create Thread Cluster Descriptor
constexpr auto partition_shape_seq =
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
Number<decltype(partition_shape)::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
// Apply projection on thread idxs to remove not needed idxs
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Slice descriptor
const auto transforms = generate_tuple(
[&](auto i) {
return make_slice_transform(partition_shape.At(i),
offset_multi_idxs.At(i),
partition_shape.At(i) + offset_multi_idxs.At(i));
},
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
auto sliced_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
// Create layout
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
partition_shape, unrolled_desc);
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, sliced_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
}
......@@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor,
* \param thread_id Thread index represented as integer.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
const index_t thread_id)
{
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
return make_local_partition(tensor, thread_lengths, thread_id, projection);
}
......@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param block_idxs Tuple of block indexes represented as integer. If slice,
* then get whole dim.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
template <typename TensorType,
typename BlockShapeTuple,
typename BlockIdxs,
typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const index_t block_id,
const BlockIdxs& block_idxs,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr bool is_default_projection =
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
static_assert(!IsNestedTuple(BlockIdxs{}));
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
// TODO: Enable block_2_tile_map partitioning for non-default projection.
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
// Number of dims which are partitioned
constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
if constexpr(decltype(dims_to_partition)::Size() == I2)
{
// Optimized version for 2d tile shape [MxK]
const auto shape_with_projection_dims =
detail::CalculateShapeWithProjection(shape(tensor), projection);
// Set Value for M, N partition
const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
// Get 1D block id
const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
// Optimized version for 2d tile shape [MxN]
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
BlockShapeTuple{}.At(I1),
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
NPerBlock,
remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
const auto offset_multi_idxs =
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// Apply 0 for non partitioned dims
const auto offset_multi_idxs = generate_tuple(
[&](auto i) {
if constexpr(i == dims_to_partition.At(I0))
{
return m_block_data_idx_on_grid;
}
else if constexpr(i == dims_to_partition.At(I1))
{
return n_block_data_idx_on_grid;
}
else
{
return Number<0>{};
}
},
Number<BlockShapeTuple::Size()>{});
const auto projected_offset_multi_idxs =
detail::ApplyProjection(offset_multi_idxs, projection);
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
projected_tile_shape, aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
return tile_tensor;
}
else
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
constexpr auto projected_tile_shape_seq =
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
Number<ProjectedTileShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
const auto projected_block_idxs =
to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
......@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param block_idxs Tuple of block indexes represented as integer. If slice,
* then get whole dim.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const BlockIdxs& block_idxs)
{
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
return make_local_tile(tensor, tile_shape, block_id, projection);
}
/**
* \brief Pad tensor shapes to be adjusted to tile lengths.
*
*
* \param tensor Tensor to pad.
* \param tile_lengths Tile lengths to align tensor shape.
* \return Padded tensor.
*/
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
const auto& tensor_shape = shape(tensor);
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) {
const auto& dim = size<i>(tensor_shape);
const auto& tile_length = size<i>(tile_lengths);
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
},
Number<TileLengths::Size()>{});
// Create layout and tensor
const auto padded_layout =
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
return partition_tensor;
return make_local_tile(tensor, tile_shape, block_idxs, projection);
}
} // namespace wrapper
......
......@@ -133,6 +133,252 @@ struct ReferenceBatchedGemm : public device::BaseOperator
}
};
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferenceBatchedGemm_MQA : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_1_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_g0_g1_m_k_{a_g0_g1_m_k},
b_g0_1_k_n_{b_g0_1_k_n},
c_g0_g1_m_n_{c_g0_g1_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_g0_g1_m_k_;
const Tensor<BDataType>& b_g0_1_k_n_;
Tensor<CDataType>& c_g0_g1_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceBatchedGemm_MQA::Argument;
float Run(const Argument& arg)
{
auto f_g0g1mk_g01kn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) {
const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k));
arg.b_element_op_(v_b, arg.b_g0_1_k_n_(g0, 0, k, n));
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_g0g1mk_g01kn_g0g1mn,
arg.c_g0_g1_m_n_.mDesc.GetLengths()[0],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[1],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[2],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_1_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{
a_g0_g1_m_k, b_g0_1_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceBatchedGemm_MQA"
<< std::endl;
// clang-format on
return str.str();
}
};
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
ck::index_t QueryGroupNumber>
struct ReferenceBatchedGemm_GQA : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_gq_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_g0_g1_m_k_{a_g0_g1_m_k},
b_g0_gq_k_n_{b_g0_gq_k_n},
c_g0_g1_m_n_{c_g0_g1_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_g0_g1_m_k_;
const Tensor<BDataType>& b_g0_gq_k_n_;
Tensor<CDataType>& c_g0_g1_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceBatchedGemm_GQA::Argument;
float Run(const Argument& arg)
{
auto f_g0g1mk_g0gqkn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) {
const int G1 = arg.a_g0_g1_m_k_.mDesc.GetLengths()[1];
const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k));
arg.b_element_op_(v_b, arg.b_g0_gq_k_n_(g0, g1 * QueryGroupNumber / G1, k, n));
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_g0g1mk_g0gqkn_g0g1mn,
arg.c_g0_g1_m_n_.mDesc.GetLengths()[0],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[1],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[2],
arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_g0_g1_m_k,
const Tensor<BDataType>& b_g0_gq_k_n,
Tensor<CDataType>& c_g0_g1_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{
a_g0_g1_m_k, b_g0_gq_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceBatchedGemm_GQA"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -25,25 +25,35 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ck::index_t NumAElementwiseTensor = 0,
ck::index_t NumBElementwiseTensor = 0,
ck::index_t NumDElementwiseTensor = 0,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdData : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
Argument(
Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors,
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors,
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors)
: input_{input},
weight_{weight},
output_{output},
elementwise_a_tensors_{elementwise_a_tensors},
elementwise_b_tensors_{elementwise_b_tensors},
elementwise_d_tensors_{elementwise_d_tensors},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
......@@ -58,6 +68,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const Tensor<WeiDataType>& weight_;
const Tensor<OutDataType>& output_;
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors_;
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
......@@ -106,26 +120,46 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
for(std::size_t k = 0; k < K; ++k)
{
float v_out = 0;
float v_wei = 0;
arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
arg.wei_element_op_(
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x)));
v_acc += v_out * v_wei;
OutDataType v_out;
WeiDataType v_wei;
ExecuteElementwiseOp(arg.out_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_out,
arg.output_(g, n, k, wo),
g,
n,
k,
wo);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, x),
g,
k,
c,
x);
v_acc += ck::type_convert<float>(v_out) *
ck::type_convert<float>(v_wei);
}
}
}
}
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
InDataType& v_in = arg.input_(g, n, c, wi);
ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_in,
v_acc_converted,
g,
n,
c,
wi);
};
make_ParallelTensorFunctor(f_ncw,
......@@ -175,20 +209,34 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
for(std::size_t k = 0; k < K; ++k)
{
float v_out = 0;
float v_wei = 0;
OutDataType v_out;
WeiDataType v_wei;
arg.out_element_op_(
ExecuteElementwiseOp(
arg.out_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_out,
ck::type_convert<float>(
arg.output_(g, n, k, ho, wo)));
arg.wei_element_op_(
arg.output_(g, n, k, ho, wo),
g,
n,
k,
ho,
wo);
ExecuteElementwiseOp(
arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
ck::type_convert<float>(
arg.weight_(g, k, c, y, x)));
v_acc += v_out * v_wei;
arg.weight_(g, k, c, y, x),
g,
k,
c,
y,
x);
v_acc += ck::type_convert<float>(v_out) *
ck::type_convert<float>(v_wei);
}
}
}
......@@ -197,11 +245,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
InDataType& v_in = arg.input_(g, n, c, hi, wi);
ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_in,
v_acc_converted,
g,
n,
c,
hi,
wi);
};
make_ParallelTensorFunctor(f_nchw,
......@@ -270,20 +325,37 @@ struct ReferenceConvBwdData : public device::BaseOperator
{
for(std::size_t k = 0; k < K; ++k)
{
float v_out = 0;
float v_wei = 0;
OutDataType v_out;
WeiDataType v_wei;
arg.out_element_op_(
ExecuteElementwiseOp(
arg.out_element_op_,
arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
v_out,
ck::type_convert<float>(arg.output_(
g, n, k, do_, ho, wo)));
arg.wei_element_op_(
arg.output_(g, n, k, do_, ho, wo),
g,
n,
k,
do_,
ho,
wo);
ExecuteElementwiseOp(
arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
ck::type_convert<float>(
arg.weight_(g, k, c, z, y, x)));
v_acc += v_out * v_wei;
arg.weight_(g, k, c, z, y, x),
g,
k,
c,
z,
y,
x);
v_acc +=
ck::type_convert<float>(v_out) *
ck::type_convert<float>(v_wei);
}
}
}
......@@ -295,11 +367,19 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
InDataType v_acc_converted = ck::type_convert<InDataType>(v_acc);
InDataType& v_in = arg.input_(g, n, c, di, hi, wi);
ExecuteElementwiseOp(arg.in_element_op_,
arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
v_in,
v_acc_converted,
g,
n,
c,
di,
hi,
wi);
};
make_ParallelTensorFunctor(f_ncdhw,
......@@ -325,6 +405,36 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
};
template <typename... Args,
typename ElementwiseOp,
typename ElementwiseTensor,
typename NumTensor,
typename T>
static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op,
ElementwiseTensor& elementwise_tensors,
NumTensor,
T& y,
const T& x,
Args... dims)
{
if constexpr(NumTensor::value == 0)
{
elementwise_op(y, x);
}
else if constexpr(NumTensor::value == 1)
{
elementwise_op(y, x, elementwise_tensors[0](dims...));
}
else if constexpr(NumTensor::value == 2)
{
elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
}
}
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
......@@ -333,16 +443,20 @@ struct ReferenceConvBwdData : public device::BaseOperator
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
static auto MakeArgument(
Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors = {},
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors = {},
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors = {})
{
return Argument{input,
weight,
......@@ -353,7 +467,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
out_element_op,
elementwise_a_tensors,
elementwise_b_tensors,
elementwise_d_tensors};
}
static auto MakeInvoker() { return Invoker{}; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename ScaleDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct ReferencefpAintBGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
b_k_n_{b_k_n},
scale_k_n_{scale_k_n},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<BDataType>& b_k_n_;
const Tensor<ScaleDataType>& scale_k_n_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferencefpAintBGemm::Argument;
float Run(const Argument& arg)
{
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1];
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a;
BDataType v_b;
ScaleDataType v_scale;
ADataType v_converted_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
}
// same for scale matrix
if constexpr(is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
ck::tensor_operation::element_wise::PassThrough{}(v_scale,
arg.scale_k_n_(k, n));
}
else
{
arg.b_element_op_(v_scale, arg.scale_k_n_(k, n));
}
v_converted_b = type_convert<ADataType>(v_b) * v_scale;
v_acc += ck::type_convert<AccDataType>(v_a) *
ck::type_convert<AccDataType>(v_converted_b);
}
AccDataType v_c;
arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(
f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& scale_k_n,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -384,6 +384,26 @@ void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
instances);
#endif
void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
template <typename ALayout,
typename BLayout,
typename CLayout,
......@@ -478,6 +498,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
......@@ -493,6 +514,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
op_ptrs);
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -505,6 +527,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
......@@ -517,6 +540,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif
......
......@@ -189,6 +189,11 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
......@@ -352,6 +357,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs);
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using BF8 = ck::bf8_t;
using F8 = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
// f16_f16_f32_f16
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F16, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// bf16_bf16_f32_bf16
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, BF16, BF16, F32, BF16, Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// f32_f32_f32_f32
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F32, F32, F32, F32, Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>
// clang-format on
>;
// f16_f16_f16_comp_f8
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bilinear_input_fp16_comp_bf8f8_instances = std::tuple<
// clang-format off
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, BF8, F8>,
// instances for small conv.K and conv.C
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple<ELayout>, ELayout, F16, F16, F32, F32, Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -54,36 +54,36 @@ template <index_t NDSpatial,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Prefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 256, 32, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 256, 32, 32, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 128, 32, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 16, 64, 32, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 64, 16, 32, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 32, 32, 32, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
......@@ -97,36 +97,36 @@ template <index_t NDSpatial,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_wmma_i8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Prefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//generic instance
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>,
// blocksize=256
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 256, 64, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 64, 64, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 128, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// blocksize=128
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 64, 128, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 64, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 128, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 64, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 128, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 256, 64, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 256, 32, 64, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// blocksize=64
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 64, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 64, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 32, 128, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 128, 64, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
// blocksize=32
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 16, 64, 64, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 64, 16, 64, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 32, 32, 64, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, DsDatatype, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>
// clang-format on
>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0;
static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
static constexpr auto ConvFwdOddC =
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_bf16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_f16_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, ck::Tuple<F16>, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F32, F32, F32, F32, ck::Tuple<F32>, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>
// clang-format on
>;
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_bilinear_int8_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
// instances for small conv.K and conv.C
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, int8_t, int8_t, int32_t, int8_t, ck::Tuple<int8_t>, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment