Unverified commit 9f8ab221 authored by zjing14, committed by GitHub

Merge branch 'develop' into add_int8_wmma_example_instance

parents 755ace59 b4fc4d0b
@@ -495,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
 BlockSize,
 FloatAB,
+FloatAB,
 FloatGemmAcc,
 decltype(a_block_desc_ak0_m_ak1),
 decltype(b_block_desc_bk0_n_bk1),
......
@@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
 TileMathThreadGroupSize,
 ABDataType,
+ABDataType,
 FloatGemmAcc,
 decltype(a_block_desc_ak0_m_ak1),
 decltype(b_block_desc_bk0_n_bk1),
......
@@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen
 }
 template <typename GridwiseGemm,
-typename FloatAB,
+typename FloatA,
+typename FloatB,
 typename FloatC,
 typename AGridDesc_B_K0_M_K1,
 typename BGridDesc_B_K0_N_K1,
@@ -153,8 +154,8 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
+kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
+const FloatB* __restrict__ p_b_grid,
 FloatC* __restrict__ p_c_grid,
 const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
 const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
@@ -181,21 +182,22 @@ __global__ void
 c_element_op,
 c_block_cluster_adaptor);
 #else
 ignore = p_a_grid;
 ignore = p_b_grid;
 ignore = p_c_grid;
 ignore = a_b_k0_m_k1_grid_desc;
 ignore = b_b_k0_n_k1_grid_desc;
 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
 ignore = a_element_op;
 ignore = b_element_op;
 ignore = c_element_op;
 ignore = c_block_cluster_adaptor;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 template <index_t BlockSize,
-typename FloatAB,
+typename FloatA,
+typename FloatB,
 typename FloatAcc,
 typename FloatC,
 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
@@ -242,7 +244,9 @@ template <index_t BlockSize,
 bool ABlockLdsExtraM1Wrw = false,
 bool BBlockLdsExtraN1Wrw = false,
 index_t NumGemmKPrefetchStage = 1,
-PipelineVersion PipelineVer = PipelineVersion::v1>
+PipelineVersion PipelineVer = PipelineVersion::v1,
+typename ComputeTypeA = FloatA,
+typename ComputeTypeB = ComputeTypeA>
 struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 {
 static constexpr auto I0 = Number<0>{};
@@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 // denorm test fix, required to work around fp16 mfma issue
 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
 // when mfma if fixed, remove this section and update
-// FloatABAdjusted -> FloatAB throughout this file
+// FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
+// throughout this file
 #if CK_WORKAROUND_DENORM_FIX
-using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
+using FloatAAdjusted =
+conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>;
+using FloatBAdjusted =
+conditional_t<is_same_v<ComputeTypeB, ck::half_t>, ck::bhalf_t, ComputeTypeB>;
 #else
-using FloatABAdjusted = FloatAB;
+using FloatAAdjusted = ComputeTypeA;
+using FloatBAdjusted = ComputeTypeB;
 #endif
 // M0/M1/M1Padding
@@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 constexpr auto c_block_size =
 GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
-return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
+return math::max((a_block_space_size * sizeof(FloatAAdjusted) +
+b_block_space_size * sizeof(FloatBAdjusted)),
 c_block_size * sizeof(FloatC));
 }
@@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1));
 template <bool HasMainKBlockLoop>
-__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
+__device__ static void Run(const FloatA* __restrict__ p_a_grid,
+const FloatB* __restrict__ p_b_grid,
 FloatC* __restrict__ p_c_grid,
 void* __restrict__ p_shared,
 const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
@@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 Sequence<1, K0PerBlock, MPerBlock, K1>,
 ABlockTransferThreadClusterLengths_K0_M_K1,
 ABlockTransferThreadClusterArrangeOrder,
-FloatAB,
-FloatABAdjusted,
+FloatA,
+FloatAAdjusted,
 decltype(a_b_k0_m_k1_grid_desc),
 decltype(a_b_k0_m_k1_block_desc),
 ABlockTransferSrcAccessOrder,
@@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 Sequence<1, K0PerBlock, NPerBlock, K1>,
 BBlockTransferThreadClusterLengths_K0_N_K1,
 BBlockTransferThreadClusterArrangeOrder,
-FloatAB,
-FloatABAdjusted,
+FloatB,
+FloatBAdjusted,
 decltype(b_b_k0_n_k1_grid_desc),
 decltype(b_b_k0_n_k1_block_desc),
 BBlockTransferSrcAccessOrder,
@@ -733,11 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 // sanity check
 constexpr index_t KPack =
-math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
+math::max(K1,
+MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted>::selected_mfma
+.k_per_blk);
 auto blockwise_gemm =
 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-FloatABAdjusted,
+FloatAAdjusted,
+FloatBAdjusted,
 FloatAcc,
 decltype(a_k0_m_k1_block_desc),
 decltype(b_k0_n_k1_block_desc),
@@ -757,10 +770,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
 constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
+static_cast<FloatAAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
 auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
+static_cast<FloatBAdjusted*>(p_shared) + a_block_space_size,
 b_k0_n_k1_block_desc.GetElementSpaceSize());
 // gridwise GEMM pipeline
......
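The hunks above split the former FloatAB/FloatABAdjusted parameters into separate A- and B-side compute types and size the LDS tiles with their own element sizes. Below is a minimal host-side sketch of both ideas, assuming placeholder 16-bit types and standard C++ only; it is not the actual ck:: types or device code.

// Minimal host-side sketch of the type adjustment and LDS sizing shown above.
// half_t/bhalf_t are placeholders standing in for ck::half_t/ck::bhalf_t.
#include <cstddef>
#include <cstdio>
#include <type_traits>

struct half_t  { unsigned short bits; }; // placeholder 16-bit storage types
struct bhalf_t { unsigned short bits; };

template <typename ComputeTypeA, typename ComputeTypeB, bool DenormFix>
struct AdjustedTypes
{
    // Under the denorm workaround, fp16 compute types are promoted to bf16 before the MFMA.
    using A = std::conditional_t<DenormFix && std::is_same_v<ComputeTypeA, half_t>, bhalf_t, ComputeTypeA>;
    using B = std::conditional_t<DenormFix && std::is_same_v<ComputeTypeB, half_t>, bhalf_t, ComputeTypeB>;
};

// The LDS requirement now sums the A and B tiles with their own element sizes,
// mirroring max(a_size*sizeof(A) + b_size*sizeof(B), c_size*sizeof(C)).
template <typename A, typename B, typename C>
std::size_t lds_bytes(std::size_t a_elems, std::size_t b_elems, std::size_t c_elems)
{
    const std::size_t ab = a_elems * sizeof(A) + b_elems * sizeof(B);
    const std::size_t c  = c_elems * sizeof(C);
    return ab > c ? ab : c;
}

int main()
{
    using Adj = AdjustedTypes<half_t, half_t, true>;
    static_assert(std::is_same_v<Adj::A, bhalf_t> && std::is_same_v<Adj::B, bhalf_t>);
    std::printf("%zu\n", lds_bytes<Adj::A, Adj::B, float>(4096, 4096, 8192));
}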
@@ -490,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
 auto blockwise_gemm =
 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+FloatAB,
 FloatAB,
 FloatAcc,
 decltype(a_block_desc_k0_m_k1),
......
@@ -175,7 +175,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
 }
-__host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); }
+__host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); }
 // Argument
 struct Problem
@@ -369,9 +369,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 "Invalid tuning param!");
 // check gridwise gemm pipeline
-const index_t K0 = problem.K / K1Value;
-const auto num_k_loop = K0 / K0PerBlock;
+const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 {
 return false;
@@ -426,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 using BlockwiseGemm =
 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+FloatABAdjusted,
 FloatABAdjusted,
 FloatAcc,
 decltype(a_block_desc_k0_m_k1),
@@ -571,6 +570,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
 BlockSize,
 FloatABAdjusted,
+FloatABAdjusted,
 FloatAcc,
 decltype(a_block_desc_k0_m_k1),
 decltype(b_block_desc_k0_n_k1),
@@ -945,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
 }
 }();
-if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
+if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
 {
 return transform_tensor_descriptor(c_grid_desc_m_n,
 make_tuple(make_right_pad_transform(M, MPad - M),
@@ -1026,8 +1027,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
 }
 // check gridwise gemm pipeline
-const index_t K0 = problem.K / K1;
-const auto num_k_loop = K0 / K0PerBlock;
+const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 {
......
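The v2r3 hunks above switch K0 and the main-loop count from floor to ceiling division, so a K that is not a multiple of K1 (or a K0 that is not a multiple of K0PerBlock) still yields a loop count covering the whole K range. A small self-contained check of that arithmetic, assuming a plain (a + b - 1) / b helper rather than the CK math:: functions:

// Illustrative host-side check of the K0 / main-loop math; plain C++, not CK code.
constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int K = 100, K1 = 8, K0PerBlock = 4;
    constexpr int K0_floor   = K / K1;                              // 12: drops the K tail
    constexpr int K0_ceil    = integer_divide_ceil(K, K1);          // 13: covers all of K
    constexpr int num_k_loop = integer_divide_ceil(K0_ceil, K0PerBlock); // 4 main-loop iterations
    static_assert(K0_floor == 12 && K0_ceil == 13 && num_k_loop == 4);
}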
@@ -22,13 +22,19 @@ namespace ck {
 template <typename GridwiseGemm,
 bool HasMainKBlockLoop,
 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-typename Block2CTileMap>
+typename Block2CTileMap,
+typename AElementwiseOperation,
+typename BElementwiseOperation,
+typename CElementwiseOperation>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
 kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
-const Block2CTileMap& b2c_map)
+const Block2CTileMap& b2c_map,
+const AElementwiseOperation a_element_op,
+const BElementwiseOperation b_element_op,
+const CElementwiseOperation c_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
 defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
@@ -37,10 +43,13 @@ __global__ void
 __shared__ uint8_t p_shared[shared_size];
 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-karg, static_cast<void*>(p_shared), b2c_map);
+karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
 #else
 ignore = karg;
 ignore = b2c_map;
+ignore = a_element_op;
+ignore = b_element_op;
+ignore = c_element_op;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 typename Block2CTileMap>
 __device__ static void Run(const Argument& karg,
 void* __restrict__ p_shared_block,
-const Block2CTileMap& block_2_ctile_map)
+const Block2CTileMap& block_2_ctile_map,
+const AElementwiseOperation a_element_op = AElementwiseOperation{},
+const BElementwiseOperation b_element_op = BElementwiseOperation{},
+const CElementwiseOperation c_element_op = CElementwiseOperation{})
 {
 const FloatA* p_a_grid = karg.p_a_grid;
 const FloatB* p_b_grid = karg.p_b_grid;
@@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
-const AElementwiseOperation a_element_op = AElementwiseOperation{};
-const BElementwiseOperation b_element_op = BElementwiseOperation{};
-const CElementwiseOperation c_element_op = CElementwiseOperation{};
 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
@@ -761,7 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
 BlockSize,
-ComputeType,
+ComputeType, // ComputeType A
+ComputeType, // ComputeType B
 FloatAcc,
 decltype(a_k0_m_k1_block_desc),
 decltype(b_k0_n_k1_block_desc),
......
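In the v2r4r2 hunks above, the A/B/C elementwise operations are no longer default-constructed inside Run; they are forwarded from the kernel launch, with default-constructed fallbacks kept as default arguments. A rough stand-alone illustration of that calling pattern, using plain C++ functors rather than the CK element_wise types:

// Sketch only: forwarding elementwise functors with default-constructed fallbacks.
#include <cstdio>

struct PassThrough { float operator()(float x) const { return x; } };
struct Scale       { float s = 2.0f; float operator()(float x) const { return s * x; } };

template <typename AOp = PassThrough, typename BOp = PassThrough>
void run(float a, float b, const AOp& a_op = AOp{}, const BOp& b_op = BOp{})
{
    std::printf("%f\n", a_op(a) * b_op(b));
}

int main()
{
    run(1.5f, 2.0f);                       // defaults, matching the old in-place construction
    run(1.5f, 2.0f, Scale{3.0f}, Scale{}); // stateful ops supplied by the caller
}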
@@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 auto blockwise_gemm =
 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+FloatAB,
 FloatAB,
 FloatAcc,
 decltype(a_block_desc_ak0_m_ak1),
......
@@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 auto blockwise_gemm =
 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+FloatAB,
 FloatAB,
 FloatAcc,
 decltype(a_block_desc_k0_m_k1),
......
@@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 auto blockwise_gemm =
 BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
+FloatAB,
 FloatAB,
 FloatAcc,
 decltype(a_block_desc_k0_m_k1),
......
@@ -16,6 +16,36 @@
 namespace ck {
+template <typename InputGridDesc,
+typename InputDataType,
+typename OutputGridDesc,
+typename OutputDataType,
+typename Block2ETileMap,
+typename GridwiseTensorRearrangeKernel>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+kernel_tensor_rearrange(const InputGridDesc in_grid_desc,
+const InputDataType* __restrict__ p_in_global,
+const OutputGridDesc out_grid_desc,
+OutputDataType* __restrict__ p_out_global,
+const Block2ETileMap block_2_tile_map)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
+defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \
+defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
+GridwiseTensorRearrangeKernel::Run(
+in_grid_desc, p_in_global, out_grid_desc, p_out_global, block_2_tile_map);
+#else
+ignore = in_grid_desc;
+ignore = p_in_global;
+ignore = out_grid_desc;
+ignore = p_out_global;
+ignore = block_2_tile_map;
+#endif
+}
 template <typename InputGridDesc,
 typename InputDataType,
 typename OutputGridDesc,
@@ -25,8 +55,9 @@ template <typename InputGridDesc,
 index_t KPerBlock,
 typename ThreadClusterLengths,
 index_t ScalarPerVector,
+InMemoryDataOperationEnum DstInMemOp,
 typename Block2ETileMap>
-struct GridwiseImageToColumn
+struct GridwiseTensorRearrange
 {
 static constexpr auto I0 = Number<0>{};
@@ -55,27 +86,27 @@ struct GridwiseImageToColumn
 auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 p_out_global, out_grid_desc.GetElementSpaceSize());
-auto copy_global_to_global = ThreadGroupTensorSliceTransfer_v7<
-ThisThreadBlock,
+auto copy_global_to_global =
+ThreadGroupTensorSliceTransfer_v7<ThisThreadBlock,
 Tuple<InputDataType>,
 Tuple<OutputDataType>,
 decltype(tie(in_grid_desc)),
 decltype(tie(out_grid_desc)),
 tensor_operation::element_wise::PassThrough,
-Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
+Sequence<static_cast<index_t>(DstInMemOp)>,
 Sequence<MPerBlock, KPerBlock>,
 ThreadClusterLengths,
 Sequence<0, 1>,
 Sequence<0, 1>,
 I1,
 ScalarPerVector,
 Sequence<true>,
 Sequence<true>>{
 in_grid_desc,
 make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
 out_grid_desc,
 make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
 tensor_operation::element_wise::PassThrough{}};
 copy_global_to_global.Run(
 tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
......
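The image-to-column grid op above is generalized into GridwiseTensorRearrange, gaining a standalone kernel_tensor_rearrange entry point and a DstInMemOp template parameter in place of the hard-coded InMemoryDataOperationEnum::Set. The sketch below only illustrates what such a compile-time destination-operation switch does, using plain C++ stand-ins rather than ThreadGroupTensorSliceTransfer_v7:

// Sketch: the same copy loop either overwrites (Set) or accumulates into (Add)
// the destination, chosen at compile time. Illustration only, not CK code.
#include <cstdio>

enum class InMemOp { Set, Add };

template <InMemOp Op, typename T>
void store(T* dst, T value)
{
    if constexpr (Op == InMemOp::Set) { *dst = value; }
    else                              { *dst += value; }
}

template <InMemOp Op>
void rearrange(const float* in, float* out, int n)
{
    for (int i = 0; i < n; ++i)
        store<Op>(&out[i], in[i]); // identity index map, for brevity
}

int main()
{
    float in[4] = {1, 2, 3, 4}, out[4] = {10, 10, 10, 10};
    rearrange<InMemOp::Set>(in, out, 4); // out = {1, 2, 3, 4}
    rearrange<InMemOp::Add>(in, out, 4); // out = {2, 4, 6, 8}
    std::printf("%f %f\n", out[0], out[3]);
}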
@@ -18,9 +18,11 @@ template <typename XDataType,
 typename GammaDataType,
 typename BetaDataType,
 typename YDataType,
+typename SaveMeanInvStdDataType,
 typename ComputeDataType,
 typename YElementwiseOperation,
 typename GridDesc_M_K,
+typename GridDesc_M,
 index_t BlockSize,
 index_t MThreadClusterSize,
 index_t KThreadClusterSize,
@@ -34,6 +36,7 @@ template <typename XDataType,
 index_t BetaSrcVectorSize,
 index_t YDstVectorDim,
 index_t YDstVectorSize,
+index_t SaveMeanInvStdDstVectorSize,
 bool SweepOnce>
 struct GridwiseNormalizationNaiveVariance_mk_to_mk
 {
@@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+"configuration, please check!");
 static_assert(XSrcVectorSize == YDstVectorSize);
 static_assert(XSrcVectorSize == GammaSrcVectorSize);
 static_assert(XSrcVectorSize == BetaSrcVectorSize);
@@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
 make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
+using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
+static constexpr auto thread_buffer_desc_m =
+make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
 using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
 make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
 using ThreadReduceDstDesc_M =
@@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 reduce::Add,
 true>;
+using PassThroughOp = tensor_operation::element_wise::PassThrough;
 static constexpr auto I0 = Number<0>{};
 static constexpr auto I1 = Number<1>{};
 static constexpr auto I2 = Number<2>{};
@@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 const GridDesc_M_K& gamma_grid_desc_m_k,
 const GridDesc_M_K& beta_grid_desc_m_k,
 const GridDesc_M_K& y_grid_desc_m_k,
+const GridDesc_M& save_mean_grid_desc_m,
+const GridDesc_M& save_inv_std_grid_desc_m,
 index_t num_k_block_tile_iteration,
 ComputeDataType epsilon,
 const XDataType* const __restrict__ p_x_global,
 const GammaDataType* const __restrict__ p_gamma_global,
 const BetaDataType* const __restrict__ p_beta_global,
 YDataType* const __restrict__ p_y_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
 const YElementwiseOperation y_elementwise_op)
 {
 // LDS
@@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
+auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
+auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
 auto x_thread_buf = generate_tuple(
 [&](auto) {
 return StaticBuffer<AddressSpaceEnum::Vgpr,
@@ -152,6 +175,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 mean_square_thread_buf;
 StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
 var_thread_buf = mean_square_thread_buf;
+StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>&
+inv_std_thread_buf = mean_square_thread_buf;
 const index_t thread_local_id = get_thread_local_1d_id();
 const index_t block_global_id = get_block_1d_id();
@@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 thread_k_cluster_id * YDstVectorSize),
 y_elementwise_op);
+auto threadwise_mean_store =
+ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+SaveMeanInvStdDataType,
+decltype(thread_buffer_desc_m),
+GridDesc_M,
+PassThroughOp,
+ThreadBufferLengths_M,
+Sequence<0>, // DimAccessOrder
+0, // SrcVectorDim
+SaveMeanInvStdDstVectorSize, // ScalarPerVector
+InMemoryDataOperationEnum::Set,
+1,
+true>(
+save_mean_grid_desc_m,
+make_multi_index(block_global_id * M_BlockTileSize +
+thread_m_cluster_id * MThreadSliceSize),
+PassThroughOp{});
+auto threadwise_inv_std_store =
+ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+SaveMeanInvStdDataType,
+decltype(thread_buffer_desc_m),
+GridDesc_M,
+PassThroughOp,
+ThreadBufferLengths_M,
+Sequence<0>, // DimAccessOrder
+0, // SrcVectorDim
+SaveMeanInvStdDstVectorSize, // ScalarPerVector
+InMemoryDataOperationEnum::Set,
+1,
+true>(
+save_inv_std_grid_desc_m,
+make_multi_index(block_global_id * M_BlockTileSize +
+thread_m_cluster_id * MThreadSliceSize),
+PassThroughOp{});
 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
 constexpr auto thread_copy_bwd_step_m_k =
 make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
@@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 // E(x), E[x^2], var(x)
 // FIXME: Should not hack the transform from deviceOP
-int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
+ComputeDataType reduce_length = type_convert<ComputeDataType>(
+x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);
 static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
 mean_thread_buf(I) = reduce::Add::template GetIdentityValue<ComputeDataType>();
@@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 // var(x) = E[x^2] - E[x]^2
 var_thread_buf(I) =
 mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
+ck::math::sqrt(var_thread_buf(I) + epsilon);
 });
+// save mean and inverse std for backward (optional)
+if(thread_k_cluster_id == 0)
+{
+if(p_save_mean_global != nullptr)
+{
+threadwise_mean_store.Run(thread_buffer_desc_m,
+make_tuple(I0),
+mean_thread_buf,
+save_mean_grid_desc_m,
+save_mean_global_val_buf);
+}
+if(p_save_inv_std_global != nullptr)
+{
+threadwise_inv_std_store.Run(thread_buffer_desc_m,
+make_tuple(I0),
+inv_std_thread_buf,
+save_inv_std_grid_desc_m,
+save_inv_std_global_val_buf);
+}
+}
+// normalization
 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
 static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
 constexpr auto offset_m_k =
@@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 // normalize
 y_thread_buf(iK0)(Number<offset_m_k>{}) =
 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-divisor;
+inv_std_thread_buf(iM);
 // gamma & beta
 y_thread_buf(iK0)(Number<offset_m_k>{}) =
@@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 // var(x) = E[x^2] - E[x]^2
 var_thread_buf(I) =
 mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
+inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
 });
+if(thread_k_cluster_id == 0)
+{
+if(p_save_mean_global != nullptr)
+{
+threadwise_mean_store.Run(thread_buffer_desc_m,
+make_tuple(I0),
+mean_thread_buf,
+save_mean_grid_desc_m,
+save_mean_global_val_buf);
+}
+if(p_save_inv_std_global != nullptr)
+{
+threadwise_inv_std_store.Run(thread_buffer_desc_m,
+make_tuple(I0),
+inv_std_thread_buf,
+save_inv_std_grid_desc_m,
+save_inv_std_global_val_buf);
+}
+}
 auto thread_copy_tail_m_k =
 (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
@@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 });
 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
 static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
 constexpr auto offset_m_k =
@@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
 // normalize
 y_thread_buf(iK0)(Number<offset_m_k>{}) =
 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-divisor;
+inv_std_thread_buf(iM);
 // gamma
 y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
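The naive-variance layernorm above now derives an inverse standard deviation per row and optionally writes the mean and inverse std out for a backward pass, reusing 1/sqrt(var + eps) in the normalization instead of recomputing a divisor per slice. Below is a single-threaded reference sketch of that per-row math in plain C++, not the device kernel:

// Reference sketch: mean = E[x], var = E[x^2] - E[x]^2, inv_std = 1/sqrt(var + eps),
// y = (x - mean) * inv_std * gamma + beta, with mean and inv_std also saved.
#include <cmath>
#include <cstdio>
#include <vector>

void layernorm_row(const std::vector<float>& x, float gamma, float beta, float eps,
                   std::vector<float>& y, float& saved_mean, float& saved_inv_std)
{
    double sum = 0.0, sum_sq = 0.0;
    for (float v : x) { sum += v; sum_sq += double(v) * v; }
    const double mean    = sum / x.size();
    const double var     = sum_sq / x.size() - mean * mean; // E[x^2] - E[x]^2
    const double inv_std = 1.0 / std::sqrt(var + eps);

    saved_mean    = float(mean);    // kept for the backward pass
    saved_inv_std = float(inv_std);

    y.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = float((x[i] - mean) * inv_std) * gamma + beta;
}

int main()
{
    std::vector<float> x = {1, 2, 3, 4}, y;
    float mean, inv_std;
    layernorm_row(x, 1.0f, 0.0f, 1e-5f, y, mean, inv_std);
    std::printf("mean=%f inv_std=%f y0=%f\n", mean, inv_std, y[0]);
}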
@@ -12,31 +12,42 @@ template <typename GridwiseReduction,
 typename GammaDataType,
 typename BetaDataType,
 typename YDataType,
+typename SaveMeanInvStdDataType,
 typename ComputeDataType,
 typename YElementwiseOperation,
-typename GridDesc_M_K>
-__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
-const GridDesc_M_K gamma_grid_desc_m_k,
-const GridDesc_M_K beta_grid_desc_m_k,
-const GridDesc_M_K y_grid_desc_m_k,
-index_t num_k_block_tile_iteration,
-ComputeDataType epsilon,
-const XDataType* const __restrict__ p_x_global,
-const GammaDataType* const __restrict__ p_gamma_global,
-const BetaDataType* const __restrict__ p_beta_global,
-YDataType* const __restrict__ p_y_global,
-const YElementwiseOperation y_elementwise_op)
+typename GridDesc_M_K,
+typename GridDesc_M>
+__global__ void
+kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
+const GridDesc_M_K gamma_grid_desc_m_k,
+const GridDesc_M_K beta_grid_desc_m_k,
+const GridDesc_M_K y_grid_desc_m_k,
+const GridDesc_M save_mean_grid_desc_m,
+const GridDesc_M save_inv_std_grid_desc_m,
+index_t num_k_block_tile_iteration,
+ComputeDataType epsilon,
+const XDataType* const __restrict__ p_x_global,
+const GammaDataType* const __restrict__ p_gamma_global,
+const BetaDataType* const __restrict__ p_beta_global,
+YDataType* const __restrict__ p_y_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
+const YElementwiseOperation y_elementwise_op)
 {
 GridwiseReduction::Run(x_grid_desc_m_k,
 gamma_grid_desc_m_k,
 beta_grid_desc_m_k,
 y_grid_desc_m_k,
+save_mean_grid_desc_m,
+save_inv_std_grid_desc_m,
 num_k_block_tile_iteration,
 epsilon,
 p_x_global,
 p_gamma_global,
 p_beta_global,
 p_y_global,
+p_save_mean_global,
+p_save_inv_std_global,
 y_elementwise_op);
 };
@@ -44,9 +55,11 @@ template <typename XDataType,
 typename GammaDataType,
 typename BetaDataType,
 typename YDataType,
+typename SaveMeanInvStdDataType,
 typename ComputeDataType,
 typename YElementwiseOperation,
 typename GridDesc_M_K,
+typename GridDesc_M,
 index_t BlockSize,
 index_t MThreadClusterSize,
 index_t KThreadClusterSize,
@@ -60,6 +73,7 @@ template <typename XDataType,
 index_t BetaSrcVectorSize,
 index_t YDstVectorDim,
 index_t YDstVectorSize,
+index_t SaveMeanInvStdDstVectorSize,
 bool UseWelford>
 auto NormalizationKernelSelector(bool isSweepOnce)
 {
@@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce)
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
 GridDesc_M_K,
+GridDesc_M,
 BlockSize,
 MThreadClusterSize,
 KThreadClusterSize,
@@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
 BetaSrcVectorSize,
 YDstVectorDim,
 YDstVectorSize,
+SaveMeanInvStdDstVectorSize,
 false>;
 using GridwiseNormalizationSweepOnceNaive =
 GridwiseNormalizationNaiveVariance_mk_to_mk<XDataType,
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
 GridDesc_M_K,
+GridDesc_M,
 BlockSize,
 MThreadClusterSize,
 KThreadClusterSize,
@@ -106,15 +125,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
 BetaSrcVectorSize,
 YDstVectorDim,
 YDstVectorSize,
+SaveMeanInvStdDstVectorSize,
 true>;
 using GridwiseNormalizationGenericWelford =
 GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
 GridDesc_M_K,
+GridDesc_M,
 BlockSize,
 MThreadClusterSize,
 KThreadClusterSize,
@@ -128,15 +150,18 @@ auto NormalizationKernelSelector(bool isSweepOnce)
 BetaSrcVectorSize,
 YDstVectorDim,
 YDstVectorSize,
+SaveMeanInvStdDstVectorSize,
 false>;
 using GridwiseNormalizationSweepOnceWelford =
 GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
 GridDesc_M_K,
+GridDesc_M,
 BlockSize,
 MThreadClusterSize,
 KThreadClusterSize,
@@ -150,6 +175,7 @@ auto NormalizationKernelSelector(bool isSweepOnce)
 BetaSrcVectorSize,
 YDstVectorDim,
 YDstVectorSize,
+SaveMeanInvStdDstVectorSize,
 true>;
 if constexpr(UseWelford)
@@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
-GridDesc_M_K>
+GridDesc_M_K,
+GridDesc_M>
 : kernel_normalization<GridwiseNormalizationGenericWelford,
 XDataType,
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
-GridDesc_M_K>;
+GridDesc_M_K,
+GridDesc_M>;
 }
 else
 {
@@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce)
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
-GridDesc_M_K>
+GridDesc_M_K,
+GridDesc_M>
 : kernel_normalization<GridwiseNormalizationGenericNaive,
 XDataType,
 GammaDataType,
 BetaDataType,
 YDataType,
+SaveMeanInvStdDataType,
 ComputeDataType,
 YElementwiseOperation,
-GridDesc_M_K>;
+GridDesc_M_K,
+GridDesc_M>;
 }
 }
......
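NormalizationKernelSelector above threads the new SaveMeanInvStd types and descriptors through all four gridwise instantiations and still picks one of them from a compile-time UseWelford flag plus a runtime isSweepOnce flag. A stripped-down sketch of that selection pattern with trivial placeholder kernels, offered only as an illustration:

// Sketch: compile-time algorithm choice via if constexpr, runtime sweep choice via ternary.
#include <cstdio>

template <bool SweepOnce> struct NaiveKernel   { static void run() { std::printf("naive, sweep_once=%d\n", SweepOnce); } };
template <bool SweepOnce> struct WelfordKernel { static void run() { std::printf("welford, sweep_once=%d\n", SweepOnce); } };

template <bool UseWelford>
auto select_kernel(bool is_sweep_once)
{
    if constexpr (UseWelford)
        return is_sweep_once ? &WelfordKernel<true>::run : &WelfordKernel<false>::run;
    else
        return is_sweep_once ? &NaiveKernel<true>::run : &NaiveKernel<false>::run;
}

int main()
{
    select_kernel<true>(false)();  // welford, sweep_once=0
    select_kernel<false>(true)();  // naive, sweep_once=1
}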
@@ -17,11 +17,13 @@ template <typename MeanVarDataType,
 typename GammaDataType,
 typename BetaDataType,
 typename YDataType,
+typename SaveMeanInvStdDataType,
 typename ComputeDataType,
 typename YElementwiseOperation,
 typename MeanVarGridDesc_M_KBlock,
 typename CountGridDesc_M_KBlock,
 typename XYGammaBetaGridDesc_M_K,
+typename SaveMeanInvStdGridDesc_M,
 index_t BlockSize,
 index_t MThreadClusterSize,
 index_t KThreadClusterSize,
@@ -34,7 +36,8 @@ template <typename MeanVarDataType,
 index_t BetaSrcVectorDim,
 index_t BetaSrcVectorSize,
 index_t YDstVectorDim,
-index_t YDstVectorSize>
+index_t YDstVectorSize,
+index_t SaveMeanInvStdDstVectorSize>
 struct GridwiseNormalizationSplitK2nd
 {
 static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
@@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd
 (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+"configuration, please check!");
 static_assert(XSrcVectorSize == YDstVectorSize);
 static_assert(XSrcVectorSize == GammaSrcVectorSize);
 static_assert(XSrcVectorSize == BetaSrcVectorSize);
@@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd
 static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
 make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
+using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
+static constexpr auto thread_buffer_desc_m =
+make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
 using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
 static constexpr auto thread_buffer_desc_m_1 =
 make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, I1));
@@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd
 const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
 const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
 const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
+const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
+const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
 index_t num_k_mean_var_count_iteration,
 index_t num_k_block_tile_iteration,
 index_t k_grid_size,
@@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd
 const GammaDataType* const __restrict__ p_gamma_global,
 const BetaDataType* const __restrict__ p_beta_global,
 YDataType* const __restrict__ p_y_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
 const YElementwiseOperation y_elementwise_op)
 {
 // Thread/Block id
@@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd
 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
+auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
+auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
 // VGPR
 StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
 in_mean_thread_buf;
@@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd
 var_thread_buf;
 StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
 welford_count_thread_buf;
+auto& inv_std_thread_buf = var_thread_buf;
 auto x_thread_buf = generate_tuple(
 [&](auto) {
@@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd
 thread_k_cluster_id * YDstVectorSize),
 y_elementwise_op);
+auto threadwise_mean_store =
+ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+SaveMeanInvStdDataType,
+decltype(thread_buffer_desc_m),
+SaveMeanInvStdGridDesc_M,
+PassThroughOp,
+ThreadBufferLengths_M,
+Sequence<0>, // DimAccessOrder
+0, // SrcVectorDim
+SaveMeanInvStdDstVectorSize, // ScalarPerVector
+InMemoryDataOperationEnum::Set,
+1,
+true>(
+save_mean_grid_desc_m,
+make_multi_index(block_m_cluster_id * M_BlockTileSize +
+thread_m_cluster_id * MThreadSliceSize),
+PassThroughOp{});
+auto threadwise_inv_std_store =
+ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+SaveMeanInvStdDataType,
+decltype(thread_buffer_desc_m),
+SaveMeanInvStdGridDesc_M,
+PassThroughOp,
+ThreadBufferLengths_M,
+Sequence<0>, // DimAccessOrder
+0, // SrcVectorDim
+SaveMeanInvStdDstVectorSize, // ScalarPerVector
+InMemoryDataOperationEnum::Set,
+1,
+true>(
+save_inv_std_grid_desc_m,
+make_multi_index(block_m_cluster_id * M_BlockTileSize +
+thread_m_cluster_id * MThreadSliceSize),
+PassThroughOp{});
 // step1: Merge mean and variance
 constexpr auto mean_var_count_thread_copy_step_I0_k =
 make_multi_index(I0, KThreadClusterSize);
@@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd
 BlockwiseWelford::Run(
 mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
+inv_std_thread_buf(I) =
+type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
 });
-// step2: normalization
+// step2: save mean and inverse std for backward (optional)
+if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
+{
+if(p_save_mean_global != nullptr)
+{
+threadwise_mean_store.Run(thread_buffer_desc_m,
+make_tuple(I0),
+mean_thread_buf,
+save_mean_grid_desc_m,
+save_mean_global_val_buf);
+}
+if(p_save_inv_std_global != nullptr)
+{
+threadwise_inv_std_store.Run(thread_buffer_desc_m,
+make_tuple(I0),
+inv_std_thread_buf,
+save_inv_std_grid_desc_m,
+save_inv_std_global_val_buf);
+}
+}
+// step3: normalization
 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
 for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
@@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd
 });
 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
 static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
 constexpr auto offset_m_k =
@@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd
 // normalize
 y_thread_buf(iK0)(Number<offset_m_k>{}) =
 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-divisor;
+inv_std_thread_buf(iM);
 // gamma
 y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
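The split-K second stage above first merges the per-block partial means, variances, and counts (BlockwiseWelford::Run), then derives inv_std and optionally stores mean/inv_std before normalizing. The merge step itself is not shown in these hunks; the sketch below is the textbook Chan/Welford combine for two partial results, offered only as an illustration of the idea, not of CK's BlockwiseWelford internals.

// Combine two partial (mean, M2, count) statistics; M2 is the sum of squared deviations.
#include <cstdio>

struct Partial { double mean; double m2; long long count; };

Partial combine(Partial a, Partial b)
{
    if (a.count == 0) return b;
    if (b.count == 0) return a;
    const long long n = a.count + b.count;
    const double delta = b.mean - a.mean;
    Partial out;
    out.count = n;
    out.mean  = a.mean + delta * double(b.count) / double(n);
    out.m2    = a.m2 + b.m2 + delta * delta * double(a.count) * double(b.count) / double(n);
    return out;
}

int main()
{
    // Two "blocks" covering {1,2,3} and {4,5}: the merged variance matches the full set.
    Partial a{2.0, 2.0, 3}, b{4.5, 0.5, 2};
    Partial m = combine(a, b);
    std::printf("mean=%f var=%f\n", m.mean, m.m2 / m.count); // mean=3, var=2
}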
@@ -16,9 +16,11 @@ template <typename XDataType,
 typename GammaDataType,
 typename BetaDataType,
 typename YDataType,
+typename SaveMeanInvStdDataType,
 typename ComputeDataType,
 typename YElementwiseOperation,
 typename GridDesc_M_K,
+typename GridDesc_M,
 index_t BlockSize,
 index_t MThreadClusterSize,
 index_t KThreadClusterSize,
@@ -32,6 +34,7 @@ template <typename XDataType,
 index_t BetaSrcVectorSize,
 index_t YDstVectorDim,
 index_t YDstVectorSize,
+index_t SaveMeanInvStdDstVectorSize,
 bool SweepOnce>
 struct GridwiseNormalizationWelfordVariance_mk_to_mk
 {
@@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
 (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
+"Invalid thread slice sizes and/or save mean and inverse std vector sizes "
+"configuration, please check!");
 static_assert(XSrcVectorSize == YDstVectorSize);
 static_assert(XSrcVectorSize == GammaSrcVectorSize);
 static_assert(XSrcVectorSize == BetaSrcVectorSize);
@@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
 static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
 make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));
+using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
+static constexpr auto thread_buffer_desc_m =
+make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
 using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
 make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
 using ThreadReduceDstDesc_M =
@@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
 ThreadClusterLengths_M_K,
 ThreadClusterArrangeOrder>;
+using PassThroughOp = tensor_operation::element_wise::PassThrough;
 static constexpr auto I0 = Number<0>{};
 static constexpr auto I1 = Number<1>{};
 static constexpr auto I2 = Number<2>{};
@@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
 const GridDesc_M_K& gamma_grid_desc_m_k,
 const GridDesc_M_K& beta_grid_desc_m_k,
 const GridDesc_M_K& y_grid_desc_m_k,
+const GridDesc_M& save_mean_grid_desc_m,
+const GridDesc_M& save_inv_std_grid_desc_m,
 index_t num_k_block_tile_iteration,
 ComputeDataType epsilon,
 const XDataType* const __restrict__ p_x_global,
 const GammaDataType* const __restrict__ p_gamma_global,
 const BetaDataType* const __restrict__ p_beta_global,
 YDataType* const __restrict__ p_y_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
+SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
 const YElementwiseOperation y_elementwise_op)
 {
-auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
 auto x_thread_buf = generate_tuple(
 [&](auto) {
 return StaticBuffer<AddressSpaceEnum::Vgpr,
@@ -150,6 +164,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
 mean_thread_buf;
 StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
 var_thread_buf;
+auto& inv_std_thread_buf = var_thread_buf;
 const index_t thread_local_id = get_thread_local_1d_id();
 const index_t block_global_id = get_block_1d_id();
@@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
 thread_k_cluster_id * YDstVectorSize),
 y_elementwise_op);
+auto threadwise_mean_store =
+ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+SaveMeanInvStdDataType,
+decltype(thread_buffer_desc_m),
+GridDesc_M,
+PassThroughOp,
+ThreadBufferLengths_M,
+Sequence<0>, // DimAccessOrder
+0, // SrcVectorDim
+SaveMeanInvStdDstVectorSize, // ScalarPerVector
+InMemoryDataOperationEnum::Set,
+1,
+true>(
+save_mean_grid_desc_m,
+make_multi_index(block_global_id * M_BlockTileSize +
+thread_m_cluster_id * MThreadSliceSize),
+PassThroughOp{});
+auto threadwise_inv_std_store =
+ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+SaveMeanInvStdDataType,
+decltype(thread_buffer_desc_m),
+GridDesc_M,
+PassThroughOp,
+ThreadBufferLengths_M,
+Sequence<0>, // DimAccessOrder
+0, // SrcVectorDim
+SaveMeanInvStdDstVectorSize, // ScalarPerVector
+InMemoryDataOperationEnum::Set,
+1,
+true>(
+save_inv_std_grid_desc_m,
+make_multi_index(block_global_id * M_BlockTileSize +
+thread_m_cluster_id * MThreadSliceSize),
+PassThroughOp{});
 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
constexpr auto thread_copy_bwd_step_m_k = constexpr auto thread_copy_bwd_step_m_k =
make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);
...@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
auto threadwise_welford = ThreadwiseWelford(); auto threadwise_welford = ThreadwiseWelford();
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
...@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = type_convert<ComputeDataType>(1.0f) /
ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
// save mean and inverse std for backward (optional)
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
// normalization
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma & beta // gamma & beta
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
...@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
int count = threadwise_welford.cur_count_; int count = threadwise_welford.cur_count_;
BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon);
}); });
if(thread_k_cluster_id == 0)
{
if(p_save_mean_global != nullptr)
{
threadwise_mean_store.Run(thread_buffer_desc_m,
make_tuple(I0),
mean_thread_buf,
save_mean_grid_desc_m,
save_mean_global_val_buf);
}
if(p_save_inv_std_global != nullptr)
{
threadwise_inv_std_store.Run(thread_buffer_desc_m,
make_tuple(I0),
inv_std_thread_buf,
save_inv_std_grid_desc_m,
save_inv_std_global_val_buf);
}
}
auto thread_copy_tail_m_k = auto thread_copy_tail_m_k =
(num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k;
...@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
}); });
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
constexpr auto offset_m_k = constexpr auto offset_m_k =
...@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ...@@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
// normalize // normalize
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
divisor; inv_std_thread_buf(iM);
// gamma // gamma
y_thread_buf(iK0)(Number<offset_m_k>{}) = y_thread_buf(iK0)(Number<offset_m_k>{}) =
......
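The hunks above extend the Welford layernorm kernel to optionally write out the per-row mean and inverse standard deviation (guarded by thread_k_cluster_id == 0 and null checks on p_save_mean_global / p_save_inv_std_global) so a backward pass can reuse them, and the normalization loop now multiplies by the precomputed inv_std_thread_buf instead of recomputing a divisor per M slice. For orientation, this is the per-row math the kernel implements, written as a minimal scalar host sketch (not the blockwise/threadwise Welford code itself; names are illustrative):

#include <cmath>
#include <cstddef>

// One row of length K: Welford mean/variance, inv_std = 1 / sqrt(var + eps),
// y = gamma * (x - mean) * inv_std + beta. save_mean / save_inv_std mimic the
// optional p_save_mean_global / p_save_inv_std_global outputs.
void layernorm_row_reference(const float* x, const float* gamma, const float* beta,
                             float* y, std::size_t K, float eps,
                             float* save_mean = nullptr, float* save_inv_std = nullptr)
{
    float mean = 0.f, m2 = 0.f;
    for(std::size_t k = 0; k < K; ++k)
    {
        const float delta = x[k] - mean;              // Welford update
        mean += delta / static_cast<float>(k + 1);
        m2 += delta * (x[k] - mean);
    }
    const float var     = m2 / static_cast<float>(K); // population variance
    const float inv_std = 1.f / std::sqrt(var + eps);
    if(save_mean != nullptr)    { *save_mean = mean; }
    if(save_inv_std != nullptr) { *save_inv_std = inv_std; }
    for(std::size_t k = 0; k < K; ++k)
        y[k] = gamma[k] * (x[k] - mean) * inv_std + beta[k];
}

In the kernel the store happens only once per row, which is why it is restricted to thread_k_cluster_id == 0, and only when the caller actually passed the optional output pointers.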
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp" #include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck { namespace ck {
...@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1
auto src_vector_container = src_vector_type{ auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)}; src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
dst_vector_type op_r_v;
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
});
// copy data from src_vector_container into src_thread_scratch_ // copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id) src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>( .template SetAsType<dst_vector_t>(src_data_idx_seq,
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]); op_r_v.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
{ {
...@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{ {
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford<SliceLengths>{}([&](auto idx) { static_ford<SliceLengths>{}([&](auto idx) {
// convert from SrcData to DstData here dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
dst_thread_scratch_(idx) =
type_convert<DstData>(src_thread_scratch_tuple_[thread_scratch_id][idx]);
}); });
#else #else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype // TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim && if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<SrcData>>::value && ((is_same<half_t, remove_cvref_t<DstData>>::value &&
is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<SrcData>>::value && (is_same<int8_t, remove_cvref_t<DstData>>::value &&
is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{ {
// each transpose does // each transpose does
...@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto data_idx_seq = generate_sequence_v2( constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{}); [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<SrcData, SrcScalarPerVector>; using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>; using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from // get DstScalarPerVector # of read-only references to src vectors from
...@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number<num_dst_vector>{}); Number<num_dst_vector>{});
// do data transpose // do data transpose
transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}( transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs); src_vector_refs, dst_vector_refs);
}); });
} }
else
static_ford<SliceLengths>{}([&](auto idx) { {
// apply the src elementwise op and convert to DstData under the hood if needed static_ford<SliceLengths>{}([&](auto idx) {
DstData dst_v; dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]); });
dst_thread_scratch_(idx) = dst_v; }
});
#endif #endif
} }
...@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, using SrcThreadScratch =
SrcData, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcScalarPerVector, DstData, // data is converted to DstData when stored into SrcThreadScratch
decltype(src_thread_scratch_desc_), SrcScalarPerVector,
true>; decltype(src_thread_scratch_desc_),
true>;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, DstData,
......
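The v3r1 changes above do two things: the source elementwise operator can now declare pack2/pack4/pack8 invocability and is applied to as many elements at once as it supports (falling back to scalar application), and SrcThreadScratch now holds DstData, so the SrcData-to-DstData conversion happens while reading instead of in the scratch-to-scratch transfer. Below is a standalone sketch of the detection idiom only, with a hand-rolled detector instead of ck/utility/is_detected.hpp; the member name is_pack2_invocable mirrors the trait used above, everything else is illustrative:

#include <algorithm>
#include <type_traits>

// Hand-rolled detection idiom (stand-in for ck::is_detected).
template <typename, template <typename> class, typename = void>
struct detect : std::false_type {};
template <typename T, template <typename> class Op>
struct detect<T, Op, std::void_t<Op<T>>> : std::true_type {};

// Detector: does the op expose a static constexpr bool is_pack2_invocable?
template <typename Op>
using is_pack2_invocable_t = decltype(Op::is_pack2_invocable);

struct PackedScaleOp { static constexpr bool is_pack2_invocable = true; }; // example op
struct PlainOp {};                                                         // no pack support

template <typename Op, int SrcScalarPerVector>
constexpr int elem_op_vec_len()
{
    if constexpr(detect<Op, is_pack2_invocable_t>::value)
    {
        if constexpr(Op::is_pack2_invocable)
            return std::min(2, SrcScalarPerVector);
    }
    return 1; // scalar fallback, as in the lambda above
}

static_assert(elem_op_vec_len<PackedScaleOp, 8>() == 2);
static_assert(elem_op_vec_len<PlainOp, 8>() == 1);

The real lambda checks pack8, then pack4, then pack2, so the widest supported width wins.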
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/utility/is_detected.hpp"
namespace ck {
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalarPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
// 6. Does not need to know src_descs and dst_descs at compile-time
// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
//
// Does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray or vector_type instead of a C array for the thread buffer
// 2. Pass tensor descriptors by reference (or tuple of references)
// 3. Does not keep reference to tensor descriptor
// 4. Does not construct new tensor coordinate when call Run()
template <typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
typename SliceLengths,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...>
struct ThreadwiseTensorSliceTransfer_v7r2
{
static constexpr auto I0 = Number<0>{};
static constexpr index_t nDim = SliceLengths::Size();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>;
// return a tuple of coordinates for a tuple of tensors
template <typename Descs,
typename Indices,
enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
{
return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
Number<Descs::Size()>{});
}
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
// scalar per access on each dim
// FIXME: don't use lambda_scalar_per_access
static constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
static constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DstDimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
const ElementwiseOperation& element_op)
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! cannot evenly divide");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! cannot evenly divide");
}
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
const Indices& src_slice_origin_idxs)
{
static_for<0, nSrc, 1>{}([&](auto i) {
src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
});
}
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
__device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
const Indices& dst_slice_origin_idxs)
{
static_for<0, nDst, 1>{}([&](auto i) {
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
});
}
template <typename DataTypes, index_t ScalarPerVector>
__device__ static auto generate_vectors()
{
auto data_types = DataTypes{};
constexpr index_t num = data_types.Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
{
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]);
src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(),
is_src_valid);
});
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
{
if constexpr(decltype(element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
// apply pointwise function
static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& {
using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
},
Number<nSrc>{});
// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto iDst) -> auto& {
using DstData = remove_cvref_t<tuple_element_t<iDst.value, DstDatas>>;
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
// apply pointwise function
// pointwise function signature:
// element_op_(dst_data_refs[I0],
// dst_data_refs[I1],
// ...,
// src_data_refs[I0],
// src_data_refs[I1],
// ...)
unpack2(element_op_, dst_data_refs, src_data_refs);
});
dst_vectors_tuple_(iAccess) = dst_vectors;
// move coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nSrc, 1>{}([&](auto i) {
move_tensor_coordinate(src_descs[i],
src_coords_(i),
make_tensor_coordinate_step(src_descs[i], forward_step));
});
}
});
// move coordinate back to slice origin (or not)
static_for<0, nSrc, 1>{}([&](auto i) {
if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
{
const auto src_reset_step =
make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep());
move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
}
});
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[iAccess];
// copy data from dst_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_coords_[i].GetOffset(),
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
});
// move coordinate
if constexpr(iAccess.value != num_access - 1)
{
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
static_for<0, nDst, 1>{}([&](auto i) {
move_tensor_coordinate(dst_descs[i],
dst_coords_(i),
make_tensor_coordinate_step(dst_descs[i], forward_step));
});
}
});
static_for<0, nDst, 1>{}([&](auto i) {
if constexpr(DstResetCoordinateAfterRunFlags::At(i))
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep());
move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
}
});
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename SrcBuffers,
typename DstBuffers,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
DstDescs::Size() == DstBuffers::Size(),
bool> = false>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
if constexpr(num_access == 0)
{
return typename SrcSpaceFillingCurve::Index{};
}
else
{
return SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
if constexpr(num_access == 0)
{
return typename DstSpaceFillingCurve::Index{};
}
else
{
return DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
}
}
// src_slice_origin_step_idx needs to be known at compile-time for performance reasons
template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
Number<ISrc> iSrc,
const Index& src_slice_origin_step_idx)
{
// if the src coord was not reset by RunRead(), the step needs to be adjusted here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRunFlags::At(iSrc)
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
}
// dst_slice_origin_step_idx needs to be known at compile-time for performance reasons
template <index_t IDst>
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
Number<IDst> iDst,
const Index& dst_slice_origin_step_idx)
{
// if the dst coord was not reset by RunWrite(), the step needs to be adjusted here
const auto adjusted_step_idx =
DstResetCoordinateAfterRunFlags::At(iDst)
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}
private:
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_;
SrcCoords src_coords_;
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
};
} // namespace ck
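ThreadwiseTensorSliceTransfer_v7r2 generalizes the single-source transfer to tuples of sources and destinations: RunRead walks the source space-filling curve, loads one vector per source, applies element_op_ with the element_op_(dst0, dst1, ..., src0, src1, ...) convention documented above, and stages the results in dst_vectors_tuple_; RunWrite then walks the destination curve and commits each staged vector with its per-destination InMemoryDataOperationEnum. The host-side sketch below imitates that two-phase, two-source/one-destination flow with plain arrays; it is a conceptual analogue only and uses none of the CK descriptor or buffer machinery:

#include <array>
#include <cstddef>

// Conceptual analogue of RunRead/RunWrite with nSrc = 2, nDst = 1 (a sketch, not CK API).
template <std::size_t NumAccess, typename ElementOp>
void transfer_two_src_one_dst(const float* src0, const float* src1, float* dst,
                              const ElementOp& element_op)
{
    std::array<float, NumAccess> staged{}; // plays the role of dst_vectors_tuple_

    // "RunRead": load from every source, apply the pointwise op, stage the result.
    for(std::size_t i = 0; i < NumAccess; ++i)
        element_op(staged[i], src0[i], src1[i]); // element_op(dst, src0, src1)

    // "RunWrite": commit the staged values to the destination (InMemoryDataOperationEnum::Set).
    for(std::size_t i = 0; i < NumAccess; ++i)
        dst[i] = staged[i];
}

// Example op matching the element_op_(dst..., src...) convention,
// e.g. transfer_two_src_one_dst<8>(a, b, out, AddOp{});
struct AddOp
{
    void operator()(float& d, float s0, float s1) const { d = s0 + s1; }
};

Staging everything read by RunRead is what lets RunWrite run later without touching the sources again, at the cost of keeping num_access destination vectors live in the thread.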
...@@ -31,7 +31,13 @@ enum struct MfmaInstr ...@@ -31,7 +31,13 @@ enum struct MfmaInstr
mfma_i32_16x16x32i8, mfma_i32_16x16x32i8,
mfma_f64_16x16x4f64, mfma_f64_16x16x4f64,
mfma_f32_32x32x16f8f8, mfma_f32_32x32x16f8f8,
mfma_f32_16x16x32f8f8 mfma_f32_16x16x32f8f8,
mfma_f32_32x32x16bf8bf8,
mfma_f32_16x16x32bf8bf8,
mfma_f32_32x32x16f8bf8,
mfma_f32_16x16x32f8bf8,
mfma_f32_32x32x16bf8f8,
mfma_f32_16x16x32bf8f8
}; };
template <MfmaInstr instr> template <MfmaInstr instr>
...@@ -456,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64> ...@@ -456,7 +462,6 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
} }
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8> struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{ {
...@@ -500,12 +505,149 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8> ...@@ -500,12 +505,149 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c); intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
} }
}; };
#endif
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops> template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8bf8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32f8bf8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
struct MfmaSelector struct MfmaSelector
{ {
template <typename base_type_, index_t MPerXdlops_, index_t NPerXdlops_> template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
static constexpr auto GetMfma(); static constexpr auto GetMfma();
template <> template <>
...@@ -642,7 +784,6 @@ struct MfmaSelector ...@@ -642,7 +784,6 @@ struct MfmaSelector
} }
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() static constexpr auto GetMfma<f8_t, 32, 32>()
{ {
...@@ -654,9 +795,45 @@ struct MfmaSelector ...@@ -654,9 +795,45 @@ struct MfmaSelector
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
#endif
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{}; template <>
static constexpr auto GetMfma<bf8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
return MfmaInstr::mfma_f32_32x32x16f8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
return MfmaInstr::mfma_f32_32x32x16bf8f8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
__host__ __device__ constexpr MfmaSelector() __host__ __device__ constexpr MfmaSelector()
{ {
...@@ -703,7 +880,8 @@ template <typename base_type, ...@@ -703,7 +880,8 @@ template <typename base_type,
index_t MPerXdlops, index_t MPerXdlops,
index_t NPerXdlops, index_t NPerXdlops,
index_t KPack, index_t KPack,
bool TransposeC = false> typename additional_type = base_type,
bool TransposeC = false>
struct XdlopsGemm struct XdlopsGemm
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -854,14 +1032,14 @@ struct XdlopsGemm ...@@ -854,14 +1032,14 @@ struct XdlopsGemm
template <class FloatA, class FloatB, class FloatC> template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{ {
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value || static_assert(
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value || is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, int8_t>::value is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
#if defined CK_ENABLE_FP8 is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value ||
|| is_same<base_type, f8_t>::value is_same<base_type, bf8_t>::value ||
#endif (is_same<base_type, f8_t>::value && is_same<additional_type, bf8_t>::value) ||
, (is_same<base_type, bf8_t>::value && is_same<additional_type, f8_t>::value),
"base base_type must be double, float, half, bfloat16, and int8_t!"); "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC) if constexpr(!TransposeC)
...@@ -957,7 +1135,7 @@ struct XdlopsGemm ...@@ -957,7 +1135,7 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
} }
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{}; static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
static constexpr auto mfma_instr = mfma.selected_mfma; static constexpr auto mfma_instr = mfma.selected_mfma;
......
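Every added bf8/f8 mfma_type specialization reuses the tile shapes of the existing f8f8 entries; only the A/B operand types change, and MfmaSelector/XdlopsGemm gain an additional_type parameter so mixed f8/bf8 pairs can pick the right instruction. As a sanity check on the constants (not part of the change), the K extent in the instruction name is k_per_blk * num_input_blks under k-reduction, and the accumulators of one output block fill m_per_blk * n_per_blk / wave_size registers per lane. A small constexpr sketch using the 32x32x16 values above:

// Sanity arithmetic for the mfma_f32_32x32x16* entries above (illustrative only).
namespace mfma_check {
constexpr int m_per_blk        = 32;
constexpr int n_per_blk        = 32;
constexpr int k_per_blk        = 8;
constexpr int num_input_blks   = 2;
constexpr int wave_size        = 64;
constexpr int num_regs_per_blk = 16;

// is_k_reduction == true: each of the num_input_blks input blocks contributes k_per_blk,
// so one instruction covers K = 8 * 2 = 16 -> "32x32x16".
static_assert(k_per_blk * num_input_blks == 16);

// 32 x 32 accumulators spread over a 64-lane wave: 1024 / 64 = 16 regs per lane.
static_assert(m_per_blk * n_per_blk / wave_size == num_regs_per_blk);
} // namespace mfma_check

The 16x16x32 entries check out the same way: 8 * 4 = 32 and 16 * 16 / 64 = 4.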
...@@ -20,348 +20,13 @@ struct TransformConvFwdToGemm ...@@ -20,348 +20,13 @@ struct TransformConvFwdToGemm
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
template <typename ALayout,
typename std::enable_if<NDimSpatial == 1 &&
is_same_v<ALayout, tensor_layout::convolution::GNWC>,
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Wi = a_g_n_c_wis_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[3];
const index_t ConvStrideW = conv_filter_strides[0];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
template <typename ALayout,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALayout, tensor_layout::convolution::GNHWC>,
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = c_g_n_k_wos_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const auto in_n_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmk_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
template <typename ALayout,
typename std::enable_if<NDimSpatial == 3 &&
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>,
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const index_t N = a_g_n_c_wis_lengths[1];
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = c_g_n_k_wos_lengths[3];
const index_t Ho = c_g_n_k_wos_lengths[4];
const index_t Wo = c_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto in_gemmm_gemmk_desc =
make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
return in_gemmm_gemmk_desc;
}
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
{
const auto in_n_di_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const auto in_n_di_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return in_gemmm_gemmk_desc;
}
}
// TODO: implement ck::tensor_layout::convolution that describes packed/strided dimension as // TODO: implement ck::tensor_layout::convolution that describes packed/strided dimension as
// properties // properties
template <typename ALayout, template <typename ALayout,
typename std::enable_if<NDimSpatial == 1 && typename std::enable_if<NDimSpatial == 1 &&
(is_same_v<ALayout, tensor_layout::convolution::G_NW_C> || (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NWGC>), is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNWC>),
bool>::type = false> bool>::type = false>
static auto static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
...@@ -473,7 +138,8 @@ struct TransformConvFwdToGemm ...@@ -473,7 +138,8 @@ struct TransformConvFwdToGemm
template <typename ALayout, template <typename ALayout,
typename std::enable_if< typename std::enable_if<
NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> || NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGC>), is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
bool>::type = false> bool>::type = false>
static auto static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
...@@ -601,7 +267,8 @@ struct TransformConvFwdToGemm ...@@ -601,7 +267,8 @@ struct TransformConvFwdToGemm
template <typename ALayout, template <typename ALayout,
typename std::enable_if< typename std::enable_if<
NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> || NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGC>), is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
bool>::type = false> bool>::type = false>
static auto static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
......
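The deleted packed-layout overloads (GNWC/GNHWC/GNDHWC) of MakeADescriptor_M_K are folded into the strided-layout versions by extending their enable_if lists, so one code path now serves both. In every case the input is exposed as an implicit-GEMM matrix: pad the spatial dims, embed them with the convolution stride and dilation, then merge (N, output spatial dims) into GEMM M and (filter window, C) into GEMM K. A small numeric sketch of that bookkeeping for the 1D case, with made-up sizes (standard im2col arithmetic, not CK code):

#include <cassert>

// Implicit-GEMM sizes for a 1D forward convolution (illustrative, hypothetical values).
int main()
{
    const int N = 2, C = 8, Wi = 16;   // input: batch N, channels C, width Wi
    const int K = 4, X = 3;            // filter: K output channels, window X
    const int stride = 2, dilation = 1;
    const int pad_left = 1, pad_right = 1;

    // Output width, as produced by the pad + embed transforms in the descriptor.
    const int Wo = (Wi + pad_left + pad_right - dilation * (X - 1) - 1) / stride + 1; // = 8

    // GEMM view built by MakeADescriptor_M_K: merge (N, Wo) into M, (X, C) into K.
    const int gemm_m = N * Wo; // 16
    const int gemm_k = X * C;  // 24
    const int gemm_n = K;      // 4, from the weight side (not shown in this hunk)

    assert(Wo == 8 && gemm_m == 16 && gemm_k == 24 && gemm_n == 4);
    return 0;
}

For a 1x1 filter the window collapses to a single tap, which is why the Filter1x1Stride1Pad0 branch describes (N*Wo, C) directly and the Filter1x1Pad0 branch only needs the stride embedding, with no padding.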
...@@ -299,584 +299,255 @@ enum struct AmdBufferCoherenceEnum ...@@ -299,584 +299,255 @@ enum struct AmdBufferCoherenceEnum
GLC_SLC = 3, GLC_SLC = 3,
}; };
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<int8_t, N>::type
amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource,
                         index_t src_thread_addr_offset,
                         index_t src_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset,
                                              static_cast<index_t>(coherence));
    }
    else if constexpr(N == 2)
    {
        int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));

        return bit_cast<int8x2_t>(tmp);
    }
    else if constexpr(N == 4)
    {
        int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));

        return bit_cast<int8x4_t>(tmp);
    }
    else if constexpr(N == 8)
    {
        int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));

        return bit_cast<int8x8_t>(tmp);
    }
    else if constexpr(N == 16)
    {
        int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));

        return bit_cast<int8x16_t>(tmp);
    }
    else if constexpr(N == 32)
    {
        int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast<index_t>(coherence));
        int32x4_t tmp1 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 4 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        vector_type<int32_t, 8> tmp;

        tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
        tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;

        return bit_cast<int8x32_t>(tmp);
    }
    else if constexpr(N == 64)
    {
        int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast<index_t>(coherence));
        int32x4_t tmp1 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 4 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        int32x4_t tmp2 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 8 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        int32x4_t tmp3 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 12 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));

        vector_type<int32_t, 16> tmp;

        tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
        tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
        tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
        tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;

        return bit_cast<int8x64_t>(tmp);
    }
}
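// A minimal, host-compilable sketch of the dispatch above, assuming only standard C++:
// a load is described purely by its byte count and mapped onto the widest dword-based
// buffer operation that covers it. The helper name and the labels are illustrative
// stand-ins, not CK identifiers or LLVM intrinsic names.
#include <cstddef>

constexpr const char* raw_load_shape(std::size_t bytes)
{
    // mirrors the if-constexpr ladder of amd_buffer_load_impl_raw
    return bytes == 1    ? "1 byte load"
           : bytes == 2  ? "1 short load"
           : bytes == 4  ? "1 dword load"
           : bytes == 8  ? "1 dwordx2 load"
           : bytes == 16 ? "1 dwordx4 load"
           : bytes == 32 ? "2 dwordx4 loads"
           : bytes == 64 ? "4 dwordx4 loads"
                         : "unsupported";
}

static_assert(sizeof(float) * 4 == 16, "four floats map onto a single dwordx4 load");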
template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<T, N>::type
amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                     index_t src_thread_addr_offset,
                     index_t src_wave_addr_offset)
{
    static_assert(
        (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t = typename vector_type<T, N>::type;

    auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);

    return bit_cast<r_t>(raw_data);
}
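// A self-contained sketch (standard C++20, host-side) of the typed -> raw -> typed
// round trip performed by the wrapper above: 8 half-precision values occupy 16 bytes,
// so they come back from the N = 16 raw path and are reinterpreted with bit_cast.
// std::uint16_t stands in for half_t so the sketch compiles outside of CK.
#include <array>
#include <bit>
#include <cstdint>

using raw16 = std::array<std::uint8_t, 16>;  // what the 16-byte raw load returns
using half8 = std::array<std::uint16_t, 8>;  // what the typed caller asked for

inline half8 typed_view_of_raw_load(const raw16& raw_data)
{
    static_assert(sizeof(raw16) == sizeof(half8), "byte images must match");
    return std::bit_cast<half8>(raw_data); // same reinterpretation as bit_cast<r_t>(raw_data)
}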
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void
amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread_data,
                          int32x4_t dst_wave_buffer_resource,
                          index_t dst_thread_addr_offset,
                          index_t dst_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
                                        dst_wave_buffer_resource,
                                        dst_thread_addr_offset,
                                        dst_wave_addr_offset,
                                        static_cast<index_t>(coherence));
    }
    else if constexpr(N == 2)
    {
        llvm_amdgcn_raw_buffer_store_i16(bit_cast<int16_t>(src_thread_data),
                                         dst_wave_buffer_resource,
                                         dst_thread_addr_offset,
                                         dst_wave_addr_offset,
                                         static_cast<index_t>(coherence));
    }
    else if constexpr(N == 4)
    {
        llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
                                         dst_wave_buffer_resource,
                                         dst_thread_addr_offset,
                                         dst_wave_addr_offset,
                                         static_cast<index_t>(coherence));
    }
    else if constexpr(N == 8)
    {
        llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 16)
    {
        llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 32)
    {
        vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 4,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 64)
    {
        vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 4,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<2>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 8,
                                           static_cast<index_t>(coherence));

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<3>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 12,
                                           static_cast<index_t>(coherence));
    }
}
template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
                                      int32x4_t dst_wave_buffer_resource,
                                      index_t dst_thread_addr_offset,
                                      index_t dst_wave_addr_offset)
{
    static_assert(
        (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;

    amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
                                                        dst_wave_buffer_resource,
                                                        dst_thread_addr_offset,
                                                        dst_wave_addr_offset);
}
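// Symmetric sketch for the store direction, again in plain C++20: the typed payload
// is reinterpreted as sizeof(T) * N raw bytes before being handed to the byte-count
// dispatcher, so e.g. four floats become one 16-byte (dwordx4) store. to_raw_bytes
// is an illustrative helper, not a CK function.
#include <array>
#include <bit>
#include <cstdint>

template <typename T, std::size_t N>
std::array<std::uint8_t, sizeof(T) * N> to_raw_bytes(const std::array<T, N>& typed)
{
    return std::bit_cast<std::array<std::uint8_t, sizeof(T) * N>>(typed);
}

// usage: to_raw_bytes(std::array<float, 4>{1.f, 2.f, 3.f, 4.f}).size() == 16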
template <typename T, index_t N>
...@@ -1127,38 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
    uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

    return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
    vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, 0);

    return src_thread_element_valid ? tmp : vector_t(0);
#endif
}
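// Sketch of the offset trick taken when CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
// is enabled: instead of branching on validity, an invalid element adds 0x80000000 to the
// thread offset so the access lands outside the buffer descriptor's range check, and the
// buffer-addressing behavior this trick relies on returns zero for the load (and drops the
// matching store). Plain C++ illustration only; the helper name is not a CK identifier.
#include <cstdint>

inline std::uint32_t oob_checked_offset(bool element_valid, std::uint32_t thread_offset)
{
    const std::uint32_t shift = element_valid ? 0u : 0x80000000u; // pushes past num_records
    return shift + thread_offset;
}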
...@@ -1216,41 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_store_impl<scalar_t, vector_size, coherence>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

namespace ck {
...@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
    }
};

template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -418,6 +414,194 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
    }
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8;
template <>
struct intrin_mfma_f32_32x32x16bf8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
    }
};
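// Whether it executes the native gfx94x instruction or the #else branch, which widens each
// bf8 value to float and replays eight 32x32x2 fp32 MFMA steps, one Run() above advances
// every accumulator element by a K = 16 slice of the 32x32 product. A scalar, host-side
// reference of that accumulation for a single output element (bf8 -> float decoding omitted;
// the kernel uses ck::type_convert for it, and the function name here is illustrative only):
inline float mfma_reference_element(const float* a_row, const float* b_col, int K, float acc)
{
    for(int k = 0; k < K; ++k)
    {
        acc += a_row[k] * b_col[k]; // accumulate along the K dimension
    }
    return acc; // with K = 16 this matches what one element of reg_c gains per Run()
}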
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8bf8;
template <>
struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8;
template <>
struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8bf8;
template <>
struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<bf8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8;
template <>
struct intrin_mfma_f32_32x32x16bf8f8<32, 32>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf8f8;
template <>
struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
{
template <class FloatC>
__device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<bf8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
} // namespace ck
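// A hedged caller-side sketch (not part of this patch) of how the dispatch structs above are
// used: the specialization hides whether the bf8 path is the native gfx94x instruction or the
// fp32 emulation. It assumes the CK headers defining bf8x8_t and vector_type are included and
// that the register fragments and accumulator were staged by the surrounding blockwise GEMM;
// run_one_bf8_mfma is an illustrative name only.
__device__ void run_one_bf8_mfma(const ck::bf8x8_t& a_frag,
                                 const ck::bf8x8_t& b_frag,
                                 ck::vector_type<float, 16>& c_regs)
{
    // one 32x32 output tile per wave, K advances by 16 per call
    ck::intrin_mfma_f32_32x32x16bf8bf8<32, 32>::Run(a_frag, b_frag, c_regs);
}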