"...composable_kernel.git" did not exist on "fb0dc35861056cbf08f68fd3208aa787e789230e"
Commit be4e3133 authored by Chao Liu

add dynamic forward v4r4, but see lots of scratch mem

parent dab53610
@@ -53,6 +53,7 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
+    ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
     ${PROJECT_SOURCE_DIR}/external/half/include
     ${PROJECT_SOURCE_DIR}/driver/include
     ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
......
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t BlockSize,
typename Float,
typename AccFloat,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
{
template <typename... Wei, typename... In, typename... Out>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const MultiIndex<2> conv_strides,
const MultiIndex<2> conv_dilations,
const MultiIndex<2> in_left_pads,
const MultiIndex<2> in_right_pads,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
const index_t ConvStrideH = conv_strides[I0];
const index_t ConvStrideW = conv_strides[I1];
const index_t ConvDilationH = conv_dilations[I0];
const index_t ConvDilationW = conv_dilations[I1];
const index_t InLeftPadH = in_left_pads[I0];
const index_t InLeftPadW = in_left_pads[I1];
const index_t InRightPadH = in_right_pads[I0];
const index_t InRightPadW = in_right_pads[I1];
// weight tensor
#if 0
// TODO implement graph optimization of tensor descriptor transformation
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(DynamicPassThrough{K}, DynamicMerge<3>{make_multi_index(C, Y, X)}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
#else
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<2>(make_multi_index(K, C * Y * X)),
make_tuple(DynamicPassThrough{K}, DynamicPassThrough{C * Y * X}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
#endif
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(DynamicPassThrough{N},
DynamicPassThrough{C},
DynamicLeftPad{Hi, InLeftPadH},
DynamicLeftPad{Wi, InLeftPadW}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})),
make_tuple(DynamicPassThrough{N},
DynamicPassThrough{C},
DynamicRightPad{Hi + InLeftPadH, InRightPadH},
DynamicRightPad{Wi + InLeftPadW, InRightPadW}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(I2);
const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(I3);
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(DynamicPassThrough{N},
DynamicPassThrough{C},
DynamicEmbed<2>{make_multi_index(Y, Ho),
make_multi_index(ConvDilationH, ConvStrideH)},
DynamicEmbed<2>{make_multi_index(X, Wo),
make_multi_index(ConvDilationW, ConvStrideW)}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(DynamicMerge<3>{make_multi_index(C, Y, X)},
DynamicMerge<3>{make_multi_index(N, Ho, Wo)}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
#if 0
//TODO: implement graph optimization of tensor descriptor transformation
const auto out_gemmm_gemmn_global_desc =
transform_dynamic_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(DynamicPassThrough{K}, DynamicMerge<3>{make_multi_index(N, Ho, Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#else
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<3>(make_multi_index(N, K, Ho * Wo)),
make_tuple(DynamicPassThrough{K}, DynamicMerge<2>{make_multi_index(N, Ho * Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#endif
const index_t GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const index_t GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const index_t GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const index_t GemmM0 = GemmM / GemmM1;
const index_t GemmN0 = GemmN / GemmN1;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(DynamicUnMerge<2>{make_multi_index(GemmM0, GemmM1)},
DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool is_even_number_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
const auto kernel_even =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>>;
const auto kernel_odd =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>>;
if(is_even_number_k_block_loop)
{
launch_kernel(kernel_even,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{});
}
else
{
launch_kernel(kernel_odd,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{});
}
}
};
} // namespace ck
#endif
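For reference, the GemmM/GemmN/GemmK mapping and the grid-size check in the driver above amount to simple arithmetic. The following is a standalone host-side illustration only; the layer shape and variable names here are chosen for the example and are not part of the commit.

// Illustrative arithmetic for the implicit-GEMM view used by the v4r4 driver.
#include <cassert>
#include <cstdio>

int main()
{
    // example NCHW forward convolution: 3x3 filter, stride 1, dilation 1, pad 1
    const int N = 128, C = 192, Hi = 28, Wi = 28, K = 128, Y = 3, X = 3;
    const int stride = 1, dilation = 1, pad = 1;

    const int Ho = (Hi + 2 * pad - dilation * (Y - 1) - 1) / stride + 1; // 28
    const int Wo = (Wi + 2 * pad - dilation * (X - 1) - 1) / stride + 1; // 28

    // GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X (as in the driver comments)
    const int GemmM = K;           // 128
    const int GemmN = N * Ho * Wo; // 100352
    const int GemmK = C * Y * X;   // 1728

    // block tiling as in the 128x128x8 configuration of the device code
    const int GemmMPerBlock = 128, GemmNPerBlock = 128, GemmKPerBlock = 8;
    assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
           GemmK % GemmKPerBlock == 0);

    // one workgroup per MxN tile of the output GEMM
    const int GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); // 1 * 784

    std::printf("GemmM %d GemmN %d GemmK %d GridSize %d\n", GemmM, GemmN, GemmK, GridSize);
    return 0;
}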
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
+#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
@@ -49,7 +49,7 @@ template <index_t GridSize,
           typename WeiBlockCopyDstAccessOrder,
           index_t WeiBlockCopySrcDataPerRead_E,
           index_t WeiBlockCopyDstDataPerWrite_K>
-struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
+struct GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
......
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
@@ -43,7 +43,7 @@ template <index_t GridSize,
           index_t GemmBBlockCopySrcDataPerRead_GemmN,
           index_t GemmBBlockCopyDstDataPerWrite_GemmN,
           index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
-struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
+struct GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
 {
     __device__ void Run(const Float* const __restrict__ p_in_global,
                         const Float* const __restrict__ p_wei_global,
......
-#ifndef CK_DYNAMIC_GRIDWISE_COL2IM_GEMMKGEMMN_NCHW_HPP
-#define CK_DYNAMIC_GRIDWISE_COL2IM_GEMMKGEMMN_NCHW_HPP
+#ifndef CK_GRIDWISE_DYNAMIC_COL2IM_GEMMKGEMMN_NCHW_HPP
+#define CK_GRIDWISE_DYNAMIC_COL2IM_GEMMKGEMMN_NCHW_HPP

 #include "common_header.hpp"
 #include "dynamic_tensor_descriptor.hpp"
@@ -94,7 +94,7 @@ template <index_t BlockSize,
           typename BlockCopySrcAccessOrder,
           typename BlockCopyDstAccessOrder,
           index_t BlockCopyDataPerAccess_GemmN>
-struct DynamicGridwiseCol2Im_gemmkgemmn_nchw
+struct GridwiseDynamicCol2Im_gemmkgemmn_nchw
 {
     // this version has scratch memory issue, due to:
     // 1. ThreadwiseDynamicTensorSliceTransfer_v1r1 keeps reference to tensor descriptor
......
@@ -6,19 +6,27 @@
 namespace ck {

+/*
+ * These functions create tensor descriptors at runtime. If they are not constexpr, you will
+ * likely see usage of scratch memory during construction of these tensor descriptors. So
+ * it's better to call these functions on the host and then pass the constructed tensor
+ * descriptors to the GPU. If the tensor descriptors being constructed are constexpr, then
+ * you can call these functions on the GPU without worrying about scratch memory usage.
+ */
 template <index_t N>
 __host__ __device__ constexpr auto
-make_dynamic_native_tensor_descriptor_packed(const MultiIndex<N>& lengths)
+make_dynamic_naive_tensor_descriptor(const MultiIndex<N>& lengths, const MultiIndex<N>& strides)
 {
-    const auto transforms = make_tuple(DynamicUnMerge<N>{lengths});
+    const auto transforms = make_tuple(DynamicEmbed<N>{lengths, strides});

     constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
     constexpr auto up_dim_hidden_idss =
         make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
     constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

-    const index_t element_space_size =
-        container_reduce(lengths, math::multiplies<index_t>{}, index_t{1});
+    index_t element_space_size = 1;
+    static_for<0, N, 1>{}([&](auto i) { element_space_size += (lengths[i] - 1) * strides[i]; });

     return DynamicTensorDescriptor<decltype(transforms),
                                    decltype(low_dim_hidden_idss),
@@ -29,17 +37,17 @@ make_dynamic_native_tensor_descriptor_packed(const MultiIndex<N>& lengths)
 template <index_t N>
 __host__ __device__ constexpr auto
-make_dynamic_native_tensor_descriptor(const MultiIndex<N>& lengths, const MultiIndex<N>& strides)
+make_dynamic_naive_tensor_descriptor_packed(const MultiIndex<N>& lengths)
 {
-    const auto transforms = make_tuple(DynamicEmbed<N>{lengths, strides});
+    const auto transforms = make_tuple(DynamicUnMerge<N>{lengths});

     constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
     constexpr auto up_dim_hidden_idss =
         make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
     constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

-    index_t element_space_size = 1;
-    static_for<0, N, 1>{}([&](auto i) { element_space_size += (lengths[i] - 1) * strides[i]; });
+    const index_t element_space_size =
+        container_reduce(lengths, math::multiplies<index_t>{}, index_t{1});

     return DynamicTensorDescriptor<decltype(transforms),
                                    decltype(low_dim_hidden_idss),
@@ -48,5 +56,22 @@ make_dynamic_native_tensor_descriptor(const MultiIndex<N>& lengths, const MultiI
                                    element_space_size};
 }

+template <index_t N>
+__host__ __device__ constexpr auto
+make_dynamic_naive_tensor_descriptor_aligned(const MultiIndex<N>& lengths, index_t align)
+{
+    auto strides = make_zero_multi_index<N>();
+
+    strides(Number<N - 1>{}) = 1;
+    strides(Number<N - 2>{}) = math::lcm(lengths[Number<N - 1>{}], align);
+
+    static_for<N - 3, -1, -1>{}([&](auto i) {
+        constexpr auto i_p1 = i + Number<1>{};
+        strides(i) = strides(i_p1) * lengths(i_p1);
+    });
+
+    return make_dynamic_naive_tensor_descriptor<N>(lengths, strides);
+}
+
 } // namespace ck
 #endif
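The comment added above recommends building these descriptors on the host; the arithmetic involved is small. Below is a standalone mirror (not library code, toy 2-d shape chosen for illustration) of the stride rule used by make_dynamic_naive_tensor_descriptor_aligned and of the element-space formula in make_dynamic_naive_tensor_descriptor: the last dimension stays contiguous, the second-to-last stride is rounded up to lcm(last length, align), and the element space is the offset of the last element plus one.

// Standalone host-side illustration only; requires C++17 for std::lcm.
#include <cstdio>
#include <numeric>

int main()
{
    const int lengths[2] = {8, 6}; // toy [dim0, dim1] tile
    const int align      = 4;      // e.g. a DstScalarPerVector alignment requirement

    int strides[2];
    strides[1] = 1;                           // innermost dimension is contiguous
    strides[0] = std::lcm(lengths[1], align); // 12 instead of 6: row stride padded for alignment

    // element_space_size = 1 + sum_i (lengths[i] - 1) * strides[i]
    int element_space_size = 1;
    for(int i = 0; i < 2; ++i)
        element_space_size += (lengths[i] - 1) * strides[i];

    std::printf("strides = {%d, %d}, element_space_size = %d\n",
                strides[0], strides[1], element_space_size); // {12, 1}, 90
    return 0;
}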
@@ -7,8 +7,8 @@
 namespace ck {

 template <typename Lengths, typename Strides>
-__host__ __device__ constexpr auto make_dynamic_native_tensor_descriptor_v1(const Lengths& lengths,
+__host__ __device__ constexpr auto make_dynamic_naive_tensor_descriptor_v1(const Lengths& lengths,
                                                                             const Strides& strides)
 {
     static_assert(Lengths::Size() == Strides::Size(), "wrong! Size not the same");
......
@@ -250,7 +250,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r1
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     static constexpr auto thread_buffer_desc_ =
-        make_dynamic_native_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
+        make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));

     using ThreadwiseRead = ThreadwiseDynamicTensorSliceTransfer_v1r1<BlockSrcDesc,
                                                                      decltype(thread_buffer_desc_),
@@ -425,7 +425,7 @@ struct BlockwiseDynamicTensorSliceTransfer_v2r2
         make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

     static constexpr auto thread_buffer_desc_ =
-        make_dynamic_native_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));
+        make_dynamic_naive_tensor_descriptor_packed<nDim>(to_multi_index(ThreadSliceLengths{}));

     using ThreadwiseRead = ThreadwiseDynamicTensorSliceTransfer_v1r2<BlockSrcDesc,
                                                                      decltype(thread_buffer_desc_),
......
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t BlockSize,
typename Float,
typename AccFloat,
InMemoryDataOperation CGlobalMemoryDataOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct GridwiseDynamicGemm_km_kn_mn_v1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
BBlockTransferDstScalarPerVector_N,
MPerThread,
NPerThread);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, MPerBlock), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
}
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t K = a_k_m_global_desc.GetLength(I0);
const index_t M = a_k_m_global_desc.GetLength(I1);
const index_t N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
const index_t m_block_work_num = M / MPerBlock;
const index_t n_block_work_num = N / NPerBlock;
const index_t m_block_work_id = get_block_1d_id() / n_block_work_num;
const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num;
const index_t m_block_data_on_global = m_block_work_id * MPerBlock;
const index_t n_block_data_on_global = n_block_work_id * NPerBlock;
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockTransferDstScalarPerVector_M,
BBlockTransferDstScalarPerVector_N,
MPerThread,
NPerThread);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, MPerBlock), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned<2>(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
// A matrix blockwise copy
auto a_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
Float,
Float,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1>(a_k_m_global_desc,
make_multi_index(0, m_block_data_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_block_copy =
BlockwiseDynamicTensorSliceTransfer_v2r2<BlockSize,
Float,
Float,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
AddressSpace::Global,
AddressSpace::Lds,
InMemoryDataOperation::Set,
1,
1>(b_k_n_global_desc,
make_multi_index(0, n_block_data_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t a_k_m_block_mtx_stride =
a_k_m_block_desc.CalculateOffset(make_multi_index(1, 0)) -
a_k_m_block_desc.CalculateOffset(make_multi_index(0, 0));
constexpr index_t b_k_n_block_mtx_stride =
b_k_n_block_desc.CalculateOffset(make_multi_index(1, 0)) -
b_k_n_block_desc.CalculateOffset(make_multi_index(0, 0));
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerBlock>{}, Number<MPerBlock>{}, Number<a_k_m_block_mtx_stride>{});
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerBlock>{}, Number<NPerBlock>{}, Number<b_k_n_block_mtx_stride>{});
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
const auto block_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr index_t b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// LDS double buffer: preload data into LDS
{
a_block_copy.Run(a_k_m_global_desc, p_a_global, a_k_m_block_desc, p_a_block_double);
b_block_copy.Run(b_k_n_global_desc, p_b_global, b_k_n_block_desc, p_b_block_double);
}
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space_size;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space_size;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space_size : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space_size : p_b_block_double;
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
// LDS double buffer: load next data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data
block_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_next);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_next);
}
}
// LDS double buffer: tail
{
if constexpr(IsEvenNumberKBlockLoop) // 2 iterations left
{
a_block_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_block_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads();
// LDS double buffer: load last data from device mem
a_block_copy.RunRead(a_k_m_global_desc, p_a_global);
b_block_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on 2nd-last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_block_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_block_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
p_c_thread);
}
else // 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
block_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// output: register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed<4>(
make_multi_index(MRepeat, MPerThread, NRepeat, NPerThread));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
block_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseDynamicTensorSliceTransfer_v1r2<
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
1,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
1>(c_m0_m1_n0_n1_thread_desc,
make_multi_index(0, 0, 0, 0),
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run(c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
}
}
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global,
integral_constant<bool, IsEvenNumberKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
__shared__ Float p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>{});
}
};
} // namespace ck
#endif
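The shared-memory requirement of GridwiseDynamicGemm_km_kn_mn_v1 is just the two LDS tiles times two for the ping-pong double buffer. The sketch below restates that arithmetic on the host for the 128x128x8 "vector 4" configuration used by the device code (illustrative only; with those parameters max_lds_align = 4 and MPerBlock/NPerBlock are already multiples of it, so no alignment padding appears).

// Plain restatement of GetSharedMemoryNumberOfByte for one configuration.
#include <cstdio>

int main()
{
    const int KPerBlock = 8, MPerBlock = 128, NPerBlock = 128;

    // aligned [KPerBlock, MPerBlock] tile: K-stride is MPerBlock (already a
    // multiple of max_lds_align = 4), so the tile occupies KPerBlock * MPerBlock elements
    const int a_block_space = KPerBlock * MPerBlock; // 1024
    const int b_block_space = KPerBlock * NPerBlock; // 1024

    // factor 2 is the LDS double buffer used in the main K loop
    const int lds_bytes = 2 * (a_block_space + b_block_space) * (int)sizeof(float);

    std::printf("LDS per workgroup: %d bytes\n", lds_bytes); // 16384
    return 0;
}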
@@ -15,7 +15,7 @@ template <typename SrcData,
           typename DstDesc,
           typename SliceLengths,
           typename SrcDstDimAccessOrder,
-          index_t SrcDstVectorAccessDim,
+          index_t SrcDstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
           AddressSpace SrcAddressSpace,
@@ -128,7 +128,7 @@ template <typename SrcData,
           typename DstDesc,
           typename SliceLengths,
           typename SrcDstDimAccessOrder,
-          index_t SrcDstVectorAccessDim,
+          index_t SrcDstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
           AddressSpace SrcAddressSpace,
@@ -144,108 +144,285 @@ threadwise_dynamic_tensor_slice_transfer_v1r2(const SrcDesc& src_desc,
                                               DynamicTensorCoordinate_t<DstDesc>& dst_coord,
                                               DstData* p_dst)
 {
-    // TODO use constexpr for coordinate-step to make sure compiler behave correctly
-    const auto src_step_0_p1 =
-        make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 1));
-    const auto src_step_0_m1 =
-        make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, -1));
-    const auto src_step_p1_0 =
-        make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(1, 0));
-    const auto src_step_m1_0 =
-        make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(-1, 0));
-    const auto dst_step_0_p1 =
-        make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 1));
-    const auto dst_step_0_m1 =
-        make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, -1));
-    const auto dst_step_p1_0 =
-        make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(1, 0));
-    const auto dst_step_m1_0 =
-        make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-1, 0));
-
-    constexpr index_t J0 = SliceLengths{}[0];
-    constexpr index_t J1 = SliceLengths{}[1];
-
-    bool forward_dim0 = true;
-    bool forward_dim1 = true;
-
-    // hardcoded for 2d loop for now
-#pragma unroll
-    for(index_t j0 = 0; j0 < J0; ++j0)
-    {
-#pragma unroll
-        for(index_t j1 = 0; j1 < J1; ++j1)
-        {
-            // do work
-            transfer_data<SrcData,
-                          1,
-                          SrcAddressSpace,
-                          DstAddressSpace,
-                          DstInMemOp,
-                          SrcScalarStrideInVector,
-                          DstScalarStrideInVector>(
-                p_src,
-                src_coord.GetOffset(),
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord),
-                src_desc.GetElementSpaceSize(),
-                p_dst,
-                dst_coord.GetOffset(),
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord),
-                dst_desc.GetElementSpaceSize());
-
-            // move dim1 iterator
-            if(j1 < J1 - 1)
-            {
-                if(forward_dim1)
-                {
-                    move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_p1);
-                    move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_p1);
-                }
-                else
-                {
-                    move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_m1);
-                    move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_m1);
-                }
-            }
-        }
-
-        // switch dim1 iteration direction
-        forward_dim1 = !forward_dim1;
-
-        // move dim0 iterator
-        if(j0 < J0 - 1)
-        {
-            if(forward_dim0)
-            {
-                move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_p1_0);
-                move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_p1_0);
-            }
-            else
-            {
-                move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_m1_0);
-                move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_m1_0);
-            }
-        }
-    }
-
-    // move src and dst coordinate back to their origins
-    // hardcoded for 2d loop
-    if constexpr(J0 % 2 == 0)
-    {
-        const auto src_step_back =
-            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(-(J0 - 1), 0));
-        const auto dst_step_back =
-            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-(J0 - 1), 0));
-
-        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_back);
-        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_back);
-    }
-    else
-    {
-        const auto src_step_back =
-            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(-(J0 - 1), -(J1 - 1)));
-        const auto dst_step_back =
-            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-(J0 - 1), -(J1 - 1)));
-
-        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_back);
-        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_back);
-    }
+    static_assert(remove_reference_t<SrcDesc>::GetNumOfDimension() ==
+                      remove_reference_t<DstDesc>::GetNumOfDimension(),
+                  "inconsistent # of dimension");
+
+    if constexpr(remove_reference_t<SrcDesc>::GetNumOfDimension() == 2)
+    {
+        // TODO use constexpr for coordinate-step to make sure compiler behave correctly
+        const auto src_step_0_p1 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 1));
+        const auto src_step_0_m1 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, -1));
+
+        const auto src_step_p1_0 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(1, 0));
+        const auto src_step_m1_0 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(-1, 0));
+
+        const auto dst_step_0_p1 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 1));
+        const auto dst_step_0_m1 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, -1));
+
+        const auto dst_step_p1_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(1, 0));
+        const auto dst_step_m1_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-1, 0));
+
+        constexpr index_t Len0 = SliceLengths{}[0];
+        constexpr index_t Len1 = SliceLengths{}[1];
+
+        bool forward_dim0 = true;
+        bool forward_dim1 = true;
+
+#pragma unroll
+        for(index_t j0 = 0; j0 < Len0; ++j0)
+        {
+#pragma unroll
+            for(index_t j1 = 0; j1 < Len1; ++j1)
+            {
+                // do work
+                transfer_data<SrcData,
+                              1,
+                              SrcAddressSpace,
+                              DstAddressSpace,
+                              DstInMemOp,
+                              SrcScalarStrideInVector,
+                              DstScalarStrideInVector>(
+                    p_src,
+                    src_coord.GetOffset(),
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
+                                                                                src_coord),
+                    src_desc.GetElementSpaceSize(),
+                    p_dst,
+                    dst_coord.GetOffset(),
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc,
+                                                                                dst_coord),
+                    dst_desc.GetElementSpaceSize());
+
+                // move dim1 iterator
+                if(j1 < Len1 - 1)
+                {
+                    if(forward_dim1)
+                    {
+                        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_p1);
+                        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_p1);
+                    }
+                    else
+                    {
+                        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_m1);
+                        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_m1);
+                    }
+                }
+            }
+
+            // switch dim1 iteration direction
+            forward_dim1 = !forward_dim1;
+
+            // move dim0 iterator
+            if(j0 < Len0 - 1)
+            {
+                if(forward_dim0)
+                {
+                    move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_p1_0);
+                    move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_p1_0);
+                }
+                else
+                {
+                    move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_m1_0);
+                    move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_m1_0);
+                }
+            }
+        }
+
+        // move src and dst coordinate back to their origins
+        constexpr index_t loc0 = Len0 - 1;
+        constexpr index_t loc1 = Len0 % 2 == 0 ? 0 : Len1 - 1;
+
+        const auto src_step_back =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(-loc0, -loc1));
+        const auto dst_step_back =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-loc0, -loc1));
+
+        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_back);
+        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_back);
+    }
+    else if constexpr(remove_reference_t<SrcDesc>::GetNumOfDimension() == 4)
+    {
+        // TODO use constexpr for coordinate-step to make sure compiler behave correctly
+        const auto src_step_0_0_0_p1 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 0, 0, 1));
+        const auto src_step_0_0_0_m1 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 0, 0, -1));
+
+        const auto src_step_0_0_p1_0 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 0, 1, 0));
+        const auto src_step_0_0_m1_0 =
+            make_dynamic_tensor_coordinate_step(src_desc, make_multi_index(0, 0, -1, 0));
+
+        const auto src_step_0_p1_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 1, 0, 0));
+        const auto src_step_0_m1_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, -1, 0, 0));
+
+        const auto src_step_p1_0_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(1, 0, 0, 0));
+        const auto src_step_m1_0_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-1, 0, 0, 0));
+
+        const auto dst_step_0_0_0_p1 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 0, 0, 1));
+        const auto dst_step_0_0_0_m1 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 0, 0, -1));
+
+        const auto dst_step_0_0_p1_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 0, 1, 0));
+        const auto dst_step_0_0_m1_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 0, -1, 0));
+
+        const auto dst_step_0_p1_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, 1, 0, 0));
+        const auto dst_step_0_m1_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(0, -1, 0, 0));
+
+        const auto dst_step_p1_0_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(1, 0, 0, 0));
+        const auto dst_step_m1_0_0_0 =
+            make_dynamic_tensor_coordinate_step(dst_desc, make_multi_index(-1, 0, 0, 0));
+
+        constexpr index_t Len0 = SliceLengths{}[0];
+        constexpr index_t Len1 = SliceLengths{}[1];
+        constexpr index_t Len2 = SliceLengths{}[2];
+        constexpr index_t Len3 = SliceLengths{}[3];
+
+        bool forward_dim0 = true;
+        bool forward_dim1 = true;
+        bool forward_dim2 = true;
+        bool forward_dim3 = true;
+
+#pragma unroll
+        for(index_t j0 = 0; j0 < Len0; ++j0)
+        {
+#pragma unroll
+            for(index_t j1 = 0; j1 < Len1; ++j1)
+            {
+#pragma unroll
+                for(index_t j2 = 0; j2 < Len2; ++j2)
+                {
+#pragma unroll
+                    for(index_t j3 = 0; j3 < Len3; ++j3)
+                    {
+                        // do work
+                        transfer_data<SrcData,
+                                      1,
+                                      SrcAddressSpace,
+                                      DstAddressSpace,
+                                      DstInMemOp,
+                                      SrcScalarStrideInVector,
+                                      DstScalarStrideInVector>(
+                            p_src,
+                            src_coord.GetOffset(),
+                            coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
+                                                                                        src_coord),
+                            src_desc.GetElementSpaceSize(),
+                            p_dst,
+                            dst_coord.GetOffset(),
+                            coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc,
+                                                                                        dst_coord),
+                            dst_desc.GetElementSpaceSize());
+
+                        // move dim3 iterator
+                        if(j3 < Len3 - 1)
+                        {
+                            if(forward_dim3)
+                            {
+                                move_dynamic_tensor_coordinate(
+                                    src_desc, src_coord, src_step_0_0_0_p1);
+                                move_dynamic_tensor_coordinate(
+                                    dst_desc, dst_coord, dst_step_0_0_0_p1);
+                            }
+                            else
+                            {
+                                move_dynamic_tensor_coordinate(
+                                    src_desc, src_coord, src_step_0_0_0_m1);
+                                move_dynamic_tensor_coordinate(
+                                    dst_desc, dst_coord, dst_step_0_0_0_m1);
+                            }
+                        }
+                    }
+
+                    // switch dim3 iteration direction
+                    forward_dim3 = !forward_dim3;
+
+                    // move dim2 iterator
+                    if(j2 < Len2 - 1)
+                    {
+                        if(forward_dim2)
+                        {
+                            move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_0_p1_0);
+                            move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_0_p1_0);
+                        }
+                        else
+                        {
+                            move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_0_m1_0);
+                            move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_0_m1_0);
+                        }
+                    }
+                }
+
+                // switch dim2 iteration direction
+                forward_dim2 = !forward_dim2;
+
+                // move dim1 iterator
+                if(j1 < Len1 - 1)
+                {
+                    if(forward_dim1)
+                    {
+                        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_p1_0_0);
+                        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_p1_0_0);
+                    }
+                    else
+                    {
+                        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_0_m1_0_0);
+                        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_0_m1_0_0);
+                    }
+                }
+            }
+
+            // switch dim1 iteration direction
+            forward_dim1 = !forward_dim1;
+
+            // move dim0 iterator
+            if(j0 < Len0 - 1)
+            {
+                if(forward_dim0)
+                {
+                    move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_p1_0_0_0);
+                    move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_p1_0_0_0);
+                }
+                else
+                {
+                    move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_m1_0_0_0);
+                    move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_m1_0_0_0);
+                }
+            }
+        }
+
+        // move src and dst coordinate back to their origins
+        constexpr index_t loc0 = Len0 - 1;
+        constexpr index_t loc1 = Len0 % 2 == 0 ? 0 : Len1 - 1;
+        constexpr index_t loc2 = Len1 % 2 == 0 ? 0 : Len2 - 1;
+        constexpr index_t loc3 = Len2 % 2 == 0 ? 0 : Len3 - 1;
+
+        const auto src_step_back = make_dynamic_tensor_coordinate_step(
+            src_desc, make_multi_index(-loc0, -loc1, -loc2, -loc3));
+        const auto dst_step_back = make_dynamic_tensor_coordinate_step(
+            dst_desc, make_multi_index(-loc0, -loc1, -loc2, -loc3));
+
+        move_dynamic_tensor_coordinate(src_desc, src_coord, src_step_back);
+        move_dynamic_tensor_coordinate(dst_desc, dst_coord, dst_step_back);
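The loc0/loc1 (and loc2/loc3) constants in the new code encode where the zig-zag (snake) traversal of the slice finishes, so a single step moves the coordinate back to the origin. A standalone check of that claim for the 2-d case (plain C++, no library dependencies, written only to verify the formula):

#include <cassert>

int main()
{
    for(int Len0 = 1; Len0 <= 4; ++Len0)
    {
        for(int Len1 = 1; Len1 <= 4; ++Len1)
        {
            int j0 = 0, j1 = 0;
            bool forward_dim1 = true;

            // replay the zig-zag visiting order used by the 2-d branch above
            for(int i0 = 0; i0 < Len0; ++i0)
            {
                for(int i1 = 0; i1 < Len1; ++i1)
                {
                    if(i1 < Len1 - 1)
                        j1 += forward_dim1 ? 1 : -1;
                }
                forward_dim1 = !forward_dim1;
                if(i0 < Len0 - 1)
                    ++j0;
            }

            // end position predicted by the kernel's formula
            const int loc0 = Len0 - 1;
            const int loc1 = Len0 % 2 == 0 ? 0 : Len1 - 1;
            assert(j0 == loc0 && j1 == loc1);
        }
    }
    return 0;
}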
@@ -259,7 +436,7 @@ template <typename SrcDesc,
           typename DstDesc,
           typename SliceLengths,
           typename SrcDstDimAccessOrder,
-          index_t SrcDstVectorAccessDim,
+          index_t SrcDstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
           AddressSpace SrcAddressSpace,
@@ -304,7 +481,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r1
           DstDesc,
           SliceLengths,
           SrcDstDimAccessOrder,
-          SrcDstVectorAccessDim,
+          SrcDstVectorDim,
           SrcScalarPerVector,
           DstScalarPerVector,
           SrcAddressSpace,
@@ -358,7 +535,7 @@ template <typename SrcDesc,
           typename DstDesc,
           typename SliceLengths,
           typename SrcDstDimAccessOrder,
-          index_t SrcDstVectorAccessDim,
+          index_t SrcDstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
           AddressSpace SrcAddressSpace,
@@ -402,7 +579,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r2
           DstDesc,
           SliceLengths,
           SrcDstDimAccessOrder,
-          SrcDstVectorAccessDim,
+          SrcDstVectorDim,
           SrcScalarPerVector,
           DstScalarPerVector,
           SrcAddressSpace,
......
@@ -3,7 +3,7 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "gridwise_operation_wrapper.hpp"
-#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
+#include "gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"

 template <typename T,
           typename InDesc,
@@ -13,17 +13,17 @@ template <typename T,
           typename ConvDilations,
           typename LeftPads,
           typename RightPads>
-void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
+void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
                                                                   const Tensor<T>& in_nchw,
                                                                   WeiDesc,
                                                                   const Tensor<T>& wei_kcyx,
                                                                   OutDesc,
                                                                   Tensor<T>& out_nkhw,
                                                                   ConvStrides,
                                                                   ConvDilations,
                                                                   LeftPads,
                                                                   RightPads,
                                                                   ck::index_t nrepeat)
 {
     using namespace ck;
@@ -770,45 +770,46 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
+    using gridwise_conv =
+        GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
             GridSize,
             BlockSize,
             T,
             T,
             decltype(in_nchw_desc),
             decltype(wei_kcyx_desc),
             decltype(out_nkhw_desc),
             ConvStrides,
             ConvDilations,
             LeftPads,
             RightPads,
             BPerBlock,
             KPerBlock,
             EPerBlock,
             GemmNRepeat,
             GemmMPerThread,
             GemmNPerThread,
             GemmKPerThread,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
             GemmDataPerReadA,
             GemmDataPerReadB,
             InBlockCopySubLengths_E_N1_B_N2,
             InBlockCopyClusterLengths_E_N1_B_N2,
             InBlockCopyThreadClusterArrangeOrder,
             InBlockCopySrcAccessOrder,
             InBlockCopyDstAccessOrder,
             InBlockCopySrcDataPerRead_B,
             InBlockCopyDstDataPerWrite_N2,
             WeiBlockCopySubLengths_E_K,
             WeiBlockCopyClusterLengths_E_K,
             WeiBlockCopyThreadClusterArrangeOrder,
             WeiBlockCopySrcAccessOrder,
             WeiBlockCopyDstAccessOrder,
             WeiBlockCopySrcDataPerRead_E,
             WeiBlockCopyDstDataPerWrite_K>;

     for(index_t i = 0; i < 5; ++i)
     {
......
@@ -2,7 +2,7 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "gridwise_operation_wrapper.hpp"
-#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

 template <class T,
           class InDesc,
@@ -12,17 +12,17 @@ template <class T,
           class ConvDilations,
           class InLeftPads,
           class InRightPads>
-void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
+void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
                                                                   const Tensor<T>& in_nchw,
                                                                   WeiDesc,
                                                                   const Tensor<T>& wei_kcyx,
                                                                   OutDesc,
                                                                   Tensor<T>& out_nkhw,
                                                                   ConvStrides,
                                                                   ConvDilations,
                                                                   InLeftPads,
                                                                   InRightPads,
                                                                   ck::index_t nrepeat)
 {
     using namespace ck;
@@ -153,7 +153,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;

-#elif 0
+#elif 1
     // cdata = 64, BlockSize = 256, 128x128x8
     // vector 4
     constexpr index_t BlockSize = 256;
@@ -187,7 +187,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;

-#elif 1
+#elif 0
     // cdata = 64, BlockSize = 256, 128x128x16
     constexpr index_t BlockSize = 256;
@@ -289,6 +289,41 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
     constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;

     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
+#elif 1
+    // cdata = 64, BlockSize = 256, 128x128x16
+    // GemmBBlockCopySrcDataPerRead_GemmN = 4
+    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 16;
+
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<1, 8>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 1;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<2, 4>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 0
     // cdata = 64, BlockSize = 128, 128x64x4
@@ -968,7 +1003,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

-    using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
+    using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw<
         GridSize,
         BlockSize,
         TDevice,
......
@@ -28,11 +28,11 @@ void device_dummy_dynamic_transform(InDesc,
     using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

-    const auto in_nchw_desc = make_dynamic_native_tensor_descriptor<4>(
+    const auto in_nchw_desc = make_dynamic_naive_tensor_descriptor<4>(
         to_multi_index(InDesc::GetLengths()), to_multi_index(InDesc::GetStrides()));

-    const auto wei_kcyx_desc = make_dynamic_native_tensor_descriptor<4>(
+    const auto wei_kcyx_desc = make_dynamic_naive_tensor_descriptor<4>(
         to_multi_index(WeiDesc::GetLengths()), to_multi_index(WeiDesc::GetStrides()));

-    const auto out_nkhw_desc = make_dynamic_native_tensor_descriptor<4>(
+    const auto out_nkhw_desc = make_dynamic_naive_tensor_descriptor<4>(
         to_multi_index(OutDesc::GetLengths()), to_multi_index(OutDesc::GetStrides()));

     const auto conv_strides = to_multi_index(ConvStrides{});
......
@@ -28,11 +28,11 @@ void device_dummy_dynamic_transform_v1(InDesc,
     using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

-    const auto in_nchw_desc = make_dynamic_native_tensor_descriptor_v1(
+    const auto in_nchw_desc = make_dynamic_naive_tensor_descriptor_v1(
         to_multi_index(InDesc::GetLengths()), to_multi_index(InDesc::GetStrides()));

-    const auto wei_kcyx_desc = make_dynamic_native_tensor_descriptor_v1(
+    const auto wei_kcyx_desc = make_dynamic_naive_tensor_descriptor_v1(
         to_multi_index(WeiDesc::GetLengths()), to_multi_index(WeiDesc::GetStrides()));

-    const auto out_nkhw_desc = make_dynamic_native_tensor_descriptor_v1(
+    const auto out_nkhw_desc = make_dynamic_naive_tensor_descriptor_v1(
         to_multi_index(OutDesc::GetLengths()), to_multi_index(OutDesc::GetStrides()));

     const auto conv_strides = to_multi_index(ConvStrides{});
......
...@@ -3,7 +3,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_dynamic_col2im_gemmkgemmn_nchw.hpp"
template <typename T,
typename ColDesc,
...@@ -40,10 +40,10 @@ void device_dynamic_col2im_gemmkgemmn_nchw(ColDesc,
col_gemmk_gemmn_device_buf.ToDevice(col_gemmk_gemmn.mData.data());
img_n_c_hi_wi_device_buf.ToDevice(img_n_c_hi_wi.mData.data());
const auto col_gemmk_gemmn_desc = make_dynamic_naive_tensor_descriptor<2>(
to_multi_index(ColDesc::GetLengths()), to_multi_index(ColDesc::GetStrides()));
const auto img_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor<4>(
to_multi_index(ImgDesc::GetLengths()), to_multi_index(ImgDesc::GetStrides()));
const auto filter_sizes = to_multi_index(FilterSizes{});
...@@ -83,7 +83,7 @@ void device_dynamic_col2im_gemmkgemmn_nchw(ColDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_col2im =
GridwiseDynamicCol2Im_gemmkgemmn_nchw<BlockSize,
GemmKPerBlock,
GemmNPerBlock,
BlockCopySubLengths_GemmK_GemmN,
...
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations,
class InLeftPads,
class InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ck::index_t nrepeat)
{
using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
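// Note (my reading of the host tensor API, not stated in the source): GetElementSpace()
// is the strided footprint in elements rather than the packed element count, so buffers
// for non-packed layouts are allocated in full.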
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor<4>(
to_multi_index(InDesc::GetLengths()), to_multi_index(InDesc::GetStrides()));
const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor<4>(
to_multi_index(WeiDesc::GetLengths()), to_multi_index(WeiDesc::GetStrides()));
const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor<4>(
to_multi_index(OutDesc::GetLengths()), to_multi_index(OutDesc::GetStrides()));
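// The dynamic naive descriptors carry lengths and strides as runtime multi-indices, so a
// single compiled kernel can serve any NCHW problem size; the static InDesc/WeiDesc/OutDesc
// types are only used here to read those values out.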
const auto conv_strides = to_multi_index(ConvStrides{});
const auto conv_dilations = to_multi_index(ConvDilations{});
const auto in_left_pads = to_multi_index(InLeftPads{});
const auto in_right_pads = to_multi_index(InRightPads{});
#if 1
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
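// Sanity check on this tuning block (my arithmetic, not from the source): the thread
// clusters give 2 * 2 * 8 * 8 = 256 threads, matching BlockSize, and each thread owns a
// 4x4 micro-tile, so covering the 128x128 block tile needs a 2x2 repeat per thread,
// i.e. 4 * 4 * 2 * 2 = 64 C elements per thread -- presumably the "cdata = 64" above.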
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const index_t N = out_n_k_ho_wo_desc.GetLength(I0);
const index_t K = out_n_k_ho_wo_desc.GetLength(I1);
const index_t Ho = out_n_k_ho_wo_desc.GetLength(I2);
const index_t Wo = out_n_k_ho_wo_desc.GetLength(I3);
const index_t C = wei_k_c_y_x_desc.GetLength(I1);
const index_t Y = wei_k_c_y_x_desc.GetLength(I2);
const index_t X = wei_k_c_y_x_desc.GetLength(I3);
const index_t GemmM = K;
const index_t GemmN = N * Ho * Wo;
const index_t GemmK = C * Y * X;
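// Implicit GEMM view of the forward convolution: each of the K filters contributes one
// GEMM row, each output position across the batch (N * Ho * Wo) one GEMM column, and the
// reduction runs over the C * Y * X filter window, with no explicit im2col buffer.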
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size not divisible");
}
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
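// One workgroup per GemmMPerBlock x GemmNPerBlock tile of the GemmM x GemmN output; the
// divisibility check above guarantees these tiles cover the output exactly.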
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto conv_driver = DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw<
BlockSize,
TDevice,
TDevice,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
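// Timing loop: 5 measurement rounds, each launching the kernel nrepeat times and
// reporting the per-launch average (the earlier rounds effectively act as warm-up).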
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
for(index_t j = 0; j < nrepeat; ++j)
{
conv_driver.Run(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
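// Unit check: calculate_convolution_flops() / 1e9 gives GFLOPs, and ave_time is in ms,
// so GFLOP / ms comes out as TFlop/s, matching the printed label.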
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
...@@ -11,8 +11,9 @@
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dummy_static_transform.hpp"
#include "device_dummy_dynamic_transform_v1.hpp"
#include "device_dummy_dynamic_transform.hpp"
...@@ -22,6 +23,21 @@ int main(int argc, char* argv[])
using namespace ck;
#if 0
// 1x1, 8x8
constexpr index_t N = 2;
constexpr index_t C = 24;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
// 3x3, 71x71
constexpr index_t N = 128;
constexpr index_t C = 192;
...@@ -550,7 +566,7 @@ int main(int argc, char* argv[])
}
#if 0
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
...@@ -562,17 +578,29 @@ int main(int argc, char* argv[])
RightPads{},
nrepeat);
#elif 0
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#elif 0
device_dummy_static_transform(in_nchw_desc,
in_nchw,
...