Commit 7a89684f authored by Chao Liu

refactor

parent eafdabba
@@ -5,9 +5,8 @@
 #include "ConstantMatrixDescriptor.hip.hpp"
 #include "blockwise_generic_tensor_slice_op.hip.hpp"
 #include "blockwise_gemm.hip.hpp"
-#include "threadwise_tensor_slice_op.hip.hpp"

-// define B = merge(N, Ho, Wo)
+// define B = merge(N0, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,
           class Float,
@@ -42,7 +41,7 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
                        Float* const __restrict__ p_out_global) const
     {
         // this is a mess
-        // TODO: fidn more elegent way of specifying (or calculating) performance parameters
+        // TODO: find more elegent way of specifying (or calculating) performance parameters
         static_assert(N2 == GemmNPerThreadSubC, "wrong!");
         static_assert((N1 * N2 * BPerBlock) %
                           (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
@@ -144,46 +143,34 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
         // be careful of LDS alignment
         constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, KPerBlock>{},
-            Number<mod_conv::max(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
+            Number<mod_conv::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});

         // operator for blockwise copy of weight into LDS
         //     slice a tensor, and copy it into another tensor
         //     this copy operator already have blockwise offset built-in
         const auto blockwise_wei_copy =
-#if 0
             BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                Float,
                                                decltype(wei_c_k_global_desc),
                                                decltype(wei_c_k_block_desc),
                                                decltype(wei_c_k_block_desc.GetLengths()),
                                                WeiBlockCopySubLengths_C_K,
                                                WeiBlockCopyClusterLengths_C_K,
                                                Sequence<0, 1>, // thread_arrange_order [C, K]
                                                Sequence<0, 1>, // src_access_order [C, K]
                                                Sequence<0, 1>, // dst_access_order [C, K]
                                                WeiBlockCopyDataPerAccess_K,
                                                WeiBlockCopyDataPerAccess_K>(
                 {0, k_block_data_on_global}, {0, 0});
-#else
-            Blockwise2dTensorCopy3<BlockSize,
-                                   Float,
-                                   decltype(wei_c_k_global_desc),
-                                   decltype(wei_c_k_block_desc),
-                                   decltype(wei_c_k_block_desc.GetLengths()),
-                                   WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global},
-                                                                {0, 0});
-#endif

         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
         //     a_mtx[CPerBlock, KPerBlock] is in LDS
         //     b_mtx[CPerBlocl, N1 * BPerBlock * N2] is in LDS
         //     c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
         //       register
-        constexpr auto a_c_k_block_mtx_desc =
-            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
-                                          Number<KPerBlock>{},
-                                          Number<wei_c_k_block_desc.GetStride(I0)>{});
+        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
+            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

         constexpr auto b_c_n1bn2_block_mtx_desc =
             make_ConstantMatrixDescriptor(Number<CPerBlock>{},
@@ -228,7 +215,7 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
         };

         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::max(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDataPerAccess_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);
@@ -261,18 +248,8 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
         // LDS double buffer: preload data into LDS
         {
-            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
-            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
-
-            blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global,
-                                                       p_in_register_clipboard);
-            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
-                                                        p_wei_register_clipboard);
-
-            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
-                                                        p_in_block_double);
-            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
-                                                         p_wei_block_double);
+            blockwise_in_copy.Run(p_in_block_on_global, p_in_block_double);
+            blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block_double);
         }

         // LDS double buffer: main body
@@ -413,7 +390,8 @@ struct GridwiseConvolutionImplicitGemm_v3_lds_double_buffer_nchw_cyxk_nkhw
                     p_out_thread_on_global,
                     {0, 0, 0, 0, 0, 0, 0, 0},
                     out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
-                    arithmetic_sequence_gen<0, 8, 1>::SeqType{});
+                    arithmetic_sequence_gen<0, 8, 1>::SeqType{},
+                    Number<1>{});
             }
         }
     };
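The recurring change from mod_conv::max to mod_conv::lcm in this commit is an alignment fix: the LDS tile is accessed with two different vector widths (the blockwise-copy width and the GEMM read width), and only an alignment equal to the least common multiple of the widths is a multiple of both; the maximum is not in general. A minimal standalone sketch of the arithmetic (plain C++, assuming mod_conv::lcm behaves like an ordinary constexpr lcm):

    #include <cstddef>

    // Recursive constexpr gcd/lcm; assumed to match what mod_conv::lcm computes.
    constexpr std::size_t gcd(std::size_t a, std::size_t b) { return b == 0 ? a : gcd(b, a % b); }
    constexpr std::size_t lcm(std::size_t a, std::size_t b) { return a / gcd(a, b) * b; }

    // Example widths 2 and 3: an lcm-aligned base address (multiple of 6) serves both
    // vectorized accesses, while a max-aligned base (multiple of 3) breaks the width-2 one.
    static_assert(lcm(2, 3) % 2 == 0 && lcm(2, 3) % 3 == 0, "lcm alignment satisfies both widths");
    static_assert(3 % 2 != 0, "max(2, 3) = 3 is not a multiple of the width-2 access");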
@@ -5,9 +5,8 @@
 #include "ConstantMatrixDescriptor.hip.hpp"
 #include "blockwise_generic_tensor_slice_op.hip.hpp"
 #include "blockwise_gemm.hip.hpp"
-#include "threadwise_tensor_slice_op.hip.hpp"

-// define B = merge(N, Ho, Wo)
+// define B = merge(N0, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,
           class Float,
@@ -42,7 +41,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
                        Float* const __restrict__ p_out_global) const
     {
         // this is a mess
-        // TODO: fidn more elegent way of specifying (or calculating) performance parameters
+        // TODO: find more elegent way of specifying (or calculating) performance parameters
         static_assert(N2 == GemmNPerThreadSubC, "wrong!");
         static_assert((N1 * N2 * BPerBlock) %
                           (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
@@ -147,13 +146,12 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
         // be careful of LDS alignment
         constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<CPerBlock, KPerBlock>{},
-            Number<mod_conv::max(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
+            Number<mod_conv::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});

         // operator for blockwise copy of weight into LDS
         //     slice a tensor, and copy it into another tensor
         //     this copy operator already have blockwise offset built-in
         auto blockwise_wei_copy =
-#if 1
             BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                Float,
                                                decltype(wei_c_k_global_desc),
@@ -167,15 +165,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
                                                WeiBlockCopyDataPerAccess_K,
                                                WeiBlockCopyDataPerAccess_K>(
                 {0, k_block_data_on_global}, {0, 0});
-#else
-            Blockwise2dTensorCopy3<BlockSize,
-                                   Float,
-                                   decltype(wei_c_k_global_desc),
-                                   decltype(wei_c_k_block_desc),
-                                   decltype(wei_c_k_block_desc.GetLengths()),
-                                   WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global},
-                                                                {0, 0});
-#endif

         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
@@ -219,8 +208,17 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
                                             GemmDataPerReadA,
                                             GemmDataPerReadB>{};

+        // choose GEMM implementation here
+        const auto run_blockwise_gemm = [&](auto... Xs) {
+#if 1
+            return blockwise_gemm.Run(Xs...);
+#else
+            return blockwise_gemm.Run_asm(Xs...);
+#endif
+        };
+
         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::max(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDataPerAccess_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);
@@ -264,7 +262,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
             __syncthreads();

-            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
+            run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);

             __syncthreads();
         }
@@ -294,7 +292,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
             __syncthreads();

-            // move on C: C_N1_B_N2, C_K
             blockwise_in_copy.MoveSlicingWindowOnSourceTensor(
                 I0, Number<CPerBlock>{}, True);
@@ -366,7 +363,8 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
                     p_out_thread_on_global,
                     {0, 0, 0, 0, 0, 0, 0, 0},
                     out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
-                    arithmetic_sequence_gen<0, 8, 1>::SeqType{});
+                    arithmetic_sequence_gen<0, 8, 1>::SeqType{},
+                    Number<1>{});
             }
         }
     };
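The run_blockwise_gemm wrapper introduced above keeps a single call site while the actual GEMM routine (the generic Run versus the hand-written Run_asm) is selected once by a preprocessor switch. A reduced, self-contained sketch of this generic-lambda dispatch pattern (stub types for illustration only; the real blockwise GEMM lives in blockwise_gemm.hip.hpp):

    // Stand-in for the blockwise GEMM object; only the dispatch pattern matters here.
    struct BlockwiseGemmStub
    {
        void Run(const float*, const float*, float*) const {}     // generic C++ path
        void Run_asm(const float*, const float*, float*) const {} // inline-asm path
    };

    void example(const BlockwiseGemmStub& blockwise_gemm,
                 const float* p_a_block,
                 const float* p_b_block,
                 float* p_c_thread)
    {
        // A generic lambda forwards any argument pack to whichever implementation
        // the preprocessor switch selects, so every call site stays unchanged.
        const auto run_blockwise_gemm = [&](auto... Xs) {
    #if 1
            return blockwise_gemm.Run(Xs...);
    #else
            return blockwise_gemm.Run_asm(Xs...);
    #endif
        };

        run_blockwise_gemm(p_a_block, p_b_block, p_c_thread);
    }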
@@ -7,7 +7,7 @@
 #include "blockwise_gemm.hip.hpp"
 #include "threadwise_generic_tensor_slice_op.hip.hpp"

-// define B = merge(N, Ho, Wo)
+// define B = merge(N0, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,
           class Float,
@@ -165,12 +165,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // be careful of LDS alignment
         constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<EPerBlock, KPerBlock>{},
-            Number<mod_conv::max(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
+            Number<mod_conv::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

         // operator for blockwise copy of weight into LDS
         //     slice a tensor, and copy it into another tensor
         //     this copy operator already have blockwise offset built-in
-#if 1
         auto blockwise_wei_copy =
             BlockwiseGenericTensorSliceCopy_v1<BlockSize,
                                                Float,
@@ -185,22 +184,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
                                                WeiBlockCopySrcDataPerRead_E,
                                                WeiBlockCopyDstDataPerWrite_K>(
                 {0, k_block_data_on_global}, {0, 0});
-#else
-        constexpr auto map_k_e_2_e_k = Sequence<1, 0>{};
-
-        auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3<
-            BlockSize,
-            Float,
-            decltype(wei_e_k_global_desc.ReorderGivenNew2Old(map_k_e_2_e_k)),
-            decltype(wei_e_k_block_desc),
-            decltype(wei_e_k_block_desc.GetLengths().ReorderGivenNew2Old(map_k_e_2_e_k)),
-            decltype(WeiBlockCopySubLengths_E_K::ReorderGivenNew2Old(map_k_e_2_e_k)),
-            decltype(WeiBlockCopyClusterLengths_E_K::ReorderGivenNew2Old(map_k_e_2_e_k)),
-            Sequence<1, 0>, // MapDst2Src
-            WeiBlockCopyThreadClusterArrangeOrder,
-            WeiBlockCopySrcDataPerRead_E,
-            WeiBlockCopyDstDataPerWrite_K>({k_block_data_on_global, 0}, {0, 0});
-#endif

         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
@@ -254,7 +237,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         };

         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::max(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDstDataPerWrite_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);
@@ -273,18 +256,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // zero out threadwise output
         threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);

-#if 0
-        if(get_block_1d_id() == 0)
-        {
-            printf("id %5u %5u: "
-                   "mThreadSrcOffset %u, mThreadDstOffset %u \n",
-                   get_block_1d_id(),
-                   get_thread_local_1d_id(),
-                   blockwise_wei_copy.mThreadSrcOffset,
-                   blockwise_wei_copy.mThreadDstOffset);
-        }
-#endif
-
         const Float* p_wei_block_on_global = p_wei_global;

         // LDS double buffer: preload data into LDS
......
@@ -7,7 +7,7 @@
 #include "blockwise_gemm.hip.hpp"
 #include "threadwise_generic_tensor_slice_op.hip.hpp"

-// define B = merge(N, Ho, Wo)
+// define B = merge(N0, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,
           class Float,
@@ -30,10 +30,16 @@ template <index_t GridSize,
           index_t GemmDataPerReadB,
           class InBlockCopySubLengths_E_N1_B_N2,
           class InBlockCopyClusterLengths_E_N1_B_N2,
+          class InBlockCopyThreadClusterArrangeOrder,
+          class InBlockCopySrcAccessOrder,
+          class InBlockCopyDstAccessOrder,
           index_t InBlockCopySrcDataPerRead_B,
           index_t InBlockCopyDstDataPerWrite_N2,
           class WeiBlockCopySubLengths_E_K,
           class WeiBlockCopyClusterLengths_E_K,
+          class WeiBlockCopyThreadClusterArrangeOrder,
+          class WeiBlockCopySrcAccessOrder,
+          class WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K>
 struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
@@ -146,19 +152,20 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
         // input blockwise copy
         //     slice a merged tensor, reorder and copy to a normal tensor
         //     this copy operator already has blockwise offset built-in
-        auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
-            BlockSize,
-            Float,
-            decltype(in_e_n1_b_n2_global_merged_desc),
-            decltype(in_e_n1_b_n2_block_desc),
-            decltype(in_e_n1_b_n2_block_desc.GetLengths()),
-            InBlockCopySubLengths_E_N1_B_N2,
-            InBlockCopyClusterLengths_E_N1_B_N2,
-            Sequence<0, 1, 3, 2>, // thread_arrange_order [E, N1, N2, B]
-            Sequence<0, 1, 3, 2>, // src_access_order [E, N1, N2, B]
-            Sequence<0, 1, 2, 3>, // dst_access_order [E, N1, B, N2]
-            InBlockCopySrcDataPerRead_B,
-            InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
+        auto blockwise_in_copy =
+            BlockwiseGenericTensorSliceCopy_v1<BlockSize,
+                                               Float,
+                                               decltype(in_e_n1_b_n2_global_merged_desc),
+                                               decltype(in_e_n1_b_n2_block_desc),
+                                               decltype(in_e_n1_b_n2_block_desc.GetLengths()),
+                                               InBlockCopySubLengths_E_N1_B_N2,
+                                               InBlockCopyClusterLengths_E_N1_B_N2,
+                                               InBlockCopyThreadClusterArrangeOrder,
+                                               InBlockCopySrcAccessOrder,
+                                               InBlockCopyDstAccessOrder,
+                                               InBlockCopySrcDataPerRead_B,
+                                               InBlockCopyDstDataPerWrite_N2>(
+                {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});

         // weight tensor
         //     tensor descriptor in device memory, src of blockwise copy
@@ -169,7 +176,7 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
         // be careful of LDS alignment
         constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
             Sequence<EPerBlock, KPerBlock>{},
-            Number<mod_conv::max(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
+            Number<mod_conv::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

         // operator for blockwise copy of weight into LDS
         //     slice a tensor, and copy it into another tensor
@@ -182,9 +189,9 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
                                                decltype(wei_e_k_block_desc.GetLengths()),
                                                WeiBlockCopySubLengths_E_K,
                                                WeiBlockCopyClusterLengths_E_K,
-                                               Sequence<1, 0>, // thread_arrange_order [K, E]
-                                               Sequence<1, 0>, // src_access_order [K, E]
-                                               Sequence<0, 1>, // dst_access_order [E, K]
+                                               WeiBlockCopyThreadClusterArrangeOrder,
+                                               WeiBlockCopySrcAccessOrder,
+                                               WeiBlockCopyDstAccessOrder,
                                                WeiBlockCopySrcDataPerRead_E,
                                                WeiBlockCopyDstDataPerWrite_K>(
                 {0, k_block_data_on_global}, {0, 0});
@@ -231,8 +238,17 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
                                             GemmDataPerReadA,
                                             GemmDataPerReadB>{};

+        // choose GEMM implementation here
+        const auto run_blockwise_gemm = [&](auto... Xs) {
+#if 1
+            return blockwise_gemm.Run(Xs...);
+#else
+            return blockwise_gemm.Run_asm(Xs...);
+#endif
+        };
+
         // LDS allocation for input and weight: be careful of alignment
-        constexpr index_t max_align = mod_conv::max(InBlockCopyDstDataPerWrite_N2,
+        constexpr index_t max_align = mod_conv::lcm(InBlockCopyDstDataPerWrite_N2,
                                                     WeiBlockCopyDstDataPerWrite_K,
                                                     GemmDataPerReadA,
                                                     GemmDataPerReadB);
@@ -254,24 +270,13 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
         // do work
         for(index_t e = 0; e < E; e += EPerBlock)
         {
-#if 0
-            if(e == 0 * EPerBlock && get_block_1d_id() == 0)
-            {
-                printf("id %5u %5u: "
-                       "mThreadSrcOffset %u, mThreadDstOffset %u \n",
-                       get_block_1d_id(),
-                       get_thread_local_1d_id(),
-                       blockwise_wei_copy.mThreadSrcOffset,
-                       blockwise_wei_copy.mThreadDstOffset);
-            }
-#endif
-
             // marching slicing window
             blockwise_in_copy.Run(p_in_global, p_in_block);
             blockwise_wei_copy.Run(p_wei_global, p_wei_block);

             __syncthreads();

-            blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
+            run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);

             __syncthreads();
@@ -335,7 +340,8 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
                     p_out_thread_on_global,
                     {0, 0, 0, 0, 0, 0, 0, 0},
                     out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
-                    arithmetic_sequence_gen<0, 8, 1>::SeqType{});
+                    arithmetic_sequence_gen<0, 8, 1>::SeqType{},
+                    Number<1>{});
             }
         }
     };
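In this file the thread-cluster arrange order and the source/destination access orders of the blockwise copies stop being hard-coded Sequence<...> values inside the kernel and become template parameters of the gridwise struct, so the host-side tuning layer can choose them per configuration. A reduced sketch of that hoisting, with illustrative names rather than the repository's exact classes:

    // Orders encoded as types, mirroring the Sequence<...> usage in the kernels.
    template <int... Is>
    struct Order
    {
    };

    // Copy operator that is told its access order through a template parameter.
    template <class SrcAccessOrder>
    struct BlockwiseCopy
    {
    };

    // Before the refactor: the order is fixed inside the kernel body.
    using CopyWithFixedOrder = BlockwiseCopy<Order<0, 1, 3, 2>>;

    // After the refactor: the kernel forwards whatever order it was instantiated with,
    // so the tuning layer can change it without touching the kernel.
    template <class InBlockCopySrcAccessOrder>
    struct GridwiseKernel
    {
        using Copy = BlockwiseCopy<InBlockCopySrcAccessOrder>;
    };

    using TunedKernel = GridwiseKernel<Order<0, 1, 3, 2>>;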
@@ -8,7 +8,7 @@ struct integral_constant
     __host__ __device__ constexpr T Get() const { return value; }
 };

-template <class T, index_t X, index_t Y>
+template <class T, T X, T Y>
 __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
 {
     return integral_constant<T, X + Y>{};
......
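The integral_constant change retypes the non-type parameters of operator+ from index_t to T, so argument deduction works for constants of any integral type rather than only index_t-valued ones. A minimal host-only reproduction (the real operator also carries __host__ __device__ qualifiers):

    // Host-only reproduction of the fixed signature.
    template <class T, T v>
    struct integral_constant
    {
        static constexpr T value = v;
    };

    // Non-type parameters typed as T (previously index_t), so deduction succeeds for
    // integral_constant over any integral type.
    template <class T, T X, T Y>
    constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
    {
        return integral_constant<T, X + Y>{};
    }

    static_assert((integral_constant<long, 3>{} + integral_constant<long, 4>{}).value == 7,
                  "long-valued constants now add via the deduced operator+");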
@@ -62,7 +62,7 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
 #if 1
     ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
         auto data_multi_id_in_access_order = access_multi_id;
-        data_multi_id_in_access_order[nDim - 1] = access_multi_id[nDim - 1] * DataPerAccess;
+        data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;

         const auto data_multi_id =
             reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
......
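The last hunk writes through operator() instead of operator[]. A plausible reading, stated here as an assumption rather than a fact about the repository's Array type, is that operator[] only offers read access while operator() returns a mutable reference, so the element assignment has to go through operator(). A hypothetical container showing that split:

    #include <cstddef>

    // Hypothetical container illustrating the assumed read/write split:
    // operator[] returns by value (read-only), operator() returns a reference (writable).
    template <class T, std::size_t N>
    struct MultiIndex
    {
        T e[N];
        constexpr T operator[](std::size_t i) const { return e[i]; }
        T& operator()(std::size_t i) { return e[i]; }
    };

    int main()
    {
        constexpr std::size_t nDim = 4;
        constexpr int DataPerAccess = 4;

        MultiIndex<int, nDim> multi_id{{0, 1, 2, 3}};

        // The write goes through operator(); reads can stay on operator[].
        multi_id(nDim - 1) = multi_id[nDim - 1] * DataPerAccess;

        return multi_id[nDim - 1] == 12 ? 0 : 1;
    }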