Commit 32850b93 authored by Wen-Heng (Jack) Chung

Ported xdlops kernels to debug bwdwrw fp32/fp16/bfp16 issue. Verified at least fwd data fp32 works.

parent 583755a7
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_MIOPEN_IMPLICITGEMM_PARMS_HPP_
#define GUARD_MIOPEN_IMPLICITGEMM_PARMS_HPP_
enum struct ImplicitGemmDirection
{
ForwardData,
BackwardData,
BackwardWeight
};
enum struct ImplicitGemmXdlopsKernel
{
KernelFwdWrw = 0,
Kernel1x1 = 1,
};
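// Illustration only (not part of this commit): these enums are intended for
// compile-time selection by the host-side logic that instantiates the kernels;
// a hypothetical selector (name is illustrative) could look like this, falling
// back to the generic fwd/wrw kernel unless a 1x1 convolution is detected.
constexpr ImplicitGemmXdlopsKernel example_select_xdlops_kernel(bool is_1x1_conv)
{
    return is_1x1_conv ? ImplicitGemmXdlopsKernel::Kernel1x1
                       : ImplicitGemmXdlopsKernel::KernelFwdWrw;
}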
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
@@ -8,9 +8,35 @@
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
template <ImplicitGemmDirection conv_dir, typename WeiDesc>
struct make_WeiDesc
{
};
template <typename WeiDesc>
struct make_WeiDesc<ImplicitGemmDirection::ForwardData, WeiDesc>
{
__device__ constexpr auto get(WeiDesc&)
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return WeiDesc{}.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
}
};
template <typename WeiDesc>
struct make_WeiDesc<ImplicitGemmDirection::BackwardWeight, WeiDesc>
{
__device__ constexpr auto get(WeiDesc& desc)
{
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return make_ConstantMergedTensorDescriptor(
desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
}
};
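// Usage sketch (illustration only, not introduced by this commit): for the
// forward-data path the specialization above turns a [K, C, Y, X] weight
// descriptor into the [E, K] view consumed by the blockwise copy below, where
// E = C * Y * X. The wrapper name is hypothetical; the gridwise kernel calls
// make_WeiDesc directly.
template <class WeiKCYXDesc>
__device__ constexpr auto example_wei_e_k_view_fwd(WeiKCYXDesc wei_k_c_y_x_desc)
{
    return make_WeiDesc<ImplicitGemmDirection::ForwardData, WeiKCYXDesc>{}.get(
        wei_k_c_y_x_desc);
}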
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
@@ -24,8 +50,7 @@ template <index_t GridSize,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmNRepeat,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
@@ -48,17 +73,19 @@ template <index_t GridSize,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// this is a mess
// TODO: find more elegant way of specifying (or calculating) performance parameters
constexpr index_t N1 = GemmNRepeat;
constexpr index_t N2 = GemmNPerThreadSubC;
static_assert((N1 * N2 * BPerBlock) %
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
0,
@@ -86,6 +113,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread");
constexpr index_t N0 = N / (N1 * N2);
@@ -94,6 +127,14 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
constexpr index_t E = C * Y * X;
// sanity-check for vectorized memory load
static_assert(ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1,
"wrong! global vector load of input tensor is wrong");
static_assert((X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
@@ -113,15 +154,15 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
@@ -148,7 +189,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(in_e_n1_b_n2_global_merged_desc),
decltype(in_e_n1_b_n2_block_desc),
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
@@ -157,6 +197,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
2,
3,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2>(
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
@@ -164,7 +206,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
make_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(
wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
@@ -177,7 +220,6 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
@@ -186,6 +228,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});
@@ -196,13 +240,11 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
@@ -214,11 +256,12 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
1, // EPACK = 1
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
@@ -280,53 +323,58 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto fwd) {
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
// even iteration
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
@@ -384,19 +432,18 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
ThreadwiseGenericTensorSliceCopy_v1r2<
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 8, 1>::type,
7,
1,
1>(make_zero_array<index_t, 8>(), make_zero_array<index_t, 8>())
.Run(p_out_thread, p_out_thread_on_global);
}
}
};
} // namespace ck
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KC1X1_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
ImplicitGemmDirection Direction,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
bool EnableXdlops,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kc1x1_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_k_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr bool isForward = Direction == ImplicitGemmDirection::ForwardData;
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho =
std::conditional<isForward,
decltype(out_n_k_h_w_global_desc),
decltype(in_n_c_h_w_global_desc)>::type::GetLength(I2);
constexpr index_t Wo =
std::conditional<isForward,
decltype(out_n_k_h_w_global_desc),
decltype(in_n_c_h_w_global_desc)>::type::GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t E = C;
constexpr index_t B = N * Ho * Wo;
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc_forw =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
constexpr auto in_n_ho_wo_global_desc_back = in_n_c_h_w_global_desc.Extract(I0, I2, I3);
constexpr auto in_n_ho_wo_global_desc =
typename std::conditional<isForward,
decltype(in_n_ho_wo_global_desc_forw),
decltype(in_n_ho_wo_global_desc_back)>::type{};
// batch descriptor for device memory
constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Extract(I1);
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc = make_ConstantMergedTensorDescriptor(
in_c_global_desc.Embed(in_n_ho_wo_global_desc), Sequence<0>{}, Sequence<1, 2, 3>{});
// memory layout descriptor in LDS [E, B], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, BPerBlock>{},
Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB)>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc_forw = wei_c_k_global_desc;
constexpr auto wei_e_k_global_desc_back =
make_ConstantTensorDescriptor_packed(Sequence<C, K>{});
constexpr auto wei_e_k_global_desc =
typename std::conditional<isForward,
decltype(wei_e_k_global_desc_forw),
decltype(wei_e_k_global_desc_back)>::type{};
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(mfma_info<float>{}),
EnableXdlops,
GemmMPerWave,
GemmNPerWave,
GemmMWaves,
GemmNWaves,
GemmDataPerReadA,
GemmDataPerReadB>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
Float p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);
static_if<EnableXdlops>{}(
[&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// load data from xdlops_acc_regs
static_if<EnableXdlops>{}([&](auto) {
gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
});
// copy output: register to global memory
{
constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;
constexpr auto out_n_k_h_w_global_desc_forw = out_n_k_h_w_global_desc;
constexpr auto out_lengths_back =
Sequence<out_n_k_h_w_global_desc.GetLength(I0),
out_n_k_h_w_global_desc.GetLength(I1),
math::integer_divide_ceil(out_n_k_h_w_global_desc.GetLength(I2),
ConvStrides{}.Get(I0)),
math::integer_divide_ceil(out_n_k_h_w_global_desc.GetLength(I3),
ConvStrides{}.Get(I1))>{};
constexpr auto out_strides_back =
Sequence<out_n_k_h_w_global_desc.GetStride(I0),
out_n_k_h_w_global_desc.GetStride(I1),
out_n_k_h_w_global_desc.GetStride(I2) * ConvStrides{}.Get(I0),
out_n_k_h_w_global_desc.GetStride(I3) * ConvStrides{}.Get(I1)>{};
constexpr auto out_n_k_h_w_global_desc_back =
make_ConstantTensorDescriptor(out_lengths_back, out_strides_back);
constexpr auto out_n_k_h_w_global_desc_new =
typename std::conditional<isForward,
decltype(out_n_k_h_w_global_desc_forw),
decltype(out_n_k_h_w_global_desc_back)>::type{};
// This is a hack, because slicing a merged dimension is not supported yet.
// dst descriptor
constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc_new.Fold(I1, Number<K1>{}, Number<K2>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<0, 4, 5>{});
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});
using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;
constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
constexpr index_t NumBlks = GemmMPerWave / NumKPerBlk;
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / (K0 * K1),
k_thread_data_on_global % (K0 * K1) / K0,
k_thread_data_on_global % K0,
b_thread_data_on_global});
threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
}
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "implicitgemm_params.hpp"
namespace ck {
template <ImplicitGemmDirection conv_dir, typename WeiDesc>
struct make_WeiDesc_Xdlops
{
};
template <typename WeiDesc>
struct make_WeiDesc_Xdlops<ImplicitGemmDirection::ForwardData, WeiDesc>
{
__device__ constexpr auto get(WeiDesc&)
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return WeiDesc{}.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
}
};
template <typename WeiDesc>
struct make_WeiDesc_Xdlops<ImplicitGemmDirection::BackwardWeight, WeiDesc>
{
__device__ constexpr auto get(WeiDesc& desc)
{
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return make_ConstantMergedTensorDescriptor(
desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
}
};
// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class AccDataType,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
index_t EPack,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
bool EnableXdlops,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B,
ImplicitGemmDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t nonVectorizedC = C / EPack;
constexpr index_t E = nonVectorizedC * Y * X;
constexpr index_t B = N * Ho * Wo;
static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
"wrong! alignment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
// input tensor
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc =
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{});
// memory layout descriptor in LDS [E, B], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, BPerBlock>{},
Number<math::lcm(InBlockCopyDataPerAccess_B, GemmDataPerReadB)>{});
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
make_WeiDesc_Xdlops<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(
wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>({0, k_block_data_on_global}, {0, 0});
// GEMM definition
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_b_block_mtx_desc),
decltype(mfma_info<float>{}),
EnableXdlops,
GemmMPerWave,
GemmNPerWave,
GemmMWaves,
GemmNWaves,
GemmDataPerReadA,
GemmDataPerReadB>{};
constexpr auto c_k_thread_mtx_desc = blockwise_gemm.GetThreadMatrixCDescriptor();
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
WeiBlockCopyDstDataPerWrite_K,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t in_block_space =
math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccDataType p_out_thread[c_k_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_k_thread_mtx_desc, p_out_thread);
// static_if<EnableXdlops>{}(
// [&](auto) { gcnasm_accvgpr_zero<c_k_thread_mtx_desc.GetElementSpace()>(); });
const Float* p_wei_block_on_global = p_wei_global;
// LDS double buffer: preload data into LDS
{
blockwise_in_copy.Run(p_in_global, p_in_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
e_block_data_begin += 2 * EPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ImplicitGemmDirection::BackwardWeight>{}([&](auto) {
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else([&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
p_wei_block_double + wei_block_space);
// odd iteration
__syncthreads();
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
// load data from xdlops_acc_regs
// static_if<EnableXdlops>{}([&](auto) {
// gcnasm_accvgpr_read<c_k_thread_mtx_desc.GetElementSpace()>(p_out_thread);
// });
// copy output: register to global memory
{
constexpr index_t K2 = blockwise_gemm.OutputLayout.M2;
constexpr index_t K1 = blockwise_gemm.OutputLayout.M1;
constexpr index_t K0 = blockwise_gemm.OutputLayout.M0;
// This is a hack, because slicing a merged dimension is not supported yet.
// dst descriptor
constexpr auto out_k0_k1_k2_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<0, 4, 5>{});
// src descriptor
constexpr auto out_k0_k1_k2_b_thread_desc =
make_ConstantTensorDescriptor_packed(Sequence<K2, 1, K0, 1>{});
using OutThreadCopySliceLengths = Sequence<K2, 1, K0, 1>;
constexpr index_t NumKPerBlk = out_k0_k1_k2_b_thread_desc.GetElementSpace();
constexpr index_t NumBlks = GemmMPerWave / NumKPerBlk;
for(index_t i = 0; i < NumBlks; ++i)
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_k2_b_thread_desc),
decltype(out_k0_k1_k2_b_global_desc),
NormalTensorCoordinate<decltype(out_k0_k1_k2_b_thread_desc)>,
MergedTensorCoordinate<decltype(out_k0_k1_k2_b_global_desc)>,
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 4, 1>::type,
arithmetic_sequence_gen<0, 4, 1>::type,
3,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / (K0 * K1),
k_thread_data_on_global % (K0 * K1) / K0,
k_thread_data_on_global % K0,
b_thread_data_on_global});
threadwise_out_copy.Run(p_out_thread + i * NumKPerBlk, p_out_global);
}
}
}
};
} // namespace ck
#endif
@@ -2,6 +2,7 @@
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace ck {
@@ -39,7 +40,7 @@ struct ConstantMatrixDescriptor
};
template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number<NRow>, Number<NCol>)
{
return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}
@@ -51,6 +52,17 @@ __host__ __device__ constexpr auto
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}
template <class... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
{
using TDesc = ConstantTensorDescriptor<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],
TDesc::GetLengths()[1],
TDesc::GetStrides()[0]>{};
}
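// Illustration only (not part of this commit), assuming make_ConstantTensorDescriptor_packed
// is visible here through the ConstantTensorDescriptor.hpp include added above: any 2-D
// descriptor with unit stride in its last dimension can be viewed as a matrix, e.g. a packed
// [4, 8] tensor descriptor maps to a 4x8 matrix descriptor with row stride 8, which is how
// the gridwise kernels build their GEMM views from LDS block descriptors.
namespace example_detail {
using ExampleTensorDesc = decltype(make_ConstantTensorDescriptor_packed(Sequence<4, 8>{}));
using ExampleMatrixDesc = decltype(make_ConstantMatrixDescriptor(ExampleTensorDesc{}));
} // namespace example_detail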
template <class TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
...
@@ -37,7 +37,7 @@ struct ConstantMergedTensorDescriptor
return OriginalTensorDesc{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
template <index_t IDim>
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
@@ -52,7 +52,7 @@ struct ConstantMergedTensorDescriptor
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetLength(Number<IDim>)
{
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
@@ -60,22 +60,32 @@ struct ConstantMergedTensorDescriptor
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetStride(Number<IDim>)
{
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
"wrong! stride of a merged dimension is undefined");
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_original>{});
}
// this is a hack to return the stride of the last original dimension of a merged dimension
// TODO: refactor this once the concept of "dimension" is used
template <index_t IDim>
__host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number<IDim>)
{
constexpr auto idim_last_original = std::get<IDim>(mOriginalDimMergeSeqs).Back();
return OriginalTensorDesc::GetStride(Number<idim_last_original>{});
}
__host__ __device__ static constexpr auto GetLengths()
{
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
}
__host__ __device__ static constexpr auto GetElementSize()
{
return OriginalTensorDesc::GetElementSize();
}
@@ -174,6 +184,13 @@ struct ConstantMergedTensorDescriptor
return packed_desc.GetMultiIndexFrom1dIndex(id);
}
__host__ __device__ static constexpr auto Pack()
{
constexpr auto lengths = GetLengths();
constexpr auto strides = calculate_tensor_strides_packed(lengths);
return ConstantTensorDescriptor<decltype(lengths), decltype(strides)>{};
}
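// Usage note (illustration only, not part of this commit): Pack() discards the original,
// possibly non-contiguous strides and re-derives packed ones from the merged lengths via
// calculate_tensor_strides_packed, so a merged [E, B] view over a strided NCHW tensor
// packs to lengths {E, B} with strides {B, 1}.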
};
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
...
@@ -43,23 +43,15 @@ struct ConstantTensorDescriptor
return Sequence<IDim>{};
}
__host__ __device__ static constexpr auto GetNumOfDimension() { return Number<nDim>{}; }
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
__host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; }
__host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; }
struct lambda_AreDimensionsContinuous
{
@@ -102,17 +94,18 @@ struct ConstantTensorDescriptor
return false;
}
__host__ __device__ static constexpr auto GetElementSize()
{
return Number<accumulate_on_sequence(
Lengths{}, math::multiplies<index_t>{}, Number<1>{})>{};
}
__host__ __device__ static constexpr auto GetElementSpace()
{
constexpr index_t element_space_unaligned = accumulate_on_sequence(
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
return Number<element_space_unaligned>{};
}
// emulate constexpr lambda
@@ -156,13 +149,14 @@ struct ConstantTensorDescriptor
}
template <index_t... Is>
__host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence<Is...>)
{
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
constexpr auto multi_id = Sequence<Is...>{};
return Number<accumulate_on_sequence(
multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{})>{};
}
// emulate constexpr lambda
@@ -369,6 +363,12 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
template <index_t IDim, index_t... FoldIntervals>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldIntervals...>)
{
return Fold(Number<IDim>{}, Number<FoldIntervals>{}...);
}
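// Usage note (illustration only, not part of this commit): this overload simply unpacks
// the Sequence into the variadic Fold above, so Fold(I1, Sequence<K1, K2>{}) is
// equivalent to Fold(I1, Number<K1>{}, Number<K2>{}), matching how the gridwise kernels
// fold the K dimension of the output descriptor.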
// this function unfolds dimensions [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
@@ -407,6 +407,12 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
}
__host__ __device__ static constexpr auto Pack()
{
using packed_strides = decltype(calculate_tensor_strides_packed(Lengths{}));
return ConstantTensorDescriptor<Lengths, packed_strides>{};
}
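// Illustrative sketch (not part of the original change; the concrete numbers are an assumption
// picked for the example): Pack() keeps Lengths but swaps Strides for the fully packed ones.
#if 0
// a 2x3x4 tensor padded in memory with strides {16, 4, 1}
using padded_desc_t = ConstantTensorDescriptor<Sequence<2, 3, 4>, Sequence<16, 4, 1>>;
// Pack() recomputes strides as if the data were dense: {12, 4, 1}
using packed_desc_t = decltype(padded_desc_t::Pack());
static_assert(packed_desc_t::GetStride(0) == 12, "dim0 packed stride is 3 * 4");
static_assert(packed_desc_t::GetStride(2) == 1, "innermost packed stride is 1");
#endif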
template <class MapNew2Old>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
{
@@ -414,14 +420,12 @@ struct ConstantTensorDescriptor
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
}

template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
}
};
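// Illustrative sketch (assumption, not from the original source): how the compile-time queries
// of ConstantTensorDescriptor fit together for a small, non-packed tensor.
#if 0
using desc_t = ConstantTensorDescriptor<Sequence<2, 3, 4>, Sequence<16, 4, 1>>;
static_assert(desc_t::GetElementSize() == 24, "2 * 3 * 4 elements");
static_assert(desc_t::GetOffsetFromMultiIndex(Sequence<1, 2, 3>{}) == 27, "1*16 + 2*4 + 3*1");
static_assert(desc_t::GetElementSpace() == 28, "offset of the last element plus one");
#endif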
template <class Lengths>
@@ -451,7 +455,7 @@ print_ConstantTensorDescriptor(const char* s,
{
constexpr index_t ndim = sizeof...(Lengths);

static_assert(ndim > 0 && ndim <= 12, "wrong!");

static_if<ndim == 1>{}([&](auto) {
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
@@ -523,6 +527,26 @@ print_ConstantTensorDescriptor(const char* s,
Lengths...,
Strides...);
});
static_if<ndim == 11>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
"%u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
static_if<ndim == 12>{}([&](auto) {
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
"%u %u %u %u "
"%u %u %u}\n",
s,
ndim,
Lengths...,
Strides...);
});
}
} // namespace ck
#ifndef CK_TENSOR_COORDINATE_HPP
#define CK_TENSOR_COORDINATE_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
namespace ck {
template <class TensorDesc>
struct NormalTensorCoordinate
{
using type = NormalTensorCoordinate;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
__host__ __device__ constexpr NormalTensorCoordinate(Array<index_t, nDim> tensor_index)
: mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)}
{
}
template <class... Xs>
__host__ __device__ constexpr NormalTensorCoordinate(Xs... xs)
: NormalTensorCoordinate(Array<index_t, nDim>{xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes);
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
// Reposition the point of origin, and return the compensated offset.
// This is a hack to reduce index calculation while looping over a tensor whose origin is
// this TensorCoordinate. It does so by handing the run-time offset back to the caller, who
// can fold it into the run-time pointer to the tensor data, so that only 1 run-time variable
// (the updated pointer) is needed instead of 2 (the old pointer plus this offset).
// TODO: after introducing the concept of a "run-time tensor view", which contains the
// run-time pointer to the data, always keep track of the pointer instead of both the
// offset and the pointer. This also brings the additional benefit that we don't need to
// worry about the offset underflowing (offset is an unsigned integer) when updating it.
__host__ __device__ constexpr index_t RepositOrigin()
{
index_t offset_diff = mOffset;
mOffset = 0;
return offset_diff;
}
private:
index_t mOffset;
};
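// Illustrative sketch (assumption, not from the original source): a NormalTensorCoordinate is
// just a cached linear offset, and stepping it re-uses GetOffsetFromMultiIndex on the step.
#if 0
using desc_t = ConstantTensorDescriptor<Sequence<2, 3, 4>, Sequence<12, 4, 1>>;

__device__ void normal_tensor_coordinate_example()
{
    NormalTensorCoordinate<desc_t> coord(1, 2, 3); // offset = 1*12 + 2*4 + 3*1 = 23
    coord += Sequence<0, 1, 0>{};                  // one step along dim1: offset becomes 27
    const index_t offset = coord.GetOffset();      // 27
    (void)offset;
}
#endif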
template <class TensorDesc>
struct MergedTensorCoordinate
{
using type = MergedTensorCoordinate;
using tensor_desc_type = TensorDesc;
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
static constexpr index_t nOriginalDim =
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
__host__ __device__ constexpr MergedTensorCoordinate(Array<index_t, nDim> tensor_index)
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
{
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto idim) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mOriginalIndex, partial_original_dims));
});
// complete offset
mOffset =
accumulate_on_array(mPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
}
template <class... Xs>
__host__ __device__ constexpr MergedTensorCoordinate(Xs... xs)
: MergedTensorCoordinate(Array<index_t, nDim>{xs...})
{
}
__host__ __device__ constexpr index_t GetOffset() const { return mOffset; }
template <class IDim, class T, bool PositiveDirection>
__host__ __device__ void
MoveOnDimension(IDim idim_, T step_size, integral_constant<bool, PositiveDirection>)
{
constexpr auto idim = idim_;
// if step_size is known at compile time
static_if<is_static<T>::value>{}(
[&](auto) { static_if<T{} == 0>{}([&](auto) { return; }); });
// update original index
static_if<tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
constexpr auto partial_original_dims =
tensor_desc_type::GetContainedOriginalDimensions(idim);
constexpr index_t ndim_partial_original = partial_original_dims.GetSize();
constexpr auto partial_original_desc =
tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims);
const auto partial_original_step_sizes =
partial_original_desc.GetMultiIndexFrom1dIndex(step_size);
// update partial original multi-id
auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims);
static_if<PositiveDirection>{}([&](auto) {
partial_original_id += partial_original_step_sizes;
bool carry = false;
// do carry check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(carry)
{
++partial_original_id(i);
}
carry = false;
if(partial_original_id[i] >= partial_original_desc.GetLength(i))
{
partial_original_id(i) -= partial_original_desc.GetLength(i);
carry = true;
}
});
}).Else([&](auto) {
// shift up multi-id to avoid unsigned integer underflow during intermediate
// calculations. After the shift, should have new_multi_id[...] >= 1
partial_original_id +=
partial_original_desc.GetLengths() - partial_original_step_sizes;
bool borrow = false;
// do borrow check in reversed order, starting from lowest dimension
// don't check the highest dimension
static_for<0, ndim_partial_original, 1>{}([&](auto IReverse) {
constexpr index_t i = ndim_partial_original - 1 - IReverse;
if(borrow)
{
--partial_original_id(i);
}
borrow = false;
if(partial_original_id[i] < partial_original_desc.GetLength(i))
{
partial_original_id(i) += partial_original_desc.GetLength(i);
borrow = true;
}
});
// shift back down multi-id
// here, should have new_multi_id[...] >= GetLengths()
partial_original_id = partial_original_id - partial_original_desc.GetLengths();
});
// update "mOriginalIndex"
static_for<0, ndim_partial_original, 1>{}([&](auto I) {
constexpr auto idim_original = partial_original_dims[I];
mOriginalIndex(idim_original) = partial_original_id[I];
});
// calculate new partial offset on this merged dimension
const index_t old_partial_offset = mPartialOffsets[idim];
mPartialOffsets(idim) =
partial_original_desc.GetOffsetFromMultiIndex(partial_original_id);
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset;
}).Else([&](auto fwd) {
static_if<PositiveDirection>{}([&](auto) {
mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim);
}).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); });
});
}
// T is Array or Sequence
template <class T>
__host__ __device__ type operator+=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim,
"wrong! the rank of step size doesn't match with that of tensor coordinate");
static_for<0, nDim, 1>{}([&](auto idim) {
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
}
});
return *this;
}
template <class T>
__host__ __device__ type operator-=(T step_sizes)
{
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim,
"wrong! the rank of step size doesn't match with that of tensor coordinate");
static_for<0, nDim, 1>{}([&](auto idim) {
if(step_sizes[idim] != 0)
{
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
}
});
return *this;
}
template <class T>
__host__ __device__ constexpr type operator+(T step_sizes) const
{
type coord = *this;
coord += step_sizes;
return coord;
}
template <class T>
__host__ __device__ constexpr type operator-(T step_sizes) const
{
type coord = *this;
coord -= step_sizes;
return coord;
}
__host__ __device__ static constexpr index_t RepositOrigin() { return 0; }
private:
// Allocate register storage for all merged dimensions and normal dimensions.
// However, only the merged dimensions whose index is involved in arithmetic after the
// construction of this TensorCoordinate (e.g. when the user moves a slicing window along
// a merged dimension) actually need these registers.
// We rely on the compiler to optimize away the registers allocated for normal dimensions,
// and for merged dimensions that are never involved in index arithmetic after construction
// of the TensorCoordinate.
// TODO: refactor TensorCoordinate after introducing the concept of "dimensions", and
// simplify the implementation of ConstantMergedTensorDescriptor, so we don't need to
// count on the compiler to optimize away this register storage for us.
Array<index_t, nOriginalDim> mOriginalIndex;
Array<index_t, nDim> mPartialOffsets;
// complete offset
index_t mOffset;
};
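// Illustrative sketch (assumption, not from the original source): the carry/borrow logic in
// MoveOnDimension treats a merged dimension as a mixed-radix counter over its original
// dimensions. E.g. for a merged dimension built from original lengths {3, 4}, the original
// multi-index (1, 3) plus one step becomes (2, 0): dim1 overflows and carries into dim0.
#if 0
__host__ __device__ inline void merged_dimension_carry_example()
{
    index_t id[2]     = {1, 3}; // multi-index inside the merged dimension
    index_t length[2] = {3, 4};

    ++id[1]; // step by +1 on the lowest original dimension
    if(id[1] >= length[1])
    {
        id[1] -= length[1]; // wrap around ...
        ++id[0];            // ... and carry into the next higher original dimension
    }
    // id is now {2, 0}; only the partial offset of this merged dimension is recomputed
}
#endif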
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <class input_type>
struct mfma_info
{
};
template <>
struct mfma_info<float>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_blks_wave = 2;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t wave_size = 64;
};
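// Derived note (not part of the original change): for fp32 each lane owns
//   num_regs_blk    = group_size * num_groups_blk  = 16 accumulators per 32x32 block,
//   num_regs_xdlops = num_regs_blk * num_blks_wave = 32 accumulators in total,
// so one 64-lane wave covers 32 * 64 = 2048 output elements, i.e. a 32x64 C tile.
#if 0
static_assert(mfma_info<float>::num_regs_blk == 16, "");
static_assert(mfma_info<float>::num_regs_xdlops == 32, "");
static_assert(mfma_info<float>::num_regs_xdlops * mfma_info<float>::wave_size == 32 * 64, "");
#endif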
template <>
struct mfma_info<half>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_blks_wave = 2;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t wave_size = 64;
};
template <>
struct mfma_info<ushort>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_blks_wave = 2;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_blks_wave;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t wave_size = 64;
};
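// Note (inferred from the disabled xdlops path further below): the per-type k mirrors the mfma
// instruction each type maps to, 32x32x1 for f32, 32x32x4 for f16 and 32x32x2 for bf16, i.e.
// how many K elements one lane feeds into a single mfma issue.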
// emulate xdlops
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info,
class FloatA,
class FloatB,
class FloatC>
__device__ void WaveWiseGemmMx64(const FloatA* const __restrict__ p_a_wave,
const FloatB* const __restrict__ p_b_wave,
FloatC* const __restrict__ p_c_thread)
{
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
const index_t blk_id = laneId / mfma_info::num_threads_blk;
const index_t lane_b = laneId % mfma_info::num_threads_blk;
for(index_t k = 0; k < K; ++k)
{
for(index_t b = 0; b < MPerWave / mfma_info::m; ++b)
{
index_t a_off = k * M + b * mfma_info::m;
index_t b_off = k * N;
// pseudo mfma
for(index_t n = 0; n < mfma_info::num_blks_wave; ++n)
{
index_t output_m = mfma_info::num_regs_blk;
for(index_t m = 0; m < output_m; ++m)
{
index_t aindex = m % mfma_info::group_size + blk_id * mfma_info::group_size +
m / mfma_info::group_size *
(mfma_info::group_size * mfma_info::num_blks_wave) +
a_off; // A is transposed
index_t bindex = b_off + lane_b + n * mfma_info::num_threads_blk;
p_c_thread[m + n * output_m + b * output_m * mfma_info::num_blks_wave] +=
math::inner_product_with_conversion<FloatC>{}(p_a_wave[aindex],
p_b_wave[bindex]);
}
}
}
}
}
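// Worked example (illustrative, fp32 constants): lane 37 has blk_id = 1 and lane_b = 5, so with
// MPerWave = 32 (b = 0) and m = n = 0 it accumulates A[k*M + 4] * B[k*N + 5] into p_c_thread[0]
// (aindex row = 0 % 4 + 1 * 4 + (0 / 4) * 8 = 4, bindex col = 5 + 0 * 32 = 5).
//
// Minimal sketch (assumption, not part of the original change) of how the emulation is driven:
// every lane of a wave calls it on the same LDS tiles and keeps its own slice of C in registers.
#if 0
template <class Float>
__device__ void
wave_gemm_emulation_example(const Float* p_a_lds, const Float* p_b_lds, float* p_c_regs)
{
    constexpr index_t M = 64, N = 64, K = 8; // per-wave tile sizes chosen for the example
    WaveWiseGemmMx64<M, N, K, 64 /*MPerWave*/, 1, 1, mfma_info<Float>>(p_a_lds, p_b_lds, p_c_regs);
}
#endif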
#if 0
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(const float* const __restrict__ p_a_wave,
const float* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k)
{
float reg_a = p_a_wave[k * M + laneId];
float reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x1f32<MPerWave>(reg_a, reg_b, reg_c);
}
}
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(
const typename vector_type<half, 4>::MemoryType* const __restrict__ p_a_wave,
const typename vector_type<half, 4>::MemoryType* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
const index_t laneId = threadIdx.x % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k / 4)
{
typename vector_type<half, 4>::MemoryType reg_a = p_a_wave[k * M + laneId];
typename vector_type<half, 4>::MemoryType reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x4f16<MPerWave>(reg_a, reg_b, reg_c);
}
}
template <index_t M,
index_t N,
index_t K,
index_t MPerWave,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class mfma_info>
__device__ void WaveWiseGemmMx64_xdlops(
const typename vector_type<ushort, 2>::MemoryType* const __restrict__ p_a_wave,
const typename vector_type<ushort, 2>::MemoryType* const __restrict__ p_b_wave,
float* const __restrict__ p_c_thread)
{
static_assert(MPerWave == 32 || MPerWave == 64, "only support MPerWave = 32/64");
const index_t laneId = threadIdx.x % mfma_info::wave_size;
for(index_t k = 0; k < K; k += mfma_info::k / 2)
{
typename vector_type<ushort, 2>::MemoryType reg_a = p_a_wave[k * M + laneId];
typename vector_type<ushort, 2>::MemoryType reg_b = p_b_wave[k * N + laneId];
float32_t* reg_c = reinterpret_cast<float32_t*>(p_c_thread);
gcnasm_mfma_f32_32x32x2bf16<MPerWave>(reg_a, reg_b, reg_c);
}
}
#endif
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class mfma_info,
bool EnableXdlops,
index_t GemmMPerWave,
index_t GemmNPerWave,
index_t GemmMWaves,
index_t GemmNWaves,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct MatrixIndex
{
index_t row;
index_t col;
};
struct OutputLayout_t
{
static constexpr index_t M3 = GemmMPerWave / mfma_info::m;
static constexpr index_t M2 = mfma_info::num_groups_blk;
static constexpr index_t M1 = mfma_info::num_blks_wave;
static constexpr index_t M0 = mfma_info::group_size;
};
index_t mMyWaveOffsetA;
index_t mMyWaveOffsetB;
OutputLayout_t OutputLayout;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
{
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
static_assert(GemmNPerWave == 64, "Only support GemmNPerWave == 64 for xdlops");
static_assert(GemmMPerWave == 32 || GemmMPerWave == 64,
"Only support GemmMPerWave == 32 or 64 for xdlops");
static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
static_assert(BlockSize == GemmMWaves * GemmNWaves * 64,
"BlockSize != GemmMWaves * GemmNWaves * 64\n");
const index_t waveId = get_thread_local_1d_id() / mfma_info::wave_size;
const index_t waveId_m = waveId / GemmNWaves;
const index_t waveId_n = waveId % GemmNWaves;
mMyWaveOffsetA = waveId_m * GemmMPerWave;
mMyWaveOffsetB = waveId_n * GemmNPerWave;
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
// static_if<EnableXdlops>{}([&](auto) {
// WaveWiseGemmMx64_xdlops<M,
// N,
// K,
// GemmMPerWave,
// GemmDataPerReadA,
// GemmDataPerReadB,
// mfma_info>(
// &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
// }).Else([&](auto) {
WaveWiseGemmMx64<M, N, K, GemmMPerWave, GemmDataPerReadA, GemmDataPerReadB, mfma_info>(
&p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
// });
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
{
const index_t laneId = get_thread_local_1d_id() % mfma_info::wave_size;
const index_t waveId = get_thread_local_1d_id() / mfma_info::wave_size;
const index_t col_i = i % mfma_info::num_blks_wave;
const index_t col = waveId % GemmNWaves * mfma_info::wave_size +
laneId % mfma_info::num_threads_blk +
col_i * mfma_info::num_threads_blk;
const index_t row_i = i / mfma_info::num_blks_wave;
const index_t row = waveId / GemmNWaves * GemmMPerWave +
laneId / mfma_info::num_threads_blk * mfma_info::group_size +
row_i * mfma_info::num_threads_blk;
return MatrixIndex{row, col};
}
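// Worked example (illustrative): with GemmNWaves = 2 and GemmMPerWave = 64, thread 37
// (waveId = 0, laneId = 37) gets
//   i = 0: col = 0*64 + 37 % 32 + 0*32 = 5, row = 0*64 + (37 / 32)*4 + 0*32 = 4
//   i = 1: col = 5 + 1*32 = 37,             row = 4
// so consecutive accumulator indices first step across the num_blks_wave column blocks and then
// across the row groups of the wave's output tile.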
__device__ constexpr auto GetThreadMatrixCDescriptor() const
{
constexpr index_t num_xdlops = GemmMPerWave / mfma_info::m;
return make_ConstantMatrixDescriptor_packed(
Number<mfma_info::num_regs_xdlops * num_xdlops>{}, Number<1>{});
}
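// e.g. for fp32 with GemmMPerWave = 64: num_xdlops = 2, so each thread's C tile is described as
// a packed 64 x 1 register matrix (32 accumulators per 32x64 sub-tile, two sub-tiles).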
};
} // namespace ck
#endif
@@ -3,7 +3,7 @@

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "math.hpp"

namespace ck {

@@ -37,58 +37,18 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};

using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; j += DataPerRead)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);

*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
}
}
}
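// Note (illustrative): DataPerRead controls the vector width of the copy above, e.g. with
// Float = float and DataPerRead = 4 each inner iteration moves one float4, so a 4x8 row-major
// thread tile is copied with 8 vector loads/stores instead of 32 scalar ones.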
template <class MatrixA,
@@ -119,7 +79,6 @@ __device__ void threadwise_gemm(MatrixA,
constexpr index_t N = c_mtx.NCol();
constexpr index_t K = a_mtx.NRow(); // A is transposed

for(index_t k = 0; k < K; ++k)
{
for(index_t i = 0; i < M; ++i)
@@ -130,32 +89,8 @@ __device__ void threadwise_gemm(MatrixA,
const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

p_c_thread[cindex] += math::inner_product_with_conversion<FloatC>{}(
p_a_thread[aindex], p_b_thread[bindex]);
}
}
}
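// Note (assumption based on how it is used above): math::inner_product_with_conversion<FloatC>
// replaces the earlier per-datatype branches; for scalar float it is a plain
// convert-multiply-accumulate, while for packed half/bfloat16 operands it reduces across the
// packed lanes, so the k/i/j loops stay datatype-agnostic.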
@@ -9,7 +9,8 @@ namespace ck {

template <class TData, index_t NSize>
struct Array
{
using Type = Array<TData, NSize>;
using data_type = TData;

static constexpr index_t nSize = NSize;

@@ -20,7 +21,7 @@ struct Array
{
}

__host__ __device__ static constexpr index_t GetSize() { return NSize; }

template <index_t I>
__host__ __device__ constexpr TData operator[](Number<I>) const
@@ -208,6 +209,21 @@ __host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData,
return result;
}
// Array += Array
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator+=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a + b;
return a;
}
// Array -= Array
template <class TData, index_t NSize>
__host__ __device__ constexpr auto operator-=(Array<TData, NSize>& a, Array<TData, NSize> b)
{
a = a - b;
return a;
}
// Array = Array + Sequence
template <class TData, index_t NSize, index_t... Is>
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
@@ -6,41 +6,63 @@

namespace ck {

template <index_t...>
struct Sequence;

template <class Seq, index_t I>
struct sequence_split;

template <class>
struct sequence_reverse;

template <class>
struct sequence_map_inverse;

template <class>
struct is_valid_sequence_map;

template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);

template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);

template <index_t... Is>
struct Sequence
{
using Type = Sequence;
using data_type = index_t;

static constexpr index_t mSize = sizeof...(Is);

__host__ __device__ static constexpr auto GetSize() { return Number<mSize>{}; }

__host__ __device__ static constexpr index_t GetImpl(index_t I)
{
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
const index_t mData[mSize + 1] = {Is..., 0};
return mData[I];
}

template <index_t I>
__host__ __device__ static constexpr auto Get(Number<I>)
{
static_assert(I < mSize, "wrong! I too large");

return Number<GetImpl(Number<I>{})>{};
}

__host__ __device__ static constexpr auto Get(index_t I) { return GetImpl(I); }

template <index_t I>
__host__ __device__ constexpr auto operator[](Number<I>) const
{
return Get(Number<I>{});
}

// make sure I is constexpr if you want a constexpr return type
__host__ __device__ constexpr index_t operator[](index_t I) const { return GetImpl(I); }

template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{
@@ -52,23 +74,38 @@ struct Sequence
return Sequence<Type::Get(Number<IRs>{})...>{};
}

// MapOld2New is Sequence<...>
template <class MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
static_assert(MapOld2New::GetSize() == GetSize(),
"wrong! reorder map should have the same size as Sequence to be reordered");

static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");

return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
}

__host__ __device__ static constexpr auto Reverse()
{
return typename sequence_reverse<Type>::type{};
}

__host__ __device__ static constexpr auto Front()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<0>{});
}

__host__ __device__ static constexpr auto Back()
{
static_assert(mSize > 0, "wrong!");
return Get(Number<mSize - 1>{});
}

__host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }

__host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }

template <index_t... Xs>
__host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
@@ -107,7 +144,16 @@ struct Sequence
}

template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
{
static_assert(I < GetSize(), "wrong!");

using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::SeqType0{};
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();

return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
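// Illustrative sketch (assumption, not from the original source): the reworked Sequence API.
#if 0
using seq = Sequence<3, 5, 7>;
static_assert(seq::Get(Number<1>{}) == 5, "");
static_assert(is_same<decltype(seq::Reverse()), Sequence<7, 5, 3>>{}, "");
static_assert(is_same<decltype(seq::Modify(Number<1>{}, Number<9>{})), Sequence<3, 9, 7>>{}, "");
static_assert(is_same<decltype(seq::ReorderGivenNew2Old(Sequence<2, 0, 1>{})), Sequence<7, 3, 5>>{},
              "");
#endif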
template <class F>
__host__ __device__ static constexpr auto Transform(F f)
@@ -126,48 +172,63 @@ struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
using type = Sequence<Xs..., Ys...>;
};

// generate sequence
template <index_t IBegin, index_t NRemain, class F>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;

using type =
typename sequence_merge<typename sequence_gen_impl<IBegin, NRemainLeft, F>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, F>::type>::type;
};

template <index_t I, class F>
struct sequence_gen_impl<I, 1, F>
{
static constexpr index_t Is = F{}(Number<I>{});
using type = Sequence<Is>;
};

template <index_t I, class F>
struct sequence_gen_impl<I, 0, F>
{
using type = Sequence<>;
};

template <index_t NSize, class F>
struct sequence_gen
{
using type = typename sequence_gen_impl<0, NSize, F>::type;
};

// arithmetic sequence
template <index_t IBegin, index_t IEnd, index_t Increment>
struct arithmetic_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t i) const
{
return i * Increment + IBegin;
}
};

using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
};

// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};

using type = typename sequence_gen<NSize, F>::type;
};
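// Illustrative sketch (assumption, not from the original source): the generators expand to plain
// Sequence types at compile time.
#if 0
static_assert(is_same<arithmetic_sequence_gen<0, 8, 2>::type, Sequence<0, 2, 4, 6>>{}, "");
static_assert(is_same<uniform_sequence_gen<3, 7>::type, Sequence<7, 7, 7>>{}, "");
#endif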
// reverse inclusive scan (with init) sequence
@@ -236,6 +297,7 @@ struct sequence_reverse<Sequence<I0, I1>>
template <class Seq>
struct is_valid_sequence_map
{
// not implemented yet, always return true
static constexpr integral_constant<bool, true> value = integral_constant<bool, true>{};

// TODO: add proper check for is_valid, something like:
@@ -244,6 +306,34 @@ struct is_valid_sequence_map
// typename sequence_sort<Seq>::SortedSeqType>{};
};
template <class X2Y, class WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
private:
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::Get(Number<XBegin>{}), Number<XBegin>{});
public:
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::type;
};
template <class X2Y, class WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
using type = WorkingY2X;
};
template <class X2Y>
struct sequence_map_inverse
{
using type =
typename sequence_map_inverse_impl<X2Y,
typename uniform_sequence_gen<X2Y::GetSize(), 0>::type,
0,
X2Y::GetSize()>::type;
};
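// Illustrative sketch (assumption, not from the original source): sequence_map_inverse turns an
// x-to-y position map into the y-to-x map, which is what lets ReorderGivenOld2New above be
// implemented on top of ReorderGivenNew2Old.
#if 0
// X2Y = <2, 0, 1> says element 0 goes to position 2, element 1 to position 0, element 2 to 1,
// so the inverse map is <1, 2, 0>.
static_assert(is_same<sequence_map_inverse<Sequence<2, 0, 1>>::type, Sequence<1, 2, 0>>{}, "");
#endif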
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
{
@@ -355,8 +445,8 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
template <class Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
static_assert(Seq::GetSize() > 0, "wrong! cannot pop an empty Sequence!");

return sequence_pop_front(Seq::Reverse()).Reverse();
}

template <class F, index_t... Xs>
@@ -396,37 +486,6 @@ __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<I
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
template <index_t... Xs>
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
{
@@ -26,6 +26,8 @@
#ifndef BFLOAT16_DEVICE_HPP
#define BFLOAT16_DEVICE_HPP

#define __HIP_PLATFORM_HCC__ 1

#ifdef __cplusplus
extern "C" {
#endif