Commit 8c42225c authored by Chao Liu

minor bug fix

parent 157491ab
@@ -20,8 +20,8 @@ template <index_t GridSize,
           typename OutGlobalDesc,
           typename ConvStrides,
           typename ConvDilations,
-          typename LeftPads,
-          typename RightPads,
+          typename InputLeftPads,
+          typename InputRightPads,
           index_t GemmMPerBlock,
           index_t GemmNPerBlock,
           index_t GemmKPerBlock,
@@ -98,8 +98,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
                 PassThrough<C>{},
                 Pad<Sequence<Y, X>,
                     Sequence<0, 0>,
-                    Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
-                    true>{}),
+                    Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
@@ -121,14 +120,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             make_tuple(Sequence<0>{}, Sequence<1>{}));

         // output tensor
-        constexpr auto out_n_k_hop_wop_global_desc =
-            transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
-                                        make_tuple(PassThrough<N>{},
-                                                   PassThrough<K>{},
-                                                   Pad<Sequence<Ho, Wo>,
-                                                       Sequence<0, 0>,
-                                                       Sequence<right_pad_ho, right_pad_wo>,
-                                                       true>{}),
+        constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
+            out_n_k_ho_wo_global_desc,
+            make_tuple(
+                PassThrough<N>{},
+                PassThrough<K>{},
+                Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<right_pad_ho, right_pad_wo>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
@@ -154,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             in_n_c_hi_wi_global_desc,
             make_tuple(PassThrough<N>{},
                        PassThrough<C>{},
-                       Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
+                       Pad<Sequence<Hi, Wi>, InputLeftPads, InputRightPads>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
...
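For reference: the fix in the hunks above drops the trailing `true` template argument from Pad, which now takes only the dimension lengths and the left/right pad amounts. A minimal sketch of what a pad transform encodes, with invented names (an illustration, not the CK implementation):

// An upper (padded) dimension of length LeftPad + L + RightPad maps to the
// lower index by subtracting LeftPad; upper indices whose lower image falls
// outside [0, L) are padding.
template <int L, int LeftPad, int RightPad>
struct PadSketch
{
    static constexpr int upper_length = LeftPad + L + RightPad;

    static constexpr int CalculateLowerIndex(int i_upper) { return i_upper - LeftPad; }

    static constexpr bool IsInPadRegion(int i_upper)
    {
        return CalculateLowerIndex(i_upper) < 0 || CalculateLowerIndex(i_upper) >= L;
    }
};

static_assert(PadSketch<4, 1, 2>::upper_length == 7, "4 data + 1 left + 2 right");
static_assert(PadSketch<4, 1, 2>::IsInPadRegion(0), "upper index 0 maps to lower -1");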
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
-#include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_generic_tensor_slice_copy.hpp"
-#include "threadwise_generic_tensor_slice_copy.hpp"
-#include "blockwise_gemm.hpp"
+#include "gridwise_gemm.hpp"

 namespace ck {

-// B = merge(N, Ho, Wo)
+// GEMM_M = K
+// GEMM_N = N * Ho * Wo
+// GEMM_K = C * Y * X
 template <index_t GridSize,
           index_t BlockSize,
           typename Float,
+          typename AccFloat,
           typename InGlobalDesc,
           typename WeiGlobalDesc,
           typename OutGlobalDesc,
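The new header comments spell out the implicit-GEMM mapping for this forward NCHW kernel: weights supply the GEMM A matrix, the im2col view of the input supplies B, and the output is C. With hypothetical sizes (not taken from this commit), the shapes work out as follows:

// Hypothetical convolution sizes, to make GEMM_M/GEMM_N/GEMM_K concrete.
constexpr int N = 128, K = 256, C = 64; // batch, output channels, input channels
constexpr int Ho = 28, Wo = 28;         // output height/width
constexpr int Y = 3, X = 3;             // filter height/width

constexpr int GemmM = K;           // one GEMM row per output channel
constexpr int GemmN = N * Ho * Wo; // one GEMM column per output pixel
constexpr int GemmK = C * Y * X;   // reduction over input channels and filter taps

static_assert(GemmM == 256 && GemmN == 100352 && GemmK == 576, "implicit-GEMM shape");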
@@ -21,9 +22,9 @@ template <index_t GridSize,
           typename ConvDilations,
           typename LeftPads,
           typename RightPads,
-          index_t BPerBlock,
-          index_t KPerBlock,
-          index_t EPerBlock,
+          index_t GemmNPerBlock,
+          index_t GemmMPerBlock,
+          index_t GemmKPerBlock,
           index_t GemmMPerThreadSubC,
           index_t GemmNPerThreadSubC,
           index_t GemmMLevel0Cluster,
@@ -31,14 +32,8 @@ template <index_t GridSize,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
           index_t GemmKPerThreadLoop,
-          index_t GemmDataPerReadA,
-          index_t GemmDataPerReadB,
-          typename InBlockCopySubLengths_E_B,
-          typename InBlockCopyClusterLengths_E_B,
-          typename InBlockCopyThreadClusterArrangeOrder,
-          typename InBlockCopySrcAccessOrder,
-          typename InBlockCopyDstAccessOrder,
-          index_t InBlockCopyDataPerAccess_B,
+          index_t GemmThreadGemmDataPerReadM,
+          index_t GemmThreadGemmDataPerReadN,
           typename WeiBlockCopySubLengths_E_K,
           typename WeiBlockCopyClusterLengths_E_K,
           typename WeiBlockCopyThreadClusterArrangeOrder,
@@ -46,6 +41,12 @@ template <index_t GridSize,
           typename WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K,
+          typename InBlockCopySubLengths_E_B,
+          typename InBlockCopyClusterLengths_E_B,
+          typename InBlockCopyThreadClusterArrangeOrder,
+          typename InBlockCopySrcAccessOrder,
+          typename InBlockCopyDstAccessOrder,
+          index_t InBlockCopyDataPerAccess_B,
           index_t OutThreadCopyDataPerAccess_B>
 struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
 {
@@ -58,8 +59,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        constexpr auto True = integral_constant<bool, true>{};
-
         constexpr auto in_n_c_hi_wi_global_desc =
             make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());

         constexpr auto wei_k_c_y_x_global_desc =
@@ -94,23 +93,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
             "wrong! aligment requirement for vectorized global load of input tensor will "
             "be violated");

-        // divide block work by [K, B]
-        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
-                      "wrong! cannot divide work evenly among block");
-
-        constexpr index_t KBlockWork = K / KPerBlock;
-        constexpr index_t BBlockWork = B / BPerBlock;
-
-        constexpr auto block_work_desc =
-            make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
-
-        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
-
-        const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
-        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
+        // weight tensor
+        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
+            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

         // input tensor
-        // global mem
         constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
             in_n_c_hi_wi_global_desc,
             make_tuple(
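The new wei_gemmk_gemmm_global_desc is built by unfolding dimensions 1 through 3 of the K x C x Y x X weight descriptor into K x (C*Y*X), then transposing with the <1, 0> upper-to-lower reorder to get a GemmK x GemmM view. A plain-index sketch of the resulting mapping, assuming a packed KCYX layout (invented helper, not CK API):

// Offset of element (gemmk, gemmm) of the GemmK x GemmM weight view inside a
// packed K x C x Y x X buffer: gemmm selects the output channel k, and gemmk
// unpacks to the (c, y, x) filter tap.
constexpr int wei_offset(int gemmk, int gemmm, int C, int Y, int X)
{
    return ((gemmm * C + gemmk / (Y * X)) * Y + (gemmk % (Y * X)) / X) * X + gemmk % X;
}

static_assert(wei_offset(0, 1, 64, 3, 3) == 576, "second output channel starts C*Y*X in");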
@@ -127,54 +114,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

-        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
+        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
             in_n_c_y_ho_x_wo_global_desc,
             make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
             make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        // LDS mem
-        // be careful of LDS alignment
-        constexpr auto in_e_b_block_desc =
-            make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
-
-        // input blockwise copy
-        auto blockwise_in_copy =
-            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
-                                               decltype(in_e_b_global_desc),
-                                               decltype(in_e_b_block_desc),
-                                               decltype(in_e_b_block_desc.GetLengths()),
-                                               InBlockCopySubLengths_E_B,
-                                               InBlockCopyClusterLengths_E_B,
-                                               InBlockCopyThreadClusterArrangeOrder,
-                                               InBlockCopySrcAccessOrder,
-                                               InBlockCopyDstAccessOrder,
-                                               1,
-                                               1,
-                                               InBlockCopyDataPerAccess_B,
-                                               InBlockCopyDataPerAccess_B,
-                                               AddressSpace::global,
-                                               AddressSpace::vgpr,
-                                               AddressSpace::lds,
-                                               InMemoryDataOperation::none>(
-                {0, b_block_data_on_global}, {0, 0});
-
-        // weight tensor
-        // global mem
-        constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
-            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
-
         // LDS
         // be careful of LDS alignment
         constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
-            Sequence<EPerBlock, KPerBlock>{},
-            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
+            Sequence<GemmKPerBlock, GemmMPerBlock>{},
+            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmThreadGemmDataPerReadM)>{});

         // this check is ad-hoc
         // TODO: need to properly implement tensor descriptor with multiple alignment
         // requirements
-        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
-                      "GemmDataPerReadA alignment requirement is not satisfied");
+        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmThreadGemmDataPerReadM == 0,
+                      "GemmThreadGemmDataPerReadM alignment requirement is not satisfied");

         // weight blockwise copy
         auto blockwise_wei_copy =
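The two Merge transforms in in_gemmk_gemmn_global_desc are the im2col step expressed as descriptor algebra: GEMM_K merges (C, Y, X) and GEMM_N merges (N, Ho, Wo) out of the 6-d n, c, y, ho, x, wo view. A sketch of the index unpacking they imply (same hypothetical-naming caveat as above):

// Unmerge GEMM coordinates back into the 6-d input-view coordinates.
struct InCoord
{
    int n, c, y, ho, x, wo;
};

constexpr InCoord unmerge(int gemmk, int gemmn, int Y, int X, int Ho, int Wo)
{
    return InCoord{gemmn / (Ho * Wo),        // n   from GEMM_N = merge(N, Ho, Wo)
                   gemmk / (Y * X),          // c   from GEMM_K = merge(C, Y, X)
                   (gemmk % (Y * X)) / X,    // y
                   (gemmn % (Ho * Wo)) / Wo, // ho
                   gemmk % X,                // x
                   gemmn % Wo};              // wo
}

static_assert(unmerge(10, 100, 3, 3, 28, 28).c == 1 && unmerge(10, 100, 3, 3, 28, 28).x == 1,
              "gemmk = 10 with 3x3 filter is channel 1, tap (0, 1)");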
@@ -199,24 +155,24 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
-        //     a_mtx[EPerBlock, KPerBlock] is in LDS
-        //     b_mtx[EPerBlocl, BPerBlock] is in LDS
-        //     c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
+        //     a_mtx[GemmKPerBlock, GemmMPerBlock] is in LDS
+        //     b_mtx[EPerBlocl, GemmNPerBlock] is in LDS
+        //     c_mtx[GemmMPerBlock, GemmNPerBlock] is distributed among threads, and saved in
         //       register
         constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
         constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

         // sanity check
         static_assert(
-            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
-                BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
+            GemmMPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
+                GemmNPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
             "wrong!");

         constexpr index_t GemmMRepeat =
-            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
+            GemmMPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

         constexpr index_t GemmNRepeat =
-            BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);
+            GemmNPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);

         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
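GemmMRepeat and GemmNRepeat count how many times each thread's sub-tile must be replayed to cover the block tile: per pass, the level-0 and level-1 thread clusters cover GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster rows. A worked example with hypothetical tile parameters (not taken from this commit):

constexpr int GemmMPerBlock = 128, GemmMPerThreadSubC = 4;
constexpr int GemmMLevel0Cluster = 4, GemmMLevel1Cluster = 4;

// 4 rows per thread x 4 x 4 threads along M = 64 rows covered per pass,
// so a 128-row block tile needs 2 repeats per thread.
constexpr int GemmMRepeat =
    GemmMPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
static_assert(GemmMRepeat == 2, "128 / (4 * 4 * 4)");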
@@ -235,14 +191,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
             GemmKPerThreadLoop,
-            GemmDataPerReadA,
-            GemmDataPerReadB>{};
+            GemmThreadGemmDataPerReadM,
+            GemmThreadGemmDataPerReadN>{};

         // LDS allocation for input and weight: be careful of alignment
         constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                 WeiBlockCopyDstDataPerWrite_K,
-                                                GemmDataPerReadA,
-                                                GemmDataPerReadB);
+                                                GemmThreadGemmDataPerReadM,
+                                                GemmThreadGemmDataPerReadN);

         constexpr index_t in_block_space =
             math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
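max_align is the least common multiple of every vectorized access width that touches LDS, so rounding each buffer's element space up to a multiple of max_align keeps both buffers' base addresses aligned for all of those accesses. With hypothetical widths (assumptions, not values from this commit):

constexpr int gcd2(int a, int b) { return b == 0 ? a : gcd2(b, a % b); }
constexpr int lcm2(int a, int b) { return a / gcd2(a, b) * b; }

// Input copy writes 4-wide, weight copy writes 2-wide, and the thread GEMM
// reads 4-wide along M and N: every LDS buffer must start on a 4-element boundary.
static_assert(lcm2(lcm2(4, 2), lcm2(4, 4)) == 4, "lcm(4, 2, 4, 4) = 4");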
@@ -266,8 +222,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         }

         // LDS double buffer: main body
-        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
-            e_block_data_begin += 2 * EPerBlock)
+        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * GemmKPerBlock < E;
+            e_block_data_begin += 2 * GemmKPerBlock)
         {
 #pragma unroll
             for(index_t iloop = 0; iloop < 2; ++iloop)
@@ -287,8 +243,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                 Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                 Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
-                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+                blockwise_in_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);
+                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);

                 __syncthreads();
@@ -307,15 +263,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // LDS double buffer: tail
         {
-            constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
+            constexpr bool has_two_iteration_left = (E % (2 * GemmKPerBlock) == 0);

             if(has_two_iteration_left) // if has 2 iteration left
             {
                 Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                 Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

-                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
-                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
+                blockwise_in_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);
+                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<GemmKPerBlock, 0>{}, True);

                 __syncthreads();
...
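The main body and tail above implement a classic ping-pong schedule: while the block GEMM consumes one LDS buffer, the next GemmKPerBlock slice of the reduction is staged through registers into the other. A host-side sketch of the control flow (a simplification under assumed names, not the CK kernel):

// Two tiles alternate between "being consumed" and "being filled";
// the tail handles the last one or two slices left over by the 2x step.
template <typename LoadFn, typename GemmFn>
void double_buffer_loop(int E, int GemmKPerBlock, LoadFn load, GemmFn gemm)
{
    int buf = 0;  // index of the tile currently holding valid data
    load(buf, 0); // prologue: stage the first K-slice
    int e = 0;
    for(; e + 2 * GemmKPerBlock < E; e += 2 * GemmKPerBlock)
    {
        for(int i = 0; i < 2; ++i) // unrolled ping-pong, as in the kernel
        {
            load(1 - buf, e + (i + 1) * GemmKPerBlock); // prefetch next slice
            gemm(buf);                                  // consume current slice
            buf = 1 - buf;
        }
    }
    // tail: one or two slices remain depending on E % (2 * GemmKPerBlock)
    for(; e < E; e += GemmKPerBlock)
    {
        if(e + GemmKPerBlock < E)
            load(1 - buf, e + GemmKPerBlock);
        gemm(buf);
        buf = 1 - buf;
    }
}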