Commit 3552041a authored by danyao12's avatar danyao12
Browse files

Merge branch 'develop' into ck_tile/fa_bwd_opt

parents e8927110 733f33af
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#define DEBUG_LOG 0
namespace ck {
// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same
// kernel function Blockers:
// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on
// two lds chunks.
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
// buffer when we declare __shared__ inside blkgemmpipe
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType>
struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
{
using AScaleType = float;
using BScaleType = float;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ static auto CalculateKPadded(index_t K)
{
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
}
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
}
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
}
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
}
__host__ static auto CalculateMBlock(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock);
}
__host__ static auto CalculateNBlock(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock);
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
}
template <typename ELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
return generate_tuple(
[&](auto i) {
return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[i], MBlock, NBlock);
},
Number<NumDTensor>{});
}
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideC{StrideC_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KRead{CalculateKRead(K_, KBatch_)},
KPadded{CalculateKPadded(K_, KBatch_)},
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
}
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideC;
index_t KBatch;
index_t MPadded;
index_t NPadded;
index_t KRead;
index_t KPadded;
index_t AK0;
index_t BK0;
index_t MBlock;
index_t NBlock;
};
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{},
p_c_grid{p_c_grid_},
p_a_scale_grid{p_a_scale_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
c_element_op{c_element_op_}
{
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
});
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
DsGridPointer p_ds_grid;
CDataType* p_c_grid;
const AScaleType* p_a_scale_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead;
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
}
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(LDSTypeA);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
else // ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(LDSTypeA));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
? 1
: ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
? M0
: 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * M1>{},
Number<kfold * M0 / mpair>{},
Number<mpair>{},
AK1Number));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
}
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(LDSTypeB);
;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(make_pass_through_transform(BK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
}
else // RowMajor B
{
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(LDSTypeB));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
? 1
: ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
? N0
: 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * N1>{},
Number<kfold * N0 / npair>{},
Number<npair>{},
BK1Number));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
}
}
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmABScalePipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
LDSTypeA,
LDSTypeB,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
b_block_space_size_aligned * sizeof(LDSTypeB)),
c_block_size * sizeof(CShuffleDataType));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = karg.KBatch * KReadVec;
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
{
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
{
return false;
}
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
}
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
{
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
}
template <typename CGridDesc>
__device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
const AScaleType* p_a_scale_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
LDSTypeB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
const index_t ScaleSliceSizeM = 1;
const index_t ScaleSliceSizeN = 1;
const index_t ScaleSliceSizeK = 1;
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleType,
AScaleType,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
1,
1,
false>(
a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0));
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
1,
1,
false>(
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_grid_buf,
b_scale_thread_slice_copy_step,
num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
c_grid_desc_mblock_mperblock_nblock_nperblock;
using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitray type
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_smfmac.hpp"
namespace ck {
enum struct SmfmacInstr
{
smfmac_f32_16x16x32f16 = 0,
smfmac_f32_32x32x16f16,
smfmac_f32_16x16x32bf16,
smfmac_f32_32x32x16bf16,
};
template <SmfmacInstr instr>
struct smfmac_type;
template <>
struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <>
struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <>
struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <>
struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
}
};
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
struct SmfmacSelector
{
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
static constexpr auto GetSmfmac();
template <>
static constexpr auto GetSmfmac<half_t, 16, 16>()
{
return SmfmacInstr::smfmac_f32_16x16x32f16;
}
template <>
static constexpr auto GetSmfmac<half_t, 32, 32>()
{
return SmfmacInstr::smfmac_f32_32x32x16f16;
}
template <>
static constexpr auto GetSmfmac<bhalf_t, 16, 16>()
{
return SmfmacInstr::smfmac_f32_16x16x32bf16;
}
template <>
static constexpr auto GetSmfmac<bhalf_t, 32, 32>()
{
return SmfmacInstr::smfmac_f32_32x32x16bf16;
}
static constexpr auto selected_smfmac =
smfmac_type<GetSmfmac<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
__host__ __device__ constexpr SmfmacSelector()
{
static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk ==
selected_smfmac.num_regs_per_blk,
"wrong! num_regs_per_blk");
static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk,
"n_per_blk != num_threads_per_blk");
static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks ==
selected_smfmac.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks ||
selected_smfmac.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size ==
selected_smfmac.m_per_blk * selected_smfmac.n_per_blk,
"num_regs_per_blk incorrect");
static_assert(selected_smfmac.is_k_reduction ||
(selected_smfmac.num_input_blks == selected_smfmac.num_output_blks),
"is_k_reduction wrong!");
}
static constexpr index_t GetKPerXdlops()
{
return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) *
selected_smfmac.k_per_blk;
}
static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; }
};
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
index_t KPack,
typename additional_type = base_type>
struct SparseXdlopsGemm
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
using CIndex = MultiIndex<2>;
using CIndex4D = MultiIndex<4>;
__device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks);
}
__host__ __device__ constexpr SparseXdlopsGemm()
{
static_assert(NPerXdlops == 16 || NPerXdlops == 32,
"Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops");
static_assert(MPerXdlops == 16 || MPerXdlops == 32,
"Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops");
static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
}
// XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(Number<smfmac_instr.num_groups_per_blk>{},
Number<smfmac_instr.num_input_blks>{},
Number<smfmac_instr.group_size>{})),
make_pass_through_transform(Number<smfmac_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4, 5, 6>{},
Sequence<7>{}));
}
template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
{
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
return transform_tensor_descriptor(
c_desc_g_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(G),
make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk,
smfmac_instr.num_input_blks,
smfmac_instr.group_size)),
make_pass_through_transform(smfmac_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{},
Sequence<8>{}));
}
__device__ static constexpr index_t GetRegSizePerXdlops()
{
return MPerXdlops * NPerXdlops / smfmac_instr.wave_size;
}
__device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; }
template <class FloatA, class FloatB, class Idx, class FloatC>
__device__ void
Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value,
"base base_type must be half or bfloat16!");
static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
smfmac_instr.template run<MPerXdlops, NPerXdlops>(
p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
});
}
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; }
__device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto blk_idx =
threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
const auto blk_id = blk_idx[I1];
const auto blk_td = blk_idx[I2];
return make_tuple(blk_id, blk_td);
}
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(smfmac_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
}
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
if constexpr(smfmac_instr.is_k_reduction)
{
return make_tuple(blk_id, blk_td);
}
else
{
return make_tuple(0, laneId);
}
}
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td;
index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size;
return CIndex{m_offset, n_offset};
}
__device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
{
const auto blk_idx = GetBlkIdx();
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
return CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto smfmac =
SmfmacSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
static constexpr auto smfmac_instr = smfmac.selected_smfmac;
static constexpr auto KPerXdlops = smfmac.GetKPerXdlops();
static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops();
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
__host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
{
return make_tuple(
Number<smfmac_instr.num_groups_per_blk>{}, I1, Number<smfmac_instr.group_size>{}, I1);
}
};
} // namespace ck
......@@ -27,7 +27,7 @@ template <index_t NDimSpatial,
index_t NPerBlock,
index_t GemmK1Number,
index_t K0PerBlock,
index_t NumBatchToMerge,
index_t NumGroupsToMerge,
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
struct TransformConvBwdWeightToGemmV2
{
......@@ -45,7 +45,7 @@ struct TransformConvBwdWeightToGemmV2
const index_t BatchStride = output_strides[0];
const index_t WoStride = output_strides[4];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K),
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K),
make_tuple(WoStride, BatchStride, KStride));
}
......@@ -65,13 +65,13 @@ struct TransformConvBwdWeightToGemmV2
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C),
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumGroupsToMerge, C),
make_tuple(WiStride, BatchStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, NumBatchToMerge, C),
make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
}
}
......@@ -88,30 +88,30 @@ struct TransformConvBwdWeightToGemmV2
const auto KStride = weights_strides[1];
const auto XStride = weights_strides[4];
const auto BatchStride = weights_strides[0];
// Add NumBatchToMerge for Batch+M dimension and, 1 as a placehorder
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder
// for Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumBatchToMerge, K, Y * X, 1, C),
make_tuple(NumGroupsToMerge, K, Y * X, 1, C),
make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
// Padd 1 to NumBatchToMerge
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumBatchToMerge),
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K),
make_pass_through_transform(Y * X),
make_pad_transform(1, 0, NumBatchToMerge - 1),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
NumBatchToMerge == 64);
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K),
make_pass_through_transform(Y * X),
make_pass_through_transform(C)),
......@@ -120,8 +120,8 @@ struct TransformConvBwdWeightToGemmV2
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))),
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -138,7 +138,7 @@ struct TransformConvBwdWeightToGemmV2
const index_t BatchStride = output_strides[0];
const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K),
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K),
make_tuple(WoStride, BatchStride, KStride));
}
......@@ -160,13 +160,13 @@ struct TransformConvBwdWeightToGemmV2
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C),
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C),
make_tuple(WiStride, BatchStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C),
make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride));
}
}
......@@ -184,29 +184,29 @@ struct TransformConvBwdWeightToGemmV2
const auto KStride = weights_strides[1];
const auto XStride = weights_strides[5];
const auto BatchStride = weights_strides[0];
// Add NumBatchToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension
// Add NumGroupsToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C),
make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C),
make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
// Padd 1 to NumBatchToMerge
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumBatchToMerge),
make_tuple(make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K),
make_pass_through_transform(Z * Y * X),
make_pad_transform(1, 0, NumBatchToMerge - 1),
make_pad_transform(1, 0, NumGroupsToMerge - 1),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
NumBatchToMerge == 64);
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K),
make_pass_through_transform(Z * Y * X),
make_pass_through_transform(C)),
......@@ -215,8 +215,8 @@ struct TransformConvBwdWeightToGemmV2
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))),
make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -262,8 +262,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t InRightPadW = input_right_pads[1];
const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K * NumBatchToMerge;
const index_t GemmN = C * X * Y * NumBatchToMerge;
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
......@@ -286,7 +286,7 @@ struct TransformConvBwdWeightToGemmV2
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -302,7 +302,7 @@ struct TransformConvBwdWeightToGemmV2
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -324,7 +324,7 @@ struct TransformConvBwdWeightToGemmV2
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -341,7 +341,7 @@ struct TransformConvBwdWeightToGemmV2
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
......@@ -354,7 +354,7 @@ struct TransformConvBwdWeightToGemmV2
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
......@@ -366,7 +366,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)),
make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
......@@ -465,8 +465,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t InRightPadW = input_right_pads[2];
const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K * NumBatchToMerge;
const index_t GemmN = C * Z * X * Y * NumBatchToMerge;
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * Z * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
......@@ -489,7 +489,7 @@ struct TransformConvBwdWeightToGemmV2
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -505,7 +505,7 @@ struct TransformConvBwdWeightToGemmV2
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -527,7 +527,7 @@ struct TransformConvBwdWeightToGemmV2
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -545,7 +545,7 @@ struct TransformConvBwdWeightToGemmV2
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
......@@ -567,7 +567,7 @@ struct TransformConvBwdWeightToGemmV2
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
......@@ -584,7 +584,7 @@ struct TransformConvBwdWeightToGemmV2
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)),
make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
......
......@@ -14,27 +14,26 @@
namespace ck {
namespace tensor_operation {
// function to be used on device, emulates std::accumulate
template <typename T, typename ForwardIterator, typename Size>
__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init)
{
for(ForwardIterator x = first; x != first + count; x++)
{
init *= *x;
}
return init;
}
template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
template <index_t NDimSpatial,
device::ConvolutionForwardSpecialization ConvForwardSpecialization,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1>
struct TransformConvFwdToGemm
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static long_index_t
calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
const std::array<index_t, NDimSpatial + 3>& strides,
index_t i)
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
......@@ -46,11 +45,11 @@ struct TransformConvFwdToGemm
return acc;
}
template <typename ADataType, typename CDataType>
static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
template <typename ConvDimsType>
static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
{
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
......@@ -96,640 +95,421 @@ struct TransformConvFwdToGemm
}
}
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties
template <typename ALayout,
typename std::enable_if<NDimSpatial == 1 &&
(is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
public:
__host__ __device__ constexpr TransformConvFwdToGemm() {}
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 1, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads)
: Di_{I1},
Hi_{I1},
Wi_{a_g_n_c_wis_lengths[I3]},
Do_{I1},
Ho_{I1},
Wo_{c_g_n_k_wos_lengths[I3]},
Z_{I1},
Y_{I1},
X_{b_g_k_c_xs_lengths[I3]},
K_{c_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
DiStride_{I1},
HiStride_{I1},
WiStride_{a_g_n_c_wis_strides[I3]},
WoStride_{c_g_n_k_wos_strides[I3]},
XStride_{b_g_k_c_xs_strides[I3]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
ConvStrideD_{I1},
ConvStrideH_{I1},
ConvStrideW_{conv_filter_strides[I0]},
ConvDilationD_{I1},
ConvDilationH_{I1},
ConvDilationW_{conv_filter_dilations[I0]},
InLeftPadD_{I0},
InLeftPadH_{I0},
InLeftPadW_{input_left_pads[I0]},
InRightPadD_{I0},
InRightPadH_{I0},
InRightPadW_{input_right_pads[I0]},
ZYX_{X_}
{
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
if constexpr(SplitN)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Wo_;
}
template <typename ALayout,
typename std::enable_if<
NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 2, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads)
: Di_{I1},
Hi_{a_g_n_c_wis_lengths[I3]},
Wi_{a_g_n_c_wis_lengths[I4]},
Do_{I1},
Ho_{c_g_n_k_wos_lengths[I3]},
Wo_{c_g_n_k_wos_lengths[I4]},
Z_{I1},
Y_{b_g_k_c_xs_lengths[I3]},
X_{b_g_k_c_xs_lengths[I4]},
K_{c_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
DiStride_{I1},
HiStride_{a_g_n_c_wis_strides[I3]},
WiStride_{a_g_n_c_wis_strides[I4]},
WoStride_{c_g_n_k_wos_strides[I4]},
XStride_{b_g_k_c_xs_strides[I4]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
ConvStrideD_{I1},
ConvStrideH_{conv_filter_strides[I0]},
ConvStrideW_{conv_filter_strides[I1]},
ConvDilationD_{I1},
ConvDilationH_{conv_filter_dilations[I0]},
ConvDilationW_{conv_filter_dilations[I1]},
InLeftPadD_{I0},
InLeftPadH_{input_left_pads[I0]},
InLeftPadW_{input_left_pads[I1]},
InRightPadD_{I0},
InRightPadH_{input_right_pads[I0]},
InRightPadW_{input_right_pads[I1]},
ZYX_{Y_ * X_}
{
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = c_g_n_k_wos_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
if constexpr(SplitN)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Ho_ * Wo_;
}
template <typename ALayout,
typename std::enable_if<
NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 3, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads)
: Di_{a_g_n_c_wis_lengths[I3]},
Hi_{a_g_n_c_wis_lengths[I4]},
Wi_{a_g_n_c_wis_lengths[I5]},
Do_{c_g_n_k_wos_lengths[I3]},
Ho_{c_g_n_k_wos_lengths[I4]},
Wo_{c_g_n_k_wos_lengths[I5]},
Z_{b_g_k_c_xs_lengths[I3]},
Y_{b_g_k_c_xs_lengths[I4]},
X_{b_g_k_c_xs_lengths[I5]},
K_{c_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
DiStride_{a_g_n_c_wis_strides[I3]},
HiStride_{a_g_n_c_wis_strides[I4]},
WiStride_{a_g_n_c_wis_strides[I5]},
WoStride_{c_g_n_k_wos_strides[I5]},
XStride_{b_g_k_c_xs_strides[I5]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
ConvStrideD_{conv_filter_strides[I0]},
ConvStrideH_{conv_filter_strides[I1]},
ConvStrideW_{conv_filter_strides[I2]},
ConvDilationD_{conv_filter_dilations[I0]},
ConvDilationH_{conv_filter_dilations[I1]},
ConvDilationW_{conv_filter_dilations[I2]},
InLeftPadD_{input_left_pads[I0]},
InLeftPadH_{input_left_pads[I1]},
InLeftPadW_{input_left_pads[I2]},
InRightPadD_{input_right_pads[I0]},
InRightPadH_{input_right_pads[I1]},
InRightPadW_{input_right_pads[I2]},
ZYX_{Z_ * Y_ * X_}
{
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
const index_t Do = c_g_n_k_wos_lengths[3];
const index_t Ho = c_g_n_k_wos_lengths[4];
const index_t Wo = c_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
if constexpr(SplitN)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
N_ = c_g_n_k_wos_lengths[I1];
}
NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
}
template <typename BLayout,
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
bool>::type = false>
static auto
MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto wei_gemmn_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
return wei_gemmn_gemmk_desc;
}
template <
typename BLayout,
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
bool>::type = false>
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
wei_k_yx_c_desc,
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return wei_gemmn_gemmk_desc;
}
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const index_t N)
{
const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
return out_gemmm_gemmn_desc;
}
template <
typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{
const index_t K = c_g_n_k_wos_lengths[2];
const auto KStride = I1;
const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
return out_gemmm_gemmn_desc;
}
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
{
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
return out_gemmm_gemmn_desc;
}
// Overloaded functions for hipRTC purposes
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties
template <typename ALayout,
typename std::enable_if<NDimSpatial == 1 &&
(is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNWC>),
bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
__host__ __device__ auto MakeADescriptor_M_K() const
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
if constexpr(NumGroupsToMerge == 1)
{
return in_gemmm_gemmk_desc;
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(Number<3>{})),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(Number<3>{})),
make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t WiStride = a_g_n_c_wis_strides[3];
const auto CStride = I1;
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_merge_transform(make_tuple(X_, C_))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(X_, C_))),
make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
}
......@@ -739,126 +519,229 @@ struct TransformConvFwdToGemm
is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = c_g_n_k_wos_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
__host__ __device__ auto MakeADescriptor_M_K() const
{
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hip_wip_groups_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
return transform_tensor_descriptor(
in_n_ho_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t HiStride = a_g_n_c_wis_strides[3];
const index_t WiStride = a_g_n_c_wis_strides[4];
const auto CStride = I1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_merge_transform(make_tuple(Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return in_gemmm_gemmk_desc;
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hip_wip_groups_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5>{},
Sequence<6>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
}
......@@ -868,149 +751,293 @@ struct TransformConvFwdToGemm
is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = c_g_n_k_wos_lengths[3];
const index_t Ho = c_g_n_k_wos_lengths[4];
const index_t Wo = c_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
__host__ __device__ auto MakeADescriptor_M_K() const
{
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// This is different
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
return in_gemmm_gemmk_desc;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
return transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
return transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
// This is different
const index_t NStride = a_g_n_c_wis_strides[1];
const index_t DiStride = a_g_n_c_wis_strides[3];
const index_t HiStride = a_g_n_c_wis_strides[4];
const index_t WiStride = a_g_n_c_wis_strides[5];
const auto CStride = I1;
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{},
Sequence<8>{}));
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
}
......@@ -1019,20 +1046,53 @@ struct TransformConvFwdToGemm
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
__host__ __device__ auto MakeBDescriptor_N_K() const
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX =
mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
using FilterSizeNumType =
std::conditional_t<NDimSpatial == 1,
Number<3>,
std::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;
const auto wei_gemmn_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
}
else
{
return wei_gemmn_gemmk_desc;
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_pass_through_transform(FilterSizeNumType{})),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
}
else
{
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_pass_through_transform(ZYX_ * C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
}
template <
......@@ -1044,26 +1104,14 @@ struct TransformConvFwdToGemm
is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
__host__ __device__ auto MakeBDescriptor_N_K() const
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX =
mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_));
const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
wei_k_yx_c_desc,
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -1075,23 +1123,14 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo =
N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
return out_gemmm_gemmn_desc;
return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
}
template <
typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
......@@ -1099,45 +1138,82 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const auto KStride = I1;
const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
const index_t NHoWo =
N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
return out_gemmm_gemmn_desc;
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
// Padd 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_pass_through_transform(NDoHoWo_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We need only matrices from diagonal. X_or returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NDoHoWo_),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
// for output bias
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const index_t N = c_g_n_k_wos_lengths[1];
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2];
const index_t NHoWo =
N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
return out_gemmm_gemmn_desc;
}
public:
index_t N_;
private:
const index_t Di_, Hi_, Wi_;
const index_t Do_, Ho_, Wo_;
const index_t Z_, Y_, X_;
const index_t K_, C_;
const index_t DiStride_, HiStride_, WiStride_;
const index_t WoStride_;
const index_t XStride_;
const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
const index_t NStrideTensorA_;
const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
const index_t InRightPadD_, InRightPadH_, InRightPadW_;
const index_t ZYX_;
index_t NDoHoWo_;
};
// wrapper class to call member functions on TransformConvToGemm struct at runtime
......@@ -1149,26 +1225,22 @@ struct TransformConv
template <index_t NDimSpatial,
device::ConvolutionForwardSpecialization ConvForwardSpecialization>
auto
transform_func(ck::Array<index_t, NDimSpatial + 3> out_lengths,
ck::Array<index_t, NDimSpatial + 3> out_strides,
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
transform_func(TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
{
if(NDimSpatial == 2)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>(out_lengths,
out_strides);
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
}
else if(NDimSpatial == 3)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>(out_lengths,
out_strides);
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
}
else if(NDimSpatial == 1)
{
return conv_fwd_to_gemm.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>(
out_lengths, out_strides);
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
}
}
};
......
......@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
......@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
......@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
......@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
ignore = reg_idx;
#endif
}
};
......
......@@ -839,7 +839,7 @@ inline __device__ T rcp(T x)
template <typename T>
inline __device__ T exp(T x)
{
return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
return ck::type_convert<T>(__ocml_exp_f32(ck::type_convert<float>(x)));
};
template <>
......@@ -851,7 +851,7 @@ inline __device__ half_t exp<half_t>(half_t x)
template <>
inline __device__ float exp<float>(float x)
{
return __expf(x);
return __ocml_exp_f32(x);
};
template <>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -52,11 +52,19 @@ struct Add
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ + b_);
}
};
struct SquaredAdd
......@@ -104,11 +112,19 @@ struct Mul
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, half_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<bhalf_t>(a_ * b_);
}
};
struct Max
......
......@@ -54,233 +54,318 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
} // namespace impl
// TODO: glc/slc/...
template <index_t bytes>
template <index_t bytes, bool pre_nop = false>
struct buffer_load;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast"
// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type
// (exp_vector_type(xxx))
template <>
struct buffer_load<16>
template <bool pre_nop>
struct buffer_load<16, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<8>
template <bool pre_nop>
struct buffer_load<8, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<4>
template <bool pre_nop>
struct buffer_load<4, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<2>
template <bool pre_nop>
struct buffer_load<2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <>
struct buffer_load<1>
template <bool pre_nop>
struct buffer_load<1, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 0)
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
template <index_t bytes>
template <index_t bytes, bool pre_nop = false>
struct buffer_load_if;
template <>
struct buffer_load_if<16>
template <bool pre_nop>
struct buffer_load_if<16, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<8>
template <bool pre_nop>
struct buffer_load_if<8, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<4>
template <bool pre_nop>
struct buffer_load_if<4, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<2>
template <bool pre_nop>
struct buffer_load_if<2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
template <>
struct buffer_load_if<1>
template <bool pre_nop>
struct buffer_load_if<1, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
else
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec)
: "memory");
}
};
#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast"
......@@ -294,17 +379,16 @@ struct buffer_store<16>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
asm volatile(
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
......@@ -315,17 +399,16 @@ struct buffer_store<8>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
asm volatile(
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
......@@ -336,17 +419,16 @@ struct buffer_store<4>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_dword %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
......@@ -357,17 +439,16 @@ struct buffer_store<2>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 2);
using mbuf_t = short;
asm volatile(
"buffer_store_short %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
......@@ -378,17 +459,16 @@ struct buffer_store<1>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile(
"buffer_store_byte %0, %1, %2, %3 offen offset:%4"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
: "memory");
asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3"
:
: "v"(bit_cast<mbuf_t>(value)), "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
}
};
......@@ -402,21 +482,20 @@ struct buffer_store_if<16>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 16);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
......@@ -431,7 +510,7 @@ struct buffer_store_if<8>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
......@@ -439,14 +518,13 @@ struct buffer_store_if<8>
auto save_exec = __builtin_amdgcn_read_exec();
// TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
......@@ -461,21 +539,20 @@ struct buffer_store_if<4>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dword %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_dword %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
......@@ -490,21 +567,20 @@ struct buffer_store_if<2>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 2);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = short;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_short %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_short %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
......@@ -519,21 +595,20 @@ struct buffer_store_if<1>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t s_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_byte %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"buffer_store_byte %0, %1, %2, 0 offen offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(bit_cast<mbuf_t>(value)),
"v"(v_offset),
"s"(res),
"s"(s_offset),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
......@@ -901,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
CK_TILE_DEVICE void async_buffer_load_dword(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0)
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc,
index_t voffset,
index_t /*soffset*/,
index_t ioffset /*max 0xFFF*/,
index_t /*flag*/ = 0,
bool_constant<pre_nop> = {})
{
asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset)
: "memory");
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
else
asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds"
: "=r"(smem) /*dummy dependency for smem*/
: "v"(voffset), "s"(rsrc), "n"(ioffset)
: "memory");
}
CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
......@@ -1223,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t flag = 0)
index_t flag = 0,
bool_constant<pre_nop> = {})
{
constexpr index_t bytes = sizeof(T) * N;
static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16,
......@@ -1237,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
using type = thread_buffer<T, N>;
if constexpr(oob_conditional_check)
{
buffer_load_if<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
buffer_load_if<sizeof(type), pre_nop>{}(dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
}
else
{
buffer_load<sizeof(type)>{}(
dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag);
buffer_load<sizeof(type), pre_nop>{}(dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
flag,
bool_constant<pre_nop>{});
}
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0)
index_t src_immediate_addr_offset = 0,
bool_constant<pre_nop> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
async_buffer_load_dword(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset);
async_buffer_load_dword_v(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
0,
bool_constant<pre_nop>{});
}
template <index_t N,
......@@ -1909,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size,
index_t is_valid_element = 0)
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check>(
dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
is_valid_element,
bool_constant<pre_nop>{});
}
// unfortunately async copy can not make sure invalid data is zero inside LDS
......@@ -1931,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
// buffer_load OOB still working.
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size)
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
......@@ -1943,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem,
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0);
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool pre_nop = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
}
// buffer_store requires:
......
......@@ -82,14 +82,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
" ::);
}
CK_TILE_DEVICE void s_nop()
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
asm volatile("s_nop %0" : : "n"(cnt) :);
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(cnt);
#endif
}
......
......@@ -21,6 +21,7 @@
#define __gfx12__
#endif
#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
......@@ -147,6 +148,14 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
......
......@@ -331,7 +331,10 @@ bfloat16_t sqrt(bfloat16_t x)
};
CK_TILE_DEVICE
bfloat16_t exp(bfloat16_t x) { return static_cast<bfloat16_t>(__expf(static_cast<float>(x))); };
bfloat16_t exp(bfloat16_t x)
{
return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
};
CK_TILE_DEVICE
bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
......
......@@ -835,7 +835,7 @@ CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__expf(static_cast<float>(x))); };
fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE
fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
......@@ -860,7 +860,7 @@ CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__expf(static_cast<float>(x))); };
bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE
bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
......
......@@ -374,7 +374,7 @@ half_t sqrt(half_t x)
};
CK_TILE_DEVICE
half_t exp(half_t x) { return static_cast<half_t>(__expf(static_cast<float>(x))); };
half_t exp(half_t x) { return static_cast<half_t>(__ocml_exp_f32(static_cast<float>(x))); };
CK_TILE_DEVICE
half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x))); };
......
......@@ -519,7 +519,7 @@ CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __expf(x); };
float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
......
......@@ -69,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
{
}
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::generic;
......@@ -224,25 +226,36 @@ struct buffer_view<address_space_enum::global,
T* p_data_ = nullptr;
BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0};
CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{}
: p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
: p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
: p_data_{p_data},
buffer_size_{buffer_size},
cached_buf_res_{0},
invalid_element_value_{invalid_element_value}
{
}
// this is non constexpr intentially (will call some intrinsic internally)
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::global;
......@@ -333,12 +346,15 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
index_t i,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
......@@ -349,18 +365,21 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
dst, p_data_, i, buffer_size_, is_valid_element);
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
index_t i,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
......@@ -371,8 +390,8 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem, p_data_, i, buffer_size_);
amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
smem, cached_buf_res_, i, bool_constant<pre_nop>{});
}
// i is offset of T, not X. i should be aligned to X
......@@ -627,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
{
}
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::lds;
......@@ -909,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
{
}
CK_TILE_HOST_DEVICE void init_raw() {}
CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::vgpr;
......
......@@ -36,30 +36,37 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord>
index_t NumCoord,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load(lds_tile);
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
......@@ -35,6 +35,8 @@ struct null_tile_window
CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
CK_TILE_DEVICE void init_raw() {}
WindowLengths window_lengths_;
};
......
......@@ -36,6 +36,8 @@ struct tensor_view
{
}
CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
......@@ -85,30 +87,34 @@ struct tensor_view
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check>(
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
{
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
}
// X is vector of DataType.
......
......@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor...
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{
constexpr index_t tensor_bytes =
DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
if constexpr(v == 0 && tensor_bytes % 4 == 0)
using elem_type = typename DstrTensors::DataType;
constexpr index_t elem_size = sizeof(elem_type);
constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
// # bytes per write = 4
if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
{
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
auto& buffer = dstr_tensor.get_thread_buffer();
static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
if constexpr(elem_size == 1)
{
// # elements per write = 4
constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
buffer[i_write * 4 + 0] = values.x;
buffer[i_write * 4 + 1] = values.y;
buffer[i_write * 4 + 2] = values.z;
buffer[i_write * 4 + 3] = values.w;
}
else if constexpr(elem_size == 2)
{
// # elements per write = 2
constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
buffer[i_write * 2 + 0] = values.x;
buffer[i_write * 2 + 1] = values.y;
}
else if constexpr(elem_size == 4)
{
// # elements per write = 1
constexpr elem_type value = 0;
buffer[i_write] = value;
}
else
{
static_assert(false, "type not supported");
}
});
#else
using dvec_t = array<index_t, tensor_bytes / 4>;
auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
for(auto i = 0; i < tensor.size(); i++)
tensor.get(i) = v;
#endif
}
else
{
tile_elementwise_inout(
[](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
dstr_tensor);
tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
dstr_tensor);
}
}
......
......@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true>
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
bool_constant<oob_conditional_check> = {}) const
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
......@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
......@@ -384,7 +391,8 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
bool_constant<oob_conditional_check>{});
bool_constant<oob_conditional_check>{},
pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -399,12 +407,17 @@ struct tile_window_with_static_distribution
}
});
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {}) const
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
......@@ -449,11 +462,17 @@ struct tile_window_with_static_distribution
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto pre_nop_ = [&]() {
if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
return bool_constant<true>{};
else
return bool_constant<false>{};
}();
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord);
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
......@@ -668,6 +687,67 @@ struct tile_window_with_static_distribution
});
}
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_dstr_.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment