Unverified Commit 9697ad4e authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into add_int8_wmma_example_instance

parents 1c97db8a 582e31e8
...@@ -109,30 +109,37 @@ struct BlockToCTileMap_M00_N0_M01 ...@@ -109,30 +109,37 @@ struct BlockToCTileMap_M00_N0_M01
// Rows of column-vectors // Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range // This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N> template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_M00_N0_M01Adapt struct BlockToCTileMap_M00_N0_M01Adapt;
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
index_t M01 = 8) default;
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{ {
} }
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock);
const index_t grid_size = M0 * N0;
return grid_size; return M0 * N0;
} }
template <typename TopIdx> template <typename TopIdx>
...@@ -140,8 +147,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -140,8 +147,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
{ {
auto block_1d_id = idx_top[I0]; auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
...@@ -209,11 +216,36 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -209,11 +216,36 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize() return true; // always valid provided that user gets grid size from CalculateGridSize()
} }
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private: private:
index_t M_;
index_t N_;
index_t M01_; index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_; };
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
using Parent = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>;
using Parent::I0;
using Parent::I1;
using Parent::Parent;
using Parent::operator=;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: Parent(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return Parent::CalculateGridSize(c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
}; };
// 2D slices of column-vectors in 3D space // 2D slices of column-vectors in 3D space
......
...@@ -15,6 +15,7 @@ namespace ck { ...@@ -15,6 +15,7 @@ namespace ck {
template <typename GridwiseReduction, template <typename GridwiseReduction,
bool OutputIndex, bool OutputIndex,
bool TransformIndexKtoGlobal,
bool HaveIndexInput, bool HaveIndexInput,
typename InDataType, typename InDataType,
typename OutDataType, typename OutDataType,
...@@ -48,16 +49,17 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, ...@@ -48,16 +49,17 @@ __global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
} }
else else
{ {
GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k, GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
out_grid_desc_m, in_grid_desc_m_k,
in_elementwise_op, out_grid_desc_m,
acc_elementwise_op, in_elementwise_op,
alpha, acc_elementwise_op,
p_in_value_global, alpha,
p_in_index_global, p_in_value_global,
beta, p_in_index_global,
p_out_value_global, beta,
p_out_index_global); p_out_value_global,
p_out_index_global);
}; };
}; };
...@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise
reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
}; };
template <bool HaveIndexInput> template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
__device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
const OutGridDesc_M& out_grid_desc_m, const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op, const InElementwiseOperation& in_elementwise_op,
...@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise
indexStart += KThreadSliceSize; indexStart += KThreadSliceSize;
reducedLength += KThreadSliceSize; reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength); } while(reducedLength < toReduceLength);
if constexpr(TransformIndexKtoGlobal)
{
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
const auto coord = make_tensor_coordinate(
in_grid_desc_m_k,
make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
accu_index_buf(I)));
accu_index_buf(I) = coord.GetOffset();
});
}
}; };
// for indiced operation, acc_elementwise_op shoud do nothing // for indiced operation, acc_elementwise_op shoud do nothing
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
// Device entry point for the v1r2 dlops contraction.
//
// Declares a statically sized __shared__ array of FloatAB whose element count is
// GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB), then hands
// the grid pointers, the descriptors (received by value) and the block-id mapping
// to GridwiseContraction::Run. The two bool template parameters are passed on as
// integral_constant tag objects.
template <typename GridwiseContraction,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
          typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
          typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
          typename CGridBlockCluster_BlockId_To_GM10_GN10,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_contraction_dlops_v1r2(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
        const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
        const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
        const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
{
    // Tag types that carry the loop-shape flags into Run's overload.
    using MainKBlockLoopTag       = integral_constant<bool, HasMainKBlockLoop>;
    using DoubleTailKBlockLoopTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    // Shared-memory workspace size, converted from bytes to FloatAB elements.
    constexpr index_t lds_element_count =
        GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB lds_workspace[lds_element_count];

    GridwiseContraction::Run(p_a_grid,
                             p_b_grid,
                             p_c_grid,
                             lds_workspace,
                             a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                             b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                             c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                             c_grid_block_cluster_blockid_to_gm10_gn10,
                             MainKBlockLoopTag{},
                             DoubleTailKBlockLoopTag{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t GK0PerBlock,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs,
typename BM10BN10ThreadClusterBN10Xs,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
// Compile-time index constants; used below as dimension selectors for GetLength()
// and as the literal lengths 0..3 inside tuples.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to be known at compile-time.
// All three lengths are read from default-constructed descriptor objects:
// GM0 = dimension 0 of C, GN0 = dimension 2 of C, GK1 = dimension 3 of A.
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
// Size in bytes of the LDS workspace used by Run().
//
// The A tile has lengths [GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1] and the B tile
// has lengths [GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1]. Each tile's element-space
// size is rounded up to a multiple of GK1, and room is reserved for two copies of
// each tile.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // GK1 is used as the LDS alignment unit.
    // TODO: part of this should be moved into blockwise-gemm
    // TODO: change this; it probably needs multi-dimensional alignment
    constexpr auto lds_align = GK1;

    // Layout of the A tile as written by the blockwise copy.
    constexpr auto a_tile_desc = make_naive_tensor_descriptor_aligned(
        make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
        lds_align);

    // Layout of the B tile as written by the blockwise copy.
    constexpr auto b_tile_desc = make_naive_tensor_descriptor_aligned(
        make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
        lds_align);

    // Per-tile element counts, padded to the alignment unit.
    constexpr auto a_tile_elements =
        math::integer_least_multiple(a_tile_desc.GetElementSpaceSize(), lds_align);
    constexpr auto b_tile_elements =
        math::integer_least_multiple(b_tile_desc.GetElementSpaceSize(), lds_align);

    // Factor 2: two slots per operand.
    return 2 * (a_tile_elements + b_tile_elements) * sizeof(FloatAB);
}
// Returns true when the three grid descriptors agree with each other and with the
// block tile sizes:
//   - C lengths equal (GM0, GM1, GN0, GN1),
//   - A lengths at dims 1/2 equal (GM0, GM1),
//   - B lengths at dims 0/1/2/3 equal (GK0, GN0, GN1, GK1),
//   - GM1, GN1 and GK0 are multiples of GM1PerBlockGM11, GN1PerBlockGN11 and
//     GK0PerBlock respectively.
// GM1 and GK0 are taken from the A descriptor, GN1 from the B descriptor.
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
              const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
              const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
    static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
                      is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
                  "wrong! GM0 and GN0 need to be known at compile-time");

    const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
    const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
    const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);

    const bool c_lengths_match = GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
                                 GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
                                 GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
                                 GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);

    const bool a_lengths_match = GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
                                 GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);

    const bool b_lengths_match = GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
                                 GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
                                 GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
                                 GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3);

    const bool tiles_divide_evenly =
        GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0;

    return c_lengths_match && a_lengths_match && b_lengths_match && tiles_divide_evenly;
}
// Number of blocks covering C: (GM1 / GM1PerBlockGM11) * (GN1 / GN1PerBlockGN11),
// with GM1 and GN1 read from dimensions 1 and 3 of the C descriptor.
// Both quotients use truncating integer division.
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
    const index_t blocks_along_gm1 =
        c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) / GM1PerBlockGM11;
    const index_t blocks_along_gn1 =
        c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) / GN1PerBlockGN11;

    return blocks_along_gm1 * blocks_along_gn1;
}
// True when (GK0 + GK0PerBlock) / (2 * GK0PerBlock), using integer division,
// is greater than 1.
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
{
    return (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
}
// True when the number of GK0PerBlock-sized slices in GK0 (integer division) is even.
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
{
    return (GK0 / GK0PerBlock) % 2 == 0;
}
// Builds the 5-d A descriptor [GK0, GM0, GM10, GM11, GK1] from the 4-d one
// [GK0, GM0, GM1, GK1]: dimension 2 (GM1) is unmerged into (GM10, GM11) with
// GM11 = GM1PerBlockGM11 and GM10 = GM1 / GM11; the other dimensions pass through.
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
    const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
{
    const auto gk0_len = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
    const auto gm1_len = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);

    const auto gm11_len = Number<GM1PerBlockGM11>{};
    const auto gm10_len = gm1_len / gm11_len;

    return transform_tensor_descriptor(
        a_grid_desc_gk0_gm0_gm1_gk1,
        make_tuple(make_pass_through_transform(gk0_len),
                   make_pass_through_transform(GM0),
                   make_unmerge_transform(make_tuple(gm10_len, gm11_len)),
                   make_pass_through_transform(GK1)),
        // source dimensions consumed by each transform
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        // destination dimensions produced by each transform
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
}
// Builds the 5-d B descriptor [GK0, GN0, GN10, GN11, GK1] from the 4-d one
// [GK0, GN0, GN1, GK1]: dimension 2 (GN1) is unmerged into (GN10, GN11) with
// GN11 = GN1PerBlockGN11 and GN10 = GN1 / GN11; the other dimensions pass through.
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
    const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
{
    const auto gk0_len = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
    const auto gn1_len = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);

    const auto gn11_len = Number<GN1PerBlockGN11>{};
    const auto gn10_len = gn1_len / gn11_len;

    return transform_tensor_descriptor(
        b_grid_desc_gk0_gn0_gn1_gk1,
        make_tuple(make_pass_through_transform(gk0_len),
                   make_pass_through_transform(GN0),
                   make_unmerge_transform(make_tuple(gn10_len, gn11_len)),
                   make_pass_through_transform(GK1)),
        // source dimensions consumed by each transform
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        // destination dimensions produced by each transform
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
}
// Builds the 6-d C descriptor [GM10, BM0, BM1, GN10, BN0, BN1] from the 4-d one
// [GM0, GM1, GN0, GN1] in three chained transforms:
//   1. unmerge GM1 -> (GM10, GM11) and GN1 -> (GN10, GN11);
//   2. merge (GM0, GM11) -> BM and (GN0, GN11) -> BN, moving GM10 / GN10 in front;
//   3. unmerge BM -> (BM0, BM1) and BN -> (BN0, BN1).
// BM1 / BN1 are the product of the thread-cluster sequence entries times the
// per-thread size; BM0 = BM / BM1 and BN0 = BN / BN1.
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
    const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
    // M side: per-block split of GM1.
    const auto gm1_len      = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
    constexpr auto gm11_len = Number<GM1PerBlockGM11>{};
    const auto gm10_len     = gm1_len / gm11_len;

    // N side: per-block split of GN1.
    const auto gn1_len      = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
    constexpr auto gn11_len = Number<GN1PerBlockGN11>{};
    const auto gn10_len     = gn1_len / gn11_len;

    // Block-level extents and their two-level split.
    constexpr auto bm_len = GM0 * gm11_len;
    constexpr auto bn_len = GN0 * gn11_len;

    constexpr auto bm1_len =
        Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
               BM1PerThreadBM11>{};
    constexpr auto bn1_len =
        Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
               BN1PerThreadBN11>{};

    constexpr auto bm0_len = bm_len / bm1_len;
    constexpr auto bn0_len = bn_len / bn1_len;

    // Stage 1: [GM0, GM1, GN0, GN1] -> [GM0, GM10, GM11, GN0, GN10, GN11]
    const auto desc_gm0_gm10_gm11_gn0_gn10_gn11 = transform_tensor_descriptor(
        c_grid_desc_gm0_gm1_gn0_gn1,
        make_tuple(make_pass_through_transform(GM0),
                   make_unmerge_transform(make_tuple(gm10_len, gm11_len)),
                   make_pass_through_transform(GN0),
                   make_unmerge_transform(make_tuple(gn10_len, gn11_len))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

    // Stage 2: [GM0, GM10, GM11, GN0, GN10, GN11] -> [GM10, BM, GN10, BN]
    const auto desc_gm10_bm_gn10_bn = transform_tensor_descriptor(
        desc_gm0_gm10_gm11_gn0_gn10_gn11,
        make_tuple(make_pass_through_transform(gm10_len),
                   make_merge_transform(make_tuple(GM0, gm11_len)),
                   make_pass_through_transform(gn10_len),
                   make_merge_transform(make_tuple(GN0, gn11_len))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    // Stage 3: [GM10, BM, GN10, BN] -> [GM10, BM0, BM1, GN10, BN0, BN1]
    return transform_tensor_descriptor(
        desc_gm10_bm_gn10_bn,
        make_tuple(make_pass_through_transform(gm10_len),
                   make_unmerge_transform(make_tuple(bm0_len, bm1_len)),
                   make_pass_through_transform(gn10_len),
                   make_unmerge_transform(make_tuple(bn0_len, bn1_len))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
}
// Builds a single-stage adaptor that merges (GM10, GN10) into one index, where
// GM10 = GM1 / GM1PerBlockGM11 and GN10 = GN1 / GN1PerBlockGN11 are derived from
// dimensions 1 and 3 of the C descriptor. Run() calls CalculateBottomIndex() on it
// to turn a 1-d block id into a (GM10, GN10) pair.
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
    const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
    const auto gm1_len = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
    const auto gn1_len = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);

    constexpr auto gm11_len = Number<GM1PerBlockGM11>{};
    constexpr auto gn11_len = Number<GN1PerBlockGN11>{};

    const auto gm10_len = gm1_len / gm11_len;
    const auto gn10_len = gn1_len / gn11_len;

    return make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(gm10_len, gn10_len))),
        make_tuple(Sequence<0, 1>{}),
        make_tuple(Sequence<0>{}));
}
// Types of the transformed descriptors / block-id adaptor, deduced from the return
// types of the Make* helpers when called with default-constructed input descriptors.
// Run() takes its descriptor arguments as these types.
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
// Per-block body of the contraction.
//
// Outline:
//   1. wrap the raw grid pointers in global-memory buffers and decode this block's
//      (GM10, GN10) position from its 1-d block id;
//   2. preload the first A/B tile from global memory into the "even" LDS slot;
//   3. iterate over GK0 in steps of GK0PerBlock, alternating between the "even" and
//      "odd" LDS slots while the blockwise GEMM accumulates into a per-thread
//      register buffer;
//   4. write the accumulated tile to C in global memory.
//
// Parameters:
//   p_a_grid, p_b_grid  - read-only global pointers to A and B.
//   p_c_grid            - global pointer to C (written using CGlobalMemoryDataOperation).
//   p_shared_block      - LDS workspace; addressed below as two A slots followed by
//                         two B slots, matching GetSharedMemoryNumberOfByte().
//   *_grid_desc_*       - transformed descriptors (see the Make* helpers).
//   c_grid_block_cluster_blockid_to_gm10_gn10
//                       - adaptor mapping the 1-d block id to (GM10, GN10).
//   integral_constant<bool, ...> (unnamed)
//                       - tag arguments; only their types are used, to carry the
//                         HasMainKBlockLoop / HasDoubleTailKBlockLoop flags.
//                         Presumably callers derive the flags from
//                         CalculateHasMainKBlockLoop / CalculateHasDoubleTailKBlockLoop
//                         -- verify against the launch site.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
// global-memory views of A, B and C, sized by each descriptor's element space
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
// total GK0 extent; bounds the main loop below
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
// lds max alignment
// TODO: part of them should be moved into blockwise-gemm
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = GK1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
// the 5-d copy layout and the 3-d GEMM layout describe the same LDS region, so
// their element-space sizes must be equal
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
"wrong!");
// A matrix blockwise copy
// source window starts at GM10 index igm10 of the grid; destination starts at the LDS origin
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, igm10, 0, 0),
a_block_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
// source window starts at GN10 index ign10 of the grid; destination starts at the LDS origin
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, ign10, 0, 0),
b_block_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[GK0PerBlock, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_gk0_bm_gk1),
decltype(b_block_desc_gk0_bn_gk1),
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
BM1PerThreadBM11,
BN1PerThreadBN11>{};
// per-thread C tile: lengths come from the blockwise GEMM, stored packed
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
// workspace layout: [A even | A odd | B even | B odd]; B starts after both A slots
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
// zero-initialize the accumulator
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
// each window move advances the source by GK0PerBlock along dimension 0 (GK0)
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
// LDS views of the even (first) and odd (second) slot of each operand
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
}
// the main loop is compiled in only when HasMainKBlockLoop is true
if constexpr(HasMainKBlockLoop)
{
index_t gk0_block_on_grid = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
// each pass handles two GK0PerBlock slices: GEMM on even while filling odd, then
// GEMM on odd while filling even
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
// barrier between the previous LDS write to the even slot and the GEMM read below
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
// barrier between the LDS write to the odd slot above and the GEMM read below
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
gk0_block_on_grid += 2 * GK0PerBlock;
// stop while at least the tail slices remain; they are handled after the loop
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
// 6-d packed view of the per-thread C tile, with unit lengths in the GM10 / GN10 slots
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
// where this thread's tile sits inside the block tile
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
// destination origin = (block position igm10 / ign10, thread origin within the block)
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
Sequence<1,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
1,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_multi_index(igm10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
ign10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf,
CGridStepHacks{});
}
}
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
// Device entry point for the v1r2 DL-ops GEMM.
//
// Declares the workgroup's LDS scratch array and hands every argument on to
// GridwiseGemm::Run. The two bool template parameters are forwarded as
// integral_constant tags, so Run selects its K-loop shape at compile time.
// The three grid descriptors and the block-id adaptor are received by value as
// kernel arguments.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AKM0M1GridDesc,
          typename BKN0N1GridDesc,
          typename CM0M10M11N0N10N11GridDesc,
          typename CBlockIdToM0N0BlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_dlops_v1r2(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AKM0M1GridDesc a_k_m0_m1_grid_desc,
            const BKN0N1GridDesc b_k_n0_n1_grid_desc,
            const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
            const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
{
    using MainLoopTag   = integral_constant<bool, HasMainKBlockLoop>;
    using DoubleTailTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    // GetSharedMemoryNumberOfByte() reports bytes; the LDS array is declared in
    // FloatAB elements, hence the division.
    constexpr index_t lds_element_count =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB lds_workspace[lds_element_count];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      lds_workspace,
                      a_k_m0_m1_grid_desc,
                      b_k_n0_n1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      cblockid_to_m0_n0_block_cluster_adaptor,
                      MainLoopTag{},
                      DoubleTailTag{});
}
// Grid-level GEMM (c_mtx += transpose(a_mtx) * b_mtx) built from a blockwise copy into LDS
// and a blockwise GEMM, with the K dimension processed KPerBlock rows at a time.
//
// Template parameters, as they are used inside this struct:
//   BlockSize                      forwarded to the blockwise copies and the blockwise GEMM
//   FloatAB / FloatAcc / FloatC    element type of A, B and LDS / accumulator type / C type
//   CGlobalMemoryDataOperation     memory operation used by the final register-to-C transfer
//   AKMGridDesc, BKNGridDesc,      descriptors of A [K, M], B [K, N] and C [M, N]
//   CMNGridDesc                    (dimension order as read by CheckValidity)
//   MPerBlockM1, NPerBlockN1,      extents of the tile handled by one workgroup per step;
//   KPerBlock                      M, N and K must be multiples of them (CheckValidity)
//   M1PerThreadM111, N1PerThreadN111, KPerThread, M11N11ThreadCluster*
//                                  forwarded to the blockwise GEMM; the M/N ones also define
//                                  M11 / N11 in MakeCM0M10M11N0N10N11GridDescriptor
//   ABlockTransfer*, BBlockTransfer*, AThreadTransferSrcResetCoordinateAfterRun,
//   BThreadTransferSrcResetCoordinateAfterRun
//                                  forwarded to BlockwiseTensorSliceTransfer_v4 for A and B
//   CThreadTransfer*               forwarded to ThreadwiseTensorSliceTransfer_v1r3 for C
//   *GridStepHacks, *GridMoveSliceWindowStepHacks
//                                  default-constructed in Run and passed to RunRead /
//                                  MoveSrcSliceWindow / the C transfer
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
index_t MPerBlockM1,
index_t NPerBlockN1,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r2
{
// Compile-time indices used to address descriptor dimensions and tuple elements.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// Returns the LDS footprint of Run(), in bytes.
//
// The result is 2 * (A tile + B tile) * sizeof(FloatAB), where each tile's element count is
// rounded up to the common LDS alignment. Run() uses the factor of two for its "even" and
// "odd" buffers.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // Common alignment (in elements) applied to both tiles.
    constexpr auto lds_alignment = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
                                             Number<BBlockTransferDstScalarPerVector_N1>{},
                                             Number<M1PerThreadM111>{},
                                             Number<N1PerThreadN111>{});

    // [KPerBlock, MPerBlockM1] tile of A as it is laid out in LDS.
    constexpr auto a_tile_lds_desc = make_naive_tensor_descriptor_aligned(
        make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), lds_alignment);
    constexpr auto a_tile_elements =
        math::integer_least_multiple(a_tile_lds_desc.GetElementSpaceSize(), lds_alignment);

    // [KPerBlock, NPerBlockN1] tile of B as it is laid out in LDS.
    constexpr auto b_tile_lds_desc = make_naive_tensor_descriptor_aligned(
        make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), lds_alignment);
    constexpr auto b_tile_elements =
        math::integer_least_multiple(b_tile_lds_desc.GetElementSpaceSize(), lds_alignment);

    // Elements needed for one copy of A plus one copy of B.
    constexpr auto elements_per_stage = a_tile_elements + b_tile_elements;

    return 2 * elements_per_stage * sizeof(FloatAB);
}
// Checks that the three grid descriptors describe a problem this struct can run.
//
// Returns true only when
//   - M of A equals M of C, N of B equals N of C, and K of A equals K of B, and
//   - M, N and K are whole multiples of MPerBlockM1, NPerBlockN1 and KPerBlock.
__host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
                                                        const BKNGridDesc& b_k_n_grid_desc,
                                                        const CMNGridDesc& c_m_n_grid_desc)
{
    const auto K = a_k_m_grid_desc.GetLength(I0);
    const auto M = a_k_m_grid_desc.GetLength(I1);
    const auto N = b_k_n_grid_desc.GetLength(I1);

    // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

    // The three descriptors must agree on the problem size ...
    if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
         K == b_k_n_grid_desc.GetLength(I0)))
    {
        return false;
    }

    // ... and every dimension must split into whole tiles.
    return M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0;
}
// Number of MPerBlockM1 x NPerBlockN1 tiles in an M x N output.
// Both divisions truncate; CheckValidity() is what rejects sizes that are not tile multiples.
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
    const index_t tiles_along_m = M / MPerBlockM1;
    const index_t tiles_along_n = N / NPerBlockN1;

    return tiles_along_m * tiles_along_n;
}
// Value for Run()'s HasMainKBlockLoop template argument.
// For non-negative K the expression is true exactly when K >= 3 * KPerBlock.
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
    return (K + KPerBlock) / (2 * KPerBlock) > 1;
}
// Value for Run()'s HasDoubleTailKBlockLoop template argument:
// true when K / KPerBlock (truncating) is even.
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
{
    const index_t k_block_count = K / KPerBlock;

    return k_block_count % 2 == 0;
}
// Re-views the [K, M] descriptor of A as [K, M0, M1], with M1 = MPerBlockM1 and
// M0 = M / M1 (truncating division).
__host__ __device__ static constexpr auto
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{
    const auto k_length = a_k_m_grid_desc.GetLength(I0);
    const auto m_length = a_k_m_grid_desc.GetLength(I1);

    const auto m_per_block   = Number<MPerBlockM1>{};
    const auto m_block_count = m_length / m_per_block;

    // K is passed through unchanged; M is unmerged into (M0, M1).
    return transform_tensor_descriptor(
        a_k_m_grid_desc,
        make_tuple(make_pass_through_transform(k_length),
                   make_unmerge_transform(make_tuple(m_block_count, m_per_block))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
}
// Re-views the [K, N] descriptor of B as [K, N0, N1], with N1 = NPerBlockN1 and
// N0 = N / N1 (truncating division).
__host__ __device__ static constexpr auto
MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
{
    const auto k_length = b_k_n_grid_desc.GetLength(I0);
    const auto n_length = b_k_n_grid_desc.GetLength(I1);

    const auto n_per_block   = Number<NPerBlockN1>{};
    const auto n_block_count = n_length / n_per_block;

    // K is passed through unchanged; N is unmerged into (N0, N1).
    return transform_tensor_descriptor(
        b_k_n_grid_desc,
        make_tuple(make_pass_through_transform(k_length),
                   make_unmerge_transform(make_tuple(n_block_count, n_per_block))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
}
// Re-views the [M, N] descriptor of C as [M0, M10, M11, N0, N10, N11]:
//   M0  = M / MPerBlockM1                                   (run-time, truncating)
//   M11 = M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111
//   M10 = MPerBlockM1 / M11                                 (compile-time, truncating)
// and likewise for the N side.
// NOTE(review): nothing here checks that MPerBlockM1 / NPerBlockN1 are multiples of
// M11 / N11; presumably the tuning parameters guarantee it — confirm.
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
    // M side
    const auto m_length      = c_m_n_grid_desc.GetLength(I0);
    constexpr auto m1_length = Number<MPerBlockM1>{};
    constexpr auto m11_length =
        Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
    constexpr auto m10_length = m1_length / m11_length;
    const auto m0_length      = m_length / m1_length;

    // N side
    const auto n_length      = c_m_n_grid_desc.GetLength(I1);
    constexpr auto n1_length = Number<NPerBlockN1>{};
    constexpr auto n11_length =
        Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
    constexpr auto n10_length = n1_length / n11_length;
    const auto n0_length      = n_length / n1_length;

    // M -> (M0, M10, M11) on dims 0..2, N -> (N0, N10, N11) on dims 3..5.
    return transform_tensor_descriptor(
        c_m_n_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(m0_length, m10_length, m11_length)),
                   make_unmerge_transform(make_tuple(n0_length, n10_length, n11_length))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
}
// Builds the adaptor Run() uses to turn a 1-D block id into an (m0, n0) tile index.
// It is a single merge transform over (M0, N0), where M0 = M / MPerBlockM1 and
// N0 = N / NPerBlockN1 (both truncating).
__host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
    const auto m_block_count = c_m_n_grid_desc.GetLength(I0) / Number<MPerBlockM1>{};
    const auto n_block_count = c_m_n_grid_desc.GetLength(I1) / Number<NPerBlockN1>{};

    return make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(m_block_count, n_block_count))),
        make_tuple(Sequence<0, 1>{}),
        make_tuple(Sequence<0>{}));
}
// Types of the transformed descriptors / adaptor that Run() takes as parameters.
// Each one is deduced from the matching Make* helper applied to a default-constructed
// input descriptor.
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
// Per-workgroup GEMM body: accumulates one MPerBlockM1 x NPerBlockN1 tile of C in registers
// and writes it out.
//
// Steps, in the order they appear below:
//   1. map the 1-D block id to (im0, in0) via cblockid_to_m0_n0_block_cluster_adaptor;
//   2. preload the first KPerBlock slice of A and B into the "even" LDS buffers;
//   3. if HasMainKBlockLoop: do-while loop that consumes two K slices per trip, alternating
//      between the even and odd LDS buffers;
//   4. tail: two remaining slices if HasDoubleTailKBlockLoop, otherwise one;
//   5. transfer the per-thread accumulators to C using CGlobalMemoryDataOperation.
//
// p_shared_block is carved into 2 x A tile followed by 2 x B tile (see the pointer
// arithmetic below).
// The two trailing integral_constant parameters are unnamed tags; they only carry the
// HasMainKBlockLoop / HasDoubleTailKBlockLoop template arguments.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
// wrap the raw global pointers in buffer views sized by each descriptor's element space
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
// total K extent; only used as the bound of the main loop
const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: readfirstlane forces the block indices into SGPRs
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
// lds max alignment (same expression as in GetSharedMemoryNumberOfByte)
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
Number<BBlockTransferDstScalarPerVector_N1>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A tile in LDS as a 2-D [KPerBlock, MPerBlockM1] view; used as a type argument of the
// blockwise GEMM. Be careful of LDS alignment.
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B tile in LDS as a 2-D [KPerBlock, NPerBlockN1] view; used as a type argument of the
// blockwise GEMM. Be careful of LDS alignment.
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// A tile in LDS as a 3-D [KPerBlock, 1, MPerBlockM1] view; destination of the blockwise
// copy and basis of the LDS space computation below.
constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
// B tile in LDS as a 3-D [KPerBlock, 1, NPerBlockN1] view; destination of the blockwise
// copy and basis of the LDS space computation below.
constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
// A matrix blockwise copy: global [K, M0, M1] -> LDS, source window starting at (0, im0, 0)
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m0_m1_grid_desc),
decltype(a_k_m0_m1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k_m0_m1_grid_desc,
make_multi_index(0, im0, 0),
a_k_m0_m1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy: global [K, N0, N1] -> LDS, source window starting at (0, in0, 0)
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n0_n1_grid_desc),
decltype(b_k_n0_n1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k_n0_n1_grid_desc,
make_multi_index(0, in0, 0),
b_k_n0_n1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
// b_mtx[KPerBlock, NPerBlockN1] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
// per-thread C sub-tile: lengths reported by the blockwise GEMM, packed descriptor over them
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);
// LDS layout: [A even | A odd | B even | B odd]
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
// zero the accumulators before the first GEMM step
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
// each window move advances the source slice by KPerBlock along K only
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
AGridMoveSliceWindowStepHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};
// LDS buffer views: "even" is the first copy of each tile, "odd" the second
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
// K offset consumed so far by the main loop; advances by 2 * KPerBlock per trip
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
// barrier between the preceding RunWrite into the even buffers and the GEMM that reads them
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
// barrier between the preceding RunWrite into the odd buffers and the GEMM that reads them
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
// stop once at most 2 * KPerBlock of K is left; the tail below handles the remainder
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
// barrier between the last RunWrite into the even buffers and the GEMM that reads them
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
// barrier between the RunWrite into the odd buffers and the final GEMM
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// barrier between the last RunWrite into the even buffers and the final GEMM
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
// 6-D view of the per-thread C sub-tile; the M0 and N0 extents are 1
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
// this thread's origin inside the block tile, as reported by the blockwise GEMM
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
// destination origin on the grid: (im0, thread M origin, in0, thread N origin)
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridStepHacks{});
}
}
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_V2_HPP
#define CK_GRIDWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "blockwise_gemm_dlops_v3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalStepHacks,
typename BGlobalStepHacks,
typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
// Returns the LDS footprint of Run(), in bytes: one alignment-padded [E, KPerBlock] tile of A,
// with E = EPerBlock * 3 * 3. Only A is sized here.
// NOTE(review): the 3 * 3 factor is hard-coded; presumably a 3x3 filter window — confirm.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // Alignment (in elements) of the A tile in LDS.
    constexpr auto lds_alignment =
        math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});

    // Number of E rows held in LDS.
    constexpr auto e_row_count = EPerBlock * 3 * 3;

    // A tile in LDS, destination of the blockwise copy.
    constexpr auto a_tile_lds_desc = make_naive_tensor_descriptor_aligned(
        make_tuple(Number<e_row_count>{}, Number<KPerBlock>{}), lds_alignment);

    constexpr auto a_tile_elements =
        math::integer_least_multiple(a_tile_lds_desc.GetElementSpaceSize(), lds_alignment);

    return a_tile_elements * sizeof(FloatAB);
}
// Per-workgroup body of the v3 GEMM, with the LDS buffer supplied by the caller.
//
// As written below:
//   - A ([E, K]) is copied once, as a whole [E, KPerBlock] tile with E = EPerBlock * 3 * 3,
//     from global memory into LDS (p_shared_block);
//   - B ([E, N, Ho, Wo]) is read EPerBlock rows at a time straight from global memory into
//     two per-thread register buffers (even / odd) that are used alternately;
//   - after each GEMM step the blockwise GEMM's A window is moved by EPerBlock rows;
//   - the per-thread accumulators are finally written to C using CGlobalMemoryDataOperation.
//
// The two trailing integral_constant parameters are unnamed tags; they only carry the
// HasMainKBlockLoop / HasDoubleTailKBlockLoop template arguments.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
// wrap the raw global pointers in buffer views sized by each descriptor's element space
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
// E is fixed at compile time instead of being read from the descriptor (see next line).
// NOTE(review): the 3 * 3 factor is hard-coded; presumably a 3x3 filter window — confirm.
constexpr auto E = EPerBlock * 3 * 3;
// const auto E = a_e_k_global_desc.GetLength(I0);
// NOTE(review): K and N are not referenced anywhere else in this function.
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work: the 1-D block id is decomposed into (k, ho, wo) block indices
#if 0
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: readfirstlane forces the results into SGPRs
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#endif
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A in LDS: a_e_k_block_desc is the [EPerBlock, KPerBlock] window consumed per GEMM step,
// a_e_k_desc is the whole [E, KPerBlock] tile that the blockwise copy writes.
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// Block-level B tile shape [EPerBlock, 1, HoPerBlock, WoPerBlock]; in this function it is
// only used as a type argument of the blockwise GEMM
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
// TODO:: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
// this thread's (k, h, w) position inside the block tile, as reported by the blockwise GEMM
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
// block and thread origins in global coordinates
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy: global [E, K] -> LDS, source window starting at
// (0, k_block_data_on_global)
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_desc,
make_multi_index(0, 0));
// per-thread B slice [EPerBlock, 1, HoPerThread, WoPerThread] held in registers
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
// B threadwise transfer: global -> registers, source origin at this thread's (ho, wo)
auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
1,
true>(
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
// LDS view over the caller-provided shared block; holds the whole A tile
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf;
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
// each B window move advances the source slice by EPerBlock along E only
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// NOTE(review): neither of the two move-slice-window hack objects below is referenced
// later in this function.
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
BGlobalMoveSliceWindowStepHacks{};
// double register buffer for b
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;
// preload: whole A tile into LDS, first B slice into the even register buffer
{
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
}
// barrier between the write of the A tile into LDS and the GEMM steps that read it
__syncthreads();
if constexpr(HasMainKBlockLoop)
{
// E offset consumed so far by the main loop; advances by 2 * EPerBlock per trip
index_t e_block_data_begin = 0;
// double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration: fetch the next B slice into the odd buffer, GEMM on the even buffer
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
// double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
// odd iteration: fetch the next B slice into the even buffer, GEMM on the odd buffer
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
// double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
e_block_data_begin += 2 * EPerBlock;
// stop once at most 2 * EPerBlock of E is left; the tail below handles the remainder
} while(e_block_data_begin < E - 2 * EPerBlock);
}
// double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
// double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
// double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
// destination origin on the grid: (k, 0, ho, wo) of this thread's sub-tile
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_k_n_ho_wo_global_desc,
c_global_buf,
c_k_n_ho_wo_global_tensor_step_hacks);
}
}
// Entry point that receives the tensor descriptors by const reference.
// It declares a static __shared__ FloatAB array whose element count is
// GetSharedMemoryNumberOfByte() / sizeof(FloatAB) and forwards all arguments,
// plus that array, to the Run overload that takes the shared buffer explicitly.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
                    const FloatAB* __restrict__ p_a_global,
                    const BGlobalDesc& b_e_n_ho_wo_global_desc,
                    const FloatAB* __restrict__ p_b_global,
                    const CGlobalDesc& c_k_n_ho_wo_global_desc,
                    FloatC* __restrict__ p_c_global,
                    integral_constant<bool, HasMainKBlockLoop>,
                    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
    using MainLoopTag   = integral_constant<bool, HasMainKBlockLoop>;
    using DoubleTailTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    // byte requirement expressed as a number of FloatAB elements
    constexpr index_t lds_element_count = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB lds_block[lds_element_count];

    Run(a_e_k_global_desc,
        p_a_global,
        b_e_n_ho_wo_global_desc,
        p_b_global,
        c_k_n_ho_wo_global_desc,
        p_c_global,
        lds_block,
        MainLoopTag{},
        DoubleTailTag{});
}
// Entry point that receives the tensor descriptors through typed pointers.
// Each descriptor is copied out of its pointee into a local object, and the
// copies are then handed to the by-reference Run overload.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
                    const FloatAB* __restrict__ p_a_global,
                    const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
                    const FloatAB* __restrict__ p_b_global,
                    const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
                    FloatC* __restrict__ p_c_global,
                    integral_constant<bool, HasMainKBlockLoop>,
                    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
    using MainLoopTag   = integral_constant<bool, HasMainKBlockLoop>;
    using DoubleTailTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    // local copies of the three descriptors
    const AGlobalDesc a_desc = *p_a_e_k_global_desc;
    const BGlobalDesc b_desc = *p_b_e_n_ho_wo_global_desc;
    const CGlobalDesc c_desc = *p_c_k_n_ho_wo_global_desc;

    Run(a_desc,
        p_a_global,
        b_desc,
        p_b_global,
        c_desc,
        p_c_global,
        MainLoopTag{},
        DoubleTailTag{});
}
// Entry point that receives the tensor descriptors as untyped (void) pointers.
// Each pointer is reinterpreted as a pointer to the matching descriptor type,
// the descriptor is copied into a local object, and the copies are handed to
// the by-reference Run overload.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
                    const FloatAB* __restrict__ p_a_global,
                    const void* p_b_e_n_ho_wo_global_desc,
                    const FloatAB* __restrict__ p_b_global,
                    const void* p_c_k_n_ho_wo_global_desc,
                    FloatC* __restrict__ p_c_global,
                    integral_constant<bool, HasMainKBlockLoop>,
                    integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
    using MainLoopTag   = integral_constant<bool, HasMainKBlockLoop>;
    using DoubleTailTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    // recover the typed descriptors from the opaque pointers
    const AGlobalDesc a_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
    const BGlobalDesc b_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
    const CGlobalDesc c_desc = *reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);

    Run(a_desc,
        p_a_global,
        b_desc,
        p_b_global,
        c_desc,
        p_c_global,
        MainLoopTag{},
        DoubleTailTag{});
}
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_GRIDWISE_GEMM_V3_HPP
#define CK_GRIDWISE_GEMM_V3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
#include "blockwise_gemm_dlops_v3.hpp"
namespace ck {
// Kernel wrapper around GridwiseGemm::ConvBiasActiv.
// Declares a static __shared__ FloatAB array holding
// GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB) elements and
// passes it, together with the grid pointers, the grid descriptors and the
// block-id adaptor, to the gridwise routine. HasMainE0BlockLoop and ActivType
// are forwarded as integral_constant tag objects.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_dlops_v3(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        const FloatC* __restrict__ p_bias_grid,
        FloatC* __restrict__ p_c_grid,
        const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
        const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
        const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
        const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor)
{
    using MainE0LoopTag = integral_constant<bool, HasMainE0BlockLoop>;
    using ActivTag      = integral_constant<ActivTypeEnum, ActivType>;

    // shared-memory requirement converted from bytes to FloatAB elements
    constexpr index_t lds_element_count =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB lds_block[lds_element_count];

    GridwiseGemm::ConvBiasActiv(p_a_grid,
                                p_b_grid,
                                p_bias_grid,
                                p_c_grid,
                                lds_block,
                                a_e0_e1_k0_k1_e2_grid_desc,
                                b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                cblockid_to_k_n_h_w_block_cluster_adaptor,
                                MainE0LoopTag{},
                                ActivTag{});
}
// Kernel wrapper around GridwiseGemm::ConvBiasActivResizeAdd.
// Declares a static __shared__ FloatAB array holding
// GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB) elements and
// passes it, together with the grid pointers (A, B, bias, D), the four grid
// descriptors and the block-id adaptor, to the gridwise routine.
// HasMainE0BlockLoop and ActivType are forwarded as integral_constant tags.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_dlops_v3_resize_add(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        const FloatC* __restrict__ p_bias_grid,
        FloatC* __restrict__ p_d_grid,
        const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
        const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
        const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
        const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
        const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor)
{
    using MainE0LoopTag = integral_constant<bool, HasMainE0BlockLoop>;
    using ActivTag      = integral_constant<ActivTypeEnum, ActivType>;

    // shared-memory requirement converted from bytes to FloatAB elements
    constexpr index_t lds_element_count =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB lds_block[lds_element_count];

    GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid,
                                         p_b_grid,
                                         p_bias_grid,
                                         p_d_grid,
                                         lds_block,
                                         a_e0_e1_k0_k1_e2_grid_desc,
                                         b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                         c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                         d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                         cblockid_to_k_n_h_w_block_cluster_adaptor,
                                         MainE0LoopTag{},
                                         ActivTag{});
}
// Kernel wrapper around GridwiseGemm::ConvBiasActivMaxpool.
// Declares a static __shared__ FloatAB array holding
// GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB) elements and
// passes it, together with the grid pointers (A, B, bias, C, D), the four grid
// descriptors and the block-id adaptor, to the gridwise routine.
// HasMainE0BlockLoop and ActivType are forwarded as integral_constant tags.
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum ActivType>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_gemm_dlops_v3_maxpool(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        const FloatC* __restrict__ p_bias_grid,
        FloatC* __restrict__ p_c_grid,
        FloatC* __restrict__ p_d_grid,
        const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
        const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
        const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
        const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
        const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor)
{
    using MainE0LoopTag = integral_constant<bool, HasMainE0BlockLoop>;
    using ActivTag      = integral_constant<ActivTypeEnum, ActivType>;

    // shared-memory requirement converted from bytes to FloatAB elements
    constexpr index_t lds_element_count =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB lds_block[lds_element_count];

    GridwiseGemm::ConvBiasActivMaxpool(p_a_grid,
                                       p_b_grid,
                                       p_bias_grid,
                                       p_c_grid,
                                       p_d_grid,
                                       lds_block,
                                       a_e0_e1_k0_k1_e2_grid_desc,
                                       b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
                                       c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
                                       d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
                                       cblockid_to_k_n_h_w_block_cluster_adaptor,
                                       MainE0LoopTag{},
                                       ActivTag{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
typename DGridDesc_K_N_Hx_Wx,
index_t E1_,
index_t E2_,
index_t K2_,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t E1PerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_E2,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalStepHacks,
typename BGlobalStepHacks,
typename CGlobalStepHacks,
typename DGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
// Compile-time index constants, used below to select descriptor dimensions
// and tuple / multi-index elements.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// Integer template parameters E1_, E2_ and K2_ re-exposed as Number<> objects.
static constexpr auto E1 = Number<E1_>{};
static constexpr auto E2 = Number<E2_>{};
static constexpr auto K2 = Number<K2_>{};
// N block-tile size is fixed to 1; it is the divisor applied to N when the
// number of blocks is computed.
static constexpr auto NPerBlock = I1;
// Factor applied to negative accumulator values by Activation when
// activ_type_ == 1.
static constexpr FloatAcc alpha = 0.3;
// Returns the number of bytes of shared memory (LDS) this struct asks for.
// Only the A block tile is sized here: a descriptor with lengths
// (1, E1, KPerBlock, E2), aligned to ABlockTransferDstScalarPerVector_E2.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    constexpr auto lds_alignment = Number<ABlockTransferDstScalarPerVector_E2>{};

    // layout of the A tile in LDS (destination of the blockwise copy)
    constexpr auto a_tile_lds_desc = make_naive_tensor_descriptor_aligned(
        make_tuple(I1, Number<E1>{}, Number<KPerBlock>{}, Number<E2>{}), lds_alignment);

    // element count of the A tile, passed through integer_least_multiple with the alignment
    constexpr auto a_tile_element_count =
        math::integer_least_multiple(a_tile_lds_desc.GetElementSpaceSize(), lds_alignment);

    return a_tile_element_count * sizeof(FloatAB);
}
// Returns the product of the per-dimension block counts
// (K / KPerBlock) * (N / NPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock),
// with K, N, Ho and Wo read from dimensions 0..3 of the C grid descriptor.
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
    // block-tile counts along each output dimension
    const auto k_tiles  = c_k_n_ho_wo_grid_desc.GetLength(I0) / KPerBlock;
    const auto n_tiles  = c_k_n_ho_wo_grid_desc.GetLength(I1) / NPerBlock;
    const auto ho_tiles = c_k_n_ho_wo_grid_desc.GetLength(I2) / HoPerBlock;
    const auto wo_tiles = c_k_n_ho_wo_grid_desc.GetLength(I3) / WoPerBlock;

    return k_tiles * n_tiles * ho_tiles * wo_tiles;
}
// True when E0 is greater than 1.
__host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0)
{
    return E0 > 1;
}
// True when (E1 + E1PerBlock) / (2 * E1PerBlock) is greater than 1.
__host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop()
{
    return ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1;
}
// True when E1 / E1PerBlock is an even number.
__host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop()
{
    return (E1 / E1PerBlock) % 2 == 0;
}
// Turns the 4-d A grid descriptor (E0, E1, K, E2) into a 5-d one
// (E0, E1, K0, K1, E2) by unmerging K into (K / KPerBlock, KPerBlock).
// E0 is read from the descriptor; the E1 and E2 extents come from the
// struct-level constants.
__host__ __device__ static constexpr auto
MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc)
{
    const auto e0_length = a_e0_e1_k_e2_grid_desc.GetLength(I0);
    const auto k_length  = a_e0_e1_k_e2_grid_desc.GetLength(I2);

    const auto k_per_block   = Number<KPerBlock>{};
    const auto k_block_count = k_length / k_per_block;

    // dims 0, 1 and 3 pass through; dim 2 (K) becomes output dims 2 and 3
    return transform_tensor_descriptor(
        a_e0_e1_k_e2_grid_desc,
        make_tuple(make_pass_through_transform(e0_length),
                   make_pass_through_transform(E1),
                   make_unmerge_transform(make_tuple(k_block_count, k_per_block)),
                   make_pass_through_transform(E2)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
}
// Turns the 6-d B grid descriptor (E0, E1, N, Ho, Wo, E2) into a 10-d one
// (E0, E1, N, H0, H1, H2, W0, W1, W2, E2):
//   Ho -> (Ho / HoPerBlock, HoPerBlock / HoPerThread, HoPerThread)
//   Wo -> (Wo / WoPerBlock, WoPerBlock / WoPerThread, WoPerThread)
// E0 and N are read from the descriptor; the E1 and E2 extents come from the
// struct-level constants rather than from the descriptor.
__host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(
    const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc)
{
    const auto e0_length = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0);
    const auto n_length  = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2);
    const auto ho_length = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3);
    const auto wo_length = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4);

    // Ho split factors
    const auto h_per_thread        = Number<HoPerThread>{};
    const auto h_threads_per_block = Number<HoPerBlock / HoPerThread>{};
    const auto h_block_count       = ho_length / (h_threads_per_block * h_per_thread);

    // Wo split factors
    const auto w_per_thread        = Number<WoPerThread>{};
    const auto w_threads_per_block = Number<WoPerBlock / WoPerThread>{};
    const auto w_block_count       = wo_length / (w_threads_per_block * w_per_thread);

    return transform_tensor_descriptor(
        b_e0_e1_n_ho_wo_e2_grid_desc,
        make_tuple(make_pass_through_transform(e0_length),
                   make_pass_through_transform(E1),
                   make_pass_through_transform(n_length),
                   make_unmerge_transform(
                       make_tuple(h_block_count, h_threads_per_block, h_per_thread)),
                   make_unmerge_transform(
                       make_tuple(w_block_count, w_threads_per_block, w_per_thread)),
                   make_pass_through_transform(E2)),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
                   Sequence<2>{},
                   Sequence<3>{},
                   Sequence<4>{},
                   Sequence<5>{}),
        make_tuple(Sequence<0>{},
                   Sequence<1>{},
                   Sequence<2>{},
                   Sequence<3, 4, 5>{},
                   Sequence<6, 7, 8>{},
                   Sequence<9>{}));
}
// Turns the 4-d C grid descriptor (K, N, Ho, Wo) into a 9-d one
// (K0, K1, N, H0, H1, H2, W0, W1, W2):
//   K  -> (K / KPerBlock, KPerBlock)
//   Ho -> (Ho / HoPerBlock, HoPerBlock / HoPerThread, HoPerThread)
//   Wo -> (Wo / WoPerBlock, WoPerBlock / WoPerThread, WoPerThread)
// N passes through unchanged.
__host__ __device__ static constexpr auto
MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
    const auto k_length  = c_k_n_ho_wo_grid_desc.GetLength(I0);
    const auto n_length  = c_k_n_ho_wo_grid_desc.GetLength(I1);
    const auto ho_length = c_k_n_ho_wo_grid_desc.GetLength(I2);
    const auto wo_length = c_k_n_ho_wo_grid_desc.GetLength(I3);

    // K split factors
    const auto k_per_block   = Number<KPerBlock>{};
    const auto k_block_count = k_length / k_per_block;

    // Ho split factors
    const auto h_per_thread        = Number<HoPerThread>{};
    const auto h_threads_per_block = Number<HoPerBlock / HoPerThread>{};
    const auto h_block_count       = ho_length / (h_threads_per_block * h_per_thread);

    // Wo split factors
    const auto w_per_thread        = Number<WoPerThread>{};
    const auto w_threads_per_block = Number<WoPerBlock / WoPerThread>{};
    const auto w_block_count       = wo_length / (w_threads_per_block * w_per_thread);

    return transform_tensor_descriptor(
        c_k_n_ho_wo_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(k_block_count, k_per_block)),
                   make_pass_through_transform(n_length),
                   make_unmerge_transform(
                       make_tuple(h_block_count, h_threads_per_block, h_per_thread)),
                   make_unmerge_transform(
                       make_tuple(w_block_count, w_threads_per_block, w_per_thread))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
}
// Turns the 4-d D grid descriptor (K, N, Hx, Wx) used by the max-pool output
// into a 9-d one (K0, K1, N, H0, H1, Hx', W0, W1, Wx'):
//   K  -> (K / KPerBlock, KPerBlock)
//   Hx -> (Hx / (H1 * H2), H1 = HoPerBlock / HoPerThread, H2 = HoPerThread / 2)
//   Wx -> (Wx / (W1 * W2), W1 = WoPerBlock / WoPerThread, W2 = WoPerThread / 2)
// With CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR the H/W factors are Number<>
// objects; otherwise they are plain integers.
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
    const auto k_length  = d_k_n_hx_wx_grid_desc.GetLength(I0);
    const auto n_length  = d_k_n_hx_wx_grid_desc.GetLength(I1);
    const auto hx_length = d_k_n_hx_wx_grid_desc.GetLength(I2);
    const auto wx_length = d_k_n_hx_wx_grid_desc.GetLength(I3);

    const auto k_per_block   = Number<KPerBlock>{};
    const auto k_block_count = k_length / k_per_block;

#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
    const auto h_inner = Number<HoPerThread / 2>{};
    const auto h_mid   = Number<HoPerBlock / HoPerThread>{};
    const auto h_outer = Number<hx_length / (h_mid * h_inner)>{};

    const auto w_inner = Number<WoPerThread / 2>{};
    const auto w_mid   = Number<WoPerBlock / WoPerThread>{};
    const auto w_outer = Number<wx_length / (w_mid * w_inner)>{};
#else
    const auto h_inner = HoPerThread / 2;
    const auto h_mid   = HoPerBlock / HoPerThread;
    const auto h_outer = hx_length / (h_mid * h_inner);

    const auto w_inner = WoPerThread / 2;
    const auto w_mid   = WoPerBlock / WoPerThread;
    const auto w_outer = wx_length / (w_mid * w_inner);
#endif

    return transform_tensor_descriptor(
        d_k_n_hx_wx_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(k_block_count, k_per_block)),
                   make_pass_through_transform(n_length),
                   make_unmerge_transform(make_tuple(h_outer, h_mid, h_inner)),
                   make_unmerge_transform(make_tuple(w_outer, w_mid, w_inner))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
}
// Turns the 4-d D grid descriptor (K, N, Hx, Wx) used by the resize-add output
// into a 9-d one (K0, K1, N, H0, H1, Hx', W0, W1, Wx'):
//   K  -> (K / KPerBlock, KPerBlock)
//   Hx -> (Hx / (H1 * H2), H1 = HoPerBlock / HoPerThread, H2 = HoPerThread * 2)
//   Wx -> (Wx / (W1 * W2), W1 = WoPerBlock / WoPerThread, W2 = WoPerThread * 2)
// With CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR the outermost H/W factors are
// wrapped in Number<>; otherwise they are the plain quotient.
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
    const auto k_length  = d_k_n_hx_wx_grid_desc.GetLength(I0);
    const auto n_length  = d_k_n_hx_wx_grid_desc.GetLength(I1);
    const auto hx_length = d_k_n_hx_wx_grid_desc.GetLength(I2);
    const auto wx_length = d_k_n_hx_wx_grid_desc.GetLength(I3);

    const auto k_per_block   = Number<KPerBlock>{};
    const auto k_block_count = k_length / k_per_block;

    const auto h_inner = Number<HoPerThread * 2>{};
    const auto h_mid   = Number<HoPerBlock / HoPerThread>{};

    const auto w_inner = Number<WoPerThread * 2>{};
    const auto w_mid   = Number<WoPerBlock / WoPerThread>{};

#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
    const auto h_outer = Number<hx_length / (h_mid * h_inner)>{};
    const auto w_outer = Number<wx_length / (w_mid * w_inner)>{};
#else
    const auto h_outer = hx_length / (h_mid * h_inner);
    const auto w_outer = wx_length / (w_mid * w_inner);
#endif

    return transform_tensor_descriptor(
        d_k_n_hx_wx_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(k_block_count, k_per_block)),
                   make_pass_through_transform(n_length),
                   make_unmerge_transform(make_tuple(h_outer, h_mid, h_inner)),
                   make_unmerge_transform(make_tuple(w_outer, w_mid, w_inner))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
}
// Builds a single-stage tensor adaptor whose only transform merges the four
// block counts (K / KPerBlock, N / NPerBlock, Ho / HoPerBlock, Wo / WoPerBlock)
// into one dimension. GetCBlockIndex uses its CalculateBottomIndex on the 1-d
// block id. With CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR the counts are
// wrapped in Number<>.
__host__ __device__ static constexpr auto
MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
    const auto k_length  = c_k_n_ho_wo_grid_desc.GetLength(I0);
    const auto n_length  = c_k_n_ho_wo_grid_desc.GetLength(I1);
    const auto ho_length = c_k_n_ho_wo_grid_desc.GetLength(I2);
    const auto wo_length = c_k_n_ho_wo_grid_desc.GetLength(I3);

#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR
    const auto k_blocks = Number<k_length / KPerBlock>{};
    const auto n_blocks = Number<n_length / NPerBlock>{};
    const auto h_blocks = Number<ho_length / HoPerBlock>{};
    const auto w_blocks = Number<wo_length / WoPerBlock>{};
#else
    const auto k_blocks = k_length / KPerBlock;
    const auto n_blocks = n_length / NPerBlock;
    const auto h_blocks = ho_length / HoPerBlock;
    const auto w_blocks = wo_length / WoPerBlock;
#endif

    return make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(make_tuple(k_blocks, n_blocks, h_blocks, w_blocks))),
        make_tuple(Sequence<0, 1, 2, 3>{}),
        make_tuple(Sequence<0>{}));
}
// using AGridDesc_E0_E1_K0_K1_E2 =
// decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
// using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
// decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
// using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
// decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
// using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
// decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
// Type of the adaptor returned by MakeCBlockIdToKNHoWoBlockClusterAdaptor,
// deduced from a call with a default-constructed CGridDesc_K_N_Ho_Wo.
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
// Builds a packed 2-d (K0, K1) descriptor for the bias tensor, taking both
// lengths from dimensions 0 and 1 of the transformed C grid descriptor.
template <typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>
__host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
    const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
{
    const auto k0_length = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0);
    const auto k1_length = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1);

    return make_naive_tensor_descriptor_packed(make_tuple(k0_length, k1_length));
}
// Packed compile-time descriptor of the per-thread C tile with lengths
// (KPerThread, 1, HoPerThread, WoPerThread).
__host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor()
{
    return make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, I1, Number<HoPerThread>{}, Number<WoPerThread>{}));
}
// using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor());
// Returns a default-constructed BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
// instantiated for this struct's tile configuration:
//   A block tile : (E1PerBlock, KPerBlock, E2), aligned to
//                  ABlockTransferDstScalarPerVector_E2
//   B block tile : (E1PerBlock, 1, HoPerBlock, WoPerBlock, E2), packed
//   C thread tile: the descriptor from MakeCK1NH2W2ThreadDescriptor()
__host__ __device__ static constexpr auto GetBlockWiseGemm()
{
    constexpr auto lds_alignment = Number<ABlockTransferDstScalarPerVector_E2>{};

    constexpr auto a_block_desc = make_naive_tensor_descriptor_aligned(
        make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), lds_alignment);

    constexpr auto b_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
        Number<E1PerBlock>{}, I1, Number<HoPerBlock>{}, Number<WoPerBlock>{}, Number<E2>{}));

    constexpr auto c_thread_desc = MakeCK1NH2W2ThreadDescriptor();

    using BlockwiseGemm = BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                               FloatAB,
                                                               FloatAB,
                                                               FloatAcc,
                                                               decltype(a_block_desc),
                                                               decltype(b_block_desc),
                                                               decltype(c_thread_desc),
                                                               EPerThread,
                                                               K2>;

    return BlockwiseGemm{};
}
// Returns the result of the blockwise GEMM's GetBeginOfCThreadDesc_K_N_Ho_Wo
// for the calling thread, identified by get_thread_local_1d_id().
__device__ static constexpr auto GetCThreadIndex()
{
    auto gemm = GetBlockWiseGemm();

    return gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
}
// Maps the calling block's 1-d id (get_block_1d_id()) through the adaptor's
// CalculateBottomIndex and returns the resulting block-cluster index.
__device__ static constexpr auto GetCBlockIndex(
    const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor)
{
    return cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex(
        make_multi_index(get_block_1d_id()));
}
// Adds a per-K bias to the thread's accumulator tile.
//   bias_global_buf      : source buffer holding the bias values
//   c_thread_buf         : per-thread accumulator tile, updated in place
//   c_block_idx          : block-cluster index; element I0 selects the K0 row
//   c_thread_idx         : thread index; element I0 selects the K slice
//   bias_k0_k1_grid_desc : descriptor of the (K0, K1) bias tensor
//   (unnamed)            : only the type CThreadDesc_K1_N_H2_W2 is used
// A (1, KPerThread) slice of bias is read into a Vgpr StaticBuffer, then
// bias_thread_buf[ki] is added to every (ki, 0, hi, wi) element of the tile.
template <typename BiasGlobalBuff,
typename CThreadBuff,
typename CBlockIndex,
typename CThreadIndex,
typename BiasGridDesc_K0_K1,
typename CThreadDesc_K1_N_H2_W2>
__device__ static void BiasOp(BiasGlobalBuff& bias_global_buf,
CThreadBuff& c_thread_buf,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc,
const CThreadDesc_K1_N_H2_W2&)
{
// K0 coordinate of this block, passed through readfirstlane
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const auto k_thread_id = c_thread_idx[I0];
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
// per-thread staging descriptor for the bias slice: (1, KPerThread)
constexpr auto bias_k0_k1_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
bias_k0_k1_thread_desc.GetElementSpaceSize(),
true>
bias_thread_buf;
// K1 coordinate at which this thread's slice starts
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
// source window origin is (k_block_work_id, k_thread_data_on_global)
auto bias_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatC,
FloatC,
decltype(bias_k0_k1_grid_desc),
decltype(bias_k0_k1_thread_desc),
Sequence<I1, Number<KPerThread>{}>,
Sequence<0, 1>,
1,
CThreadTransferDstScalarPerVector,
false,
true>(
bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global));
// step hacks handed to the transfer's Run call below
constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple(
make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{}));
bias_threadwise_transfer.Run(bias_k0_k1_grid_desc,
bias_global_buf,
bias_k0_k1_thread_desc,
make_tuple(I0, I0),
bias_thread_buf,
bias_k0_k1_global_tensor_step_hacks);
// c(ki, 0, hi, wi) += bias(ki), unrolled at compile time
static_for<0, KPerThread, 1>{}([&](auto ki) {
static_for<0, HoPerThread, 1>{}([&](auto hi) {
static_for<0, WoPerThread, 1>{}([&](auto wi) {
constexpr index_t c_offset =
c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi));
c_thread_buf(Number<c_offset>{}) =
c_thread_buf[Number<c_offset>{}] + bias_thread_buf[ki];
});
});
});
}
// Applies an element-wise activation, selected at compile time by
// activ_type_, to every element of the thread's accumulator tile in place.
//   activ_type_ == 1 : x >= 0 ? x : alpha * x
//   activ_type_ == 2 : v_rcp_f32 applied to (1.0 + exp(-x))
//   any other value  : the buffer is left untouched
// NOTE(review): the selector is compared against the integers 1 and 2;
// presumably these correspond to named ActivTypeEnum values - confirm against
// the enum definition.
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, ActivTypeEnum activ_type_>
__device__ static void Activation(CThreadBuff& c_thread_buf,
const CThreadDesc_K1_N_H2_W2&,
integral_constant<ActivTypeEnum, activ_type_>)
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
// compile-time loop over every element of the thread tile
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type_ == 1)
{
// negative values are scaled by the struct-level constant alpha
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
}
else if constexpr(activ_type_ == 2)
{
// NOTE(review): 1.0 is a double literal, so this sum may be evaluated in
// double precision before being stored into FloatAcc - confirm intent.
FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
// in-place v_rcp_f32 on x; the "0" constraint ties the input to the
// output register. NOTE(review): the instruction operates on 32-bit
// floats, so this presumably assumes FloatAcc is float - confirm.
asm volatile("\n \
v_rcp_f32 %0, %1 \n"
: "=v"(x)
: "0"(x));
c_thread_buf(i) = x;
}
});
}
// Writes the thread's accumulator tile to the C buffer.
//   c_thread_buf : per-thread source tile (read as FloatAcc)
//   c_global_buf : destination buffer (written as FloatC)
//   c_block_idx  : block-cluster index (K, N, Ho, Wo block coordinates)
//   c_thread_idx : thread index; elements I0, I2, I3 give the K, Ho, Wo
//                  thread coordinates
//   c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc : 9-d destination descriptor
// The copy is one ThreadwiseTensorSliceTransfer_v1r3 of shape
// (1, KPerThread, 1, 1, 1, HoPerThread, 1, 1, WoPerThread) using
// CGlobalMemoryDataOperation as the store operation.
template <typename CThreadBuff,
typename CGlobalBuff,
typename CBlockIndex,
typename CThreadIndex,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>
__device__ static void
WriteOut(const CThreadBuff& c_thread_buf,
CGlobalBuff& c_global_buf,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
{
// block coordinates, each passed through readfirstlane
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);
// thread coordinates inside the block tile (element I1 is not used)
const auto k_thread_id = c_thread_idx[I0];
const auto ho_thread_id = c_thread_idx[I2];
const auto wo_thread_id = c_thread_idx[I3];
// hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global
// tensor
constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{};
// 9-d view of the thread tile; only the K1, H2 and W2 dimensions exceed 1
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{}));
// K1 coordinate at which this thread's slice starts
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
// destination origin: (k_block, k_thread * KPerThread, n_block,
// ho_block, ho_thread, 0, wo_block, wo_thread, 0)
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc),
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThread, I1, I1, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_global_buf,
c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks);
}
// Reduces the thread's accumulator tile with a 2x2 maximum over (Ho, Wo) and
// writes the reduced tile to the D buffer.
//   c_thread_buf : per-thread source tile, indexed (k, 0, h, w)
//   d_global_buf : destination buffer (written as FloatC)
//   c_block_idx  : block-cluster index (K, N, Ho, Wo block coordinates)
//   c_thread_idx : thread index; elements I0, I2, I3 give the K, Ho, Wo
//                  thread coordinates
//   (unnamed)    : only the type CThreadDesc_K1_N_H2_W2 is used
//   d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc : 9-d destination descriptor
// Requires HoPerThread and WoPerThread to be even (checked at compile time).
// Output element (k, h, w) is the maximum of the source elements
// (k, 2h, 2w), (k, 2h, 2w+1), (k, 2h+1, 2w) and (k, 2h+1, 2w+1).
// The result is stored with InMemoryDataOperationEnum::Set.
template <typename CThreadBuff,
          typename DGlobalBuff,
          typename CBlockIndex,
          typename CThreadIndex,
          typename CThreadDesc_K1_N_H2_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>
__device__ static void
MaxPool(const CThreadBuff& c_thread_buf,
        DGlobalBuff& d_global_buf,
        const CBlockIndex& c_block_idx,
        const CThreadIndex& c_thread_idx,
        const CThreadDesc_K1_N_H2_W2&,
        const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc)
{
    // block coordinates, each passed through readfirstlane
    const index_t k_block_work_id  = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
    const index_t n_block_work_id  = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
    const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
    const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);

    // thread coordinates inside the block tile (element I1 is not used)
    const auto k_thread_id  = c_thread_idx[I0];
    const auto ho_thread_id = c_thread_idx[I2];
    const auto wo_thread_id = c_thread_idx[I3];

    constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};

    static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0,
                  "MaxPool requires even HoPerThread and WoPerThread");

    // pooled tile extents
    constexpr auto HoPerThread_2 = HoPerThread / 2;
    constexpr auto WoPerThread_2 = WoPerThread / 2;

    // 9-d view of the pooled thread tile; only K1, Hx and Wx exceed 1
    constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc =
        make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                       Number<KPerThread>{},
                                                       I1,
                                                       I1,
                                                       I1,
                                                       Number<HoPerThread_2>{},
                                                       I1,
                                                       I1,
                                                       Number<WoPerThread_2>{}));

    StaticBuffer<AddressSpaceEnum::Vgpr,
                 FloatC,
                 d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
                 true>
        d_thread_buf;

    static_for<0, KPerThread, 1>{}([&](auto ki) {
        static_for<0, HoPerThread_2, 1>{}([&](auto hi) {
            static_for<0, WoPerThread_2, 1>{}([&](auto wi) {
                constexpr index_t d_offset =
                    d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
                        make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi));

                // offsets of the four corners of the 2x2 source window
                constexpr index_t c_offset_0 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
                    make_tuple(ki, 0, hi * 2, wi * 2));
                constexpr index_t c_offset_1 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
                    make_tuple(ki, 0, hi * 2, wi * 2 + 1));
                constexpr index_t c_offset_2 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
                    make_tuple(ki, 0, hi * 2 + 1, wi * 2));
                constexpr index_t c_offset_3 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
                    make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1));

                // running maximum, one corner at a time; all four steps use the
                // single-precision fmaxf (the last step previously called fmax)
                d_thread_buf(Number<d_offset>{}) = c_thread_buf[Number<c_offset_0>{}];
                d_thread_buf(Number<d_offset>{}) =
                    fmaxf(c_thread_buf[Number<c_offset_1>{}], d_thread_buf(Number<d_offset>{}));
                d_thread_buf(Number<d_offset>{}) =
                    fmaxf(c_thread_buf[Number<c_offset_2>{}], d_thread_buf(Number<d_offset>{}));
                d_thread_buf(Number<d_offset>{}) =
                    fmaxf(c_thread_buf[Number<c_offset_3>{}], d_thread_buf(Number<d_offset>{}));
            });
        });
    });

    // K1 coordinate at which this thread's slice starts
    const index_t k_thread_data_on_global = k_thread_id * KPerThread;

    constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};

    // destination origin: (k_block, k_thread * KPerThread, n_block,
    //                      ho_block, ho_thread, 0, wo_block, wo_thread, 0)
    ThreadwiseTensorSliceTransfer_v1r3<
        FloatC,
        FloatC,
        decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
        decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
        Sequence<I1, KPerThread, I1, I1, I1, HoPerThread_2, I1, I1, WoPerThread_2>,
        CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        InMemoryDataOperationEnum::Set,
        1,
        true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
              make_multi_index(k_block_work_id,
                               k_thread_data_on_global,
                               n_block_work_id,
                               ho_block_work_id,
                               ho_thread_id,
                               0,
                               wo_block_work_id,
                               wo_thread_id,
                               0))
        .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
             d_thread_buf,
             d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
             d_global_buf,
             d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
}
// Expands the thread's accumulator tile by a factor of two along Ho and Wo
// and accumulates the result into the D buffer.
//   c_thread_buf : per-thread source tile, indexed (k, 0, h, w)
//   d_global_buf : destination buffer (FloatC)
//   c_block_idx  : block-cluster index (K, N, Ho, Wo block coordinates)
//   c_thread_idx : thread index; elements I0, I2, I3 give the K, Ho, Wo
//                  thread coordinates
//   (unnamed)    : only the type CThreadDesc_K1_N_H2_W2 is used
//   d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc : 9-d destination descriptor
// Output element (k, h, w) of the expanded tile copies source element
// (k, 0, h / 2, w / 2), so each source value fills a 2x2 patch. The transfer
// uses InMemoryDataOperationEnum::Add as its store operation.
template <typename CThreadBuff,
typename DGlobalBuff,
typename CBlockIndex,
typename CThreadIndex,
typename CThreadDesc_K1_N_H2_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>
__device__ static void
ResizeAdd(const CThreadBuff& c_thread_buf,
DGlobalBuff& d_global_buf,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const CThreadDesc_K1_N_H2_W2&,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc)
{
// block coordinates, each passed through readfirstlane
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);
// thread coordinates inside the block tile (element I1 is not used)
const auto k_thread_id = c_thread_idx[I0];
const auto ho_thread_id = c_thread_idx[I2];
const auto wo_thread_id = c_thread_idx[I3];
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
// expanded tile extents
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
// 9-d view of the expanded thread tile; only K1, Hx and Wx exceed 1
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThreadx2>{},
I1,
I1,
Number<WoPerThreadx2>{}));
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
d_thread_buf;
// d(k, h, w) = c(k, 0, h / 2, w / 2), unrolled at compile time
static_for<0, KPerThread, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
d_thread_buf(Number<d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) =
c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}];
});
});
});
// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};
// K1 coordinate at which this thread's slice starts
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
// destination origin: (k_block, k_thread * KPerThread, n_block,
// ho_block, ho_thread, 0, wo_block, wo_thread, 0)
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum::Add,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
d_global_buf,
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
}
// Accumulate this block's GEMM result into c_thread_buf.
//
// A is staged through LDS: a block-wide copy reads it from a_global_buf and
// writes it into a buffer built on p_shared_block. B is read per thread from
// b_global_buf into two register buffers (even/odd) that are used alternately,
// so the next B slice is fetched while the current one is multiplied.
//
//   a_global_buf / b_global_buf  source buffers for A and B
//   c_thread_buf                 per-thread accumulator; zeroed here before use
//   p_shared_block               LDS storage used for the A tile
//   c_block_idx                  block work ids; I0..I3 are read as k, n, ho, wo
//   c_thread_idx                 thread ids; only I2 (ho) and I3 (wo) are read here
//   a_..._grid_desc / b_..._grid_desc  descriptors of A and B
//   (unnamed) CThreadDesc_K1_N_H2_W2   used for its type only
//   (unnamed) integral_constant<bool, HasMainE0BlockLoop>  tag selecting whether
//                                the outer loop over E0 is generated
//
// The two branches of `if constexpr(HasMainE0BlockLoop)` contain the same
// E1 pipeline; the first one additionally repeats it once per E0 step.
template <typename AGlobalBuff,
typename BGlobalBuff,
typename CThreadBuff,
typename CBlockIndex,
typename CThreadIndex,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CThreadDesc_K1_N_H2_W2,
bool HasMainE0BlockLoop>
__device__ static void
GemmOp(const AGlobalBuff& a_global_buf,
const BGlobalBuff& b_global_buf,
CThreadBuff& c_thread_buf,
FloatAB* __restrict__ p_shared_block,
const CBlockIndex& c_block_idx,
const CThreadIndex& c_thread_idx,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CThreadDesc_K1_N_H2_W2&,
integral_constant<bool, HasMainE0BlockLoop>)
{
// Compile-time shape of the E1 loop: whether a main loop exists and whether
// the tail handles two slices or one.
constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop();
constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop();
// const auto c_k_n_h_w_block_cluster_idx =
// GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
// cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex(
// make_multi_index(get_block_1d_id()));
// Block coordinates, taken from the first lane of the wavefront.
const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]);
const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]);
const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]);
const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]);
// The LDS layout of A is aligned to the vector width used when writing it.
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
// Descriptors of the A and B operands as seen by the blockwise GEMM.
constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E1PerBlock>{}, Number<KPerBlock>{}, Number<E2>{}), max_lds_align);
constexpr auto b_e1_n_h_w_e2_block_gemm_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<E1PerBlock>{},
I1,
Number<HoPerBlock>{},
Number<WoPerBlock>{},
Number<E2>{}));
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e1_k1_e2_block_gemm_desc),
decltype(b_e1_n_h_w_e2_block_gemm_desc),
decltype(c_k1_n_h2_w2_thread_gemm_desc),
EPerThread,
K2>{};
// blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
const auto ho_thread_id = c_thread_idx[I2];
const auto wo_thread_id = c_thread_idx[I3];
// LDS layout for the A copy: one E0 step holding the full E1 extent.
constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<I1>{}, Number<E1>{}, I1, Number<KPerBlock>{}, Number<E2>{}),
max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum::Set,
Sequence<I1, E1, I1, KPerBlock, E2>,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e0_e1_k0_k1_e2_grid_desc),
decltype(a_e0_e1_k0_k1_e2_block_copy_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorDim,
4,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_E2,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
false>(a_e0_e1_k0_k1_e2_grid_desc,
make_multi_index(0, 0, k_block_work_id, 0, 0),
a_e0_e1_k0_k1_e2_block_copy_desc,
make_multi_index(0, 0, 0, 0, 0));
// The A source window advances by one along E0 per outer iteration.
constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0);
// Per-thread register layout for one E1PerBlock slice of B.
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<E1PerBlock>{},
I1,
I1,
I1,
Number<HoPerThread>{},
I1,
I1,
Number<WoPerThread>{},
Number<E2>{}));
// B threadwise copy; its source origin is this thread's (n, ho, wo) position.
auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc),
Sequence<I1, E1PerBlock, I1, I1, I1, HoPerThread, I1, I1, WoPerThread, E2>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
make_multi_index(0,
0,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0,
0));
// View of the shared-memory block as the LDS buffer holding the A tile.
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize());
//// register allocation for output
// StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAcc,
// c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
// true>
// c_thread_buf;
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k1_n_h2_w2_thread_gemm_desc),
Sequence<KPerThread, I1, HoPerThread, WoPerThread>>{}
.Run(c_k1_n_h2_w2_thread_gemm_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
// The B source window advances by E1PerBlock along E1 per slice.
constexpr auto b_thread_slice_copy_step =
make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{};
// double register buffer for b
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAB,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;
if constexpr(HasMainE0BlockLoop)
{
// Outer loop: one pass of the E1 pipeline per E0 step.
const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0);
index_t e0_block_data_begin = 0;
do
{
// NOTE(review): from the second E0 iteration on, the RunWrite below
// overwrites a_block_buf while no barrier is visible after the last
// blockwise_gemm.Run of the previous iteration. Unless the helpers
// synchronize internally or the block is a single wavefront, this looks
// like a write-after-read hazard on LDS -- confirm.
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
}
// The A tile written above must be visible to the whole block before the GEMM reads it.
__syncthreads();
if constexpr(HasMainE1BlockLoop)
{
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
// Each trip consumes two E1PerBlock slices: fetch into one B buffer
// while multiplying with the other, then swap roles.
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
// odd iteration: roles of the two B buffers are swapped
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
e1_block_data_begin += 2 * E1PerBlock;
} while(e1_block_data_begin < E1 - 2 * E1PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
// Prepare the next E0 step: advance the A source window by one along E0,
// move the GEMM's A window back by (E1 - E1PerBlock) -- presumably to the
// start of the LDS tile, TODO confirm -- and step the B window once more.
a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc,
a_block_slice_copy_step,
AGlobalMoveSliceWindowStepHacks{});
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
e0_block_data_begin += 1;
} while(e0_block_data_begin < E0);
}
else
{
// No outer E0 loop: run the E1 pipeline exactly once.
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(
a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf);
}
// The A tile written above must be visible to the whole block before the GEMM reads it.
__syncthreads();
if constexpr(HasMainE1BlockLoop)
{
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
// odd iteration: roles of the two B buffers are swapped
b_threadwise_transfer.MoveSrcSliceWindow(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
e1_block_data_begin += 2 * E1PerBlock;
} while(e1_block_data_begin < E1 - 2 * E1PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_thread_slice_copy_step,
BGlobalMoveSliceWindowStepHacks{});
b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
b_global_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
}
}
// Convolution without epilogue: run GemmOp for this block and store the
// per-thread accumulator to C with WriteOut.
//
// The bias and D pointers are wrapped into global buffers below, but nothing
// in this function reads or writes through those two buffers.
//
//   p_a_global / p_b_global   input operands
//   p_bias_global             bias values (buffer is created but not consumed here)
//   p_c_global                output written by WriteOut
//   p_d_global                secondary output (buffer is created but not consumed here)
//   p_shared_block            LDS scratch handed to GemmOp
//   *_grid_desc               descriptors of the corresponding tensors
//   cblockid_to_k_n_h_w_block_cluster_adaptor  passed to GetCBlockIndex
//   (unnamed) integral_constant<bool, HasMainE0BlockLoop>  forwarded to GemmOp as a tag
template <typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop>
__device__ static void
Conv(const FloatAB* __restrict__ p_a_global,
     const FloatAB* __restrict__ p_b_global,
     const FloatC* __restrict__ p_bias_global,
     FloatC* __restrict__ p_c_global,
     FloatC* __restrict__ p_d_global,
     FloatAB* __restrict__ p_shared_block,
     const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
     const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
     const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
     const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
     const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
     integral_constant<bool, HasMainE0BlockLoop>)
{
    // The bias descriptor is derived from the C descriptor.
    const auto bias_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

    // Inputs.
    const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
    const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());

    // Outputs and auxiliary buffers.
    auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
    auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
    auto bias_global_buf =
        make_dynamic_buffer<AddressSpaceEnum::Global>(p_bias_global, bias_desc.GetElementSpaceSize());

    // Per-thread accumulator held in registers.
    constexpr auto acc_thread_desc = MakeCK1NH2W2ThreadDescriptor();
    StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc_thread_desc.GetElementSpaceSize(), true>
        acc_thread_buf;

    // Where this block and this thread sit in the output.
    const auto block_work_idx  = GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
    const auto thread_work_idx = GetCThreadIndex();

    // Stage 1: GEMM into the accumulator.
    GemmOp(a_global_buf,
           b_global_buf,
           acc_thread_buf,
           p_shared_block,
           block_work_idx,
           thread_work_idx,
           a_e0_e1_k0_k1_e2_grid_desc,
           b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
           acc_thread_desc,
           integral_constant<bool, HasMainE0BlockLoop>{});

    // Stage 2: store the accumulator to C.
    WriteOut(acc_thread_buf,
             c_global_buf,
             block_work_idx,
             thread_work_idx,
             c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
}
// Convolution followed by bias and activation: GemmOp -> BiasOp -> Activation
// -> WriteOut, all operating on the same per-thread accumulator.
//
//   p_a_global / p_b_global   input operands
//   p_bias_global             bias values consumed by BiasOp
//   p_c_global                output written by WriteOut
//   p_shared_block            LDS scratch handed to GemmOp
//   *_grid_desc               descriptors of the corresponding tensors
//   cblockid_to_k_n_h_w_block_cluster_adaptor  passed to GetCBlockIndex
//   (unnamed) integral_constant<bool, HasMainE0BlockLoop>  forwarded to GemmOp as a tag
//   (unnamed) integral_constant<ActivTypeEnum, ActivType>  selects the activation
template <typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum ActivType>
__device__ static void ConvBiasActiv(
    const FloatAB* __restrict__ p_a_global,
    const FloatAB* __restrict__ p_b_global,
    const FloatC* __restrict__ p_bias_global,
    FloatC* __restrict__ p_c_global,
    FloatAB* __restrict__ p_shared_block,
    const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
    const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
    const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
    const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
    integral_constant<bool, HasMainE0BlockLoop>,
    integral_constant<ActivTypeEnum, ActivType>)
{
    // Tag object that tells Activation which function to apply.
    static constexpr auto activation_tag = integral_constant<ActivTypeEnum, ActivType>{};

    // The bias descriptor is derived from the C descriptor.
    const auto bias_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

    // Inputs.
    const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
    const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());

    // Output and bias.
    auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
    auto bias_global_buf =
        make_dynamic_buffer<AddressSpaceEnum::Global>(p_bias_global, bias_desc.GetElementSpaceSize());

    // Per-thread accumulator held in registers.
    constexpr auto acc_thread_desc = MakeCK1NH2W2ThreadDescriptor();
    StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc_thread_desc.GetElementSpaceSize(), true>
        acc_thread_buf;

    // Where this block and this thread sit in the output.
    const auto block_work_idx  = GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
    const auto thread_work_idx = GetCThreadIndex();

    // Stage 1: GEMM into the accumulator.
    GemmOp(a_global_buf,
           b_global_buf,
           acc_thread_buf,
           p_shared_block,
           block_work_idx,
           thread_work_idx,
           a_e0_e1_k0_k1_e2_grid_desc,
           b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
           acc_thread_desc,
           integral_constant<bool, HasMainE0BlockLoop>{});

    // Stage 2: bias.
    BiasOp(bias_global_buf,
           acc_thread_buf,
           block_work_idx,
           thread_work_idx,
           bias_desc,
           acc_thread_desc);

    // Stage 3: activation, in place on the accumulator.
    Activation(acc_thread_buf, acc_thread_desc, activation_tag);

    // Stage 4: store the accumulator to C.
    WriteOut(acc_thread_buf,
             c_global_buf,
             block_work_idx,
             thread_work_idx,
             c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
}
// Convolution with bias, activation and max-pooling. The same per-thread
// accumulator flows through GemmOp -> BiasOp -> Activation; it is then stored
// to C by WriteOut and, after that, handed to MaxPool together with the D buffer.
//
//   p_a_global / p_b_global   input operands
//   p_bias_global             bias values consumed by BiasOp
//   p_c_global                output written by WriteOut
//   p_d_global                output handed to MaxPool
//   p_shared_block            LDS scratch handed to GemmOp
//   *_grid_desc               descriptors of the corresponding tensors
//   cblockid_to_k_n_h_w_block_cluster_adaptor  passed to GetCBlockIndex
//   (unnamed) integral_constant<bool, HasMainE0BlockLoop>  forwarded to GemmOp as a tag
//   (unnamed) integral_constant<ActivTypeEnum, ActivType>  selects the activation
template <typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum ActivType>
__device__ static void ConvBiasActivMaxpool(
    const FloatAB* __restrict__ p_a_global,
    const FloatAB* __restrict__ p_b_global,
    const FloatC* __restrict__ p_bias_global,
    FloatC* __restrict__ p_c_global,
    FloatC* __restrict__ p_d_global,
    FloatAB* __restrict__ p_shared_block,
    const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
    const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
    const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
    const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
    const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
    integral_constant<bool, HasMainE0BlockLoop>,
    integral_constant<ActivTypeEnum, ActivType>)
{
    // Tag object that tells Activation which function to apply.
    static constexpr auto activation_tag = integral_constant<ActivTypeEnum, ActivType>{};

    // The bias descriptor is derived from the C descriptor.
    const auto bias_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

    // Inputs.
    const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
    const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());

    // Outputs and bias.
    auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
    auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
    auto bias_global_buf =
        make_dynamic_buffer<AddressSpaceEnum::Global>(p_bias_global, bias_desc.GetElementSpaceSize());

    // Per-thread accumulator held in registers.
    constexpr auto acc_thread_desc = MakeCK1NH2W2ThreadDescriptor();
    StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc_thread_desc.GetElementSpaceSize(), true>
        acc_thread_buf;

    // Where this block and this thread sit in the output.
    const auto block_work_idx  = GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
    const auto thread_work_idx = GetCThreadIndex();

    // Stage 1: GEMM into the accumulator.
    GemmOp(a_global_buf,
           b_global_buf,
           acc_thread_buf,
           p_shared_block,
           block_work_idx,
           thread_work_idx,
           a_e0_e1_k0_k1_e2_grid_desc,
           b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
           acc_thread_desc,
           integral_constant<bool, HasMainE0BlockLoop>{});

    // Stage 2: bias.
    BiasOp(bias_global_buf,
           acc_thread_buf,
           block_work_idx,
           thread_work_idx,
           bias_desc,
           acc_thread_desc);

    // Stage 3: activation, in place on the accumulator.
    Activation(acc_thread_buf, acc_thread_desc, activation_tag);

    // Stage 4: store the accumulator to C.
    WriteOut(acc_thread_buf,
             c_global_buf,
             block_work_idx,
             thread_work_idx,
             c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

    // Stage 5: max-pool from the same accumulator into D.
    MaxPool(acc_thread_buf,
            d_global_buf,
            block_work_idx,
            thread_work_idx,
            acc_thread_desc,
            d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
}
// Convolution with bias and activation whose result is handed to ResizeAdd
// together with the D buffer: GemmOp -> BiasOp -> Activation -> ResizeAdd.
// There is no C output pointer; the C grid descriptor is only used to derive
// the bias descriptor.
//
//   p_a_global / p_b_global   input operands
//   p_bias_global             bias values consumed by BiasOp
//   p_d_global                output handed to ResizeAdd
//   p_shared_block            LDS scratch handed to GemmOp
//   *_grid_desc               descriptors of the corresponding tensors
//   cblockid_to_k_n_h_w_block_cluster_adaptor  passed to GetCBlockIndex
//   (unnamed) integral_constant<bool, HasMainE0BlockLoop>  forwarded to GemmOp as a tag
//   (unnamed) integral_constant<ActivTypeEnum, ActivType>  selects the activation
template <typename AGridDesc_E0_E1_K0_K1_E2,
          typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
          typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
          typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
          typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
          bool HasMainE0BlockLoop,
          ActivTypeEnum ActivType>
__device__ static void ConvBiasActivResizeAdd(
    const FloatAB* __restrict__ p_a_global,
    const FloatAB* __restrict__ p_b_global,
    const FloatC* __restrict__ p_bias_global,
    FloatC* __restrict__ p_d_global,
    FloatAB* __restrict__ p_shared_block,
    const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
    const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
    const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
    const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
    const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor,
    integral_constant<bool, HasMainE0BlockLoop>,
    integral_constant<ActivTypeEnum, ActivType>)
{
    // Tag object that tells Activation which function to apply.
    static constexpr auto activation_tag = integral_constant<ActivTypeEnum, ActivType>{};

    // The bias descriptor is derived from the C descriptor.
    const auto bias_desc = MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);

    // Inputs.
    const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
    const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());

    // Output and bias.
    auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
    auto bias_global_buf =
        make_dynamic_buffer<AddressSpaceEnum::Global>(p_bias_global, bias_desc.GetElementSpaceSize());

    // Per-thread accumulator held in registers.
    constexpr auto acc_thread_desc = MakeCK1NH2W2ThreadDescriptor();
    StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc, acc_thread_desc.GetElementSpaceSize(), true>
        acc_thread_buf;

    // Where this block and this thread sit in the output.
    const auto block_work_idx  = GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor);
    const auto thread_work_idx = GetCThreadIndex();

    // Stage 1: GEMM into the accumulator.
    GemmOp(a_global_buf,
           b_global_buf,
           acc_thread_buf,
           p_shared_block,
           block_work_idx,
           thread_work_idx,
           a_e0_e1_k0_k1_e2_grid_desc,
           b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
           acc_thread_desc,
           integral_constant<bool, HasMainE0BlockLoop>{});

    // Stage 2: bias.
    BiasOp(bias_global_buf,
           acc_thread_buf,
           block_work_idx,
           thread_work_idx,
           bias_desc,
           acc_thread_desc);

    // Stage 3: activation, in place on the accumulator.
    Activation(acc_thread_buf, acc_thread_desc, activation_tag);

    // Stage 4: hand the accumulator to ResizeAdd, which targets D.
    ResizeAdd(acc_thread_buf,
              d_global_buf,
              block_work_idx,
              thread_work_idx,
              acc_thread_desc,
              d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
}
};
} // namespace ck
#endif
...@@ -17,17 +17,25 @@ ...@@ -17,17 +17,25 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm, bool HasMainKBlockLoop>
typename FloatAB, __global__ void
typename FloatC, #if CK_USE_LAUNCH_BOUNDS
typename AElementwiseOperation, __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
typename BElementwiseOperation, #endif
typename CElementwiseOperation, kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
typename AGridDesc_AK0_M_AK1, {
typename BGridDesc_BK0_N_BK1, #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, defined(__gfx940__))
typename Block2CTileMap, __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
bool HasMainKBlockLoop>
GridwiseGemm::template Run<HasMainKBlockLoop>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -35,55 +43,33 @@ __global__ void ...@@ -35,55 +43,33 @@ __global__ void
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op, typename GridwiseGemm::Problem problem)
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
p_b_grid,
p_c_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = a_element_op; ignore = problem;
ignore = b_element_op;
ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_ctile_map;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <typename FloatAB, template <typename ALayout,
typename BLayout,
typename CLayout,
typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC, typename FloatC,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -129,35 +115,396 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -129,35 +115,396 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
// K1 should be Number<...> // K1 should be Number<...>
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{}; static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{}; static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{}; static constexpr auto BK1Number = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Returns a 3-tuple whose first element is the grid size reported by
// Block2CTileMap for an M x N problem; the other two elements are 1.
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
    const auto tile_map_grid_size = Block2CTileMap::CalculateGridSize(M, N);
    return std::make_tuple(tile_map_grid_size, 1, 1);
}
// M rounded up to the next multiple of MPerBlock.
__host__ static auto CalculateMPadded(index_t M)
{
    const auto tiles_along_m = math::integer_divide_ceil(M, MPerBlock);
    return tiles_along_m * MPerBlock;
}
// N rounded up to the next multiple of NPerBlock.
__host__ static auto CalculateNPadded(index_t N)
{
    const auto tiles_along_n = math::integer_divide_ceil(N, NPerBlock);
    return tiles_along_n * NPerBlock;
}
// K rounded up to the next multiple of KPerBlock.
__host__ static auto CalculateKPadded(index_t K)
{
    const auto tiles_along_k = math::integer_divide_ceil(K, KPerBlock);
    return tiles_along_k * KPerBlock;
}
// Length of the AK0 dimension: the K extent divided by AK1Value.
// For the GemmSpec values that pad K (MKPadding, MNKPadding, KPadding,
// NKPadding) the padded K from CalculateKPadded is divided; otherwise the raw K.
__host__ static auto CalculateAK0(index_t K)
{
    using GemmSpecialization = tensor_operation::device::GemmSpecialization;

    constexpr bool k_is_padded = GemmSpec == GemmSpecialization::MKPadding ||
                                 GemmSpec == GemmSpecialization::MNKPadding ||
                                 GemmSpec == GemmSpecialization::KPadding ||
                                 GemmSpec == GemmSpecialization::NKPadding;

    if constexpr(!k_is_padded)
    {
        return K / AK1Value;
    }
    else
    {
        return CalculateKPadded(K) / AK1Value;
    }
}
// Length of the BK0 dimension: the K extent divided by BK1Value.
// For the GemmSpec values that pad K (NKPadding, MNKPadding, KPadding,
// MKPadding) the padded K from CalculateKPadded is divided; otherwise the raw K.
__host__ static auto CalculateBK0(index_t K)
{
    using GemmSpecialization = tensor_operation::device::GemmSpecialization;

    constexpr bool k_is_padded = GemmSpec == GemmSpecialization::NKPadding ||
                                 GemmSpec == GemmSpecialization::MNKPadding ||
                                 GemmSpec == GemmSpecialization::KPadding ||
                                 GemmSpec == GemmSpecialization::MKPadding;

    if constexpr(!k_is_padded)
    {
        return K / BK1Value;
    }
    else
    {
        return CalculateKPadded(K) / BK1Value;
    }
}
// Number of MPerBlock-sized tiles needed to cover M rows.
//
// Rounds up, so that CalculateMBlock(M) * MPerBlock == CalculateMPadded(M).
// Rounding down dropped the trailing partial tile whenever M is not a
// multiple of MPerBlock; when M is a multiple, both roundings agree, so
// unpadded configurations are unaffected.
// NOTE(review): the consumers of this value are outside this excerpt --
// confirm they expect the tile count of the padded M extent.
__host__ static auto CalculateMBlock(index_t M)
{
    return math::integer_divide_ceil(M, MPerBlock);
}
// Number of NPerBlock-sized tiles needed to cover N columns.
//
// Rounds up, so that CalculateNBlock(N) * NPerBlock == CalculateNPadded(N).
// Rounding down dropped the trailing partial tile whenever N is not a
// multiple of NPerBlock; when N is a multiple, both roundings agree, so
// unpadded configurations are unaffected.
// NOTE(review): the consumers of this value are outside this excerpt --
// confirm they expect the tile count of the padded N extent.
__host__ static auto CalculateNBlock(index_t N)
{
    return math::integer_divide_ceil(N, NPerBlock);
}
// Build the (AK0, M, AK1) descriptor of matrix A.
//
// Starts from a raw (M, K) view whose strides depend on ALayout
// (RowMajor: (StrideA, 1), ColumnMajor: (1, StrideA)), right-pads M and/or K
// as selected by GemmSpec, and finally splits the K dimension into
// (AK0, AK1Value), which become output dimensions 0 and 2; M is dimension 1.
//
//   M, K        raw extents
//   MPad, KPad  padded extents; the amount of padding is MPad - M / KPad - K
//   StrideA     stride of the non-contiguous dimension of A
//   AK0         length of the AK0 dimension
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
    index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
    using GemmSpecialization = tensor_operation::device::GemmSpecialization;

    // Which of the two dimensions does GemmSpec ask us to pad?
    constexpr bool pad_m_and_k = GemmSpec == GemmSpecialization::MKPadding ||
                                 GemmSpec == GemmSpecialization::MNKPadding;
    constexpr bool pad_m_only = GemmSpec == GemmSpecialization::MPadding ||
                                GemmSpec == GemmSpecialization::MNPadding;
    constexpr bool pad_k_only = GemmSpec == GemmSpecialization::KPadding ||
                                GemmSpec == GemmSpecialization::NKPadding;

    // Raw (M, K) view of A.
    const auto a_raw_m_k = [&]() {
        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
        }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
        {
            return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
        }
    }();

    if constexpr(pad_m_and_k)
    {
        // Pad M and K first, then split the padded K.
        const auto a_padded_m_k =
            transform_tensor_descriptor(a_raw_m_k,
                                        make_tuple(make_right_pad_transform(M, MPad - M),
                                                   make_right_pad_transform(K, KPad - K)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return transform_tensor_descriptor(
            a_padded_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                       make_pass_through_transform(MPad)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
    else if constexpr(pad_m_only)
    {
        // Split the raw K and pad M in a single transform.
        return transform_tensor_descriptor(
            a_raw_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                       make_right_pad_transform(M, MPad - M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
    else if constexpr(pad_k_only)
    {
        // Pad K (M passes through), then split the padded K.
        const auto a_padded_m_k = transform_tensor_descriptor(
            a_raw_m_k,
            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return transform_tensor_descriptor(
            a_padded_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
    else
    {
        // No padding on M or K: only split K.
        return transform_tensor_descriptor(
            a_raw_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
}
// Builds the grid descriptor for matrix B as a 3-D [BK0, N, BK1] view.
//
// The raw 2-D [N, K] descriptor is created from StrideB according to BLayout, optionally
// right-padded (N up to NPad, K up to KPad) depending on GemmSpec, and finally its K dimension
// is unmerged into (BK0, BK1Value) so that K lands on output dimensions 0 and 2 and N on
// output dimension 1.
//
// K, N       : logical extents of B
// KPad, NPad : padded extents; only read by the branches that pad the matching dimension
// StrideB    : stride of the non-unit-stride dimension of B
// BK0        : outer length of the K split (the inner length is BK1Value, defined outside this
//              function)
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
    index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
    using GemmSpecialization = tensor_operation::device::GemmSpecialization;

    // Compile-time classification of GemmSpec with respect to the two dimensions of B.
    constexpr bool pads_n_and_k = GemmSpec == GemmSpecialization::NKPadding ||
                                  GemmSpec == GemmSpecialization::MNKPadding;
    constexpr bool pads_n_only = GemmSpec == GemmSpecialization::NPadding ||
                                 GemmSpec == GemmSpecialization::MNPadding;
    constexpr bool pads_k_only = GemmSpec == GemmSpecialization::KPadding ||
                                 GemmSpec == GemmSpecialization::MKPadding;

    // Unpadded [N, K] view; BLayout decides which dimension carries StrideB.
    const auto raw_n_k = [&]() {
        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
        }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
        {
            return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
        }
    }();

    // Final step shared by every branch: unmerge K (input dim 1) into (BK0, BK1Value) on
    // output dims 0 and 2, and apply `n_transform` to N (input dim 0) on output dim 1.
    const auto split_k = [&](const auto& desc_n_k, const auto& n_transform) {
        return transform_tensor_descriptor(
            desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), n_transform),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    };

    if constexpr(pads_n_and_k)
    {
        // N and K are both right-padded first; N is then passed through at its padded length.
        const auto padded_n_k =
            transform_tensor_descriptor(raw_n_k,
                                        make_tuple(make_right_pad_transform(N, NPad - N),
                                                   make_right_pad_transform(K, KPad - K)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return split_k(padded_n_k, make_pass_through_transform(NPad));
    }
    else if constexpr(pads_n_only)
    {
        // Only N is padded; the padding is fused into the K-split transform.
        return split_k(raw_n_k, make_right_pad_transform(N, NPad - N));
    }
    else if constexpr(pads_k_only)
    {
        // Only K is padded; N keeps its original length throughout.
        const auto padded_n_k = transform_tensor_descriptor(
            raw_n_k,
            make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return split_k(padded_n_k, make_pass_through_transform(N));
    }
    else
    {
        // Neither N nor K is padded.
        return split_k(raw_n_k, make_pass_through_transform(N));
    }
}
// Builds the 2-D [M, N] grid descriptor for matrix C.
//
// The raw descriptor is created from StrideC according to CLayout. Depending on GemmSpec, M
// and/or N are then right-padded to MPad / NPad; when neither is padded the raw descriptor is
// returned untransformed.
//
// M, N       : logical extents of C
// MPad, NPad : padded extents; only read by the branches that pad the matching dimension
// StrideC    : stride of the non-unit-stride dimension of C
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
    using GemmSpecialization = tensor_operation::device::GemmSpecialization;

    // Compile-time classification of GemmSpec with respect to the two dimensions of C.
    constexpr bool pads_m_and_n = GemmSpec == GemmSpecialization::MNPadding ||
                                  GemmSpec == GemmSpecialization::MNKPadding;
    constexpr bool pads_m_only = GemmSpec == GemmSpecialization::MPadding ||
                                 GemmSpec == GemmSpecialization::MKPadding;
    constexpr bool pads_n_only = GemmSpec == GemmSpecialization::NPadding ||
                                 GemmSpec == GemmSpecialization::NKPadding;

    // Unpadded [M, N] view; CLayout decides which dimension carries StrideC.
    const auto raw_m_n = [&]() {
        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
        {
            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
        }
        else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
        {
            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
        }
    }();

    if constexpr(pads_m_and_n)
    {
        // Right-pad both dimensions.
        return transform_tensor_descriptor(raw_m_n,
                                           make_tuple(make_right_pad_transform(M, MPad - M),
                                                      make_right_pad_transform(N, NPad - N)),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}),
                                           make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else if constexpr(pads_m_only)
    {
        // Right-pad M; N is passed through unchanged.
        return transform_tensor_descriptor(
            raw_m_n,
            make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else if constexpr(pads_n_only)
    {
        // Right-pad N; M is passed through unchanged.
        return transform_tensor_descriptor(
            raw_m_n,
            make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        // No padding on M or N: hand back the raw descriptor as-is.
        return raw_m_n;
    }
}
// Host-side description of one GEMM problem: the user-supplied sizes and strides plus the
// derived quantities (padded sizes, K split lengths, block counts) computed once at
// construction through the enclosing struct's Calculate* helpers.
struct Problem
{
    // Stores the raw sizes/strides and derives every other field from M_, N_ and K_.
    __host__ Problem(index_t M_,
                     index_t N_,
                     index_t K_,
                     index_t StrideA_,
                     index_t StrideB_,
                     index_t StrideC_)
        : M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideC{StrideC_},
          MPadded{CalculateMPadded(M_)},
          NPadded{CalculateNPadded(N_)},
          KPadded{CalculateKPadded(K_)},
          AK0{CalculateAK0(K_)},
          BK0{CalculateBK0(K_)},
          MBlock{CalculateMBlock(M_)},
          NBlock{CalculateNBlock(N_)}
    {
    }

    // Writes every field to std::cout on a single line, then flushes via std::endl.
    __host__ void Print() const
    {
        std::cout << "problem {";

        // raw sizes
        std::cout << "M:" << M << ", ";
        std::cout << "N:" << N << ", ";
        std::cout << "K:" << K << ", ";

        // strides
        std::cout << "SA:" << StrideA << ", ";
        std::cout << "SB:" << StrideB << ", ";
        std::cout << "SC:" << StrideC << ", ";

        // padded sizes
        std::cout << "MP:" << MPadded << ", ";
        std::cout << "NP:" << NPadded << ", ";
        std::cout << "KP:" << KPadded << ", ";

        // K split lengths
        std::cout << "AK0:" << AK0 << ", ";
        std::cout << "BK0:" << BK0 << ", ";

        // block counts
        std::cout << "MBlock: " << MBlock << ", ";
        std::cout << "NBlock: " << NBlock << "}" << std::endl;
    }

    // Sizes exactly as passed to the constructor.
    index_t M, N, K;

    // Strides exactly as passed to the constructor.
    index_t StrideA, StrideB, StrideC;

    // Results of CalculateMPadded(M), CalculateNPadded(N), CalculateKPadded(K).
    index_t MPadded, NPadded, KPadded;

    // Results of CalculateAK0(K) and CalculateBK0(K).
    index_t AK0, BK0;

    // Results of CalculateMBlock(M) and CalculateNBlock(N).
    index_t MBlock, NBlock;
};
// Kernel argument bundle: a Problem (sizes, strides and derived values) together with the
// pointers to the A, B and C matrix data. Also derives from BaseArgument.
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
    // Forwards the six size/stride values to Problem and stores the three data pointers.
    __host__ Argument(const FloatAB* p_a_grid_,
                      const FloatAB* p_b_grid_,
                      FloatC* p_c_grid_,
                      index_t M_,
                      index_t N_,
                      index_t K_,
                      index_t StrideA_,
                      index_t StrideB_,
                      index_t StrideC_)
        : Problem(M_, N_, K_, StrideA_, StrideB_, StrideC_),
          p_a_grid(p_a_grid_),
          p_b_grid(p_b_grid_),
          p_c_grid(p_c_grid_)
    {
    }

    const FloatAB* p_a_grid; // A matrix data (read-only through this pointer)
    const FloatAB* p_b_grid; // B matrix data (read-only through this pointer)
    FloatC* p_c_grid;        // C matrix data (writable through this pointer)
};
// FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(AK0, Number<MPerBlock>{}, AK1), make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1)); make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
} }
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{ {
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(BK0, Number<NPerBlock>{}, BK1), make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1)); make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
} }
__host__ __device__ static constexpr auto __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{ {
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
...@@ -172,14 +519,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -172,14 +519,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
} }
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1); constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
...@@ -200,36 +547,102 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -200,36 +547,102 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap> __host__ static constexpr bool CheckValidity(const Problem& problem)
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.M % MPerBlock == 0))
{
return false;
}
}
if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
return false; GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(problem.N % NPerBlock == 0))
{
return false;
}
}
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
return false; GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{
if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
!(CalculateKPadded(problem.K) % BK1Value == 0))
{
return false;
}
}
else
{
if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
{
return false;
}
}
// check gridwise gemm pipeline if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
const auto num_k_loop = K / KPerBlock; {
if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
return false; if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
} }
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{ {
return false; return false;
} }
...@@ -238,22 +651,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -238,22 +651,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return true; return true;
} }
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t num_loop = K / KPerBlock; const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
} }
__host__ __device__ static constexpr auto template <typename CGridDesc>
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n, c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})), make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
...@@ -265,33 +673,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -265,33 +673,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop, typename Block2CTileMap> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const Problem& problem)
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap& block_2_ctile_map)
{ {
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -299,7 +700,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -299,7 +700,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N] // divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...@@ -319,7 +726,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -319,7 +726,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1); constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -333,7 +740,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -333,7 +740,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation, AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -364,7 +771,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -364,7 +771,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation, BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
...@@ -396,8 +803,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -396,8 +803,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
constexpr index_t KPack = math::max( constexpr index_t KPack =
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk); math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize, BlockSize,
...@@ -425,8 +833,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -425,8 +833,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned, static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// gridwise GEMM pipeline // gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>); static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
......
...@@ -8,14 +8,14 @@ ...@@ -8,14 +8,14 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace ck { namespace ck {
...@@ -55,6 +55,7 @@ template <index_t BlockSize, ...@@ -55,6 +55,7 @@ template <index_t BlockSize,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec, tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t K0PerBlock, index_t K0PerBlock,
...@@ -82,7 +83,9 @@ template <index_t BlockSize, ...@@ -82,7 +83,9 @@ template <index_t BlockSize,
index_t CShuffleMRepeatPerShuffle, index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL, index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock> typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -99,8 +102,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -99,8 +102,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
static constexpr auto M01 = 1; static constexpr auto M01 = 1;
static constexpr auto N01 = 1; static constexpr auto N01 = 1;
static constexpr auto gemm_padder =
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, K1* K0PerBlock};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
struct Argument : public ck::tensor_operation::device::BaseArgument struct Argument : public ck::tensor_operation::device::BaseArgument
{ {
const FloatAB* p_a_grid; const FloatAB* p_a_grid;
...@@ -176,12 +186,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -176,12 +186,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// prefer this to be called on host // prefer this to be called on host
__host__ __device__ static auto CalculateMPadded(index_t M) __host__ __device__ static auto CalculateMPadded(index_t M)
{ {
return (M + MPerBlock - 1) / MPerBlock * MPerBlock; return math::integer_least_multiple(M, MPerBlock);
} }
__host__ __device__ static auto CalculateNPadded(index_t N) __host__ __device__ static auto CalculateNPadded(index_t N)
{ {
return (N + NPerBlock - 1) / NPerBlock * NPerBlock; return math::integer_least_multiple(N, NPerBlock);
} }
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
...@@ -295,8 +305,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -295,8 +305,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
} }
__host__ __device__ static auto __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
MakeCGridDescriptor_M_N(index_t M, index_t N, index_t MPad, index_t NPad, index_t StrideC)
{ {
const auto c_grid_desc_m_n = [&]() { const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
...@@ -309,22 +318,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -309,22 +318,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
}(); }();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
{
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
} }
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
...@@ -383,7 +377,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -383,7 +377,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(karg.M % MPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false; return false;
}
} }
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
...@@ -391,40 +393,116 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -391,40 +393,116 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(karg.N % NPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false; return false;
}
} }
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false; return false;
}
} }
else else
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false; return false;
}
} }
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false; return false;
}
} }
else else
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false; return false;
}
} }
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{ {
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{
#if DEBUG_LOG
std::cout
<< "Arg N (" << karg.N
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false; return false;
}
} }
else else
{ {
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
{
#if DEBUG_LOG
std::cout
<< "Arg M (" << karg.M
<< ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL ("
<< CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false; return false;
}
}
const auto num_k_loop = karg.K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
#if DEBUG_LOG
std::cout << "The number of k loops (" << num_k_loop
<< ") value is not supported by GridwiseGemm Pipeline."
<< " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
} }
return true; return true;
...@@ -439,9 +517,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -439,9 +517,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{ {
const bool has_main_k0_block_loop = K0 > K0PerBlock; const index_t num_loop = K0 / K0PerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
return has_main_k0_block_loop;
} }
template <typename CGridDesc> template <typename CGridDesc>
...@@ -490,7 +567,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -490,7 +567,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>(); return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
} }
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>; using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1))>;
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>; using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
...@@ -507,8 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -507,8 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
MakeCGridDescriptor_M_N(karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
...@@ -680,20 +756,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -680,20 +756,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
#if 1
auto blockwise_gemm = auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
K1>{};
#else
auto blockwise_gemm = BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
BlockSize, BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
...@@ -703,9 +767,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -703,9 +767,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
K1>{}; K1,
LoopSched>();
#endif
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...@@ -761,7 +824,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -761,7 +824,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock; k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock)); } while(k0_block_data_begin < (karg.K0 - K0PerBlock));
} }
// tail // tail
...@@ -772,13 +835,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -772,13 +835,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
#else #else
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
(K0PerBlock * K1)); (K0PerBlock * K1));
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc, gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
a_b_k0_m_k1_block_desc, a_b_k0_m_k1_block_desc,
a_blockwise_copy, a_blockwise_copy,
...@@ -993,24 +1055,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -993,24 +1055,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
} }
} }
template <typename Layout>
struct LStr
{
static std::string Get() { return ""; }
};
template <>
struct LStr<ck::tensor_layout::gemm::RowMajor>
{
static std::string Get() { return "R"; }
};
template <>
struct LStr<ck::tensor_layout::gemm::ColumnMajor>
{
static std::string Get() { return "C"; }
};
static std::string GetTypeString() static std::string GetTypeString()
{ {
auto str = std::stringstream(); auto str = std::stringstream();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V3R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "static_tensor.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? A functor struct is used instead of a lambda because a lambda
// doesn't have a constructor.
//
// Maps a dimension index i to the number of scalars covered by one access along that
// dimension:
//   - lcm(SrcScalarPerVector, DstScalarPerVector) when i is both the src and dst vector dim
//   - SrcScalarPerVector when i is only the src vector dim
//   - DstScalarPerVector when i is only the dst vector dim
//   - 1 for every other dimension
template <index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          index_t DstVectorDim,
          index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
    __host__ __device__ constexpr auto operator()(index_t i) const
    {
        const bool on_src_vector_dim = (i == SrcVectorDim);
        const bool on_dst_vector_dim = (i == DstVectorDim);

        // both vector dims coincide: one access has to cover whole vectors of both widths
        if(on_src_vector_dim && on_dst_vector_dim)
            return math::lcm(SrcScalarPerVector, DstScalarPerVector);

        if(on_src_vector_dim)
            return SrcScalarPerVector;

        if(on_dst_vector_dim)
            return DstScalarPerVector;

        // non-vectorized dimension
        return 1;
    }
};
} // namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename Dst0Desc,
typename Dst1Desc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct ThreadwiseTensorSliceTransfer_v3r3
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{}));
using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{}));
using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{}));
// Constructs the transfer object.
// The src coordinate is created from (src_desc, src_slice_origin). The dst, dst0 and dst1
// coordinates are each created from their own descriptor but all start at the same
// dst_slice_origin. Both element-wise operations are stored by copy.
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r3(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Dst0Desc& dst0_desc,
const Dst1Desc& dst1_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
// dst0/dst1 share the slice origin of dst
dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)),
dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)),
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
}
// Re-seats the src coordinate at index src_slice_origin_idx of src_desc.
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
    const auto origin_coord = make_tensor_coordinate(src_desc, src_slice_origin_idx);

    src_coord_ = origin_coord;
}
// Re-seats the dst, dst0 and dst1 coordinates. Each coordinate is rebuilt from its own
// descriptor, and all three use the same index dst_slice_origin_idx.
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc,
                                  const Dst0Desc& dst0_desc,
                                  const Dst1Desc& dst1_desc,
                                  const Index& dst_slice_origin_idx)
{
    const auto dst_origin_coord  = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    const auto dst0_origin_coord = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx);
    const auto dst1_origin_coord = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx);

    dst_coord_  = dst_origin_coord;
    dst0_coord_ = dst0_origin_coord;
    dst1_coord_ = dst1_origin_coord;
}
// Reads the whole slice from src_buf into src_thread_scratch_.
// The slice is visited one vector (SrcScalarPerVector scalars along SrcVectorDim) at a time,
// in the dimension order given by SrcDimAccessOrder. Each vector is loaded at the current
// src_coord_ offset, passed scalar-by-scalar through src_element_op_, and stored into the
// thread scratch. src_coord_ is moved between accesses and, if SrcResetCoordinateAfterRun
// is set, moved back to the slice origin at the end.
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
// only global memory and LDS are accepted as source
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
// number of accesses needed along each dim
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
// access counts rearranged into iteration order
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
// one step per dim: +src_scalar_per_access[i] along dim i, 0 elsewhere
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
// one step per dim: -src_scalar_per_access[i] along dim i, 0 elsewhere
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
// dim 0 always sweeps forward; dim i sweeps forward when the linearized index of the
// ordered dims in front of it is even, backward when it is odd
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
// on a backward sweep the position along that dim is mirrored
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
// same index, expressed as a compile-time sequence
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
// validity flag of the current coordinate, forwarded to the buffer read below
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
// apply SrcElementwiseOperation on src_vector_container
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
src_vector_container.template AsType<SrcData>()(i) =
src_element_op_(src_vector_container.template AsType<SrcData>()[i]);
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_.template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
// dim i is stepped only when it has not reached its last access and every ordered dim
// after it is at its last access
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
// Copies the slice from src_thread_scratch_ into dst_thread_scratch_.
// Without CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE every element is converted
// from SrcData to DstData one by one. With it, half_t -> half_t transfers whose src and dst
// vector dims differ (and whose vector widths are both even) go through transpose_vectors
// instead; every other case falls back to the element-wise conversion.
__device__ void TransferDataFromSrcThreadScratchToDstThreadScratch()
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
static_assert(SrcVectorDim != DstVectorDim, "wrong");
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
// size of one transpose tile along each dim
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
// number of transpose tiles along each dim
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
// origin of the current tile inside the slice
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
// TODO type_convert is not used yet!!!!!
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return src_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
// TODO type_convert is not used yet!!!!!
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
// no transpose applicable: plain element-wise copy with conversion
static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here
dst_thread_scratch_(idx) = type_convert<DstData>(src_thread_scratch_[idx]);
});
}
#endif
}
// Writes the slice held in thread scratch to dst_buf.
// First moves the data from src_thread_scratch_ to dst_thread_scratch_ (conversion and
// optional transpose), then visits the slice one vector (DstScalarPerVector scalars along
// DstVectorDim) at a time in DstDimAccessOrder: each vector is read from
// dst_thread_scratch_, passed scalar-by-scalar through dst_element_op_, and stored at the
// current dst_coord_ offset. dst_coord_ is moved between accesses and, if
// DstResetCoordinateAfterRun is set, moved back to the slice origin at the end.
// NOTE(review): dst0_buf, dst1_buf and the dst0/dst1 forward/backward steps built below are
// never used in this function, and dst0_coord_/dst1_coord_ are not advanced here. This
// looks like an unfinished dst0/dst1 path; confirm before relying on it.
template <typename DstBuffer, typename Dst0Buffer, typename Dst1Buffer>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
const Dst0Desc& dst0_desc,
const Dst0Buffer& dst0_buf,
const Dst1Desc& dst1_desc,
const Dst1Buffer& dst1_buf)
{
// if there is transpose, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch();
// only global memory and LDS are accepted as destination
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// dst scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
// number of accesses needed along each dim
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
// access counts rearranged into iteration order
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
// one step per dim: +dst_scalar_per_access[i] along dim i, 0 elsewhere
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make forward steps: dst0
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst0_desc, forward_step_idx);
},
Number<nDim>{});
// make forward steps: dst1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst1_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
// one step per dim: -dst_scalar_per_access[i] along dim i, 0 elsewhere
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
// make backward steps: dst0
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst0_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst0_desc, backward_step_idx);
},
Number<nDim>{});
// make backward steps: dst1
// WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same
// DstScalarPerVector
// TODO: fix this
const auto dst1_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst1_desc, backward_step_idx);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
// dim 0 always sweeps forward; dim i sweeps forward when the linearized index of the
// ordered dims in front of it is even, backward when it is odd
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index
// on a backward sweep the position along that dim is mirrored
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
// same index, expressed as a compile-time sequence
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
// validity flag of the current coordinate, forwarded to the buffer write below
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
using dst_vector_type = vector_type_maker_t<DstData, DstScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
// copy data from dst_thread_scratch_ into dst_vector_container
auto dst_vector_container = dst_vector_type{
dst_thread_scratch_.template GetAsType<dst_vector_t>(dst_data_idx_seq)};
// apply DstElementwiseOperation on dst_vector_container
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
dst_vector_container.template AsType<DstData>()(i) =
dst_element_op_(dst_vector_container.template AsType<DstData>()[i]);
});
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
// dim i is stepped only when it has not reached its last access and every ordered dim
// after it is at its last access
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
// Returns the index step that moves the src coordinate from the position of the last access
// performed by RunRead() back to the slice origin (i.e. the negated data index of that last
// access).
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
    constexpr auto I0 = Number<0>{};

    // scalars covered by one access along each dim
    // TODO: don't use lambda_scalar_per_access
    constexpr auto scalar_per_access = generate_sequence(
        detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

    constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

    constexpr auto access_order = SrcDimAccessOrder{};

    constexpr auto ordered_access_lengths =
        container_reorder_given_new2old(access_lengths, access_order);

    // sweep direction of every ordered dim during the last iteration
    constexpr auto last_forward_sweep = [&]() {
        StaticallyIndexedArray<bool, nDim> sweep;

        sweep(I0) = true;

        static_for<1, nDim, 1>{}([&](auto i) {
            // linearized index of the dims in front of i, all at their last access
            index_t linear = ordered_access_lengths[I0] - 1;

            static_for<1, i, 1>{}([&](auto j) {
                linear = linear * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
            });

            sweep(i) = linear % 2 == 0;
        });

        return sweep;
    }();

    // data index of the last access in RunRead(), if RunRead() has not reset the coordinate:
    // a forward-swept dim ends at its last access, a backward-swept dim ends at 0
    constexpr auto last_data_idx = [&]() {
        Index ordered_idx;

        static_for<0, nDim, 1>{}([&](auto i) {
            ordered_idx(i) = last_forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
        });

        return container_reorder_given_old2new(ordered_idx, access_order) * scalar_per_access;
    }();

    // going back to the origin is the negation of that index
    constexpr auto reset_step = [&]() {
        Index negated_idx;

        static_for<0, nDim, 1>{}([&](auto i) { negated_idx(i) = -last_data_idx[i]; });

        return negated_idx;
    }();

    return reset_step;
}
// Returns the index step that moves the dst coordinate from the position of the last access
// performed by RunWrite() back to the slice origin (i.e. the negated data index of that last
// access).
__device__ static constexpr auto GetDstCoordinateResetStep()
{
    constexpr auto I0 = Number<0>{};

    // scalars covered by one access along each dim
    // TODO: don't use lambda_scalar_per_access
    constexpr auto scalar_per_access = generate_sequence(
        detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

    constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

    constexpr auto access_order = DstDimAccessOrder{};

    constexpr auto ordered_access_lengths =
        container_reorder_given_new2old(access_lengths, access_order);

    // sweep direction of every ordered dim during the last iteration
    constexpr auto last_forward_sweep = [&]() {
        StaticallyIndexedArray<bool, nDim> sweep;

        sweep(I0) = true;

        static_for<1, nDim, 1>{}([&](auto i) {
            // linearized index of the dims in front of i, all at their last access
            index_t linear = ordered_access_lengths[I0] - 1;

            static_for<1, i, 1>{}([&](auto j) {
                linear = linear * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
            });

            sweep(i) = linear % 2 == 0;
        });

        return sweep;
    }();

    // data index of the last access in RunWrite(), if RunWrite() has not reset the
    // coordinate: a forward-swept dim ends at its last access, a backward-swept dim ends at 0
    constexpr auto last_data_idx = [&]() {
        Index ordered_idx;

        static_for<0, nDim, 1>{}([&](auto i) {
            ordered_idx(i) = last_forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
        });

        return container_reorder_given_old2new(ordered_idx, access_order) * scalar_per_access;
    }();

    // going back to the origin is the negation of that index
    constexpr auto reset_step = [&]() {
        Index negated_idx;

        static_for<0, nDim, 1>{}([&](auto i) { negated_idx(i) = -last_data_idx[i]; });

        return negated_idx;
    }();

    return reset_step;
}
// Moves the src slice window by src_slice_origin_step_idx.
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
//
// NOTE: this member used to be defined twice with an identical signature, which is an
// ill-formed redeclaration of a class member; only one definition is kept.
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                   const Index& src_slice_origin_step_idx)
{
    // if src coord was not reset by RunRead(), then need to adjust the step here:
    // the reset step is folded into the window step
    const auto adjusted_step_idx =
        SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                   : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

    // is it OK to construct a new step every time?
    const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

    move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// Moves the dst, dst0 and dst1 slice windows by dst_slice_origin_step_idx.
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
//
// The descriptors are taken by const reference (like every other method of this class)
// instead of by value, and a separate coordinate step is built for each descriptor, the
// same way RunWrite() builds its dst/dst0/dst1 steps, so that a step is never applied to a
// coordinate of a different descriptor type.
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                   const Dst0Desc& dst0_desc,
                                   const Dst1Desc& dst1_desc,
                                   const Index& dst_slice_origin_step_idx)
{
    // if dst coord was not reset by RunWrite(), then need to adjust the step here:
    // the reset step is folded into the window step
    const auto adjusted_step_idx =
        DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                   : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

    // is it OK to construct a new step every time?
    const auto dst_step  = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
    const auto dst0_step = make_tensor_coordinate_step(dst0_desc, adjusted_step_idx);
    const auto dst1_step = make_tensor_coordinate_step(dst1_desc, adjusted_step_idx);

    // all three windows move by the same index step
    move_tensor_coordinate(dst_desc, dst_coord_, dst_step);
    move_tensor_coordinate(dst0_desc, dst0_coord_, dst0_step);
    move_tensor_coordinate(dst1_desc, dst1_coord_, dst1_step);
}
// Builds the descriptor of src_thread_scratch_.
// Stage 1 is a packed descriptor of shape (access lengths..., SrcScalarPerVector).
// Stage 2 merges the access dimension SrcVectorDim with the trailing vector dimension and
// passes every other dimension through, giving an nDim-dimensional descriptor.
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
    constexpr auto scalar_per_access = generate_sequence(
        detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

    constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

    // access lengths followed by the vector length as an extra, last dimension
    constexpr auto lengths_with_vector = container_push_back(
        sequence_to_tuple_of_number(access_lengths), Number<SrcScalarPerVector>{});

    // 1st stage of transforms
    constexpr auto packed_desc = make_naive_tensor_descriptor_packed(lengths_with_vector);

    // 2nd stage of transforms
    constexpr auto dim_transforms = generate_tuple(
        [&](auto i) {
            if constexpr(i == SrcVectorDim)
            {
                // vector dim: merge (access index, index inside the vector)
                return make_merge_transform_v3_division_mod(
                    make_tuple(lengths_with_vector[i], lengths_with_vector[Number<nDim>{}]));
            }
            else
            {
                return make_pass_through_transform(lengths_with_vector[i]);
            }
        },
        Number<nDim>{});

    // lower dims feeding each transform: the vector dim also consumes lower dim nDim
    constexpr auto lower_dim_ids = generate_tuple(
        [&](auto i) {
            if constexpr(i == SrcVectorDim)
            {
                return Sequence<i.value, nDim>{};
            }
            else
            {
                return Sequence<i.value>{};
            }
        },
        Number<nDim>{});

    // every transform produces exactly one upper dim, at the same position
    constexpr auto upper_dim_ids =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

    return transform_tensor_descriptor(packed_desc, dim_transforms, lower_dim_ids, upper_dim_ids);
}
// Builds the descriptor of dst_thread_scratch_.
// Stage 1 is a packed descriptor of shape (access lengths..., DstScalarPerVector).
// Stage 2 merges the access dimension DstVectorDim with the trailing vector dimension and
// passes every other dimension through, giving an nDim-dimensional descriptor.
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
    // 1st stage of transforms
    constexpr auto scalar_per_access = generate_sequence(
        detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

    constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

    // access lengths followed by the vector length as an extra, last dimension
    constexpr auto lengths_with_vector = container_push_back(
        sequence_to_tuple_of_number(access_lengths), Number<DstScalarPerVector>{});

    constexpr auto packed_desc = make_naive_tensor_descriptor_packed(lengths_with_vector);

    // 2nd stage of transforms
    constexpr auto dim_transforms = generate_tuple(
        [&](auto i) {
            if constexpr(i == DstVectorDim)
            {
                // vector dim: merge (access index, index inside the vector)
                return make_merge_transform_v3_division_mod(
                    make_tuple(lengths_with_vector[i], lengths_with_vector[Number<nDim>{}]));
            }
            else
            {
                return make_pass_through_transform(lengths_with_vector[i]);
            }
        },
        Number<nDim>{});

    // lower dims feeding each transform: the vector dim also consumes lower dim nDim
    constexpr auto lower_dim_ids = generate_tuple(
        [&](auto i) {
            if constexpr(i == DstVectorDim)
            {
                return Sequence<i.value, nDim>{};
            }
            else
            {
                return Sequence<i.value>{};
            }
        },
        Number<nDim>{});

    // every transform produces exactly one upper dim, at the same position
    constexpr auto upper_dim_ids =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});

    return transform_tensor_descriptor(packed_desc, dim_transforms, lower_dim_ids, upper_dim_ids);
}
private:
// Compile-time descriptors of the two scratch buffers: each is a value-initialized object of the
// type returned by the matching Get{Src,Dst}ThreadScratchDescriptor() function.
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
// Source-side scratch buffer: SrcData elements handled SrcScalarPerVector at a time, tagged with
// AddressSpaceEnum::Vgpr and laid out by src_thread_scratch_desc_.
// NOTE(review): the meaning of the trailing `true` template argument is defined by
// StaticTensorTupleOfVectorBuffer, which is not visible here -- confirm there.
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcData,
SrcScalarPerVector,
decltype(src_thread_scratch_desc_),
true>
src_thread_scratch_;
// Destination-side scratch buffer: same construction using DstData, DstScalarPerVector and
// dst_thread_scratch_desc_.
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(dst_thread_scratch_desc_),
true>
dst_thread_scratch_;
// Coordinate objects for the source and destination tensors (types SrcCoord / DstCoord are
// declared outside this excerpt).
SrcCoord src_coord_;
DstCoord dst_coord_;
// Element-wise functors for the source and destination side; const, so they are fixed once the
// object is constructed.
const SrcElementwiseOperation src_element_op_;
const DstElementwiseOperation dst_element_op_;
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP
#include "data_type.hpp"
namespace ck {
// Declaration only: a device function taking and returning int32_t whose symbol name is forced
// to "llvm.amdgcn.readfirstlane" by the __asm label, so calls are emitted against that name
// instead of a mangled C++ symbol.
__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"
#include <cstddef>
#include <cstdint>
#include <type_traits>
namespace ck {
namespace detail {

/// Trait mapping a byte count to the fixed-width unsigned integer of exactly that size.
/// Only 1, 2 and 4 are provided; any other Size names the undefined primary template and is
/// therefore a compile-time error when `::type` is requested.
template <unsigned Size>
struct get_unsigned_int;

/// 1 byte -> uint8_t
template <>
struct get_unsigned_int<1>
{
    typedef uint8_t type;
};

/// 2 bytes -> uint16_t
template <>
struct get_unsigned_int<2>
{
    typedef uint16_t type;
};

/// 4 bytes -> uint32_t
template <>
struct get_unsigned_int<4>
{
    typedef uint32_t type;
};

/// Convenience alias for `get_unsigned_int<Size>::type`.
template <unsigned Size>
using get_unsigned_int_t = typename get_unsigned_int<Size>::type;

} // namespace detail
// Scalar overload: forwards a 32-bit integer to the __builtin_amdgcn_readfirstlane compiler
// builtin and returns its result.
// NOTE(review): per the LLVM AMDGPU documentation the builtin yields the value held by the first
// active lane of the wavefront, i.e. a wave-uniform result -- confirm against the target docs.
__device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{
return __builtin_amdgcn_readfirstlane(value);
}
/// Object overload: applies amd_wave_read_first_lane() to every byte of a trivially copyable
/// class object and returns the reassembled copy.
///
/// The object is processed in 4-byte pieces through the int32_t overload. Any 1..3 trailing
/// bytes are handled separately: 1 or 2 bytes go through a single uint8_t / uint16_t carrier, and
/// a 3-byte tail is split into a 2-byte piece followed by a 1-byte piece (there is no 3-byte
/// integer type).
///
/// Only participates in overload resolution for class types that are trivially copyable.
///
/// @param obj object whose representation is read
/// @return a copy of Object rebuilt from the values returned by the scalar overload
template <
    typename Object,
    typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
__device__ auto amd_wave_read_first_lane(const Object& obj)
{
    using Size                = unsigned;
    constexpr Size SgprSize   = 4;
    constexpr Size ObjectSize = sizeof(Object);

    auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
    alignas(Object) std::byte to_obj[ObjectSize];

    // Bytes [0, CompleteSgprCopyBoundary) are whole 4-byte pieces; the rest is the tail.
    constexpr Size RemainedSize             = ObjectSize % SgprSize;
    constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;

    for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize)
    {
        using Sgpr = detail::get_unsigned_int_t<SgprSize>;

        *reinterpret_cast<Sgpr*>(to_obj + offset) =
            amd_wave_read_first_lane(*reinterpret_cast<const Sgpr*>(from_obj + offset));
    }

    if constexpr(0 < RemainedSize)
    {
        // First tail piece: the whole tail when it is 1 or 2 bytes, otherwise the first 2 of 3.
        constexpr Size HeadSize = (RemainedSize == 3) ? 2 : RemainedSize;
        using Head              = detail::get_unsigned_int_t<HeadSize>;

        // The destination must be a pointer to the carrier type; the scalar overload returns
        // int32_t, so narrow explicitly back to the carrier width.
        *reinterpret_cast<Head*>(to_obj + CompleteSgprCopyBoundary) =
            static_cast<Head>(amd_wave_read_first_lane(
                *reinterpret_cast<const Head*>(from_obj + CompleteSgprCopyBoundary)));

        if constexpr(HeadSize < RemainedSize)
        {
            // Remaining single byte of a 3-byte tail.
            constexpr Size TailOffset = CompleteSgprCopyBoundary + HeadSize;
            using Tail                = detail::get_unsigned_int_t<RemainedSize - HeadSize>;

            *reinterpret_cast<Tail*>(to_obj + TailOffset) = static_cast<Tail>(
                amd_wave_read_first_lane(*reinterpret_cast<const Tail*>(from_obj + TailOffset)));
        }
    }

    /// NOTE: Implicitly start object lifetime. It's better to use std::start_lifetime_at() in this
    /// scenario
    return *reinterpret_cast<Object*>(to_obj);
}
} // namespace ck
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "ck/utility/debug.hpp" #include "ck/utility/debug.hpp"
#include "ck/utility/amd_buffer_addressing.hpp" #include "ck/utility/amd_buffer_addressing.hpp"
#include "ck/utility/amd_wave_read_first_lane.hpp"
#include "ck/utility/generic_memory_space_atomic.hpp" #include "ck/utility/generic_memory_space_atomic.hpp"
#include "ck/utility/get_id.hpp" #include "ck/utility/get_id.hpp"
#include "ck/utility/thread_group.hpp" #include "ck/utility/thread_group.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_PRINT_HPP
#define CK_PRINT_HPP
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "container_helper.hpp"
#include "sequence.hpp"
namespace ck {
/// Prints a label, the element count and every element of a fixed-size container.
///
/// Output has the form "<s> size <N>, {e0, e1, ... }" followed by a newline; each element is
/// brace-converted to int32_t and printed with "%d".
///
/// @param s label printed in front of the contents
/// @param a container providing a compile-time Size() and operator[]
template <typename T>
__host__ __device__ void print_array(const char* s, T a)
{
    constexpr index_t count = a.Size();

    printf("%s size %d, {", s, count);

    // Visit the indices 0 .. count-1 at compile time, printing one element per index.
    const auto print_element = [&a](auto idx) constexpr { printf("%d, ", int32_t{a[idx]}); };
    static_for<0, count, 1>{}(print_element);

    printf("}\n");
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/// Host reference for a GEMM whose result is combined element-wise with a second 2-D tensor.
///
/// For every output position (m, n) the invoker accumulates
/// sum_k a_op(A(m, k)) * b_op(B(k, n)) in AccDataType, converts the sum to CDataType and passes
/// it together with C0(m, n) to the C element-wise functor, which writes C(m, n).
template <typename ADataType,
          typename BDataType,
          typename C0DataType,
          typename CDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemmBias2D : public device::BaseOperator
{
    /// Bundles the tensors and element-wise functors of one reference run.
    /// Tensors are held by reference, so they must outlive the Argument.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 const Tensor<C0DataType>& c0_m_n,
                 Tensor<CDataType>& c_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c0_m_n_{c0_m_n},
              c_m_n_{c_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        // Element type must be C0DataType to match the constructor parameter; declaring it as
        // Tensor<CDataType> breaks as soon as C0DataType differs from CDataType.
        const Tensor<C0DataType>& c0_m_n_;
        Tensor<CDataType>& c_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    /// Executes the reference computation on the host.
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmBias2D::Argument;

        /// Computes every element of arg.c_m_n_. Always returns 0 (no timing is measured).
        float Run(const Argument& arg)
        {
            auto f_mk_kn_mn = [&](auto m, auto n) {
                // K is taken from the second length of A.
                const int K = arg.a_m_k_.mDesc.GetLengths()[1];

                AccDataType a   = 0;
                AccDataType b   = 0;
                AccDataType acc = 0;

                for(int k = 0; k < K; ++k)
                {
                    arg.a_element_op_(a, ck::type_convert<AccDataType>(arg.a_m_k_(m, k)));
                    arg.b_element_op_(b, ck::type_convert<AccDataType>(arg.b_k_n_(k, n)));
                    acc += a * b;
                }

                CDataType cast_acc = static_cast<CDataType>(acc);

                // The C functor writes its first argument from the GEMM value and C0(m, n).
                arg.c_element_op_(arg.c_m_n_(m, n), cast_acc, arg.c0_m_n_(m, n));
            };

            // Iterate over the two lengths of C; hardware_concurrency() is forwarded to the
            // parallel functor (presumably as its worker count).
            make_ParallelTensorFunctor(
                f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        /// Type-erased entry point. p_arg must point to a ReferenceGemmBias2D::Argument; the
        /// result of the dynamic_cast is dereferenced without a null check.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    /// The reference accepts every argument.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    /// Convenience factory forwarding all parameters to the Argument constructor.
    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             const Tensor<C0DataType>& c0_m_n,
                             Tensor<CDataType>& c_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k, b_k_n, c0_m_n, c_m_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    /// Returns the operator name followed by a newline.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemmBias2D"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/// Host reference for a GEMM followed by a C element-wise functor that also receives a
/// per-column value C0(n).
///
/// For every (m, n): acc = sum_k a_op(A(m, k)) * b_op(B(k, n)) in float, then
/// c_op(out, acc, C0(n)) and C(m, n) = out.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemmBiasActivation : public device::BaseOperator
{
    /// Tensors (held by reference) and functors (held by value) of one reference run.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 Tensor<CDataType>& c_m_n,
                 const Tensor<CDataType>& c0_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c_m_n_{c_m_n},
              c0_n_{c0_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        Tensor<CDataType>& c_m_n_;
        const Tensor<CDataType>& c0_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    /// Runs the reference computation on the host.
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmBiasActivation::Argument;

        /// Fills arg.c_m_n_; always returns 0.
        float Run(const Argument& arg)
        {
            auto compute_element = [&](auto m, auto n) {
                const int k_len = arg.a_m_k_.mDesc.GetLengths()[1];

                float sum = 0;
                for(int k = 0; k < k_len; ++k)
                {
                    // Both temporaries are written by the element-wise functors.
                    float a_val;
                    float b_val;

                    arg.a_element_op_(a_val, static_cast<float>(arg.a_m_k_(m, k)));
                    arg.b_element_op_(b_val, static_cast<float>(arg.b_k_n_(k, n)));

                    sum += a_val * b_val;
                }

                float result;
                arg.c_element_op_(result, sum, static_cast<float>(arg.c0_n_(n)));

                arg.c_m_n_(m, n) = result;
            };

            // Visit every (m, n) of the output tensor.
            const auto& out_lengths = arg.c_m_n_.mDesc.GetLengths();
            make_ParallelTensorFunctor(compute_element, out_lengths[0], out_lengths[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        /// Type-erased entry point; p_arg is expected to point to this operator's Argument.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            const auto* typed_arg = dynamic_cast<const Argument*>(p_arg);
            return Run(*typed_arg);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    /// Every argument is accepted.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    /// Factory forwarding its parameters to the Argument constructor.
    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             Tensor<CDataType>& c_m_n,
                             const Tensor<CDataType>& c0_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    /// Operator name followed by a newline.
    std::string GetTypeString() const override
    {
        std::stringstream name;

        // clang-format off
        name << "ReferenceGemmBiasActivation"
             << std::endl;
        // clang-format on

        return name.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/// Host reference for a GEMM followed by a C element-wise functor that receives both a
/// per-column value C0(n) and a per-element value C1(m, n).
///
/// For every (m, n): acc = sum_k a_op(A(m, k)) * b_op(B(k, n)) in float, then
/// c_op(out, acc, C0(n), C1(m, n)) and C(m, n) = out.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
{
    /// Tensors (held by reference) and functors (held by value) of one reference run.
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<ADataType>& a_m_k,
                 const Tensor<BDataType>& b_k_n,
                 Tensor<CDataType>& c_m_n,
                 const Tensor<CDataType>& c0_n,
                 const Tensor<CDataType>& c1_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : a_m_k_{a_m_k},
              b_k_n_{b_k_n},
              c_m_n_{c_m_n},
              c0_n_{c0_n},
              c1_m_n_{c1_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const Tensor<ADataType>& a_m_k_;
        const Tensor<BDataType>& b_k_n_;
        Tensor<CDataType>& c_m_n_;
        const Tensor<CDataType>& c0_n_;
        const Tensor<CDataType>& c1_m_n_;

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    /// Runs the reference computation on the host.
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmBiasActivationAdd::Argument;

        /// Fills arg.c_m_n_; always returns 0.
        float Run(const Argument& arg)
        {
            auto compute_element = [&](auto m, auto n) {
                const int k_len = arg.a_m_k_.mDesc.GetLengths()[1];

                float sum = 0;
                for(int k = 0; k < k_len; ++k)
                {
                    // Both temporaries are written by the element-wise functors.
                    float a_val;
                    float b_val;

                    arg.a_element_op_(a_val, static_cast<float>(arg.a_m_k_(m, k)));
                    arg.b_element_op_(b_val, static_cast<float>(arg.b_k_n_(k, n)));

                    sum += a_val * b_val;
                }

                float result;
                arg.c_element_op_(result,
                                  sum,
                                  static_cast<float>(arg.c0_n_(n)),
                                  static_cast<float>(arg.c1_m_n_(m, n)));

                arg.c_m_n_(m, n) = result;
            };

            // Visit every (m, n) of the output tensor.
            const auto& out_lengths = arg.c_m_n_.mDesc.GetLengths();
            make_ParallelTensorFunctor(compute_element, out_lengths[0], out_lengths[1])(
                std::thread::hardware_concurrency());

            return 0;
        }

        /// Type-erased entry point; p_arg is expected to point to this operator's Argument.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            const auto* typed_arg = dynamic_cast<const Argument*>(p_arg);
            return Run(*typed_arg);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    /// Every argument is accepted.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    /// Factory forwarding its parameters to the Argument constructor.
    static auto MakeArgument(const Tensor<ADataType>& a_m_k,
                             const Tensor<BDataType>& b_k_n,
                             Tensor<CDataType>& c_m_n,
                             const Tensor<CDataType>& c0_n,
                             const Tensor<CDataType>& c1_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{
            a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>();
    }

    /// Operator name followed by a newline.
    std::string GetTypeString() const override
    {
        std::stringstream name;

        // clang-format off
        name << "ReferenceGemmBiasActivationAdd"
             << std::endl;
        // clang-format on

        return name.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/// Host reference for forward pooling over a 2-D (InOutRank == 4, WindowRank == 2) or 3-D
/// (InOutRank == 5, WindowRank == 3) window.
///
/// Each output element reduces, with ReduceOperation, the input elements covered by its window;
/// window positions falling outside the input extents are skipped. When OutputIndex is true the
/// input offset of the selected element is written to a second output tensor.
template <index_t InOutRank,
          index_t WindowRank,
          typename InDataType,
          typename OutDataType,
          typename ComputeDataType,
          typename IndexDataType,
          ck::ReduceTensorOp ReduceOpId,
          bool PropagateNan,
          bool OutputIndex>
struct ReferencePoolingFwd : public device::BaseOperator
{
    using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;

    /// Inputs of one reference run.
    ///
    /// Tensors are held by reference and must outlive the Argument. The window/stride/pad
    /// vectors are copied, so callers may pass temporaries (for example braced lists) without
    /// leaving the Argument with dangling references.
    struct Argument : public device::BaseArgument
    {
        /// @param window_spatial_lengths window extent per spatial dimension; the first
        ///        WindowRank entries are read
        /// @param window_strides window step per spatial dimension
        /// @param in_left_pads leading padding per spatial dimension
        /// The trailing right-pad parameter is accepted for interface symmetry and not used.
        Argument(const Tensor<InDataType>& in,
                 Tensor<OutDataType>& out,
                 Tensor<IndexDataType>& out_indices,
                 const std::vector<ck::index_t>& window_spatial_lengths,
                 const std::vector<ck::index_t>& window_strides,
                 const std::vector<ck::index_t>& in_left_pads,
                 const std::vector<ck::index_t>& /*in_right_pads*/)
            : in_(in),
              out_(out),
              out_indices_(out_indices),
              window_spatial_lengths_(window_spatial_lengths),
              window_strides_(window_strides),
              in_left_pads_(in_left_pads),
              reduceLength_(1)
        {
            // reduceLength_ = product of the window extents.
            static_for<0, WindowRank, 1>{}(
                [&](auto I) { reduceLength_ *= window_spatial_lengths[I]; });
        }

        const Tensor<InDataType>& in_;
        Tensor<OutDataType>& out_;
        Tensor<IndexDataType>& out_indices_;
        // Stored by value (not by reference) so their lifetime is tied to the Argument.
        const std::vector<ck::index_t> window_spatial_lengths_;
        const std::vector<ck::index_t> window_strides_;
        const std::vector<ck::index_t> in_left_pads_;
        int reduceLength_;
    };

    /// Runs the reference computation on the host.
    struct Invoker : public device::BaseInvoker
    {
        /// 3-D pooling over tensors indexed (n, c, d, h, w). Always returns 0.
        float RunPooling3dFwd(const Argument& arg)
        {
            auto elementwise_ops =
                ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
                    arg.reduceLength_);

            auto in_elementwise_op  = std::get<0>(elementwise_ops);
            auto acc_elementwise_op = std::get<1>(elementwise_ops);

            if constexpr(!OutputIndex)
            {
                using Accumulation = ck::detail::
                    AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;

                auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
                    auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();

                    for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
                    {
                        ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
                        for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
                        {
                            ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
                            for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
                            {
                                ck::index_t wi =
                                    wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];

                                // Skip window positions that lie in the padding region.
                                if(di >= 0 &&
                                   di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
                                   hi >= 0 &&
                                   hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
                                   wi >= 0 &&
                                   wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
                                {
                                    ComputeDataType currVal =
                                        static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));

                                    in_elementwise_op(currVal, currVal);

                                    Accumulation::Calculate(accuVal, currVal);
                                }
                            }
                        }
                    }

                    acc_elementwise_op(accuVal, accuVal);

                    arg.out_(n, c, do_, ho, wo) = accuVal;
                };

                make_ParallelTensorFunctor(f_ncdhw,
                                           arg.out_.mDesc.GetLengths()[0],
                                           arg.out_.mDesc.GetLengths()[1],
                                           arg.out_.mDesc.GetLengths()[2],
                                           arg.out_.mDesc.GetLengths()[3],
                                           arg.out_.mDesc.GetLengths()[4])(
                    std::thread::hardware_concurrency());
            }
            else
            {
                using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                                ReduceOperation,
                                                                                ComputeDataType,
                                                                                IndexDataType>;

                auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
                    auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
                    IndexDataType accuIndex = 0;

                    for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
                    {
                        ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
                        for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
                        {
                            ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
                            for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
                            {
                                ck::index_t wi =
                                    wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];

                                // Skip window positions that lie in the padding region.
                                if(di >= 0 &&
                                   di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
                                   hi >= 0 &&
                                   hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]) &&
                                   wi >= 0 &&
                                   wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
                                {
                                    ComputeDataType currVal =
                                        static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
                                    // The recorded index is the input tensor's offset of the
                                    // candidate element.
                                    IndexDataType currIndex =
                                        arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);

                                    in_elementwise_op(currVal, currVal);

                                    Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                                }
                            }
                        }
                    }

                    acc_elementwise_op(accuVal, accuVal);

                    arg.out_(n, c, do_, ho, wo)         = accuVal;
                    arg.out_indices_(n, c, do_, ho, wo) = accuIndex;
                };

                make_ParallelTensorFunctor(f_ncdhw,
                                           arg.out_.mDesc.GetLengths()[0],
                                           arg.out_.mDesc.GetLengths()[1],
                                           arg.out_.mDesc.GetLengths()[2],
                                           arg.out_.mDesc.GetLengths()[3],
                                           arg.out_.mDesc.GetLengths()[4])(
                    std::thread::hardware_concurrency());
            }

            return 0;
        }

        /// 2-D pooling over tensors indexed (n, c, h, w). Always returns 0.
        float RunPooling2dFwd(const Argument& arg)
        {
            auto elementwise_ops =
                ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
                    arg.reduceLength_);

            auto in_elementwise_op  = std::get<0>(elementwise_ops);
            auto acc_elementwise_op = std::get<1>(elementwise_ops);

            if constexpr(!OutputIndex)
            {
                using Accumulation = ck::detail::
                    AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;

                auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
                    auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();

                    for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
                    {
                        ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
                        for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
                        {
                            ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];

                            // Skip window positions that lie in the padding region.
                            if(hi >= 0 &&
                               hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
                               wi >= 0 &&
                               wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
                            {
                                ComputeDataType currVal =
                                    static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));

                                in_elementwise_op(currVal, currVal);

                                Accumulation::Calculate(accuVal, currVal);
                            }
                        }
                    }

                    acc_elementwise_op(accuVal, accuVal);

                    arg.out_(n, c, ho, wo) = accuVal;
                };

                make_ParallelTensorFunctor(f_nchw,
                                           arg.out_.mDesc.GetLengths()[0],
                                           arg.out_.mDesc.GetLengths()[1],
                                           arg.out_.mDesc.GetLengths()[2],
                                           arg.out_.mDesc.GetLengths()[3])(
                    std::thread::hardware_concurrency());
            }
            else
            {
                using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                                ReduceOperation,
                                                                                ComputeDataType,
                                                                                IndexDataType>;

                auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
                    auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
                    IndexDataType accuIndex = 0;

                    for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
                    {
                        ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
                        for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
                        {
                            ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];

                            // Skip window positions that lie in the padding region.
                            if(hi >= 0 &&
                               hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
                               wi >= 0 &&
                               wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
                            {
                                ComputeDataType currVal =
                                    static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
                                // The recorded index is the input tensor's offset of the
                                // candidate element.
                                IndexDataType currIndex =
                                    arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);

                                in_elementwise_op(currVal, currVal);

                                Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                            }
                        }
                    }

                    acc_elementwise_op(accuVal, accuVal);

                    arg.out_(n, c, ho, wo)         = accuVal;
                    arg.out_indices_(n, c, ho, wo) = accuIndex;
                };

                make_ParallelTensorFunctor(f_nchw,
                                           arg.out_.mDesc.GetLengths()[0],
                                           arg.out_.mDesc.GetLengths()[1],
                                           arg.out_.mDesc.GetLengths()[2],
                                           arg.out_.mDesc.GetLengths()[3])(
                    std::thread::hardware_concurrency());
            }

            return 0;
        }

        /// Dispatches on the compile-time ranks; any other rank combination throws
        /// std::runtime_error when called.
        float Run(const Argument& arg)
        {
            // TODO - support generic pooling
            if constexpr(InOutRank == 5 && WindowRank == 3)
                return RunPooling3dFwd(arg);
            else if constexpr(InOutRank == 4 && WindowRank == 2)
                return RunPooling2dFwd(arg);
            else
                throw std::runtime_error("Only support pooling3d or pooling2d so far");
        }

        /// Type-erased entry point. p_arg must point to a ReferencePoolingFwd::Argument; the
        /// result of the dynamic_cast is dereferenced without a null check.
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    /// The reference accepts every argument.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    /// Convenience factory forwarding all parameters to the Argument constructor.
    static auto MakeArgument(const Tensor<InDataType>& in,
                             Tensor<OutDataType>& out,
                             Tensor<IndexDataType>& out_indices,
                             const std::vector<ck::index_t>& window_spatial_lengths,
                             const std::vector<ck::index_t>& window_strides,
                             const std::vector<ck::index_t>& in_left_pads,
                             const std::vector<ck::index_t>& in_right_pads)
    {
        return Argument{in,
                        out,
                        out_indices,
                        window_spatial_lengths,
                        window_strides,
                        in_left_pads,
                        in_right_pads};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    /// Returns the operator name followed by a newline.
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferencePoolingFwd"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#pragma once #pragma once
#include <cstdlib> #include <vector>
#include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#pragma once #pragma once
#include <cstdlib> #include <vector>
#include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#pragma once #pragma once
#include <cstdlib> #include <vector>
#include <memory>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment