Commit 32850b93 authored by Wen-Heng (Jack) Chung

Ported xdlops kernels to debug the bwdwrw fp32/fp16/bfp16 issue. Verified that at least fwd data fp32 works.

parent 583755a7
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_MIOPEN_IMPLICITGEMM_PARMS_HPP_
#define GUARD_MIOPEN_IMPLICITGEMM_PARMS_HPP_
enum struct ImplicitGemmDirection
{
ForwardData,
BackwardData,
BackwardWeight
};
enum struct ImplicitGemmXdlopsKernel
{
KernelFwdWrw = 0,
Kernel1x1 = 1,
};
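// Usage sketch (illustrative only, not part of the original header): a host-side caller
// is expected to pick these switches at compile time, e.g.
//   constexpr auto dir    = ImplicitGemmDirection::ForwardData;
//   constexpr auto kernel = ImplicitGemmXdlopsKernel::KernelFwdWrw;
// and forward them as template / kernel-selection arguments to the gridwise convolution
// kernels defined below.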
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
template <ImplicitGemmDirection conv_dir, typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc
{
};
template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc, NonVectorizedC>
{
__device__ constexpr auto get(WeiDesc&)
{
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I4 = Number<4>{};
return WeiDesc{}
.Fold(I1, Number<NonVectorizedC>{})
.Unfold(I2, I4)
.ReorderGivenNew2Old(Sequence<2, 0, 1>{});
}
};
template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc<ImplicitGemmDirection::BackwardWeight, WeiDesc, NonVectorizedC>
{
__device__ constexpr auto get(WeiDesc& desc)
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
return make_ConstantMergedTensorDescriptor(
desc.Fold(I1, Number<NonVectorizedC>{}).Unfold(I3, I4),
Sequence<2, 3>{},
Sequence<0>{},
Sequence<1>{});
}
};
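// Layout sketch for the specializations above (assumes Fold splits dimension C into
// [C / NonVectorizedC, NonVectorizedC] with the leftover factor leading). Starting from
// wei[K, C, Y, X], the ForwardData path does:
//   Fold(I1, NonVectorizedC)      -> [K, EPACK, C/EPACK, Y, X]
//   Unfold(I2, I4)                -> [K, EPACK, E]      with E = (C/EPACK) * Y * X
//   ReorderGivenNew2Old(<2,0,1>)  -> [E, K, EPACK]
// The BackwardWeight path builds the same [E, K, EPACK] view as a merged descriptor
// instead of a plain reorder.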
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmNRepeat,
index_t EPACK,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_B_N2_EPACK,
class InBlockCopyClusterLengths_E_N1_B_N2_EPACK,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_E_K_EPACK,
class WeiBlockCopyClusterLengths_E_K_EPACK,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
// EPACK=1 for float32, =2 for bfloat16, =4 for float16
static_assert(C % EPACK == 0, "C needs to be a multiple of EPACK");
constexpr auto nonVectorizedC = C / EPACK;
constexpr index_t E = nonVectorizedC * Y * X;
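// Worked example (hypothetical sizes, for orientation only): fp16 with C = 64, Y = X = 3
// gives EPACK = 4, nonVectorizedC = 16 and E = 16 * 3 * 3 = 144; fp32 with the same
// C, Y, X gives EPACK = 1 and E = 64 * 9 = 576.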
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr index_t InBlockCopyDstDataPerWrite_EPACK = EPACK;
constexpr index_t WeiBlockCopyDstDataPerWrite_EPACK = EPACK;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
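// Example of the block decomposition (hypothetical sizes): K = 128, KPerBlock = 32,
// B = 1024, BPerBlock = 64 give KBlockWork = 4 and BBlockWork = 16, i.e. 64 workgroups;
// block id 21 then maps to multi-index (1, 5), so its tile starts at K offset 32 and
// B offset 320.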
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo, {2C/4C}]
constexpr auto in_n0_n1_n2_h_w_2cor4c_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
.Fold(I1, Number<nonVectorizedC>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 3, 5, 6>{})
.ReorderGivenNew2Old(Sequence<0, 1, 2, 4, 5, 3>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
.Fold(I1, Number<nonVectorizedC>{})
.Extract(Sequence<2, 3, 4>{});
// merged tensor descriptor in device memory [E, N1, B, N2, {2E/4E}], src of blockwise
// copy
constexpr auto in_e_n1_b_n2_2eor4e_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_2cor4c_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
Sequence<5>{},
Sequence<8>{});
// memory layout descriptor in LDS [E, N1, B, N2, {2C/4C}], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_2eor4e_block_desc =
make_ConstantTensorDescriptor_aligned(Sequence<EPerBlock, N1, BPerBlock, N2, EPACK>{},
Number<InBlockCopyDstDataPerWrite_EPACK>{});
// this check for GEMM is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_2eor4e_block_desc.GetStride(I1) % (EPACK * GemmDataPerReadB) ==
0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(in_e_n1_b_n2_2eor4e_global_merged_desc),
decltype(in_e_n1_b_n2_2eor4e_block_desc),
decltype(
in_e_n1_b_n2_2eor4e_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2_EPACK,
InBlockCopyClusterLengths_E_N1_B_N2_EPACK,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
4,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_EPACK>(
{0, 0, b_block_data_on_global, 0, 0}, {0, 0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_2eor4e_global_desc =
make_vectorized_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc), nonVectorizedC>{}
.get(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_2eor4e_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock, EPACK>{}, Number<WeiBlockCopyDstDataPerWrite_EPACK>{});
// this check for GEMM is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(wei_e_k_2eor4e_block_desc.GetStride(I1) % (EPACK * GemmDataPerReadA) == 0,
"GemmDataPerReadA alignment requirement is not satisfied");
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(wei_e_k_2eor4e_global_desc),
decltype(wei_e_k_2eor4e_block_desc),
decltype(wei_e_k_2eor4e_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K_EPACK,
WeiBlockCopyClusterLengths_E_K_EPACK,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
2,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_EPACK>(
{0, k_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS, of type float / bfloat16 vec2 / float16 vec4
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS, of type float / bfloat16 vec2 / float16
// vec4
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<EPerBlock>{}, Number<KPerBlock>{});
constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<EPerBlock>{}, Number<N1 * BPerBlock * N2>{});
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
EPACK,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t lds_allocation_align = math::lcm(InBlockCopyDstDataPerWrite_EPACK,
WeiBlockCopyDstDataPerWrite_EPACK,
EPACK * GemmDataPerReadA,
EPACK * GemmDataPerReadB);
constexpr index_t in_block_space = math::integer_least_multiple(
in_e_n1_b_n2_2eor4e_block_desc.GetElementSpace(), lds_allocation_align);
constexpr index_t wei_block_space = math::integer_least_multiple(
wei_e_k_2eor4e_block_desc.GetElementSpace(), lds_allocation_align);
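// Sizing sketch (hypothetical parameters): with EPerBlock = 8, N1 = 2, BPerBlock = 16,
// N2 = 4, EPACK = 4 and KPerBlock = 32, the input block holds 8 * 2 * 16 * 4 * 4 = 4096
// elements and the weight block holds 8 * 32 * 4 = 1024 elements before alignment
// padding; each is rounded up to a multiple of lds_allocation_align and then doubled
// for the two LDS buffers allocated below.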
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
// hcc compilation error: loop not unrolled: the optimizer was unable to perform the
// requested transformation;
// the transformation might be disabled or specified as part of an unsupported
// transformation
// ordering [-Werror,-Wpass-failed=transform-warning]
//#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global +=
EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
p_wei_block_now);
const typename vector_type<Float, EPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
p_in_block_now);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_2eor4e_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how half/bfloat16 datatypes are
// processed in gemm operation. Half type packs 4 half values while
// bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
// 2D indexes are computed with a single value in mind (e.g. float),
// to retain the same 2D indexes for half/bfloat16, we recast datatype
// from a single half to 4 packed half/2 packed bfloat16 respectively.
const typename vector_type<Float, EPACK>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
p_wei_block_double);
const typename vector_type<Float, EPACK>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
p_in_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
p_a_block_vec = reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
p_wei_block_double + wei_block_space);
p_b_block_vec = reinterpret_cast<const typename vector_type<Float, EPACK>::MemoryType*>(
p_in_block_double + in_block_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
}
// copy output: register to global memory
{
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
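// Here K is viewed as (K0, K1, K2): K2 = GemmMPerThreadSubC is the contiguous sub-block
// a thread owns, K1 = GemmMLevel0Cluster * GemmMLevel1Cluster is the cluster dimension
// (fixed to 1 in this per-thread view), and K0 = KPerBlock / (K1 * K2) equals
// GemmMRepeat, so each thread writes GemmMRepeat * GemmMPerThreadSubC rows by N1 * N2
// columns of the C tile.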
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<0, 4, 5>{},
Sequence<2>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v1r2<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 8, 1>::type,
7,
1,
1>(make_zero_array<index_t, 8>(), make_zero_array<index_t, 8>())
.Run(p_out_thread, p_out_thread_on_global);
}
}
};
} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_FP16_BFP16_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc, // exchanged outside for backward
class WeiGlobalDesc,
class OutGlobalDesc, // exchanged outside for backward
class ConvStrides,
ImplicitGemmDirection Direction,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t N1,
index_t N2,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_N1_B_N2,
class InBlockCopyClusterLengths_E_N1_B_N2,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopySrcDataPerRead_B,
index_t InBlockCopyDstDataPerWrite_N2,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kc1x1_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData;
// this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
"wrong!");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_k_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
constexpr index_t Ho =
std::conditional<isForward,
decltype(out_n_k_h_w_global_desc),
decltype(in_n_c_h_w_global_desc)>::type::GetLength(I2);
constexpr index_t Wo =
std::conditional<isForward,
decltype(out_n_k_h_w_global_desc),
decltype(in_n_c_h_w_global_desc)>::type::GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
constexpr index_t N0 = N / (N1 * N2);
constexpr index_t B = N0 * Ho * Wo;
constexpr index_t E = C;
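// With a 1x1 filter Y = X = 1, so the reduction (GEMM K) dimension collapses to E = C
// and no Y/X folding of the weight tensor is needed.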
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc_forw =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_n0_n1_n2_h_w_global_desc_back =
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_n0_n1_n2_h_w_global_desc =
typename std::conditional<isForward,
decltype(in_n0_n1_n2_h_w_global_desc_forw),
decltype(in_n0_n1_n2_h_w_global_desc_back)>::type{};
// batch descriptor for device memory
constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Extract(I1);
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc =
make_ConstantMergedTensorDescriptor(in_c_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
Sequence<0>{},
Sequence<2>{},
Sequence<1, 4, 5>{},
Sequence<3>{});
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not satisfied");
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc_forw = wei_c_k_global_desc;
constexpr auto wei_e_k_global_desc_back =
make_ConstantTensorDescriptor_packed(Sequence<C, K>{});
constexpr auto wei_e_k_global_desc =
typename std::conditional<isForward,
decltype(wei_e_k_global_desc_forw),
decltype(wei_e_k_global_desc_back)>::type{};
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
0,
"wrong!");
constexpr index_t GemmMRepeat =
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
1, // EPACK = 1
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// copy output: register to global memory
{
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptor for threadwise copy
// output memory layout descriptor in register
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
make_ConstantTensorDescriptor_packed(
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
// output tensor descriptor in register, src of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
// output memory layout descriptor in device memory, dst of threadwise copy
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
constexpr auto out_lengths_new =
Sequence<out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I5),
math::integer_divide_ceil(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I6),
ConvStrides{}.Get(I0)),
math::integer_divide_ceil(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetLength(I7),
ConvStrides{}.Get(I1))>{};
constexpr auto out_strides_new =
Sequence<out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I5),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I6) *
ConvStrides{}.Get(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw.GetStride(I7) *
ConvStrides{}.Get(I1)>{};
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_back =
make_ConstantTensorDescriptor(out_lengths_new, out_strides_new);
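// Reading of the descriptor above (backward data, 1x1, strided): since the input/output
// descriptors are exchanged by the caller for the backward direction, this copy writes
// only to the stride-spaced H/W positions of the destination tensor; the lengths are
// ceil-divided by the convolution strides and the corresponding memory strides are
// multiplied by them, so the remaining positions are not touched by this kernel.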
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = typename std::conditional<
isForward,
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_forw),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc_back)>::type{};
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<0, 4, 5>{},
Sequence<2>{});
// origin of dst in device memory
Float* p_out_thread_on_global =
p_out_global +
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v1r2<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 8, 1>::type,
7,
1,
1>(make_zero_array<index_t, 8>(), make_zero_array<index_t, 8>())
.Run(p_out_thread, p_out_thread_on_global);
}
}
};
} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KC1x1_NKHW_LDS_DOUBLE_BUFFER_HPP
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
@@ -8,9 +8,35 @@
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "implicitgemm_params.hpp"
namespace ck {
+template <ImplicitGemmDirection conv_dir, typename WeiDesc>
+struct make_WeiDesc
+{
+};
+template <typename WeiDesc>
+struct make_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc>
+{
+__device__ constexpr auto get(WeiDesc&)
+{
+constexpr auto I1 = Number<1>{};
+constexpr auto I3 = Number<3>{};
+return WeiDesc{}.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+}
+};
+template <typename WeiDesc>
+struct make_WeiDesc<ImplicitGemmDirection::BackwardWeight, WeiDesc>
+{
+__device__ constexpr auto get(WeiDesc& desc)
+{
+constexpr auto I2 = Number<2>{};
+constexpr auto I3 = Number<3>{};
+return make_ConstantMergedTensorDescriptor(
+desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
+}
+};
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
@@ -24,8 +50,7 @@ template <index_t GridSize,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
-index_t N1,
+index_t GemmNRepeat,
-index_t N2,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
@@ -48,17 +73,19 @@ template <index_t GridSize,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
-index_t WeiBlockCopyDstDataPerWrite_K>
+index_t WeiBlockCopyDstDataPerWrite_K,
+ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
{
-__device__ void
+__device__ void Run(const Float* const __restrict__ p_in_global,
-Run(const Float* const __restrict__ p_in_global,
+const Float* const __restrict__ p_wei_global,
-const Float* const __restrict__ p_wei_global,
+Float* const __restrict__ p_out_global) const
-Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find more elegent way of specifying (or calculating) performance parameters
-static_assert(N2 == GemmNPerThreadSubC, "wrong!");
+constexpr index_t N1 = GemmNRepeat;
+constexpr index_t N2 = GemmNPerThreadSubC;
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
@@ -86,6 +113,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
+constexpr index_t ConvStrideH = ConvStrides{}[0];
+constexpr index_t ConvStrideW = ConvStrides{}[1];
+constexpr index_t ConvDilationH = ConvDilations{}[0];
+constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
constexpr index_t N0 = N / (N1 * N2);
@@ -94,6 +127,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr index_t E = C * Y * X;
+// sanity-check for vectorized memory load
+static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
+"wrong! global vector load of input tensor is wrong");
+static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
+"wrong! aligment requirement for vectorized global load of input tensor will "
+"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
@@ -113,15 +154,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc =
-in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
+in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
-.StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
+.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// batch descritpor for device memory
constexpr auto in_c_y_x_global_desc =
-in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
+in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
-.StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
+.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
@@ -148,7 +189,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
-Float,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
@@ -157,6 +197,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
+2,
+3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
@@ -164,7 +206,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
-wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
+make_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(
+wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
@@ -177,7 +220,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already have blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
-Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
@@ -186,6 +228,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
+0,
+1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
@@ -196,13 +240,11 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
-constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(
-Number<EPerBlock>{}, Number<KPerBlock>{}, Number<wei_e_k_block_desc.GetStride(I0)>{});
+constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc =
-make_ConstantMatrixDescriptor(Number<EPerBlock>{},
+make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
-Number<N1 * BPerBlock * N2>{},
-Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
@@ -214,11 +256,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
-constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
+constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
+Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
+1, // EPACK = 1
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
@@ -280,53 +323,58 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
-Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
+Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
+Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
-blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
-p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
+fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
+}).Else([&](auto fwd) {
+p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
+});
__syncthreads();
// LDS doubel buffer: load next data from device mem
-blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
+blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
+blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
-p_wei_register_clipboard);
+p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
-blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
+blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
-p_in_block_next);
+blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
-blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
-p_wei_block_next);
}
}
// LDS double buffer: tail
{
-Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
-Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
// even iteration
-blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
+Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
+Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
+blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
+static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
+blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
+}).Else([&](auto fwd) {
+p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
+});
__syncthreads();
// LDS doubel buffer: load next data from device mem
-blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
+blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
-blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
+blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
-p_wei_register_clipboard);
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
-blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
+blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
-blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
+blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
@@ -384,19 +432,18 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
-threadwise_generic_tensor_slice_copy_v1(
+ThreadwiseGenericTensorSliceCopy_v1r2<
-out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
+decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
-p_out_thread,
+decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
-{0, 0, 0, 0, 0, 0, 0, 0},
+decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
-out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
+arithmetic_sequence_gen<0, 8, 1>::type,
-p_out_thread_on_global,
+7,
-{0, 0, 0, 0, 0, 0, 0, 0},
+1,
-out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
+1>(make_zero_array<index_t, 8>(), make_zero_array<index_t, 8>())
-arithmetic_sequence_gen<0, 8, 1>::type{},
+.Run(p_out_thread, p_out_thread_on_global);
-Number<1>{});
}
}
};
} // namespace ck
-#endif
+#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_FP16_BFP16_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
template <ImplicitGemmDirection conv_dir, typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc_Xdlops
{
};
template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc_Xdlops<ImplicitGemmDirection::ForwardData, WeiDesc, NonVectorizedC>
{
__device__ constexpr auto get(WeiDesc&)
{
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I4 = Number<4>{};
return WeiDesc{}
.Fold(I1, Number<NonVectorizedC>{})
.Unfold(I2, I4)
.ReorderGivenNew2Old(Sequence<2, 0, 1>{});
}
};
template <typename WeiDesc, index_t NonVectorizedC>
struct make_vectorized_WeiDesc_Xdlops<ImplicitGemmDirection::BackwardWeight,
WeiDesc,
NonVectorizedC>
{
__device__ constexpr auto get(WeiDesc& desc)
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
return make_ConstantMergedTensorDescriptor(
desc.Fold(I1, Number<NonVectorizedC>{}).Unfold(I3, I4),
Sequence<2, 3>{},
Sequence<0>{},
Sequence<1>{});
}
};
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t EPack,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
bool EnableXdlops,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B,
ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t nonVectorizedC = C / EPack;
constexpr index_t E = nonVectorizedC * Y * X;
constexpr index_t B = N * Ho * Wo;
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
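// For illustration only: K = 128, KPerBlock = 32, B = N * Ho * Wo = 16384 and BPerBlock = 64
// would give a KBlockWork x BBlockWork = 4 x 256 grid of work-groups.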
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N, Ho, Wo, {2C/4C}]
constexpr auto in_n_ho_wo_2cor4c_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I1, Number<nonVectorizedC>{})
.Extract(Sequence<0, 1, 3, 4>{})
.ReorderGivenNew2Old(Sequence<0, 2, 3, 1>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Fold(I1, Number<nonVectorizedC>{})
.Extract(Sequence<2, 3, 4>{});
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n_ho_wo_2cor4c_global_desc),
Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{},
Sequence<6>{});
// memory layout descriptor in LDS [E, B, 2Cor4C], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, BPerBlock, EPack>{},
Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB, EPack)>{});
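// For illustration only: with InBlockCopyDataPerAccess_B = 4, GemmDataPerReadB = 4 and
// EPack = 4 (fp16), the LDS descriptor is padded to a multiple of lcm(4, 4, 4) = 4 elements,
// so both the vectorized block copy into LDS and the blockwise GEMM reads stay aligned.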
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1, // Src dim to be read in vector form (B dimension)
2, // Dst dim to be written in vector form (EPack dimension)
InBlockCopyDataPerAccess_B, // Src dim vector len
InBlockCopyDataPerAccess_B>( // Dst dim vector len
{0, b_block_data_on_global, 0},
{0, 0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
make_vectorized_WeiDesc_Xdlops<conv_dir,
decltype(wei_k_c_y_x_global_desc),
nonVectorizedC>{}
.get(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock, EPack>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA, EPack)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0, // Src dim to be read in vector form (E dimension)
2, // Dst dim to be written in vector form (EPack dimension)
WeiBlockCopySrcDataPerRead_E, // Src dim vector len
WeiBlockCopyDstDataPerWrite_K>( // Dst dim vector len
{0, k_block_data_on_global, 0},
{0, 0, 0});
// GEMM definition
constexpr auto a_e_k_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<EPerBlock>{}, Number<KPerBlock>{});
constexpr auto b_e_b_block_mtx_desc =
make_ConstantMatrixDescriptor_packed(Number<EPerBlock>{}, Number<BPerBlock>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(mfma_info<Float>{}),
EnableXdlops,
GemmMPerWave,
GemmNPerWave,
GemmMWaves,
GemmNWaves,
GemmDataPerReadA,
GemmDataPerReadB>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB,
EPack);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);
// static_if<EnableXdlops>{}(
// [&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
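// Double-buffering scheme, in brief: LDS holds two copies of the input/weight tiles; each
// iteration runs the GEMM on the tile that is already resident while the next tile is staged
// through registers into the other half of the buffer, then the two halves swap roles.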
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device memory
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
p_wei_block_now);
const typename vector_type<Float, EPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
p_in_block_now);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device memory
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
// Vectorize the pointer to match with how half/bfloat16 datatypes are
// processed in gemm operation. Half type packs 4 half values while
// bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
// 2D indexes are computed with a single value in mind (e.g. float),
// to retain the same 2D indexes for half/bfloat16, we recast datatype
// from a single half to 4 packed half/2 packed bfloat16 respectively.
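// Example (values for illustration only): with Float = half and EPack = 4,
// vector_type<half, 4>::MemoryType is a packed 4-half element, so index i of the vectorized
// LDS pointer corresponds to raw half elements [4*i, 4*i + 3] of the same buffer.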
const typename vector_type<Float, EPack>::MemoryType* p_a_block_vec =
reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
p_wei_block_double);
const typename vector_type<Float, EPack>::MemoryType* p_b_block_vec =
reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
p_in_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
p_a_block_vec = reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
p_wei_block_double + wei_block_space);
p_b_block_vec = reinterpret_cast<const typename vector_type<Float, EPack>::MemoryType*>(
p_in_block_double + in_block_space);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_out_thread);
}
// load data from xdlops_acc_regs
// static_if<EnableXdlops>{}([&](auto) {
// gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
// });
// copy output: register to global memory
{
constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;
// This is a hack, because slicing a merged dimension is not supported yet.
// dst descriptor
constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<0, 4, 5>{});
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});
using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;
constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
constexpr index_t NumBlks = GemmMPerWave / NumKPerBlk;
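// The per-thread output is written back in NumBlks chunks of NumKPerBlk contiguous values;
// GetBeginOfThreadMatrixC(i) (used below) gives where chunk i lands in the block's C tile.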
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type,
3, // Src dim to be read in vector form (B dimension)
3, // Dst dim to be written in vector form (B dimension)
OutThreadCopyDataPerAccess_B, // Src dim vector len
OutThreadCopyDataPerAccess_B>( // Dst dim vector len
{0, 0, 0, 0},
{k_thread_data_on_global / (K0 * K1),
k_thread_data_on_global % (K0 * K1) / K0,
k_thread_data_on_global % K0,
b_thread_data_on_global});
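// Worked example (numbers for illustration only): with K0 = 4, K1 = 4 and
// k_thread_data_on_global = 37, the copy origin above evaluates to
// {37 / 16, (37 % 16) / 4, 37 % 4, b} = {2, 1, 1, b_thread_data_on_global}.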
threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
}
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
ImplicitGemmDirection Direction,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
bool EnableXdlops,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_k_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData;
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho =
std::conditional<isForward,
decltype(out_n_k_h_w_global_desc),
decltype(in_n_c_h_w_global_desc)>::type::GetLength(I2);
constexpr index_t Wo =
std::conditional<isForward,
decltype(out_n_k_h_w_global_desc),
decltype(in_n_c_h_w_global_desc)>::type::GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t E = C;
constexpr index_t B = N * Ho * Wo;
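// For a 1x1 filter Y = X = 1, so the GEMM reduction dimension collapses to E = C and no
// y/x offsets are needed when forming the input matrix.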
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc_forw =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
constexpr auto in_n_ho_wo_global_desc_back = in_n_c_h_w_global_desc.Extract(I0, I2, I3);
constexpr auto in_n_ho_wo_global_desc =
typename std::conditional<isForward,
decltype(in_n_ho_wo_global_desc_forw),
decltype(in_n_ho_wo_global_desc_back)>::type{};
// batch descriptor for device memory
constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Extract(I1);
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc = make_ConstantMergedTensorDescriptor(
in_c_global_desc.Embed(in_n_ho_wo_global_desc), Sequence<0>{}, Sequence<1, 2, 3>{});
// memory layout descriptor in LDS [E, B], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, BPerBlock>{},
Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB)>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc_forw = wei_c_k_global_desc;
constexpr auto wei_e_k_global_desc_back =
make_ConstantTensorDescriptor_packed(Sequence<C, K>{});
constexpr auto wei_e_k_global_desc =
typename std::conditional<isForward,
decltype(wei_e_k_global_desc_forw),
decltype(wei_e_k_global_desc_back)>::type{};
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(mfma_info<float>{}),
EnableXdlops,
GemmMPerWave,
GemmNPerWave,
GemmMWaves,
GemmNWaves,
GemmDataPerReadA,
GemmDataPerReadB>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);
static_if<EnableXdlops>{}(
[&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });
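// With EnableXdlops set, the MFMA accumulators live in the accumulation VGPRs rather than in
// p_out_thread, so they are zeroed here and read back after the GEMM loop via
// gcnasm_accvgpr_read.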
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads();
// LDS double buffer: load next data from device memory
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads();
// LDS double buffer: load next data from device memory
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// load data from xdlops_acc_regs
static_if<EnableXdlops>{}([&](auto) {
gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
});
// copy output: register to global memory
{
constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;
constexpr auto out_n_k_h_w_global_desc_forw = out_n_k_h_w_global_desc;
constexpr auto out_lengths_back =
Sequence<out_n_k_h_w_global_desc.GetLength(I0),
out_n_k_h_w_global_desc.GetLength(I1),
math::integer_divide_ceil(out_n_k_h_w_global_desc.GetLength(I2),
ConvStrides{}.Get(I0)),
math::integer_divide_ceil(out_n_k_h_w_global_desc.GetLength(I3),
ConvStrides{}.Get(I1))>{};
constexpr auto out_strides_back =
Sequence<out_n_k_h_w_global_desc.GetStride(I0),
out_n_k_h_w_global_desc.GetStride(I1),
out_n_k_h_w_global_desc.GetStride(I2) * ConvStrides{}.Get(I0),
out_n_k_h_w_global_desc.GetStride(I3) * ConvStrides{}.Get(I1)>{};
constexpr auto out_n_k_h_w_global_desc_back =
make_ConstantTensorDescriptor(out_lengths_back, out_strides_back);
constexpr auto out_n_k_h_w_global_desc_new =
typename std::conditional<isForward,
decltype(out_n_k_h_w_global_desc_forw),
decltype(out_n_k_h_w_global_desc_back)>::type{};
// This is a hack, because slicing a merged dimension is not supported yet.
// dst descriptor
constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc_new.Fold(I1, Number<K1>{}, Number<K2>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<0, 4, 5>{});
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});
using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;
constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
constexpr index_t NumBlks = GemmMPerWave / NumKPerBlk;
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / (K0 * K1),
k_thread_data_on_global % (K0 * K1) / K0,
k_thread_data_on_global % K0,
b_thread_data_on_global});
threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
}
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
template <ImplicitGemmDirection conv_dir, typename WeiDesc>
struct make_WeiDesc_Xdlops
{
};
template <typename WeiDesc>
struct make_WeiDesc_Xdlops<ImplicitGemmDirection::ForwardData, WeiDesc>
{
__device__ constexpr auto get(WeiDesc&)
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return WeiDesc{}.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
}
};
template <typename WeiDesc>
struct make_WeiDesc_Xdlops<ImplicitGemmDirection::BackwardWeight, WeiDesc>
{
__device__ constexpr auto get(WeiDesc& desc)
{
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return make_ConstantMergedTensorDescriptor(
desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
}
};
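// In rough terms: the forward-data specialization reorders wei[K, C, Y, X] into a 2D
// [C*Y*X, K] view, while the backward-weight specialization builds the same [E, K] shape as a
// merged descriptor (merging C with Y*X) so that it can be sliced along E by the blockwise copy.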
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t EPack,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
bool EnableXdlops,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B,
ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t nonVectorizedC = C / EPack;
constexpr index_t E = nonVectorizedC * Y * X;
constexpr index_t B = N * Ho * Wo;
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc =
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{});
// memory layout descriptor in LDS [E, B], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, BPerBlock>{},
Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB)>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
make_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(
wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(mfma_info<float>{}),
EnableXdlops,
GemmMPerWave,
GemmNPerWave,
GemmMWaves,
GemmNWaves,
GemmDataPerReadA,
GemmDataPerReadB>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);
// static_if<EnableXdlops>{}(
// [&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device memory
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto ) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device memory
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// load data from xdlops_acc_regs
// static_if<EnableXdlops>{}([&](auto) {
// gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
// });
// copy output: register to global memory
{
constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;
// This is a hack, because slicing a merged dimension is not supported yet.
// dst descriptor
constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<0, 4, 5>{});
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});
using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;
constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
constexpr index_t NumBlks = GemmMPerWave / NumKPerBlk;
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / (K0 * K1),
k_thread_data_on_global % (K0 * K1) / K0,
k_thread_data_on_global % K0,
b_thread_data_on_global});
threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
}
}
}
};
} // namespace ck
#endif
...@@ -2,6 +2,7 @@
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp"
+#include "ConstantTensorDescriptor.hpp"
namespace ck {
...@@ -39,7 +40,7 @@ struct ConstantMatrixDescriptor
};
template <index_t NRow, index_t NCol>
-__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
+__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
...@@ -51,6 +52,17 @@ __host__ __device__ constexpr auto
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
+template <class... Ts>
+__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
+{
+using TDesc = ConstantTensorDescriptor<Ts...>;
+static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
+static_assert(TDesc::GetStrides()[1] == 1, "wrong");
+return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
+TDesc::GetLengths()[1],
+TDesc::GetStrides()[0]>{};
+}
template <class TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
......
...@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
return OriginalTensorDesc{};
}
-__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
...@@ -52,7 +52,7 @@
}
template <index_t IDim>
-__host__ __device__ static constexpr index_t GetLength(Number<IDim>)
+__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
...@@ -60,22 +60,32 @@ struct ConstantMergedTensorDescriptor
}
template <index_t IDim>
-__host__ __device__ static constexpr index_t GetStride(Number<IDim>)
+__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
-constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
+constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_original>{});
}
+// this is a hack to return the stride of the last original dimension of a merged dimension
+// TODO: refactor this once the concept of "dimension" is used
+template <index_t IDim>
+__host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
+{
+constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
+return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
+}
__host__ __device__ static constexpr auto GetLengths()
{
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
-__host__ __device__ static constexpr index_t GetElementSize()
+__host__ __device__ static constexpr auto GetElementSize()
{
return OriginalTensorDesc::GetElementSize();
}
...@@ -174,6 +184,13 @@ struct ConstantMergedTensorDescriptor
return packed_desc.GetMultiIndexFrom1dIndex(id);
}
+__host__ __device__ static constexpr auto Pack()
+{
+constexpr auto lengths = GetLengths();
+constexpr auto strides = calculate_tensor_strides_packed(lengths);
+return ConstantTensorDescriptor<decltype(lengths), decltype(strides)>{};
+}
};
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
......
...@@ -43,23 +43,15 @@ struct ConstantTensorDescriptor
return Sequence<IDim>{};
}
-__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
+__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
-template <index_t I>
-__host__ __device__ static constexpr index_t GetLength(Number<I>)
-{
-return Lengths::Get(Number<I>{});
-}
+__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
-template <index_t I>
-__host__ __device__ static constexpr index_t GetStride(Number<I>)
-{
-return Strides::Get(Number<I>{});
-}
+__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
struct lambda_AreDimensionsContinuous
{
...@@ -102,17 +94,18 @@ struct ConstantTensorDescriptor
return false;
}
-__host__ __device__ static constexpr index_t GetElementSize()
+__host__ __device__ static constexpr auto GetElementSize()
{
-return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
+return Number<accumulate_on_sequence(
+Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
}
-__host__ __device__ static constexpr index_t GetElementSpace()
+__host__ __device__ static constexpr auto GetElementSpace()
{
constexpr index_t element_space_unaligned = accumulate_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
-return element_space_unaligned;
+return Number<element_space_unaligned>{};
}
// emulate constexpr lambda
...@@ -156,13 +149,14 @@ struct ConstantTensorDescriptor
}
template <index_t... Is>
-__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
+__host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
{
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
constexpr auto multi_id = Sequence<Is...>{};
-return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
+return Number<accumulate_on_sequence(
+multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
}
// emulate constexpr lambda
...@@ -369,6 +363,12 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
+template <index_t IDim, index_t... FoldIntervals>
+__host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldIntervals...>)
+{
+return Fold(Number<IDim>{}, Number<FoldIntervals>{}...);
+}
// this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
...@@ -407,6 +407,12 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
+__host__ __device__ static constexpr auto Pack()
+{
+using packed_strides = decltype(calculate_tensor_strides_packed(Lengths{}));
+return ConstantTensorDescriptor<Lengths, packed_strides>{};
+}
template <class MapNew2Old>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
{
...@@ -414,14 +420,12 @@ struct ConstantTensorDescriptor
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
}
-#if 0 // require sequence_sort, which is not implemented yet
template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
-decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{}
+decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
}
-#endif
};
template <class Lengths>
...@@ -451,7 +455,7 @@ print_ConstantTensorDescriptor(const char* s,
{
constexpr index_t ndim = sizeof...(Lengths);
-static_assert(ndim > 0 && ndim <= 10, "wrong!");
+static_assert(ndim > 0 && ndim <= 12, "wrong!");
static_if<ndim == 1>{}([&](auto) {
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
...@@ -523,6 +527,26 @@ print_ConstantTensorDescriptor(const char* s,
Lengths...,
Strides...);
});
+static_if<ndim == 11>{}([&](auto) {
+printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
+"%u %u "
+"%u %u %u}\n",
+s,
+ndim,
+Lengths...,
+Strides...);
+});
+static_if<ndim == 12>{}([&](auto) {
+printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
+"%u %u %u %u "
+"%u %u %u}\n",
+s,
+ndim,
+Lengths...,
+Strides...);
+});
}
} // namespace ck
......
#ifndef CK_TENSOR_COORDINATE_HPP
#define CK_TENSOR_COORDINATE_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
namespace ck {
template <class TensorDesc>
struct NormalTensorCoordinate
{
using type = NormalTensorCoordinate;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
__host__ __device__ constexpr NormalTensorCoordinate(Array<index_t, nDim> tensor_index)
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
{
}
template <class... Xs>
__host__ __device__ constexpr NormalTensorCoordinate(Xs... xs)
: NormalTensorCoordinate(Array<index_t, nDim>{xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
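// For a normal (non-merged) tensor the move is a pure affine update: the offset changes by
// the dot product of the step sizes with the tensor strides.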
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
// reposition the point of origin, and return the compensated offset.
// This is a hack to reduce index calculation during looping over
// a tensor whose origin is this TensorCoordinate. It does so by spitting
// out the run-time offset to the pointer (to the tensor data) held by this
// TensorCoordinate, so the caller can add the offset into the run-time pointer of
// the data, so that only 1 run-time variable (the updated pointer) is needed instead
// of 2 run-time variables (the old pointer and this offset).
// TODO: after introducing the concept of a "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer instead of both the
// offset and the pointer. This also brings the additional benefit that we don't need to
// worry about the offset underflowing (offset is an unsigned integer) when updating it.
__host__ __device__ constexpr index_t RepositOrigin()
{
index_t offset_diff = mOffset;
mOffset = 0;
return offset_diff;
}
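// Typical use (illustrative sketch only):
//   p_data += coord.RepositOrigin(); // fold the accumulated offset into the data pointer
//   ...                              // later moves then only update the coordinate's offset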
private:
index_t mOffset;
};
template <class TensorDesc>
struct MergedTensorCoordinate
{
using type = MergedTensorCoordinate;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
static constexpr index_t nOriginalDim =
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
__host__ __device__ constexpr MergedTensorCoordinate(Array<index_t, nDim> tensor_index)
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
{
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto idim) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mOriginalIndex, partial_original_dims));
});
// complete offset
mOffset =
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
}
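// Bookkeeping kept by this class: mPartialOffsets[idim] caches the offset contributed by the
// original dimensions behind (possibly merged) dimension idim; moves along a merged dimension
// update mOriginalIndex, the cached partial offset and mOffset together, while moves along a
// normal dimension only adjust mOffset by step * stride.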
template <class... Xs>
__host__ __device__ constexpr MergedTensorCoordinate(Xs... xs)
: MergedTensorCoordinate(Array<index_t, nDim>{xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
template <class IDim, class T, bool PositiveDirection>
__host__ __device__ void
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
{
constexpr auto idim = idim_;
// if step_size is known at compile time
static_if<is_static<T>::value>{}(
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
// update original index
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
const auto partial_original_step_sizes =
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
// update partial original multi-id
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
static_if<PositiveDirection>{}([&](auto) {
partial_original_id += partial_original_step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(carry)
{
++partial_original_id(i);
}
carry = false;
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
{
partial_original_id(i) -= partial_original_desc.GetLength(i);
carry = true;
}
});
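                // Worked example (illustrative): with partial lengths {3, 4}, partial id
                // {1, 3} and step {0, 1}: the addition gives {1, 4}; dimension 1 overflows
                // (4 >= 4), so it wraps to 0 and carries into dimension 0, yielding {2, 0},
                // i.e. the 1d index moves from 7 to 8.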
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
partial_original_id +=
partial_original_desc.GetLengths() - partial_original_step_sizes;
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(borrow)
{
--partial_original_id(i);
}
borrow = false;
if(partial_original_id[i] < partial_original_desc.GetLength(i))
{
partial_original_id(i) += partial_original_desc.GetLength(i);
borrow = true;
}
});
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
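                // Worked example (illustrative): with partial lengths {3, 4}, partial id
                // {1, 0} and step {0, 1}: shifting up gives {1, 0} + ({3, 4} - {0, 1}) = {4, 3};
                // dimension 1 is below its length (3 < 4), so it borrows and becomes {3, 7};
                // shifting back down by {3, 4} yields {0, 3}, i.e. the 1d index moves from 4 to 3.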
});
// update "mOriginalIndex"
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
constexpr auto idim_original = partial_original_dims[I];
mOriginalIndex(idim_original) = partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_partial_offset = mPartialOffsets[idim];
mPartialOffsets(idim) =
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
}).Else([&](auto fwd) {
static_if<PositiveDirection>{}([&](auto) {
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
});
}
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim,
"wrong! the rank of step size doesn't match with that of tensor coordinate");
static_for<0, nDim, 1>{}([&](auto idim) {
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
}
});
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim,
"wrong! the rank of step size doesn't match with that of tensor coordinate");
static_for<0, nDim, 1>{}([&](auto idim) {
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
}
});
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
__host__ __device__ static constexpr index_t RepositOrigin() { return 0; }
private:
    // Allocate register memory for all merged dimensions and normal dimensions.
    // However, only those merged dimensions whose index will be involved in arithmetic
    // after the construction of this TensorCoordinate (e.g. when the user moves a slicing
    // window on the merged dimension) will actually use these registers.
    // Let's hope the compiler will optimize away the registers allocated for normal
    // dimensions, and for merged dimensions that are never involved in index
    // arithmetic after construction of the TensorCoordinate.
    // TODO: refactor TensorCoordinate after introducing the concept of "dimensions",
    // and simplify the implementation of ConstantMergedTensorDescriptor, so we don't need
    // to count on the compiler to optimize away those registers for us.
Array<index_t, nOriginalDim> mOriginalIndex;
Array<index_t, nDim> mPartialOffsets;
// complete offset
index_t mOffset;
};
} // namespace ck
#endif
@@ -6,7 +6,7 @@
#include "threadwise_gemm.hpp"
#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 0
#endif
namespace ck {
@@ -14,6 +14,7 @@ namespace ck {
// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <index_t BlockSize,
          index_t EPack,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
@@ -113,6 +114,151 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
    }
#if CK_USE_AMD_INLINE_ASM
    // The PACKSIZE template parameter is the pack size; for fp32 it must be 1.
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
const typename vector_type<float, 4>::MemoryType& b,
typename vector_type<float, 4>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x4(a.x, b, c[0 * NRepeat]);
outerProduct1x4(a.y, b, c[1 * NRepeat]);
outerProduct1x4(a.z, b, c[2 * NRepeat]);
outerProduct1x4(a.w, b, c[3 * NRepeat]);
}
    // The PACKSIZE template parameter is the pack size; for fp32 it must be 1.
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
const typename vector_type<float, 4>::MemoryType& b,
typename vector_type<float, 4>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x4(a.x, b, c[0 * NRepeat]);
outerProduct1x4(a.y, b, c[1 * NRepeat]);
}
    // The PACKSIZE template parameter is the pack size; for fp32 it must be 1.
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 4>::MemoryType& a,
const typename vector_type<float, 2>::MemoryType& b,
typename vector_type<float, 2>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x2(a.x, b, c[0 * NRepeat]);
outerProduct1x2(a.y, b, c[1 * NRepeat]);
outerProduct1x2(a.z, b, c[2 * NRepeat]);
outerProduct1x2(a.w, b, c[3 * NRepeat]);
}
    // The PACKSIZE template parameter is the pack size; for fp32 it must be 1.
template <index_t PACKSIZE>
__device__ void outerProduct(const typename vector_type<float, 2>::MemoryType& a,
const typename vector_type<float, 2>::MemoryType& b,
typename vector_type<float, 2>::MemoryType* c) const
{
static_assert(1 == PACKSIZE, "only packsize of 1 is supported with float datatype!");
constexpr index_t NRepeat = 2;
outerProduct1x2(a.x, b, c[0 * NRepeat]);
outerProduct1x2(a.y, b, c[1 * NRepeat]);
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void
outerProduct(const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
4>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
4>::MemoryType& b,
typename vector_type<float, 4>::MemoryType* c) const
{
        static_assert(2 == PACKSIZE || 4 == PACKSIZE,
                      "only packsize of 2 or 4 is supported with the half datatype!");
constexpr index_t NRepeat = 2;
const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
outerProduct1x4Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
outerProduct1x4Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
outerProduct1x4Half<PACKSIZE>(reg_a[2], b, c[2 * NRepeat]);
outerProduct1x4Half<PACKSIZE>(reg_a[3], b, c[3 * NRepeat]);
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void
outerProduct(const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
2>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
2>::MemoryType& b,
typename vector_type<float, 2>::MemoryType* c) const
{
        static_assert(2 == PACKSIZE || 4 == PACKSIZE,
                      "only packsize of 2 or 4 is supported with the half datatype!");
constexpr index_t NRepeat = 2;
const typename vector_type<half, PACKSIZE>::MemoryType* reg_a =
reinterpret_cast<const typename vector_type<half, PACKSIZE>::MemoryType*>(&a);
outerProduct1x2Half<PACKSIZE>(reg_a[0], b, c[0 * NRepeat]);
outerProduct1x2Half<PACKSIZE>(reg_a[1], b, c[1 * NRepeat]);
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void
outerProduct1x4Half(const typename vector_type<half, PACKSIZE>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
4>::MemoryType& b,
vector_type<float, 4>::MemoryType& c) const
{
static_if<PACKSIZE == 4>{}([&](auto) {
outerProduct1x4dot2TwoTimes(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto) {
static_if<PACKSIZE == 2>{}([&](auto) {
outerProduct1x4dot2(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto fwd) {
// not implemented
                static_assert(fwd(false), "wrong! packsize = 1 for fp16 is not supported.");
});
});
}
// PACKSIZE for fp16 could be 4 or 2
template <index_t PACKSIZE>
__device__ void
outerProduct1x2Half(const typename vector_type<half, PACKSIZE>::MemoryType& a,
const typename vector_type<typename vector_type<half, PACKSIZE>::MemoryType,
2>::MemoryType& b,
vector_type<float, 2>::MemoryType& c) const
{
static_if<PACKSIZE == 4>{}([&](auto) {
outerProduct1x2dot2TwoTimes(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto) {
static_if<PACKSIZE == 2>{}([&](auto) {
outerProduct1x2dot2(reinterpret_cast<const half2*>(&a),
reinterpret_cast<const half2*>(&b),
reinterpret_cast<float*>(&c));
}).Else([&](auto fwd) {
// not implemented
                static_assert(fwd(false), "wrong! packsize = 1 for fp16 is not supported.");
});
});
}
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
                                const FloatB* __restrict__ p_b_block,
@@ -131,91 +277,60 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
        static_assert((MPerThreadSubC == 4 || MPerThreadSubC == 2) &&
                          (NPerThreadSubC == 4 || NPerThreadSubC == 2) && KPerThreadLoop == 1,
                      "M/NPerThreadSubC wrong!");
        static_assert(MPerThread % 4 == 0 && NPerThread % 4 == 0, "M/NPerThread % 4 != 0");
        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
        static_assert(MRepeat == 2 && NRepeat == 2, "M/NRepeat != 2");
        using typeA = typename vector_type<FloatA, MPerThreadSubC>::MemoryType;
        using typeB = typename vector_type<FloatB, NPerThreadSubC>::MemoryType;
        using typeC = typename vector_type<FloatC, NPerThreadSubC>::MemoryType;
        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
        typeA* reg_a = reinterpret_cast<typeA*>(p_a_thread);
        typeB* reg_b = reinterpret_cast<typeB*>(p_b_thread);
        typeC* reg_c = reinterpret_cast<typeC*>(p_c_thread);
        reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] =
            *reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + NPerLevel1Cluster)]);
        reg_a[1] =
            *reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + MPerLevel1Cluster)]);
        outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]);
        outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            reg_a[0] = *reinterpret_cast<const typeA*>(&p_a_block[(mMyThreadOffsetA + k * M)]);
            outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]);
            reg_b[0] = *reinterpret_cast<const typeB*>(&p_b_block[(mMyThreadOffsetB + k * N)]);
            outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]);
            reg_b[1] = *reinterpret_cast<const typeB*>(
                &p_b_block[(mMyThreadOffsetB + k * N + NPerLevel1Cluster)]);
            reg_a[1] = *reinterpret_cast<const typeA*>(
                &p_a_block[(mMyThreadOffsetA + k * M + MPerLevel1Cluster)]);
            outerProduct<EPack>(reg_a[0], reg_b[0], &reg_c[0]);
            outerProduct<EPack>(reg_a[0], reg_b[1], &reg_c[1]);
        }
        outerProduct<EPack>(reg_a[1], reg_b[0], &reg_c[NRepeat * MPerThreadSubC]);
        outerProduct<EPack>(reg_a[1], reg_b[1], &reg_c[NRepeat * MPerThreadSubC + 1]);
    }
#endif
@@ -250,8 +365,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
@@ -270,10 +385,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
            threadwise_matrix_copy(
                a_block_mtx,
                p_a_block +
                    a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                    mMyThreadOffsetA,
                a_thread_mtx,
                p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
                a_thread_sub_mtx.GetLengths(),
                Number<DataPerReadA>{});
        }
@@ -285,10 +400,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
            threadwise_matrix_copy(
                b_block_mtx,
                p_b_block +
                    b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                    mMyThreadOffsetB,
                b_thread_mtx,
                p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
                b_thread_sub_mtx.GetLengths(),
                Number<DataPerReadB>{});
        }
@@ -306,156 +421,21 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
        }
    }
template <class FloatA, class FloatB, class FloatC>
__device__ void RunRegisterDoubleBuffer_source(FloatA* const p_a_block,
FloatB* const p_b_block,
FloatC* p_c_thread) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t K = a_block_mtx.NRow();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
// register
FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// preload A, B
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_0 + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_0 + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
bool even_loop = true;
#pragma unroll
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
// preload next A, B
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
(k_begin + 1) * a_block_mtx.RowStride() +
m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_next + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
(k_begin + 1) * b_block_mtx.RowStride() +
n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_next + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread_now,
b_thread_mtx,
False,
p_b_thread_now,
c_thread_mtx,
False,
p_c_thread);
}
// last loop
{
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread_now,
b_thread_mtx,
False,
p_b_thread_now,
c_thread_mtx,
False,
p_c_thread);
}
}
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
        // The assembly path doesn't support bfloat16 using asm instructions
#if MIOPEN_USE_BFP16 == 1
        Run_source(p_a_block, p_b_block, p_c_thread);
#else
        Run_amd_asm(p_a_block, p_b_block, p_c_thread);
#endif
#else
        Run_source(p_a_block, p_b_block, p_c_thread);
#endif // CK_USE_AMD_INLINE_ASM
    }
};
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <class input_type>
struct mfma_info
{
};
template <>
struct mfma_info<float>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_blks_wave = 2;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t wave_size = 64;
};
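// Reference note (derived from the constants above; illustrative, not extra configuration):
// one 32x32 fp32 block has 32 * 32 = 1024 accumulators spread across a 64-lane wave, i.e.
// num_regs_blk = 4 * 4 = 16 registers per lane per block, and num_regs_xdlops = 16 * 2 = 32
// registers per lane for the two blocks each wave handles.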
template <>
struct mfma_info<half>
{
static const index_t group_size = 4;
static const index_t num_groups_blk = 4;
static const index_t num_blks_wave = 2;
static const index_t num_regs_blk = group_size * num_groups_blk;
static const index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
static const index_t num_threads_blk = 32;
static const index_t m = 32;
static const index_t n = 32;
static const index_t k = 4;
static const index_t wave_size = 64;
};
template <>
struct mfma_info<ushort>
{
static const index_t group_size = 4;
static const index_t num_groups_blk = 4;
static const index_t num_blks_wave = 2;
static const index_t num_regs_blk = group_size * num_groups_blk;
static const index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
static const index_t num_threads_blk = 32;
static const index_t m = 32;
static const index_t n = 32;
static const index_t k = 2;
static const index_t wave_size = 64;
};
// emulate xdlops
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info,
class FloatA,
class FloatB,
class FloatC>
__device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC* const __restrict__ p_c_thread)
{
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
const index_t blk_id = laneId / mfma_info::num_threads_blk;
const index_t lane_b = laneId % mfma_info::num_threads_blk;
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < MPerWave / mfma_info::m; ++b)
{
index_t a_off = k * M + b * mfma_info::m;
index_t b_off = k * N;
// pseudo mfma
for(index_t n = 0; n < mfma_info::num_blks_wave; ++n)
{
index_t output_m = mfma_info::num_regs_blk;
for(index_t m = 0; m < output_m; ++m)
{
index_t aindex = m % mfma_info::group_size + blk_id * mfma_info::group_size +
m / mfma_info::group_size *
(mfma_info::group_size * mfma_info::num_blks_wave) +
a_off; // A is transposed
index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk;
p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] +=
math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex],
p_b_wave[bindex]);
}
}
}
}
}
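// Illustrative note (not part of the original source): with the fp32 mfma_info above
// (group_size = 4, num_blks_wave = 2, num_threads_blk = 32), a lane with blk_id = 1
// accumulates A rows {4..7, 12..15, 20..23, 28..31} of a 32-row block -- this is the
// "blk_id * group_size + (m / group_size) * group_size * num_blks_wave" term in aindex --
// while the B column it reads is lane_b + n * num_threads_blk, i.e. lane_b and lane_b + 32.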
#if 0
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(const float* const __restrict__ p_a_wave,
const float* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k)
{
float reg_a = p_a_wave[k * M + laneId];
float reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x1f32<MPerWave>(reg_a, reg_b, reg_c);
}
}
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(
const typename vector_type<half, 4>::MemoryType* const __restrict__ p_a_wave,
const typename vector_type<half, 4>::MemoryType* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
const index_t laneId = threadIdx.x % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k / 4)
{
typename vector_type<half, 4>::MemoryType reg_a = p_a_wave[k * M + laneId];
typename vector_type<half, 4>::MemoryType reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x4f16<MPerWave>(reg_a, reg_b, reg_c);
}
}
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(
const typename vector_type<ushort, 2>::MemoryType* const __restrict__ p_a_wave,
const typename vector_type<ushort, 2>::MemoryType* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
const index_t laneId = threadIdx.x % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k / 2)
{
typename vector_type<ushort, 2>::MemoryType reg_a = p_a_wave[k * M + laneId];
typename vector_type<ushort, 2>::MemoryType reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x2bf16<MPerWave>(reg_a, reg_b, reg_c);
}
}
#endif
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class mfma_info,
bool EnableXdlops,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct MatrixIndex
{
index_t row;
index_t col;
};
struct OutputLayout_t
{
static constexpr index_t M3 = GemmMPerWave / mfma_info::m;
static constexpr index_t M2 = mfma_info::num_groups_blk;
static constexpr index_t M1 = mfma_info::num_blks_wave;
static constexpr index_t M0 = mfma_info::group_size;
};
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
OutputLayout_t OutputLayout;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(GemmNPerWave == 64, "Only support GemmNPerWave == 64 for xdlops");
static_assert(GemmMPerWave == 32 || GemmMPerWave == 64,
"Only support GemmMPerWave == 32 or 64 for xdlops");
static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
static_assert(BlockSize == GemmMWaves * GemmNWaves * 64,
"BlockSize != GemmMWaves * GemmNWaves * 64\n");
const index_t waveId = get_thread_local_1d_id() / mfma_info::wave_size;
const index_t waveId_m = waveId / GemmNWaves;
const index_t waveId_n = waveId % GemmNWaves;
mMyWaveOffsetA = waveId_m * GemmMPerWave;
mMyWaveOffsetB = waveId_n * GemmNPerWave;
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
// static_if<EnableXdlops>{}([&](auto) {
// WaveWiseGemmMx64_xdlops<M,
// N,
// K,
// GemmMPerWave,
// GemmDataPerReadA,
// GemmDataPerReadB,
// mfma_info>(
// &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
// }).Else([&](auto) {
WaveWiseGemmMx64<M, N, K, GemmMPerWave, GemmDataPerReadA, GemmDataPerReadB, mfma_info>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
// });
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
{
const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
const index_t waveId = get_thread_local_1d_id() / mfma_info::wave_size;
const index_t col_i = i % mfma_info::num_blks_wave;
const index_t col = waveId % GemmNWaves * mfma_info::wave_size +
laneId % mfma_info::num_threads_blk +
col_i * mfma_info::num_threads_blk;
const index_t row_i = i / mfma_info::num_blks_wave;
const index_t row = waveId / GemmNWaves * GemmMPerWave +
laneId / mfma_info::num_threads_blk * mfma_info::group_size +
row_i * mfma_info::num_threads_blk;
return MatrixIndex{row, col};
}
__device__ constexpr auto GetThreadMatrixCDescriptor() const
{
constexpr index_t num_xdlops = GemmMPerWave / mfma_info::m;
return make_ConstantMatrixDescriptor_packed(
Number<mfma_info::num_regs_xdlops * num_xdlops>{}, Number<1>{});
}
};
} // namespace ck
#endif
@@ -10,12 +10,541 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
#endif
#define JOINTCAT(x, y) x##y
#define ASSERT_MSG_ARG1(msg, var1) JOINTCAT(msg, var1)
#define ASSERT_MSG_ARG2(msg, var1, var2) ASSERT_MSG_ARG1(JOINTCAT(msg, var1), var2)
namespace ck {
// Slice a (normal or merged) tensor and copy it into another (normal or merged) tensor.
// The memory layout (ordering of dimensions) can be different between src and dst.
// This struct assumes each thread is reading and writing a normal (not merged) tensor,
// to simplify index calculations. To satisfy this assumption, the user needs to make sure
// that, on a merged dimension that contains multiple original dimensions, the length of
// the last original dimension is evenly divisible by the corresponding sub-length. Also, the
// repeat-length on the merged dimension needs to be 1. These sanity checks are performed
// in the constructor of BlockwiseGenericTensorSliceCopy_v1.
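// Illustrative (hypothetical) parameterization that satisfies the checks below:
// SliceLengths = <8, 128>, SubLengths = <4, 4>, ThreadClusterLengths = <2, 32> gives
// data_per_cluster_per_dims = <8, 128> and repeat_lengths = <1, 1>, and requires
// BlockSize = 2 * 32 = 64.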
template <index_t BlockSize,
class SrcDesc,
class DstDesc,
class SliceLengths,
class SubLengths,
class ThreadClusterLengths,
class ThreadClusterArrangeOrder,
class SrcDimAccessOrder,
class DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v1
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
static constexpr index_t nOriginalDimSrc =
SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
static constexpr index_t nOriginalDimDst =
DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
// per-thread offset
index_t mThreadSrcOffset;
index_t mThreadDstOffset;
// "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets, "mThreadDstOriginalMultiId",
// "mThreadDstPartialOffsets" are always calculated inside constructor, and would be
// updated if slicing-window is moved. However, they will not be used if you always move
// the slicing-window along a non-merged dimension. In that case, compiler should be
// able to remove these calculation.
// TODO: make sure compiler would actually remove them in that case
// partial offset in each (merged) dimension
Array<index_t, nDim> mThreadSrcPartialOffsets;
Array<index_t, nDim> mThreadDstPartialOffsets;
// multi-id of original tensor
Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
__device__ BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_id_begin,
Array<index_t, nDim> dst_block_data_id_begin)
{
// check NDim consistency
static_assert(
nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
nDim == ThreadClusterLengths::GetSize() &&
nDim == ThreadClusterArrangeOrder::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
"wrong");
// check thread arrange order and read/write access order are valid
static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
is_valid_sequence_map<SrcDimAccessOrder>::value &&
is_valid_sequence_map<DstDimAccessOrder>::value,
"wrong!");
// thread cluster
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
// BlockSize
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
// divide work
constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{};
static_for<0, nDim, 1>{}([&](auto IDim) {
static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
"wrong! cannot evenly divide sliced tensor into cluster");
});
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
// additional check for merged dimension
static_for<0, nDim, 1>{}([&](auto IDim_) {
// src
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
constexpr auto IDim = decltype(IDim_){};
                // on a merged dimension that contains multiple original dimensions,
                // the length of the last original dimension needs to be evenly divisible
                // by its sub-length,
                // so each thread is effectively reading a normal (not merged) tensor
constexpr auto idim_last_original_src =
SrcDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(
SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) %
SubLengths::Get(IDim) ==
0,
"wrong!");
// merged dimension should have repeat_lengths = 1
static_assert(repeat_lengths[IDim] == 1,
"wrong! repeat_lengths shoud be 1 on merged dimension");
});
// dst
static_if<DstDesc::ContainMultipleOriginalDimensions(IDim_)>{}([&](auto) {
constexpr auto IDim = decltype(IDim_){};
                // on a merged dimension that contains multiple original dimensions,
                // the length of the last original dimension needs to be evenly divisible
                // by its sub-length,
                // so each thread is effectively writing a normal (not merged) tensor
constexpr auto idim_last_original_dst =
DstDesc::GetContainedOriginalDimensions(IDim).Back();
static_assert(
DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) %
SubLengths::Get(IDim) ==
0,
"wrong!");
// merged dimension should have repeat_lengths = 1
static_assert(repeat_lengths[IDim] == 1,
"wrong! repeat_lengths shoud be 1 on merged dimension");
});
});
// calculate mThreadSrcOffset, mThreadDstOffset
const auto thread_cluster_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
const auto data_cluster_id =
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
// original multi-id
mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
src_block_data_id_begin + thread_data_id_begin);
mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
dst_block_data_id_begin + thread_data_id_begin);
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
});
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto dst_partial_original_dims =
DstDesc::GetContainedOriginalDimensions(IDim);
constexpr auto dst_partial_original_desc =
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
});
// complete offset
mThreadSrcOffset = accumulate_on_array(
mThreadSrcPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
mThreadDstOffset = accumulate_on_array(
mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
}
__device__ static constexpr auto GetRegisterBufferDescriptor()
{
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
}
__device__ static constexpr index_t GetRegisterBufferSize()
{
return GetRegisterBufferDescriptor().GetElementSpace();
}
template <class TData>
__device__ void RunLoadRegisterBuffer(const TData* __restrict__ p_src,
TData* __restrict__ p_buffer) const
{
constexpr auto thread_sub_tensor_lengths = SubLengths{};
constexpr auto data_per_cluster_per_dims =
thread_sub_tensor_lengths * ThreadClusterLengths{};
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
constexpr auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;
constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
constexpr index_t src_offset =
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);
constexpr index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
#else
ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
const auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims;
const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin);
const index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
#endif
            // Position the origin of the per-thread window at the point where the multi-index
            // of the SrcDesc (which might be a merged tensor) is all-zero. This threadwise slice
            // copy assumes each thread is copying a normal (not merged) tensor.
            // To satisfy this assumption, the user needs to make sure that, on a merged dimension
            // that contains multiple original dimensions, the length of the last original
            // dimension is evenly divisible by its sub-length. Also, the repeat-length on
            // the merged dimension needs to be 1. These sanity checks are performed in the
            // constructor of BlockwiseGenericTensorSliceCopy_v1.
ThreadwiseGenericTensorSliceCopy_v1r2<SrcDesc,
decltype(thread_buffer_desc),
SubLengths,
SrcDimAccessOrder,
SrcVectorAccessDim,
SrcDataPerAccess,
1>(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
.Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
});
}
template <class TData>
__device__ void RunStoreRegisterBuffer(const TData* __restrict__ p_buffer,
TData* __restrict__ p_dst) const
{
constexpr auto thread_sub_tensor_lengths = SubLengths{};
constexpr auto data_per_cluster_per_dims =
thread_sub_tensor_lengths * ThreadClusterLengths{};
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
constexpr auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
constexpr index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
#else
ford<decltype(repeat_lengths)>{}([&](auto repeat_id) {
const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
const auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
const index_t buffer_offset =
thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
#endif
            // Position the origin of the per-thread window at the point where the multi-index
            // of the DstDesc (which might be a merged tensor) is all-zero. This threadwise slice
            // copy assumes each thread is copying a normal (not merged) tensor.
            // To satisfy this assumption, the user needs to make sure that, on a merged dimension
            // that contains multiple original dimensions, the length of the last original
            // dimension is evenly divisible by its sub-length. Also, the repeat-length on
            // the merged dimension needs to be 1. These sanity checks are performed in the
            // constructor of BlockwiseGenericTensorSliceCopy_v1.
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(thread_buffer_desc),
DstDesc,
SubLengths,
DstDimAccessOrder,
DstVectorAccessDim,
1,
DstDataPerAccess>(
make_zero_array<index_t, nDim>(), make_zero_array<index_t, nDim>())
.Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
});
}
template <class TData>
__device__ void Run(const TData* __restrict__ p_src, TData* __restrict__ p_dst) const
{
TData p_buffer[GetRegisterBufferSize()];
RunLoadRegisterBuffer(p_src, p_buffer);
RunStoreRegisterBuffer(p_buffer, p_dst);
}
    // When moving the slicing window along a merged dimension, if the strides of the
    // original dimensions contained by the merged dimension are not in descending order,
    // then there is no guarantee that the new offset will be larger than the old offset
    // for movement in the positive direction (and vice versa for the negative direction).
    // As a result, there is the possibility that the offset calculation results in
    // unsigned integer underflow (due to the "-" operation). However, this hazard should
    // not happen as long as the user makes sure the slicing window is not moved out of
    // the boundary of the tensor being sliced. For performance reasons, this function does
    // no runtime sanity check for an out-of-bound slicing window.
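    // Illustrative example (hypothetical strides): suppose a merged dimension combines two
    // original dimensions with lengths (2, 3) and strides (1, 6), i.e. strides NOT in
    // descending order. Moving the window by +1 from multi-id (0, 2) to (1, 0) changes the
    // partial offset from 0*1 + 2*6 = 12 to 1*1 + 0*6 = 1, so the offset decreases even
    // though the movement is in the positive direction -- hence the "+ before -" update below.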
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
__device__ void MoveSlicingWindowOnSourceTensor(
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
{
constexpr auto IDim = Number<IDim_>{};
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto) {
            // Logic for a merged dimension; it also works for a non-merged dimension, but it may
            // be unnecessarily complicated for the compiler to remove the calculations that are
            // useless for a non-merged dimension.
// extract partial original dimensions
constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
// calculate new partial original multi-id
auto old_src_partial_original_id =
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
auto new_src_partial_original_id =
src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
old_src_partial_original_id, StepSize, direction);
// update "mThreadSrcOriginalMultiId"
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
constexpr auto IDimOriginal = src_partial_original_dims[I];
mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
const index_t new_src_partial_offset =
src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_id);
// update "mThreadSrcPartialOffsets"
mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
}).Else([&](auto) {
            // Logic for a non-merged dimension. If you never move the slicing window on
            // a merged dimension, then "mThreadSrcOriginalMultiId" and "mThreadSrcPartialOffsets",
            // which are being calculated here, will never be used later. In this case, the
            // compiler should be able to remove these calculations.
            // TODO: make sure the compiler actually removes them in this case.
            // It is the user's responsibility to make sure the slicing window is not moved out
            // of the boundary of the tensor being sliced. Otherwise, there can be hazards like
            // unsigned integer underflow. There is NO runtime sanity check to prevent the hazard.
constexpr auto IDimOriginal = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
static_if<PositiveDirection>{}([&](auto fwd) {
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(IDimOriginal) += StepSize;
mThreadSrcPartialOffsets(IDim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
}).Else([&](auto fwd) {
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(IDimOriginal) -= StepSize;
mThreadSrcPartialOffsets(IDim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
});
});
}
template <class T, bool PositiveDirection>
__device__ void
MoveSrcSlicingWindow(T step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{
static_for<0, nDim, 1>{}([&](auto idim) {
if(step_sizes[idim] != 0)
{
MoveSlicingWindowOnSourceTensor(idim, step_sizes[idim], positive_direction);
}
});
}
};
template <index_t BlockSize,
class SrcDesc,
class DstDesc,
class SrcCoordinate,
class DstCoordinate,
class SliceLengths,
class SubLengths,
class ThreadClusterLengths,
class ThreadClusterArrangeOrder,
class SrcDimAccessOrder,
class DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct BlockwiseGenericTensorSliceCopy_v2
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2(SrcCoordinate src_block_slice_origin,
DstCoordinate dst_block_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SubLengths::GetSize() &&
nDim == ThreadClusterLengths::GetSize() &&
nDim == ThreadClusterArrangeOrder::GetSize(),
"wrong! nDim not consistent");
static_assert(is_same<SliceLengths, decltype(SubLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths");
const auto thread_cluster_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
const auto data_cluster_id =
reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{});
const auto thread_data_id_begin = data_cluster_id * SubLengths{};
mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
mThreadwiseLoad.SetDstSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetSrcSliceOrigin(make_zero_array<index_t, nDim>());
mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
}
__device__ static constexpr index_t GetRegisterBufferSize()
{
return RegisterBufferDesc::GetElementSpace();
}
template <class TData>
__device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
{
mThreadwiseLoad.Run(p_src, p_buffer);
}
template <class TData>
__device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
{
mThreadwiseStore.Run(p_buffer, p_dst);
}
template <class TData>
__device__ void Run(const TData* p_src, TData* p_dst) const
{
TData p_buffer[GetRegisterBufferSize()];
mThreadwiseLoad.Run(p_src, p_buffer);
mThreadwiseStore.Run(p_buffer, p_dst);
}
template <class T, bool PositiveDirection>
__device__ void
MoveSrcSlicingWindow(T step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{
mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction);
}
template <class T, bool PositiveDirection>
__device__ void
MoveDstSlicingWindow(T step_sizes,
integral_constant<bool, PositiveDirection> positive_direction)
{
mThreadwiseStore.MoveDstSlicingWindow(step_sizes, positive_direction);
}
private:
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
using ThreadwiseLoad =
ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
RegisterBufferDesc,
SrcCoordinate,
NormalTensorCoordinate<RegisterBufferDesc>,
SubLengths,
SrcDimAccessOrder,
SrcDimAccessOrder,
SrcVectorAccessDim,
SrcVectorAccessDim,
SrcDataPerAccess,
1>;
using ThreadwiseStore =
ThreadwiseGenericTensorSliceCopy_v2r1<RegisterBufferDesc,
DstDesc,
NormalTensorCoordinate<RegisterBufferDesc>,
DstCoordinate,
SubLengths,
DstDimAccessOrder,
DstDimAccessOrder,
DstVectorAccessDim,
DstVectorAccessDim,
1,
DstDataPerAccess>;
ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore;
};
// this will be deprecated
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst
// For now, only support SubLengths[...] == 1 on a merged dimension
@@ -31,7 +560,7 @@ template <index_t BlockSize,
          class DstAccessOrder,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite>
struct BlockwiseGenericTensorSliceCopy_v1_deprecated
{
    static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
@@ -59,9 +588,9 @@ struct BlockwiseGenericTensorSliceCopy_v1
    Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
    Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
    __device__ BlockwiseGenericTensorSliceCopy_v1_deprecated(
        Array<index_t, nDim> src_block_data_multi_id_begin,
        Array<index_t, nDim> dst_block_data_multi_id_begin)
    {
        // check NDim consistency
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
@@ -213,15 +742,16 @@ struct BlockwiseGenericTensorSliceCopy_v1
                thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#endif
            threadwise_generic_tensor_slice_copy_v1_deprecated(SrcDesc{},
                                                               p_src + src_offset +
                                                                   mThreadSrcOffset,
                                                               make_zero_array<index_t, nDim>(),
                                                               thread_tensor_desc,
                                                               p_clipboard + clipboard_offset,
                                                               make_zero_array<index_t, nDim>(),
                                                               thread_sub_tensor_lengths,
                                                               SrcAccessOrder{},
                                                               Number<SrcDataPerRead>{});
        });
    }
@@ -264,15 +794,16 @@ struct BlockwiseGenericTensorSliceCopy_v1
        const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#endif
            threadwise_generic_tensor_slice_copy_v1_deprecated(thread_tensor_desc,
                                                               p_clipboard + clipboard_offset,
                                                               make_zero_array<index_t, nDim>(),
                                                               DstDesc{},
                                                               p_dst + dst_offset +
                                                                   mThreadDstOffset,
                                                               make_zero_array<index_t, nDim>(),
                                                               thread_sub_tensor_lengths,
                                                               DstAccessOrder{},
                                                               Number<DstDataPerWrite>{});
        });
    }
......
...@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"

namespace ck {
...@@ -37,58 +37,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
    constexpr auto src_mtx = SrcMatrix{};
    constexpr auto dst_mtx = DstMatrix{};

    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    for(index_t i = 0; i < NRow; ++i)
    {
        for(index_t j = 0; j < NCol; j += DataPerRead)
        {
            const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
            const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);

            *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
                *reinterpret_cast<const vector_t*>(&p_src[src_index]);
        }
    }
}

template <class MatrixA,
...@@ -119,7 +79,6 @@ __device__ void threadwise_gemm(MatrixA,
    constexpr index_t N = c_mtx.NCol();
    constexpr index_t K = a_mtx.NRow(); // A is transposed

    for(index_t k = 0; k < K; ++k)
    {
        for(index_t i = 0; i < M; ++i)
...@@ -130,32 +89,8 @@ __device__ void threadwise_gemm(MatrixA,
                const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

                p_c_thread[cindex] += math::inner_product_with_conversion<FloatC>{}(
                    p_a_thread[aindex], p_b_thread[bindex]);
            }
        }
    }
......
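// Illustrative only: a minimal sketch of the reduction the threadwise_gemm call above is
// expected to perform, assuming math::inner_product_with_conversion<AccType> converts every
// element of its (possibly vectorized) operands to AccType and accumulates the products.
// The scalar loop below is a stand-in for illustration, not the library implementation.
#if 0
template <class AccType, class T, index_t N>
__device__ AccType inner_product_with_conversion_sketch(const T (&a)[N], const T (&b)[N])
{
    AccType acc = 0;
    for(index_t i = 0; i < N; ++i)
    {
        // convert-then-multiply, e.g. fp16/bfp16 operands accumulated in fp32
        acc += type_convert<AccType>{}(a[i]) * type_convert<AccType>{}(b[i]);
    }
    return acc;
}
#endif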
...@@ -4,123 +4,802 @@
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_coordinate.hpp"
#include "float_types.h"

#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
#endif

#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#endif

#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
#endif

namespace ck {

// This threadwise copy allows vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows the order of access to be different on src and dst.
// It uses registers as a buffer to hold all data moving from src to dst.
// It is designed for copying a small amount of data, and src and dst are
// device memory or LDS.
// When copying a large amount of data, let's hope the compiler will reduce the
// registers used for the buffer.
template <class SrcDesc,
class DstDesc,
class SliceLengths,
class SrcDimAccessOrder,
class DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v1r1
{
static constexpr index_t nDim = SliceLengths::GetSize();
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r1(
Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() &&
nDim == DstDimAccessOrder::GetSize(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
is_valid_sequence_map<DstDimAccessOrder>::value,
"wrong! map is not valid");
static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
"wrong! cannot evenly divide");
// check vectorized memory access
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r1()
: ThreadwiseGenericTensorSliceCopy_v1r1(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
{
}
__device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
}
__device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
{
mDstSliceOrigin = dst_slice_origin;
}
template <class SrcData, class DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
SrcData p_src_buffer_[buffer_desc.GetElementSpace()];
SrcData* p_src_buffer = p_src_buffer_;
// copy data from src into src buffer
{
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
constexpr auto src_access_lengths = SliceLengths::Modify(
src_vector_access_dim,
SliceLengths::Get(src_vector_access_dim) / src_data_per_access);
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
static_ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
constexpr auto src_data_begin_id = src_access_id.Modify(
src_vector_access_dim,
src_access_id[src_vector_access_dim] * src_data_per_access);
const index_t src_offset =
SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_begin_id);
// load vector from src
const src_vector_t src_vector_data =
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
// unpack vector into buffer
static_for<0, SrcDataPerAccess, 1>{}([&](auto i) {
constexpr auto scalar_id =
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(src_vector_access_dim,
i);
constexpr index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);
p_src_buffer[buffer_offset] =
reinterpret_cast<const SrcData*>(&src_vector_data)[i];
});
});
#else
ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
auto src_data_begin_id = src_access_id;
src_data_begin_id(src_vector_access_dim) =
src_access_id[src_vector_access_dim] * src_data_per_access;
const index_t src_offset =
SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_begin_id);
// load vector from src
const src_vector_t src_vector_data =
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
// unpack vector into buffer
for(index_t i = 0; i < SrcDataPerAccess; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(src_vector_access_dim) = i;
const index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);
p_src_buffer[buffer_offset] =
reinterpret_cast<const SrcData*>(&src_vector_data)[i];
}
});
#endif
}
// copy data from buffer to dst
{
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
constexpr auto dst_access_lengths = SliceLengths::Modify(
dst_vector_access_dim,
SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
static_ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
constexpr auto dst_data_begin_id = dst_access_id.Modify(
dst_vector_access_dim,
dst_access_id[dst_vector_access_dim] * dst_data_per_access);
dst_vector_t dst_vector_data;
// pack vector from buffer and type conversion
static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
constexpr auto scalar_id =
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(dst_vector_access_dim,
i);
constexpr index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(dst_data_begin_id + scalar_id);
// SrcData to DstData type conversion is done here
reinterpret_cast<DstData*>(&dst_vector_data)[i] =
type_convert<DstData>{}(p_src_buffer[buffer_offset]);
});
const index_t dst_offset =
DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_begin_id);
// store vector into dst
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) = dst_vector_data;
});
#else
ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
auto dst_data_begin_id = dst_access_id;
dst_data_begin_id(dst_vector_access_dim) =
dst_access_id[dst_vector_access_dim] * dst_data_per_access;
dst_vector_t dst_vector_data;
// pack vector from buffer and type conversion
for(index_t i = 0; i < DstDataPerAccess; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(dst_vector_access_dim) = i;
const index_t buffer_offset =
buffer_desc.GetOffsetFromMultiIndex(dst_data_begin_id + scalar_id);
// SrcData to DstData type conversion is done here
reinterpret_cast<DstData*>(&dst_vector_data)[i] =
type_convert<DstData>{}(p_src_buffer[buffer_offset]);
}
const index_t dst_offset =
DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_begin_id);
// store vector into dst
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) = dst_vector_data;
});
#endif
}
}
private:
Array<index_t, nDim> mSrcSliceOrigin;
Array<index_t, nDim> mDstSliceOrigin;
};
// This threadwise copy allows vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
// It is designed for cases where one of src and dst is a register, and
// the other is device memory or LDS.
template <class SrcDesc,
          class DstDesc,
          class SliceLengths,
          class DimAccessOrder,
          index_t VectorAccessDim,
          index_t SrcDataPerAccess,
          index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v1r2
{
    static constexpr index_t nDim = SliceLengths::GetSize();

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2(
        Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
        : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
    {
        static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                          nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
                          nDim == DimAccessOrder::GetSize(),
                      "wrong! # of dimensions not the same");

        static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");

        static_assert(
            SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0,
            "wrong! cannot evenly divide");

        // check vectorized memory access
        constexpr auto vector_access_dim = Number<VectorAccessDim>{};

        static_if<!SrcDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
            static_assert(
                (fwd(SrcDesc{}).GetStride(vector_access_dim) == 1 || SrcDataPerAccess == 1),
                "wrong! vectorized access is allowed only if stride == 1");
        }).Else([&](auto fwd) {
            static_assert((fwd(SrcDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
                           SrcDataPerAccess == 1),
                          "wrong! vectorized access is allowed only if stride == 1");
        });

        static_if<!DstDesc::ContainMultipleOriginalDimensions(vector_access_dim)>{}([&](auto fwd) {
            static_assert(
                (fwd(DstDesc{}).GetStride(vector_access_dim) == 1 || DstDataPerAccess == 1),
                "wrong! vectorized access is allowed only if stride == 1");
        }).Else([&](auto fwd) {
            static_assert((fwd(DstDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 ||
                           DstDataPerAccess == 1),
                          "wrong! vectorized access is allowed only if stride == 1");
        });
    }

    __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2()
        : ThreadwiseGenericTensorSliceCopy_v1r2(make_zero_array<index_t, nDim>(),
                                                make_zero_array<index_t, nDim>())
    {
    }

    __device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
    {
        mSrcSliceOrigin = src_slice_origin;
    }

    __device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
    {
        mDstSliceOrigin = dst_slice_origin;
    }

    template <class SrcData, class DstData>
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
        using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;

        constexpr auto vector_access_dim = Number<VectorAccessDim>{};

        constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
        constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};

        constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};

        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);

#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2
        static_ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
            auto long_vector_access_id) {

            // data id w.r.t slicing-window
            constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
                vector_access_dim, long_vector_access_id[vector_access_dim] * long_vector_size);

            // buffer to hold a long-vector
            SrcData p_src_long_vector[long_vector_size];
            DstData p_dst_long_vector[long_vector_size];

            // load data from src to the long-vector buffer
            static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
                constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
                    vector_access_dim, i * src_data_per_access);

                const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(
                    mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id));

                constexpr index_t buffer_offset = i * src_data_per_access;

                *reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
                    *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
            });

            // type conversion
            for(index_t i = 0; i < long_vector_size; ++i)
            {
                p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
            }

            // store data from the long-vector buffer to dst
            static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
                constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
                    vector_access_dim, i * dst_data_per_access);

                constexpr index_t buffer_offset = i * dst_data_per_access;

                const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
                    mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));

                *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
                    *reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
            });
        });
#else
        ford<decltype(long_vector_access_lengths), DimAccessOrder>{}(
            [&](auto long_vector_access_id) {

                // data id w.r.t slicing-window
                auto long_vector_data_begin_id = long_vector_access_id;
                long_vector_data_begin_id(vector_access_dim) =
                    long_vector_size * long_vector_access_id[vector_access_dim];

                // buffer to hold a long-vector
                SrcData p_src_long_vector[long_vector_size];
                DstData p_dst_long_vector[long_vector_size];

                // load data from src to the long-vector buffer
                for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
                {
                    auto scalar_id               = make_zero_array<index_t, nDim>();
                    scalar_id(vector_access_dim) = i * src_data_per_access;

                    const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(
                        mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id));

                    const index_t buffer_offset = i * src_data_per_access;

                    *reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
                        *reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
                }

                // type conversion
                for(index_t i = 0; i < long_vector_size; ++i)
                {
                    p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
                }

                // store data from the long-vector buffer to dst
                for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
                {
                    auto scalar_id               = make_zero_array<index_t, nDim>();
                    scalar_id(vector_access_dim) = i * dst_data_per_access;

                    const index_t buffer_offset = i * dst_data_per_access;

                    const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
                        mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));

                    *reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
                        *reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
                }
            });
#endif
    }
private:
Array<index_t, nDim> mSrcSliceOrigin;
Array<index_t, nDim> mDstSliceOrigin;
};
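// Illustrative only: how the lcm-based "long vector" in v1r2 decomposes a copy when src and
// dst use different vector widths. With SrcDataPerAccess = 4 and DstDataPerAccess = 2,
// long_vector_size = lcm(4, 2) = 4, so each long-vector is filled by one 4-wide src load and
// drained by two 2-wide dst stores, with per-element type conversion in between.
#if 0
static_assert(math::lcm(4, 2) == 4, "one 4-wide src access per long-vector");
static_assert(4 / 2 == 2, "two 2-wide dst accesses per long-vector");
#endif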
// This threadwise copy allows vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows the order of access to be different on src and dst.
// It uses registers as a buffer to hold all data moving from src to dst.
// It is designed for copying a small amount of data, and src and dst are
// device memory or LDS.
// When copying a large amount of data, let's hope the compiler will reduce the
// registers used for the buffer.
template <class SrcDesc,
class DstDesc,
class SrcCoordinate,
class DstCoordinate,
class SliceLengths,
class SrcDimAccessOrder,
class DstDimAccessOrder,
index_t SrcVectorAccessDim,
index_t DstVectorAccessDim,
index_t SrcDataPerAccess,
index_t DstDataPerAccess>
struct ThreadwiseGenericTensorSliceCopy_v2r1
{
static constexpr index_t nDim = SliceLengths::GetSize();
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1(SrcCoordinate src_slice_origin,
DstCoordinate dst_slice_origin)
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
{
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SrcDimAccessOrder::GetSize() &&
nDim == DstDimAccessOrder::GetSize(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
is_valid_sequence_map<DstDimAccessOrder>::value,
"wrong! map is not valid");
static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
"wrong! cannot evenly divide");
// check vectorized memory access
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
SrcDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
[&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
})
.Else([&](auto fwd) {
static_assert(
(fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
DstDataPerAccess == 1),
"wrong! vectorized access is allowed only if stride == 1");
});
}
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1()
: ThreadwiseGenericTensorSliceCopy_v2r1(make_zero_array<index_t, nDim>(),
make_zero_array<index_t, nDim>())
{
}
__device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin)
{
mSrcSliceOrigin = src_slice_origin;
}
__device__ void SetDstSliceOrigin(DstCoordinate dst_slice_origin)
{
mDstSliceOrigin = dst_slice_origin;
}
template <class TDesc, class Lengths>
struct IsolateMergedDimLengths
{
template <class IDim>
__device__ constexpr index_t operator()(IDim idim) const
{
return TDesc::ContainMultipleOriginalDimensions(idim) ? Lengths{}[idim] : 1;
}
};
template <class SrcTData, class DstTData>
__device__ void Run(const SrcTData* p_src, DstTData* p_dst) const
{
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
SrcTData p_buffer_[buffer_desc.GetElementSpace()];
SrcTData* p_buffer = p_buffer_;
// copy data from src into buffer
{
using src_vector_t = typename vector_type<SrcTData, SrcDataPerAccess>::MemoryType;
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
constexpr auto src_access_lengths = SliceLengths::Modify(
src_vector_access_dim,
SliceLengths::Get(src_vector_access_dim) / src_data_per_access);
            // Offsets w.r.t. merged dimensions need to be calculated at run-time, while
            // offsets w.r.t. normal dimensions are known at compile-time.
            // Below is a hack to isolate merged-dimension ids from normal-dimension ids, so the
            // corresponding offsets can be calculated separately at run-time and compile-time.
            // src_merged_dim_access_lengths has the same value as src_access_lengths on src's
            // merged dimensions, and has value = 1 on normal dimensions;
            // src_normal_dim_access_lengths has the same value as src_access_lengths on src's
            // normal dimensions, and has value = 1 on merged dimensions.
constexpr auto src_merged_dim_access_lengths = typename sequence_gen<
nDim,
IsolateMergedDimLengths<SrcDesc, decltype(src_access_lengths)>>::type{};
constexpr auto src_normal_dim_access_lengths =
src_access_lengths + Number<1>{} - src_merged_dim_access_lengths;
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
// offset w.r.t. merged dimension need to be computed at run-time
static_ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}([&](
auto src_merged_dim_access_id_) {
constexpr auto src_merged_dim_access_id = decltype(src_merged_dim_access_id_){};
constexpr auto src_merged_dim_data_id = src_merged_dim_access_id.Modify(
src_vector_access_dim,
src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access);
const SrcTData* p_src_tmp =
p_src + (mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();
// offset w.r.t. normal dimension can be computed at compile-time
static_ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
auto src_normal_dim_access_id_) {
constexpr auto src_normal_dim_access_id = decltype(src_normal_dim_access_id_){};
constexpr auto src_normal_dim_data_id = src_normal_dim_access_id.Modify(
src_vector_access_dim,
src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access);
constexpr index_t src_normal_offset =
SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);
// load vector from src
const src_vector_t vector_data =
*reinterpret_cast<const src_vector_t*>(&p_src_tmp[src_normal_offset]);
// unpack vector into buffer
static_for<0, SrcDataPerAccess, 1>{}([&](auto i) {
constexpr auto scalar_id =
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
src_vector_access_dim, i);
constexpr index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
p_buffer[buffer_offset] =
reinterpret_cast<const SrcTData*>(&vector_data)[i];
});
});
});
#else
ford<decltype(src_merged_dim_access_lengths), SrcDimAccessOrder>{}(
[&](auto src_merged_dim_access_id) {
auto src_merged_dim_data_id = src_merged_dim_access_id;
src_merged_dim_data_id(src_vector_access_dim) =
src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access;
const SrcTData* p_src_tmp =
p_src + (mSrcSliceOrigin + src_merged_dim_data_id).GetOffset();
// these should be compile-time known
ford<decltype(src_normal_dim_access_lengths), SrcDimAccessOrder>{}([&](
auto src_normal_dim_access_id) {
auto src_normal_dim_data_id = src_normal_dim_access_id;
src_normal_dim_data_id(src_vector_access_dim) =
src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access;
const index_t src_normal_offset =
SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id);
// load vector from src
const src_vector_t vector_data =
*reinterpret_cast<const src_vector_t*>(&p_src_tmp[src_normal_offset]);
// unpack vector into buffer
for(index_t i = 0; i < SrcDataPerAccess; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(src_vector_access_dim) = i;
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
p_buffer[buffer_offset] =
reinterpret_cast<const SrcTData*>(&vector_data)[i];
}
});
});
#endif
}
// copy data from buffer into dst
{
using dst_vector_t = typename vector_type<DstTData, DstDataPerAccess>::MemoryType;
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
constexpr auto dst_access_lengths = SliceLengths::Modify(
dst_vector_access_dim,
SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
constexpr auto dst_merged_dim_access_lengths = typename sequence_gen<
nDim,
IsolateMergedDimLengths<DstDesc, decltype(dst_access_lengths)>>::type{};
constexpr auto dst_normal_dim_access_lengths =
dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths;
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
// offset w.r.t. merged dimension need to be computed at run-time
static_ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
auto dst_merged_dim_access_id_) {
constexpr auto dst_merged_dim_access_id = decltype(dst_merged_dim_access_id_){};
constexpr auto dst_merged_dim_data_id = dst_merged_dim_access_id.Modify(
dst_vector_access_dim,
dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access);
DstTData* p_dst_tmp =
p_dst + (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();
// offset w.r.t. normal dimension can be computed at compile-time
static_ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
auto dst_normal_dim_access_id_) {
constexpr auto dst_normal_dim_access_id = decltype(dst_normal_dim_access_id_){};
constexpr auto dst_normal_dim_data_id = dst_normal_dim_access_id.Modify(
dst_vector_access_dim,
dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access);
dst_vector_t vector_data{};
// pack vector from buffer
static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
constexpr auto scalar_id =
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
dst_vector_access_dim, i);
constexpr index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
reinterpret_cast<DstTData*>(&vector_data)[i] =
type_convert<DstTData>{}(p_buffer[buffer_offset]);
});
constexpr index_t dst_normal_offset =
DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);
// write vector into dst
*reinterpret_cast<dst_vector_t*>(&p_dst_tmp[dst_normal_offset]) = vector_data;
});
});
#else
// offset w.r.t. merged dimension need to be computed at run-time
ford<decltype(dst_merged_dim_access_lengths), DstDimAccessOrder>{}([&](
auto dst_merged_dim_access_id) {
auto dst_merged_dim_data_id = dst_merged_dim_access_id;
dst_merged_dim_data_id(dst_vector_access_dim) =
dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
DstTData* p_dst_tmp =
p_dst + (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset();
// offset w.r.t. normal dimension can be computed at compile-time
ford<decltype(dst_normal_dim_access_lengths), DstDimAccessOrder>{}([&](
auto dst_normal_dim_access_id) {
auto dst_normal_dim_data_id = dst_normal_dim_access_id;
dst_normal_dim_data_id(dst_vector_access_dim) =
dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access;
dst_vector_t vector_data{};
// pack vector from buffer
for(index_t i = 0; i < DstDataPerAccess; ++i)
{
auto scalar_id = make_zero_array<index_t, nDim>();
scalar_id(dst_vector_access_dim) = i;
const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id);
reinterpret_cast<DstTData*>(&vector_data)[i] =
type_convert<DstTData>{}(p_buffer[buffer_offset]);
}
const index_t dst_normal_offset =
DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id);
// write vector into dst
*reinterpret_cast<dst_vector_t*>(&p_dst_tmp[dst_normal_offset]) = vector_data;
});
});
#endif
}
}
// T can be Sequence or Array
template <class T, bool PositiveDirection>
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{
static_if<PositiveDirection>{}([&](auto) {
mSrcSliceOrigin += step_sizes;
}).Else([&](auto) { mSrcSliceOrigin -= step_sizes; });
}
template <class T, bool PositiveDirection>
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
{
static_if<PositiveDirection>{}([&](auto) {
mDstSliceOrigin += step_sizes;
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
}
private:
SrcCoordinate mSrcSliceOrigin;
DstCoordinate mDstSliceOrigin;
};
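// Illustrative only (hypothetical names): a sketch of how a gridwise kernel loop could use
// the v2r1 copy above to walk the E dimension between iterations. ThreadwiseCopy, EPerBlock
// and the fixed 4-step loop are illustrative assumptions, not the actual kernel code.
#if 0
template <class ThreadwiseCopy, class SrcData, class DstData, index_t EPerBlock>
__device__ void
slice_window_walk_sketch(ThreadwiseCopy& copy, const SrcData* p_src, DstData* p_dst)
{
    constexpr auto positive = integral_constant<bool, true>{};

    for(index_t e = 0; e < 4; ++e) // 4 E-steps, arbitrary for illustration
    {
        copy.Run(p_src, p_dst);

        // only the src window moves; dst (e.g. a fixed LDS tile) stays put
        copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, positive);
    }
}
#endif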
} // namespace ck
#endif
...@@ -9,7 +9,8 @@ namespace ck {
template <class TData, index_t NSize>
struct Array
{
    using Type      = Array<TData, NSize>;
    using data_type = TData;

    static constexpr index_t nSize = NSize;
...@@ -20,7 +21,7 @@
    {
    }

    __host__ __device__ static constexpr index_t GetSize() { return NSize; }

    template <index_t I>
    __host__ __device__ constexpr TData operator[](Number<I>) const
...@@ -208,6 +209,21 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
    return result;
}
// Array += Array
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a + b;
return a;
}
// Array -= Array
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a - b;
return a;
}
// Array = Array + Sequence
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
......
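// Illustrative only: the compound operators added above let a coordinate-like Array be
// advanced in place by another Array (or Sequence) step, which is what the slicing-window
// moves in the threadwise copies rely on. Construction syntax is abbreviated for the sketch.
#if 0
__device__ void array_step_sketch()
{
    auto origin = make_zero_array<index_t, 2>(); // {0, 0}
    origin += Array<index_t, 2>{8, 0};           // {8, 0}
    origin -= Array<index_t, 2>{8, 0};           // back to {0, 0}
}
#endif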
...@@ -6,41 +6,63 @@
namespace ck {

template <index_t...>
struct Sequence;

template <class Seq, index_t I>
struct sequence_split;

template <class>
struct sequence_reverse;

template <class>
struct sequence_map_inverse;

template <class>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

template <index_t... Is>
struct Sequence
{
    using Type      = Sequence;
    using data_type = index_t;

    static constexpr index_t mSize = sizeof...(Is);

    __host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }

    __host__ __device__ static constexpr index_t GetImpl(index_t I)
    {
        // the last dummy element is to prevent the compiler complaining about an empty array,
        // when mSize = 0
        const index_t mData[mSize + 1] = {Is..., 0};
        return mData[I];
    }

    template <index_t I>
    __host__ __device__ static constexpr auto Get(Number<I>)
    {
        static_assert(I < mSize, "wrong! I too large");

        return Number<GetImpl(Number<I>{})>{};
    }

    __host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }

    template <index_t I>
    __host__ __device__ constexpr auto operator[](Number<I>) const
    {
        return Get(Number<I>{});
    }

    // make sure I is constexpr if you want a constexpr return type
    __host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }

    template <index_t... IRs>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
    {
...@@ -52,23 +74,38 @@ struct Sequence
        return Sequence<Type::Get(Number<IRs>{})...>{};
    }

    // MapOld2New is Sequence<...>
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
        static_assert(MapOld2New::GetSize() == GetSize(),
                      "wrong! reorder map should have the same size as Sequence to be reordered");

        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

        return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
    }

    __host__ __device__ static constexpr auto Reverse()
    {
        return typename sequence_reverse<Type>::type{};
    }

    __host__ __device__ static constexpr auto Front()
    {
        static_assert(mSize > 0, "wrong!");
        return Get(Number<0>{});
    }

    __host__ __device__ static constexpr auto Back()
    {
        static_assert(mSize > 0, "wrong!");
        return Get(Number<mSize - 1>{});
    }

    __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }

    __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }

    template <index_t... Xs>
    __host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
...@@ -107,7 +144,16 @@ struct Sequence
    }

    template <index_t I, index_t X>
    __host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
    {
        static_assert(I < GetSize(), "wrong!");

        using seq_split = sequence_split<Type, I>;

        constexpr auto seq_left  = typename seq_split::SeqType0{};
        constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();

        return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
    }

    template <class F>
    __host__ __device__ static constexpr auto Transform(F f)
...@@ -126,48 +172,63 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
    using type = Sequence<Xs..., Ys...>;
};

// generate sequence
template <index_t IBegin, index_t NRemain, class F>
struct sequence_gen_impl
{
    static constexpr index_t NRemainLeft  = NRemain / 2;
    static constexpr index_t NRemainRight = NRemain - NRemainLeft;
    static constexpr index_t IMiddle      = IBegin + NRemainLeft;

    using type =
        typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
                                typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
};

template <index_t I, class F>
struct sequence_gen_impl<I, 1, F>
{
    static constexpr index_t Is = F{}(Number<I>{});
    using type                  = Sequence<Is>;
};

template <index_t I, class F>
struct sequence_gen_impl<I, 0, F>
{
    using type = Sequence<>;
};

template <index_t NSize, class F>
struct sequence_gen
{
    using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
template <index_t IBegin, index_t IEnd, index_t Increment>
struct arithmetic_sequence_gen
{
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t i) const
        {
            return i * Increment + IBegin;
        }
    };

    using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
    struct F
    {
        __host__ __device__ constexpr index_t operator()(index_t) const { return I; }
    };

    using type = typename sequence_gen<NSize, F>::type;
};

// reverse inclusive scan (with init) sequence
...@@ -236,6 +297,7 @@ struct sequence_reverse<Sequence<I0, I1>>
template <class Seq>
struct is_valid_sequence_map
{
    // not implemented yet, always return true
    static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};

    // TODO: add proper check for is_valid, something like:
...@@ -244,6 +306,34 @@ struct is_valid_sequence_map
    //     typename sequence_sort<Seq>::SortedSeqType>{};
};

template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
    private:
    static constexpr auto new_y2x =
        WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});

    public:
    using type =
        typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
};

template <class X2Y, class WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
    using type = WorkingY2X;
};

template <class X2Y>
struct sequence_map_inverse
{
    using type =
        typename sequence_map_inverse_impl<X2Y,
                                           typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
                                           0,
                                           X2Y::GetSize()>::type;
};

template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
{
...@@ -355,8 +445,8 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
    static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");
    return sequence_pop_front(Seq::Reverse()).Reverse();
}

template <class F, index_t... Xs>
...@@ -396,37 +486,6 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}

template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
......
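// Illustrative only: a concrete reading of sequence_map_inverse and ReorderGivenOld2New.
// For the old-to-new permutation map Sequence<2, 0, 1> (old dim 0 -> new dim 2, old 1 -> new 0,
// old 2 -> new 1), the inverse map is Sequence<1, 2, 0>, so reordering Sequence<10, 20, 30>
// "old to new" yields Sequence<20, 30, 10>.
#if 0
static_assert(std::is_same<typename sequence_map_inverse<Sequence<2, 0, 1>>::type,
                           Sequence<1, 2, 0>>::value,
              "inverse permutation");
static_assert(std::is_same<decltype(Sequence<10, 20, 30>{}.ReorderGivenOld2New(Sequence<2, 0, 1>{})),
                           Sequence<20, 30, 10>>::value,
              "reorder old -> new");
#endif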
...@@ -3,80 +3,111 @@ ...@@ -3,80 +3,111 @@
#include "vector_type.hpp" #include "vector_type.hpp"
#define WORKAROUND_SWDEV_202749 1
namespace ck { namespace ck {
#if !CK_USE_INLINE_ASM_XDLOPS
// A, B, C, cbsz, abid, blgp
extern "C" __device__ float32_t __llvm_amdgcn_mfma_f32_32x32x1f32(
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
extern "C" __device__ float32_t __llvm_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float32_t __llvm_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
#endif
// cast a pointer of LDS to its address // cast a pointer of LDS to its address
extern "C" __attribute__((address_space(3))) __device__ void* __to_local(void* p);
__device__ void vmcnt(index_t cnt) extern "C" __attribute__((address_space(3))) __device__ void* __to_local(const void* p);
{
if(cnt == 0)
{
asm volatile("\n \
s_waitcnt vmcnt(0) \n \
" ::);
}
else if(cnt == 1)
{
asm volatile("\n \
s_waitcnt vmcnt(1) \n \
" ::);
}
else if(cnt == 2)
{
asm volatile("\n \
s_waitcnt vmcnt(2) \n \
" ::);
}
else if(cnt == 4)
{
asm volatile("\n \
s_waitcnt vmcnt(2) \n \
" ::);
}
else
{
assert(false);
}
}
__device__ void lgkmcnt(index_t cnt) // clang-format off
{ #define REPEATx4(f, off) f(off) f(off + 1) f(off + 2) f(off + 3)
if(cnt == 0)
{ #define REPEATx16(f, off) \
asm volatile("\n \ REPEATx4(f, off) REPEATx4(f, off + 4) REPEATx4(f, off + 8) REPEATx4(f, off + 12)
s_waitcnt lgkmcnt(0) \n \
" ::); #define REPEATx64(f, off) \
} REPEATx16(f, off) REPEATx16(f, off + 16) REPEATx16(f, off + 32) REPEATx16(f, off + 48)
else if(cnt == 1)
{ #define REPEAT_STRIDEx4(f, stride, off) \
asm volatile("\n \ f(off) f(off + 1 * stride) f(off + 2 * stride) f(off + 3 * stride)
s_waitcnt lgkmcnt(1) \n \
" ::); #define REPEAT_STRIDEx16(f, stride, off) \
REPEAT_STRIDEx4(f, stride, off) REPEAT_STRIDEx4(f, stride, off + 1 * stride * 4) \
REPEAT_STRIDEx4(f, stride, off + 2 * stride * 4) \
REPEAT_STRIDEx4(f, stride, off + 3 * stride * 4)
#define REPEAT_STRIDEx64(f, stride, off) \
REPEAT_STRIDEx16(f, stride, off) REPEAT_STRIDEx16(f, stride, off + 1 * stride * 16) \
REPEAT_STRIDEx16(f, stride, off + 2 * stride * 16) \
REPEAT_STRIDEx16(f, stride, off + 3 * stride * 16)
#define NOP(n) asm volatile("\n s_nop " #n " " : :);
#define DS_READ_B32(off) \
if(offset == off) \
{ \
asm volatile("ds_read_b32 %0, %1 offset:" #off " " : "=v"(r) : "v"(__to_local(lds))); \
} }
else if(cnt == 2)
{ #define DS_READ_B128(off) \
asm volatile("\n \ if(offset == off) \
s_waitcnt lgkmcnt(2) \n \ { \
" ::); asm volatile("ds_read_b128 %0, %1 offset:" #off " " : "=v"(r) : "v"(__to_local(lds))); \
} }
else if(cnt == 3)
{ #define DS_WRITE_B128(off) \
asm volatile("\n \ if(offset == off) \
s_waitcnt lgkmcnt(3) \n \ { \
" ::); asm volatile("ds_write_b128 %0, %1 offset:" #off " " : : "v"(__to_local(lds)), "v"(r)); \
} }
else if(cnt == 4)
{ #define MFMA_F32_32x32x1F32(acc, reg_a, reg_b, cbsz, abid, blgp) \
asm volatile("\n \ asm volatile("v_mfma_f32_32x32x1f32 a[" #acc ":" #acc "+31], %0, %1, a[" #acc ":" #acc \
s_waitcnt lgkmcnt(4) \n \ "+31] cbsz: " #cbsz " abid: " #abid " blgp:" #blgp " " \
" ::); : \
: "v"(reg_a), "v"(reg_b));
#define MFMA_F32_32x32x4F16(acc, reg_a, reg_b, cbsz, abid, blgp) \
asm volatile("v_mfma_f32_32x32x4f16 a[" #acc ":" #acc "+31], %0, %1, a[" #acc ":" #acc \
"+31] cbsz: " #cbsz " abid: " #abid " blgp:" #blgp " " \
: \
: "v"(reg_a), "v"(reg_b));
#define MFMA_F32_32x32x2BF16(acc, reg_a, reg_b, cbsz, abid, blgp) \
asm volatile("v_mfma_f32_32x32x2bf16 a[" #acc ":" #acc "+31], %0, %1, a[" #acc ":" #acc \
"+31] cbsz: " #cbsz " abid: " #abid " blgp:" #blgp " " \
: \
: "v"(reg_a), "v"(reg_b));
#define ACCVGPR_READ(acc_reg_id) \
asm volatile("v_accvgpr_read_b32 %0, a[" #acc_reg_id "]" : "=v"(arch_reg[acc_reg_id]) :);
#define ACCVGPR_WRITE(acc_reg_id) \
asm volatile("v_accvgpr_write_b32 a[" #acc_reg_id "], %0" : : "v"(arch_reg[acc_reg_id]));
#define ACCVGPR_ZERO(acc_reg_id) \
asm volatile("v_accvgpr_write_b32 a[" #acc_reg_id "], 0" : :);
#define S_WAIT_VMCNT(id) \
if(cnt == id) \
{ \
asm volatile("s_waitcnt vmcnt(" #id ")" ::); \
} }
else
{ #define S_WAIT_LGKMCNT(id) \
assert(false); if(cnt == id) \
{ \
asm volatile("s_waitcnt lgkmcnt(" #id ")" ::); \
} }
}
__device__ void s_wait_vmcnt(index_t cnt) { REPEATx4(S_WAIT_VMCNT, 0) }
__device__ void s_wait_lgkmcnt(index_t cnt) { REPEATx4(S_WAIT_LGKMCNT, 0) }
__device__ void outerProduct1x4(const float* a, const float* b, float* c)
{
@@ -98,6 +129,23 @@ __device__ void outerProduct1x4(const float* a, const float* b, float* c)
                 "3"(c[3]));
}
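// The fp32 outer-product helpers accumulate one row of a rank-1 update, c[j] += a[0] * b[j],
// either through v_mac_f32 inline asm or through the plain C++ fallback. outerProduct1x2 below
// is the 2-wide counterpart of the 4-wide routine above (whose body is collapsed in this hunk).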
__device__ void outerProduct1x2(const float* a, const float* b, float* c)
{
// disable inline asm due to the compiler issue: SWDEV-202749
/// \todo: enable the inline asm after the compiler fix
#if WORKAROUND_SWDEV_202749
c[0] += a[0] * b[0];
c[1] += a[0] * b[1];
#else
asm volatile("\n \
v_mac_f32 %0, %2, %3 \n \
v_mac_f32 %1, %2, %4 \n \
"
: "=v"(c[0]), "=v"(c[1])
: "v"(a[0]), "v"(b[0]), "v"(b[1]), "0"(c[0]), "1"(c[1]));
#endif
}
__device__ void outerProduct1x4(const float& a,
                                const vector_type<float, 4>::MemoryType& b,
                                vector_type<float, 4>::MemoryType& c)
@@ -105,20 +153,14 @@ __device__ void outerProduct1x4(const float& a,
    outerProduct1x4(&a, reinterpret_cast<const float*>(&b), reinterpret_cast<float*>(&c));
}

__device__ void outerProduct1x2(const float& a,
                                const vector_type<float, 2>::MemoryType& b,
                                vector_type<float, 2>::MemoryType& c)
{
    outerProduct1x2(&a, reinterpret_cast<const float*>(&b), reinterpret_cast<float*>(&c));
}
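// The fp16 outer products below are built on v_dot2_f32_f16, which multiply-accumulates one
// packed half2 pair into an f32 accumulator (d = a.lo * b.lo + a.hi * b.hi + d). The
// "TwoTimes" variants take two half2 registers from `a` and issue two dot2 instructions per
// output element; the plain dot2 variants take a single half2 from `a` and one half2 per
// output column from `b`.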
__device__ void outerProduct1x4dot2TwoTimes(const half2* a, const half2* b, float* c)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %6 %0\n \
@@ -147,579 +189,240 @@ __device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
                 "3"(c[3])); // 3rd Src Acc registers for 2 half2 registers
}

__device__ void outerProduct1x4dot2(const half2* a, const half2* b, float* c)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %4, %5 %0\n \
            v_dot2_f32_f16 %1, %4, %6 %1\n \
            v_dot2_f32_f16 %2, %4, %7 %2\n \
            v_dot2_f32_f16 %3, %4, %8 %3\n \
            "
                 : "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3]) // Dest registers
                 : "v"(a[0]), // 1st Src register for 1 half2 registers
                   "v"(b[0]), // 2nd Src register
                   "v"(b[1]),
                   "v"(b[2]),
                   "v"(b[3]),
                   "0"(c[0]), // 3rd Src register
                   "1"(c[1]),
                   "2"(c[2]),
                   "3"(c[3]));
}

__device__ void outerProduct1x2dot2TwoTimes(const half2* a, const half2* b, float* c)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %4 %0\n \
            v_dot2_f32_f16 %1, %2, %6 %1\n \
            v_dot2_f32_f16 %0, %3, %5 %0\n \
            v_dot2_f32_f16 %1, %3, %7 %1\n \
            "
                 : "=v"(c[0]), "=v"(c[1]) // Dest registers
                 : "v"(a[0]),
                   "v"(a[1]), // 1st Src registers for 2 half2 registers
                   "v"(b[0]),
                   "v"(b[1]),
                   "v"(b[2]),
                   "v"(b[3]), // 2nd Src registers for 2 half2 registers
                   "0"(c[0]),
                   "1"(c[1])); // 3rd Src Acc registers for 2 half2 registers
}

__device__ void outerProduct1x2dot2(const half2* a, const half2* b, float* c)
{
    asm volatile("\n \
            v_dot2_f32_f16 %0, %2, %3 %0\n \
            v_dot2_f32_f16 %1, %2, %4 %1\n \
            "
                 : "=v"(c[0]), "=v"(c[1]) // Dest registers
                 : "v"(a[0]), // 1st Src register for 1 half2 registers
                   "v"(b[0]), // 2nd Src register
                   "v"(b[1]),
                   "0"(c[0]), // 3rd Src register
                   "1"(c[1]));
}
__device__ void ds_read_b32(float& r, const void* lds, index_t offset = 0) { DS_READ_B32(0) }

__device__ void
ds_read_b128(vector_type<float, 4>::MemoryType& r, const void* lds, index_t offset = 0)
{
    REPEAT_STRIDEx64(DS_READ_B128, 64, 0)
}
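// ds_read_b32/ds_read_b128 above and ds_write_b128 below dispatch the runtime `offset`
// argument onto the LDS instruction's compile-time `offset:` immediate: REPEAT_STRIDEx64
// expands DS_READ_B128/DS_WRITE_B128 into 64 guarded cases covering byte offsets 0, 64, ...,
// 4032, while ds_read_b32 currently handles only offset == 0. Offsets outside this set fall
// through all the guards and perform no LDS access.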
__device__ void
ds_write_b128(const vector_type<float, 4>::MemoryType& r, const void* lds, index_t offset = 0)
{
    REPEAT_STRIDEx64(DS_WRITE_B128, 64, 0)
}

template <index_t Size>
__device__ void gcnasm_accvgpr_read(float*)
{
}

template <>
__device__ void gcnasm_accvgpr_read<16>(float* arch_reg)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(16)
REPEATx16(ACCVGPR_READ, 0)
#else
(void)arch_reg;
#endif
}
template <>
__device__ void gcnasm_accvgpr_read<32>(float* arch_reg)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(16)
REPEATx16(ACCVGPR_READ, 0)
REPEATx16(ACCVGPR_READ, 16)
#else
(void)arch_reg;
#endif
}
template <>
__device__ void gcnasm_accvgpr_read<64>(float* arch_reg)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(16)
REPEATx64(ACCVGPR_READ, 0)
#else
(void)arch_reg;
#endif
}
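// gcnasm_accvgpr_read<Size> copies the first Size accumulation VGPRs (a[0]..a[Size-1]) into
// the ordinary VGPRs backing arch_reg via v_accvgpr_read_b32; the preceding NOP(16) appears to
// be a conservative wait for outstanding MFMA results before the accumulators are read. The
// empty primary templates in this family act as no-op fallbacks, so only the explicitly
// specialized sizes (16/32/64) generate code.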
template <index_t MPerWave>
__device__ void gcnasm_accvgpr_zero()
{
}
template <>
__device__ void gcnasm_accvgpr_zero<32>()
{
#if CK_USE_INLINE_ASM_XDLOPS
REPEATx16(ACCVGPR_ZERO, 0)
REPEATx16(ACCVGPR_ZERO, 16)
#endif
}
template <>
__device__ void gcnasm_accvgpr_zero<64>()
{
#if CK_USE_INLINE_ASM_XDLOPS
REPEATx64(ACCVGPR_ZERO, 0)
#endif
}
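// The gcnasm_mfma_* wrappers below issue the xdlops matrix fused multiply-add instructions.
// Each v_mfma_f32_32x32x1f32 accumulates into a block of 32 accumulation VGPRs
// (a[acc]..a[acc+31] in the macro); the MPerWave == 64 specializations issue two MFMAs with
// abid 0 and 1 targeting a[0:31] and a[32:63], while MPerWave == 32 issues a single one. When
// CK_USE_INLINE_ASM_XDLOPS is disabled, the same operations go through the
// __llvm_amdgcn_mfma_* compiler builtins and accumulate into reg_c instead.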
template <index_t MPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(float&, float&, float32_t*)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<64>(float& reg_a, float& reg_b, float32_t* reg_c)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(1)
(void)reg_c;
MFMA_F32_32x32x1F32(0, reg_a, reg_b, 1, 0, 0)
MFMA_F32_32x32x1F32(32, reg_a, reg_b, 1, 1, 0)
#else
reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[0], 1, 0, 0);
reg_c[1] = __llvm_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[1], 1, 1, 0);
#endif
}
template <>
__device__ void gcnasm_mfma_f32_32x32x1f32<32>(float& reg_a, float& reg_b, float32_t* reg_c)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(1)
(void)reg_c;
MFMA_F32_32x32x1F32(0, reg_a, reg_b, 1, 0, 0)
#else
reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, reg_c[0], 1, 0, 0);
#endif
}
template <index_t MPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(typename vector_type<half, 4>::MemoryType&,
typename vector_type<half, 4>::MemoryType&,
float32_t*)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<64>(typename vector_type<half, 4>::MemoryType& reg_a,
typename vector_type<half, 4>::MemoryType& reg_b,
float32_t* reg_c)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(1)
(void)reg_c;
MFMA_F32_32x32x4F16(0, reg_a, reg_b, 1, 0, 0)
MFMA_F32_32x32x4F16(32, reg_a, reg_b, 1, 1, 0)
#else
reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, reg_c[0], 1, 0, 0);
reg_c[1] = __llvm_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, reg_c[1], 1, 1, 0);
#endif
}
template <>
__device__ void gcnasm_mfma_f32_32x32x4f16<32>(typename vector_type<half, 4>::MemoryType& reg_a,
typename vector_type<half, 4>::MemoryType& reg_b,
float32_t* reg_c)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(1)
(void)reg_c;
MFMA_F32_32x32x4F16(0, reg_a, reg_b, 1, 0, 0)
#else
reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, reg_c[0], 1, 0, 0);
#endif
}
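// The bf16 path below mirrors the fp16 one but uses v_mfma_f32_32x32x2bf16: each operand is a
// vector_type<ushort, 2>, i.e. two bfloat16 values packed as 16-bit integers in one VGPR,
// which is presumably why the per-instruction K depth drops from 4 (f16) to 2 (bf16).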
template <index_t MPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(typename vector_type<ushort, 2>::MemoryType&,
typename vector_type<ushort, 2>::MemoryType&,
float32_t*)
{
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<64>(typename vector_type<ushort, 2>::MemoryType& reg_a,
typename vector_type<ushort, 2>::MemoryType& reg_b,
float32_t* reg_c)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(1)
(void)reg_c;
MFMA_F32_32x32x2BF16(0, reg_a, reg_b, 1, 0, 0)
MFMA_F32_32x32x2BF16(32, reg_a, reg_b, 1, 1, 0)
#else
reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x2bf16(reg_a, reg_b, reg_c[0], 1, 0, 0);
reg_c[1] = __llvm_amdgcn_mfma_f32_32x32x2bf16(reg_a, reg_b, reg_c[1], 1, 1, 0);
#endif
}
template <>
__device__ void gcnasm_mfma_f32_32x32x2bf16<32>(typename vector_type<ushort, 2>::MemoryType& reg_a,
typename vector_type<ushort, 2>::MemoryType& reg_b,
float32_t* reg_c)
{
#if CK_USE_INLINE_ASM_XDLOPS
NOP(1)
(void)reg_c;
MFMA_F32_32x32x2BF16(0, reg_a, reg_b, 1, 0, 0)
#else
reg_c[0] = __llvm_amdgcn_mfma_f32_32x32x2bf16(reg_a, reg_b, reg_c[0], 1, 0, 0);
#endif
}
// clang-format on
} // namespace ck
#endif
@@ -26,6 +26,8 @@
#ifndef BFLOAT16_DEVICE_HPP
#define BFLOAT16_DEVICE_HPP
#define __HIP_PLATFORM_HCC__ 1
#ifdef __cplusplus
extern "C" {
#endif
...