"vscode:/vscode.git/clone" did not exist on "c43c30c01b8a99005e1f235eb9e61d81f4f603b5"
Unverified commit 8f5f6496, authored by Chao Liu, committed by GitHub

backward data (#7)

* enabled atomic add in tensor copy
* added gridwise GEMM
* added backward data conv using GEMM + atomic
* added backward data conv using GEMM, no atomic
parent 31ded4ac
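
For orientation, the new kernels in this commit cast backward-data convolution as an implicit GEMM over the merged dimensions E = C * Y * X and B = N * Ho * Wo. A minimal sketch of the v1r1 formulation (the atomic-add variant); the index shorthand below is illustrative only and not the code's actual descriptor notation:

    din[e, b] += sum_k wei[k, e] * dout[k, b],    e = (c, y, x),  b = (n, ho, wo)

The E-by-B result is scattered back into the NCHW input gradient; several (y, x) filter taps can land on the same input pixel, hence the atomic add. The v2r1 kernel ("no atomic" in the commit message) instead re-tiles the filter into (Ydot, Ytilda) x (Xdot, Xtilda) pieces so that plain stores suffice.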
@@ -55,9 +55,11 @@ include_directories(BEFORE
 if(DEVICE_BACKEND STREQUAL "AMD")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
 endif()
 add_subdirectory(driver)
#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
template <typename GridwiseOp, typename... Xs>
__global__ void run_gridwise_operation(GridwiseOp, Xs... xs)
{
GridwiseOp{}.Run(xs...);
}
#endif
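A minimal usage sketch for this wrapper; the gridwise-op type, pointers, and launch configuration below are hypothetical and only illustrate how the arguments reach GridwiseOp::Run:
// Hypothetical host-side launch (HIP/CUDA triple-chevron syntax): the first argument
// only carries the type; Run() is invoked on a default-constructed GridwiseOp inside
// the kernel and the remaining arguments are forwarded unchanged.
using GridwiseOp = ExampleGridwiseCopy; // hypothetical gridwise operation type
run_gridwise_operation<<<dim3(grid_size), dim3(block_size), 0, 0>>>(
    GridwiseOp{}, p_src_global, p_dst_global);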
#ifndef CK_GRIDWISE_COL2IM_EB_NCHW_HPP
#define CK_GRIDWISE_COL2IM_EB_NCHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
namespace ck {
// B = merge(N, Ho, Wo)
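// GridwiseCol2Im: scatter a column-layout tensor col[E, B], with E = C * Y * X and
// B = N * Ho * Wo, back into the image tensor img[N, C, Hi, Wi]. Overlapping filter
// windows land on the same image pixel, so the blockwise copy below accumulates with
// atomic adds (InMemoryDataOperation::atomic_add).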
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename ColGlobalDesc,
typename ImgGlobalDesc,
typename FilterSizes,
typename OutputSizes,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t EPerBlock,
index_t BPerBlock,
typename BlockCopySubLengths_E_B,
typename BlockCopyClusterLengths_E_B,
typename BlockCopyThreadClusterArrangeOrder,
typename BlockCopySrcAccessOrder,
typename BlockCopyDstAccessOrder,
index_t BlockCopyDataPerAccess_B>
struct GridwiseCol2Im_eb_nchw
{
__device__ void Run(const Float* const __restrict__ p_col_global,
Float* const __restrict__ p_img_global) const
{
constexpr auto col_e_b_global_desc = ColGlobalDesc{};
constexpr auto img_n_c_hi_wi_global_desc = ImgGlobalDesc{};
constexpr index_t N = img_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = img_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = img_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = img_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t Ho = OutputSizes{}[0];
constexpr index_t Wo = OutputSizes{}[1];
constexpr index_t Y = FilterSizes{}[0];
constexpr index_t X = FilterSizes{}[1];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || BlockCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % BlockCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [E, B]
static_assert(E % EPerBlock == 0 && B % BPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t EBlockWork = E / EPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
// construct img_eb_global_desc
constexpr auto img_n_c_hip_wip_global_desc = transform_tensor_descriptor(
img_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto img_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
img_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto img_e_b_global_desc = transform_tensor_descriptor(
img_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// blockwise atomic accumulation
auto blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(col_e_b_global_desc),
decltype(img_e_b_global_desc),
Sequence<EPerBlock, BPerBlock>,
BlockCopySubLengths_E_B,
BlockCopyClusterLengths_E_B,
BlockCopyThreadClusterArrangeOrder,
BlockCopySrcAccessOrder,
BlockCopyDstAccessOrder,
1,
1,
BlockCopyDataPerAccess_B,
BlockCopyDataPerAccess_B,
AddressSpace::vgpr,
AddressSpace::vgpr,
AddressSpace::global,
InMemoryDataOperation::atomic_add>(
{e_block_data_on_global, b_block_data_on_global},
{e_block_data_on_global, b_block_data_on_global});
// blockwise copy
blockwise_copy.Run(p_col_global, p_img_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t EPerBlock,
index_t BPerBlock,
index_t KPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename WeiBlockCopySubLengths_K_E,
typename WeiBlockCopyClusterLengths_K_E,
index_t WeiBlockCopyDataPerAccess_E,
typename OutBlockCopySubLengths_K_B,
typename OutBlockCopyClusterLengths_K_B,
index_t OutBlockCopyDataPerAccess_B,
index_t InThreadCopyDataPerAccess_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t E = C * Y * X;
constexpr index_t B = N * Ho * Wo;
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
(X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// output tensor
constexpr auto out_n_k_howo_global_desc =
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
constexpr auto out_k_b_global_desc =
transform_tensor_descriptor(out_n_k_howo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// weight tensor
constexpr auto wei_k_e_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM: atomic add
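// Implicit GEMM for backward data: in[E, B] += transpose(wei[K, E]) * out[K, B].
// Several (y, x) filter taps can map to the same input pixel, so the input gradient
// is accumulated with atomic adds.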
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_k_e_global_desc),
decltype(out_k_b_global_desc),
decltype(in_e_b_global_desc),
InMemoryDataOperation::atomic_add,
EPerBlock,
BPerBlock,
KPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
WeiBlockCopySubLengths_K_E,
WeiBlockCopyClusterLengths_K_E,
WeiBlockCopyDataPerAccess_E,
OutBlockCopySubLengths_K_B,
OutBlockCopyClusterLengths_K_B,
OutBlockCopyDataPerAccess_B,
InThreadCopyDataPerAccess_B>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t EPerBlock,
index_t BPerBlock,
index_t KPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename OutBlockCopySubLengths_K_B_N0,
typename OutBlockCopyClusterLengths_K_B_N0,
index_t OutBlockCopySrcDataPerRead_B,
index_t OutBlockCopyDstDataPerWrite_N0,
typename WeiBlockCopySubLengths_K_E_C0,
typename WeiBlockCopyClusterLengths_K_E_C0,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_C0,
index_t InThreadCopyDstDataPerWrite_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
const Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t C0 = GemmMPerThreadSubC;
constexpr index_t N0 = GemmNPerThreadSubC;
static_assert(C % C0 == 0 && N % N0 == 0, "wrong!");
constexpr index_t C1 = C / C0;
constexpr index_t N1 = N / N0;
constexpr index_t E = C1 * Y * X;
constexpr index_t B = N1 * Ho * Wo;
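// Register/LDS tiling: C is split as C0 * C1 and N as N0 * N1, with C0 and N0 equal
// to the per-thread GEMM sub-tile sizes. The block-level GEMM then runs over
// E = C1 * Y * X rows and B = N1 * Ho * Wo columns, carrying C0 and N0 as the
// innermost dimensions of the LDS tiles built below.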
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDstDataPerWrite_B == 1)) &&
(X == 1 || ConvDilationW % InThreadCopyDstDataPerWrite_B == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// divide block work by [K, B]
static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t EBlockWork = E / EPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
// output tensor
// global tensor in global memory, src of blockwise copy
constexpr auto out_n_k_howo_global_desc =
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
constexpr auto out_n0_n1_k_howo_global_desc = transform_tensor_descriptor(
out_n_k_howo_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{}, PassThrough<K>{}, PassThrough<Ho * Wo>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
constexpr auto out_k_b_n0_global_desc = transform_tensor_descriptor(
out_n0_n1_k_howo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N1, Ho * Wo>>{}, PassThrough<N0>{}),
make_tuple(Sequence<2>{}, Sequence<1, 3>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto out_k_b_n0_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, BPerBlock, N0>{}, Number<OutBlockCopyDstDataPerWrite_N0>{});
// output tensor blockwise copy
auto blockwise_out_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(out_k_b_n0_global_desc),
decltype(out_k_b_n0_block_desc),
decltype(out_k_b_n0_block_desc.GetLengths()),
OutBlockCopySubLengths_K_B_N0,
OutBlockCopyClusterLengths_K_B_N0,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
1,
2,
OutBlockCopySrcDataPerRead_B,
OutBlockCopyDstDataPerWrite_N0,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, b_block_data_on_global, 0}, {0, 0, 0});
// weight tensor
// global tensor in global memory, src of blockwise copy
constexpr auto wei_k_cyx_global_desc =
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);
constexpr auto wei_k_c0_e_global_desc =
transform_tensor_descriptor(wei_k_cyx_global_desc,
make_tuple(PassThrough<K>{}, UnMerge<Sequence<C0, E>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto wei_k_e_c0_global_desc = reorder_tensor_descriptor_given_lower2upper(
wei_k_c0_e_global_desc, Sequence<0, 2, 1>{});
// block tensor in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_k_e_c0_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, EPerBlock, C0>{}, Number<WeiBlockCopyDstDataPerWrite_C0>{});
// weight tensor blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_k_e_c0_global_desc),
decltype(wei_k_e_c0_block_desc),
decltype(wei_k_e_c0_block_desc.GetLengths()),
WeiBlockCopySubLengths_K_E_C0,
WeiBlockCopyClusterLengths_K_E_C0,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
Sequence<0, 1, 2>,
1,
2,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_C0,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, e_block_data_on_global, 0}, {0, 0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, EPerBlock*C0] is in LDS
// b_mtx[KPerBlock, BPerBlock*N0] is in LDS
// c_mtx[EPerBlock*C0, BPerBlock*N0] is distributed among threads, and saved in
// register
constexpr auto a_k_ec0_block_mtx_desc = make_ConstantMatrixDescriptor(
wei_k_e_c0_block_desc.GetLength(I0),
wei_k_e_c0_block_desc.GetLength(I1) * wei_k_e_c0_block_desc.GetLength(I2),
wei_k_e_c0_block_desc.GetStride(I0));
constexpr auto b_k_bn0_block_mtx_desc = make_ConstantMatrixDescriptor(
out_k_b_n0_block_desc.GetLength(I0),
out_k_b_n0_block_desc.GetLength(I1) * out_k_b_n0_block_desc.GetLength(I2),
out_k_b_n0_block_desc.GetStride(I0));
// sanity check alignment
// TODO: this check is ad hoc; instead, enforce the alignment of
// wei_k_e_c0_block_desc and out_k_b_n0_block_desc directly
static_assert(a_k_ec0_block_mtx_desc.RowStride() % GemmDataPerReadB == 0, "wrong!");
static_assert(b_k_bn0_block_mtx_desc.RowStride() % GemmDataPerReadA == 0, "wrong!");
// sanity check
static_assert(EPerBlock % (GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
BPerBlock % (GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat = EPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster);
constexpr index_t GemmNRepeat = BPerBlock / (GemmNLevel0Cluster * GemmNLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_ec0_block_mtx_desc),
decltype(b_k_bn0_block_mtx_desc),
decltype(c_e0e1c0_b0b1n0_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB>{};
// LDS allocation for input and weight: be careful of alignment
constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDstDataPerWrite_C0,
OutBlockCopyDstDataPerWrite_N0,
GemmDataPerReadA,
GemmDataPerReadB);
constexpr index_t out_block_space =
math::integer_least_multiple(out_k_b_n0_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t wei_block_space =
math::integer_least_multiple(wei_k_e_c0_block_desc.GetElementSpace(), max_lds_align);
__shared__ Float p_out_block_double[2 * out_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
// register allocation for output
AccFloat p_in_thread[c_e0e1c0_b0b1n0_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_e0e1c0_b0b1n0_thread_mtx_desc, p_in_thread);
// LDS double buffer: preload data into LDS
{
blockwise_out_copy.Run(p_out_global, p_out_block_double);
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
}
// LDS double buffer: main body
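// Pipeline per iteration: advance the source window by KPerBlock, prefetch the next
// K-slice from global memory into registers, run the GEMM on the slice already in LDS,
// then store the prefetched slice into the other LDS buffer; the two buffers swap
// roles every iteration.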
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_out_block_now =
even_loop ? p_out_block_double : p_out_block_double + out_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
Float* p_out_block_next =
even_loop ? p_out_block_double + out_block_space : p_out_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread);
// LDS double buffer: store next data to LDS
blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if there are 2 iterations left
{
Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0, 0>{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
// LDS double buffer: store last data to LDS
blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer,
p_out_block_double + out_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_out_block_double + out_block_space,
p_in_thread);
}
else // if there is 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
}
}
// input: register to global memory, atomic add
{
constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t E0 = E / E1;
constexpr index_t B1 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t B0 = B / B1;
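// E is split as E0 * E1 and B as B0 * B1, with E1 and B1 matching the level-0/level-1
// thread-cluster extents of the blockwise GEMM, so each thread's result tile can be
// addressed as (e0, e1, c0, b0, b1, n0) in the global descriptor built below.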
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, 1, GemmMPerThreadSubC, GemmNRepeat, 1, GemmNPerThreadSubC>{});
// global input tensor, dst of threadwise copy
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n0_n1_c0_c1_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{},
UnMerge<Sequence<C0, C1>>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
constexpr auto in_e_c0_b_n0_global_desc = transform_tensor_descriptor(
in_n0_n1_c0_c1_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C1, Y, X>>{},
PassThrough<C0>{},
Merge<Sequence<N1, Ho, Wo>>{},
PassThrough<N0>{}),
make_tuple(Sequence<3, 4, 6>{}, Sequence<2>{}, Sequence<1, 5, 7>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
constexpr auto in_e0_e1_c0_b0_b1_n0_global_desc = transform_tensor_descriptor(
in_e_c0_b_n0_global_desc,
make_tuple(UnMerge<Sequence<E0, E1>>{},
PassThrough<C0>{},
UnMerge<Sequence<B0, B1>>{},
PassThrough<N0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t e_thread_data_on_global =
e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThreadSubC;
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThreadSubC;
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(in_e0_e1_c0_b0_b1_n0_thread_desc),
decltype(in_e0_e1_c0_b0_b1_n0_global_desc),
decltype(in_e0_e1_c0_b0_b1_n0_thread_desc.GetLengths()),
Sequence<0, 1, 2, 3, 4, 5>,
4,
1,
InThreadCopyDstDataPerWrite_B,
AddressSpace::vgpr,
AddressSpace::global,
InMemoryDataOperation::atomic_add>({0, 0, 0, 0, 0, 0},
{e_thread_data_on_global / E1,
e_thread_data_on_global % E1,
0,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1,
0})
.Run(p_in_thread, p_in_global);
}
}
};
} // namespace ck
#endif
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V2R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmK = K * Ydot * Xdot;
// GemmM = C * Ytilda * Xtilda;
// GemmN = N * Htilda * Wtilda;
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmThreadGemmDataPerReadM,
index_t GemmThreadGemmDataPerReadN,
typename GemmABlockCopySubLengths, // Gemm-K, Gemm-M
typename GemmABlockCopyClusterLengths, // Gemm-K, Gemm-M
index_t GemmABlockCopyDataPerAccess, // Gemm-M
typename GemmBBlockCopySubLengths, // Gemm-K, Gemm-N
typename GemmBBlockCopyClusterLengths, // Gemm-K, Gemm-N
index_t GemmBBlockCopyDataPerAccess, // Gemm-N
index_t GemmCThreadCopyDataPerAccess // Gemm-N
>
struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
// sanity-check for vectorized memory load
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmCThreadCopyDataPerAccess == 1)) &&
(X == 1 || ConvDilationW % GemmCThreadCopyDataPerAccess == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
constexpr index_t Htilda = Ho + right_pad_ho;
constexpr index_t Wtilda = Wo + right_pad_wo;
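// Filter re-tiling for the "no atomic" formulation: per spatial dimension, with
// g = hcf(stride, dilation), Ytilda = ConvStrideH / g and Ydot = ceil(Y / Ytilda)
// (likewise Xtilda, Xdot). Each (ytilda, htilda) pair then maps to a distinct padded
// input row hip = ytilda * ConvDilationH + htilda * ConvStrideH, which is what lets
// the GEMM below write the input gradient with plain stores
// (InMemoryDataOperation::none) instead of atomic adds.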
// weight tensor
constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Pad<Sequence<Y, X>,
Sequence<0, 0>,
Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
wei_k_c_yp_xp_global_desc,
make_tuple(PassThrough<K>{},
PassThrough<C>{},
Embed<Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
wei_k_c_ydot_ytilda_xdot_xtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<C, Ytilda, Xtilda>>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_n_k_hop_wop_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Pad<Sequence<Ho, Wo>,
Sequence<0, 0>,
Sequence<right_pad_ho, right_pad_wo>,
true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
out_n_k_hop_wop_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<K>{},
Embed<Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Embed<Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(
out_n_k_ydot_htilda_xdot_wtilda_global_desc,
make_tuple(Merge<Sequence<K, Ydot, Xdot>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// input tensor
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Ytilda, Htilda>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<Xtilda, Wtilda>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1r1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(out_gemmk_gemmn_global_desc),
decltype(in_gemmm_gemmn_global_desc),
InMemoryDataOperation::none,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopySubLengths,
GemmABlockCopyClusterLengths,
GemmABlockCopyDataPerAccess,
GemmBBlockCopySubLengths,
GemmBBlockCopyClusterLengths,
GemmBBlockCopyDataPerAccess,
GemmCThreadCopyDataPerAccess>{};
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
};
} // namespace ck
#endif
@@ -93,7 +93,8 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
     constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
         Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());
-    constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
+    constexpr auto out_nkhw_thread_desc =
+        get_convolution_output_default_4d_tensor_descriptor_deprecated(
         in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
     // register
@@ -107,11 +107,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         constexpr auto True = integral_constant<bool, true>{};
-        constexpr auto generic_address_space =
-            integral_constant<AddressSpace, AddressSpace::generic>{};
-        constexpr auto global_address_space =
-            integral_constant<AddressSpace, AddressSpace::global>{};
         static_assert(ConvDirection == ConvolutionDirection::Forward ||
                           ConvDirection == ConvolutionDirection::BackwardWeight,
                       "wrong! this kernel only support convolution forward and backward-weight");
@@ -130,17 +125,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
         constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
-        constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
-        constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
-        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
-        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
-        constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
-        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
-        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
-        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
-        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
+        constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0];
+        constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1];
+        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
+        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];
+        constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1];
+        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
+        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];
+        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
+        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
         constexpr index_t ConvStrideH = ConvStrides{}[0];
         constexpr index_t ConvStrideW = ConvStrides{}[1];
@@ -230,7 +225,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 2,
                 3,
                 InBlockCopySrcDataPerRead_B,
-                InBlockCopyDstDataPerWrite_N2>(
+                InBlockCopyDstDataPerWrite_N2,
+                AddressSpace::global,
+                AddressSpace::vgpr,
+                AddressSpace::lds,
+                InMemoryDataOperation::none>(
                 {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
         // weight tensor
@@ -266,7 +265,11 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 0,
                 1,
                 WeiBlockCopySrcDataPerRead_E,
-                WeiBlockCopyDstDataPerWrite_K>(
+                WeiBlockCopyDstDataPerWrite_K,
+                AddressSpace::global,
+                AddressSpace::vgpr,
+                AddressSpace::lds,
+                InMemoryDataOperation::none>(
                 {0, k_block_data_on_global}, {0, 0});
         // GEMM definition
@@ -334,10 +337,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.Run(
-                p_in_global, p_in_block_double, global_address_space, generic_address_space);
-            blockwise_wei_copy.Run(
-                p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
+            blockwise_in_copy.Run(p_in_global, p_in_block_double);
+            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
         }
         // LDS double buffer: main body
@@ -368,10 +369,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 __syncthreads();
                 // LDS doubel buffer: load next data from device mem
-                blockwise_in_copy.RunLoadThreadBuffer(
-                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
-                blockwise_wei_copy.RunLoadThreadBuffer(
-                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
+                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -397,10 +396,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
                 __syncthreads();
                 // LDS double buffer: load last data from device mem
-                blockwise_in_copy.RunLoadThreadBuffer(
-                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
-                blockwise_wei_copy.RunLoadThreadBuffer(
-                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
+                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
                 // LDS double buffer: GEMM on 2nd-last data
                 blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
@@ -474,20 +471,23 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             const index_t b_thread_data_on_global =
                 b_block_data_on_global + c_thread_mtx_on_block.col / N2;
-            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
-                                                  decltype(out_k0_k1_n1_b_n2_global_desc),
-                                                  decltype(
-                                                      out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
-                                                  arithmetic_sequence_gen<0, 5, 1>::type,
-                                                  3,
-                                                  1,
-                                                  1>({0, 0, 0, 0, 0},
-                                                     {k_thread_data_on_global / K1,
-                                                      k_thread_data_on_global % K1,
-                                                      0,
-                                                      b_thread_data_on_global,
-                                                      0})
-                .Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
+            ThreadwiseGenericTensorSliceCopy_v4r2<
+                decltype(out_k0_k1_n1_b_n2_thread_desc),
+                decltype(out_k0_k1_n1_b_n2_global_desc),
+                decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
+                arithmetic_sequence_gen<0, 5, 1>::type,
+                3,
+                1,
+                1,
+                AddressSpace::vgpr,
+                AddressSpace::global,
+                InMemoryDataOperation::none>({0, 0, 0, 0, 0},
+                                             {k_thread_data_on_global / K1,
+                                              k_thread_data_on_global % K1,
+                                              0,
+                                              b_thread_data_on_global,
+                                              0})
+                .Run(p_out_thread, p_out_global);
         }
     }
};
@@ -10,7 +10,6 @@
 #include "blockwise_gemm.hpp"
 namespace ck {
 // B = merge(N, Ho, Wo)
 template <index_t GridSize,
           index_t BlockSize,
@@ -61,11 +60,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         constexpr auto True = integral_constant<bool, true>{};
-        constexpr auto generic_address_space =
-            integral_constant<AddressSpace, AddressSpace::generic>{};
-        constexpr auto global_address_space =
-            integral_constant<AddressSpace, AddressSpace::global>{};
         constexpr auto in_n_c_hi_wi_global_desc =
             make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
         constexpr auto wei_k_c_y_x_global_desc =
@@ -158,7 +152,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                 1,
                 1,
                 InBlockCopyDataPerAccess_B,
-                InBlockCopyDataPerAccess_B>(
+                InBlockCopyDataPerAccess_B,
+                AddressSpace::global,
+                AddressSpace::vgpr,
+                AddressSpace::lds,
+                InMemoryDataOperation::none>(
                 {0, b_block_data_on_global}, {0, 0});
         // weight tensor
@@ -192,7 +190,11 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                 0,
                 1,
                 WeiBlockCopySrcDataPerRead_E,
-                WeiBlockCopyDstDataPerWrite_K>(
+                WeiBlockCopyDstDataPerWrite_K,
+                AddressSpace::global,
+                AddressSpace::vgpr,
+                AddressSpace::lds,
+                InMemoryDataOperation::none>(
                 {0, k_block_data_on_global}, {0, 0});
         // GEMM definition
@@ -202,7 +204,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
         // register
         constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
         constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);
         // sanity check
@@ -260,10 +261,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
         // LDS double buffer: preload data into LDS
         {
-            blockwise_in_copy.Run(
-                p_in_global, p_in_block_double, global_address_space, generic_address_space);
-            blockwise_wei_copy.Run(
-                p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
+            blockwise_in_copy.Run(p_in_global, p_in_block_double);
+            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
         }
         // LDS double buffer: main body
@@ -294,10 +293,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                 __syncthreads();
                 // LDS doubel buffer: load next data from device mem
-                blockwise_in_copy.RunLoadThreadBuffer(
-                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
-                blockwise_wei_copy.RunLoadThreadBuffer(
-                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
+                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -323,10 +320,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                 __syncthreads();
                 // LDS double buffer: load last data from device mem
-                blockwise_in_copy.RunLoadThreadBuffer(
-                    p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
-                blockwise_wei_copy.RunLoadThreadBuffer(
-                    p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
+                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
+                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);
                 // LDS double buffer: GEMM on 2nd-last data
                 blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
@@ -397,17 +392,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
                 arithmetic_sequence_gen<0, 4, 1>::type,
                 3,
                 OutThreadCopyDataPerAccess_B,
-                OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
+                OutThreadCopyDataPerAccess_B,
+                AddressSpace::vgpr,
+                AddressSpace::global>({0, 0, 0, 0},
                                               {k_thread_data_on_global / K1,
                                                k_thread_data_on_global % K1,
                                                b_thread_data_on_global / B1,
                                                b_thread_data_on_global % B1})
-#if 1
-                .Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
-#else // tweaking
-                .Run_optimized_dst_address_calculation(
-                    p_out_thread, p_out_global, generic_address_space, global_address_space);
-#endif
+                .Run(p_out_thread, p_out_global);
         }
     }
};
@@ -60,7 +60,7 @@ __host__ __device__ constexpr auto
 template <typename... Ts>
 __host__ __device__ constexpr auto
 make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
 {
     using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
     static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
@@ -83,37 +83,16 @@ struct Pad
     __host__ __device__ constexpr bool
     IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
     {
-#if 0
-        struct lambda_no_pad
-        {
-            __host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
-        };
-        if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
-           sequence_all_of(RightPads{}, lambda_no_pad{}))
-        {
-            return true;
-        }
-        else
-#endif
-        {
-            bool flag = true;
-            static_for<0, nDim, 1>{}([&](auto idim) {
-                // only check if there is left-padding
-                static_if<(LeftPads::At(idim) != 0)>{}(
-                    [&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });
-                // only check if there is right-padding
-                static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
-                    flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
-                });
-            });
-            return flag;
-        }
+        bool flag = true;
+        static_for<0, nDim, 1>{}([&](auto idim) {
+            flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
+                   (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
+        });
+        return flag;
     }
 };
 // LowerLengths: Sequence<...>
@@ -64,7 +64,7 @@ template <typename LowerTensorDescriptor,
           index_t... LowerDimensionIds,
           index_t... UpperDimensionIds>
 __host__ __device__ constexpr auto
 reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
                                            Sequence<LowerLengths...>,
                                            Sequence<LowerDimensionIds...>,
                                            Sequence<UpperDimensionIds...>)
@@ -78,7 +78,7 @@ __host__ __device__ constexpr auto
 // reorder a NativeTensorDescriptor
 template <typename... Ts, typename MapLower2Upper>
 __host__ __device__ constexpr auto
 reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
 {
     static_assert(is_valid_sequence_map<MapLower2Upper>{},
                   "wrong! MapLower2Upper is not a valid map");
@@ -96,7 +96,7 @@ __host__ __device__ constexpr auto
 // reorder a TransformedTensorDescriptor
 template <typename... Ts, typename MapLower2Upper>
 __host__ __device__ constexpr auto
 reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
 {
     static_assert(is_valid_sequence_map<MapLower2Upper>{},
                   "wrong! MapLower2Upper is not a valid map");
@@ -21,7 +21,11 @@ template <index_t BlockSize,
           index_t SrcVectorAccessDim,
           index_t DstVectorAccessDim,
           index_t SrcDataPerAccess,
-          index_t DstDataPerAccess>
+          index_t DstDataPerAccess,
+          AddressSpace SrcAddressSpace = AddressSpace::generic,
+          AddressSpace ThreadBufferAddressSpace = AddressSpace::generic,
+          AddressSpace DstAddressSpace = AddressSpace::generic,
+          InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
 struct BlockwiseGenericTensorSliceCopy_v4
 {
     static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
...@@ -66,120 +70,57 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -66,120 +70,57 @@ struct BlockwiseGenericTensorSliceCopy_v4
return ThreadBufferDesc::GetElementSpace(); return ThreadBufferDesc::GetElementSpace();
} }
template <typename BlockSrcData, template <typename BlockSrcData, typename ThreadBufferData>
typename ThreadBufferData, __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
AddressSpace BlockSrcAddressSpace, ThreadBufferData* p_thread_buffer) const
AddressSpace ThreadBufferAddressSpace>
__device__ void
RunLoadThreadBuffer(const BlockSrcData* p_block_src,
ThreadBufferData* p_thread_buffer,
integral_constant<AddressSpace, BlockSrcAddressSpace>,
integral_constant<AddressSpace, ThreadBufferAddressSpace>) const
{ {
constexpr auto block_src_address_space =
integral_constant<AddressSpace, BlockSrcAddressSpace>{};
constexpr auto thread_buffer_address_space =
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
// TODO: threadwise copy is still being tweaked // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation) if(has_optimized_address_calculation)
{ {
mThreadwiseLoad.Run_optimized_src_address_calculation( mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
} }
else else
{ {
mThreadwiseLoad.Run( mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
} }
} }
template <typename BlockSrcData, typename ThreadBufferData> template <typename ThreadBufferData, typename BlockDstData>
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src, __device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
ThreadBufferData* p_thread_buffer) const BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
RunLoadThreadBuffer(
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
}
template <typename ThreadBufferData,
typename BlockDstData,
AddressSpace ThreadBufferAddressSpace,
AddressSpace BlockDstAddressSpace>
__device__ void
RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
BlockDstData* p_block_dst,
integral_constant<AddressSpace, ThreadBufferAddressSpace>,
integral_constant<AddressSpace, BlockDstAddressSpace>) const
{ {
constexpr auto thread_buffer_address_space =
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
constexpr auto block_dst_address_space =
integral_constant<AddressSpace, BlockDstAddressSpace>{};
constexpr bool has_optimized_address_calculation = constexpr bool has_optimized_address_calculation =
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
// TODO: threadwise copy is still being tweaked // TODO: threadwise copy is still being tweaked
if(has_optimized_address_calculation) if(has_optimized_address_calculation)
{ {
mThreadwiseStore.Run_optimized_dst_address_calculation( mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer, p_block_dst);
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
} }
else else
{ {
mThreadwiseStore.Run( mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
} }
} }
template <typename ThreadBufferData, typename BlockDstData> template <typename BlockSrcData, typename BlockDstData>
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer, __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
BlockDstData* p_block_dst) const
{ {
constexpr auto generic_address_space = static_assert(ThreadBufferAddressSpace == AddressSpace::vgpr,
integral_constant<AddressSpace, AddressSpace::generic>{}; "wrong! This function uses vgpr as its thread "
"buffer, but RunLoadThreadBuffer and RunStoreThreadBuffer are set to use "
"ThreadBufferAddressSpace as their thread buffer, which is not vgpr. "
"Behavior may differ");
RunStoreThreadBuffer(
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
}
template <typename BlockSrcData,
typename BlockDstData,
AddressSpace BlockSrcAddressSpace,
AddressSpace BlockDstAddressSpace>
__device__ void
Run(const BlockSrcData* p_block_src,
BlockDstData* p_block_dst,
integral_constant<AddressSpace, BlockSrcAddressSpace> block_src_address_space,
integral_constant<AddressSpace, BlockDstAddressSpace> block_dst_address_space) const
{
BlockSrcData p_thread_buffer[GetThreadBufferSize()]; BlockSrcData p_thread_buffer[GetThreadBufferSize()];
constexpr auto generic_address_space = RunLoadThreadBuffer(p_block_src, p_thread_buffer);
integral_constant<AddressSpace, AddressSpace::generic>{};
RunLoadThreadBuffer(
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
// if there is type conversion, it's done during store // if there is type conversion, it's done during store
RunStoreThreadBuffer( RunStoreThreadBuffer(p_thread_buffer, p_block_dst);
p_thread_buffer, p_block_dst, generic_address_space, block_dst_address_space);
}
template <typename BlockSrcData, typename BlockDstData>
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
} }
template <typename T, bool PositiveDirection> template <typename T, bool PositiveDirection>
...@@ -207,7 +148,10 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -207,7 +148,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
SrcDimAccessOrder, SrcDimAccessOrder,
SrcVectorAccessDim, SrcVectorAccessDim,
SrcDataPerAccess, SrcDataPerAccess,
1>; 1,
SrcAddressSpace,
ThreadBufferAddressSpace,
InMemoryDataOperation::none>;
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc, using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2<ThreadBufferDesc,
BlockDstDesc, BlockDstDesc,
...@@ -215,7 +159,10 @@ struct BlockwiseGenericTensorSliceCopy_v4 ...@@ -215,7 +159,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
DstDimAccessOrder, DstDimAccessOrder,
DstVectorAccessDim, DstVectorAccessDim,
1, 1,
DstDataPerAccess>; DstDataPerAccess,
ThreadBufferAddressSpace,
DstAddressSpace,
DstInMemOp>;
ThreadwiseLoad mThreadwiseLoad; ThreadwiseLoad mThreadwiseLoad;
ThreadwiseStore mThreadwiseStore; ThreadwiseStore mThreadwiseStore;
......
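With the address spaces and the destination memory operation folded into the template parameters above, a call site now selects its behavior once at the type level. A minimal sketch, in which every descriptor and length alias is a placeholder rather than a name from this diff:
// placeholders: SrcGlobalDesc, DstGlobalDesc, SliceLengths, SubLengths, ClusterLengths
using CopyAccumulateToGlobal =
    BlockwiseGenericTensorSliceCopy_v4<256,                         // BlockSize
                                       SrcGlobalDesc, DstGlobalDesc,
                                       SliceLengths, SubLengths, ClusterLengths,
                                       Sequence<0, 1>, Sequence<0, 1>, Sequence<0, 1>,
                                       1, 1,                        // src/dst vector access dim
                                       1, 1,                        // src/dst data per access
                                       AddressSpace::global,        // source memory
                                       AddressSpace::vgpr,          // thread buffer
                                       AddressSpace::global,        // destination memory
                                       InMemoryDataOperation::atomic_add>; // accumulate on store
Run() stages data through the vgpr thread buffer and RunStoreThreadBuffer() applies the destination operation, which is why Run() static_asserts that ThreadBufferAddressSpace is vgpr.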
#ifndef CK_GRIDWISE_GEMM_HPP
#define CK_GRIDWISE_GEMM_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
InMemoryDataOperation CGlobalMemoryDataOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t ThreadGemmDataPerReadM,
index_t ThreadGemmDataPerReadN,
typename ABlockCopySubLengths_K_M,
typename ABlockCopyClusterLengths_K_M,
index_t ABlockCopyDataPerAccess_M,
typename BBlockCopySubLengths_K_N,
typename BBlockCopyClusterLengths_K_N,
index_t BBlockCopyDataPerAccess_N,
index_t CThreadCopyDataPerAccess_N>
struct GridwiseGemmTransposedANormalBNormalC_v1r1
{
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K = a_k_m_global_desc.GetLengths()[0];
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
BBlockCopyDataPerAccess_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopySubLengths_K_M,
ABlockCopyClusterLengths_K_M,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
ABlockCopyDataPerAccess_M,
ABlockCopyDataPerAccess_M,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, m_block_data_on_global}, {0, 0});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopySubLengths_K_N,
BBlockCopyClusterLengths_K_N,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
BBlockCopyDataPerAccess_N,
BBlockCopyDataPerAccess_N,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, n_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check
static_assert(MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat =
MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat =
NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: find a more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThreadSubC>{}, Number<GemmNRepeat * NPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThreadSubC,
NPerThreadSubC,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThreadLoop,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if there are 2 iterations left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// output: copy C matrix from register to global memory
{
constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThreadSubC, GemmNRepeat, NPerThreadSubC>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
Sequence<0, 1, 2, 3>,
3,
CThreadCopyDataPerAccess_N,
CThreadCopyDataPerAccess_N,
AddressSpace::vgpr,
AddressSpace::global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1})
.Run(p_c_thread, p_c_global);
}
}
};
} // namespace ck
#endif
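As a rough illustration of how the CGlobalMemoryDataOperation parameter is meant to be used, here is a compressed instantiation sketch (assuming the ck namespace is open); all sizes and descriptors are illustrative placeholders, not tuned values from this commit:
constexpr index_t K = 256, M = 512, N = 512;
constexpr auto a_k_m_desc = make_native_tensor_descriptor_packed(Sequence<K, M>{});
constexpr auto b_k_n_desc = make_native_tensor_descriptor_packed(Sequence<K, N>{});
constexpr auto c_m_n_desc = make_native_tensor_descriptor_packed(Sequence<M, N>{});
using GemmAtomicC = GridwiseGemmTransposedANormalBNormalC_v1r1<
    (M / 128) * (N / 128),             // GridSize: one workgroup per 128x128 tile of C
    256,                               // BlockSize
    float, float,                      // Float, AccFloat
    decltype(a_k_m_desc), decltype(b_k_n_desc), decltype(c_m_n_desc),
    InMemoryDataOperation::atomic_add, // C is accumulated into, not overwritten
    128, 128, 8,                       // MPerBlock, NPerBlock, KPerBlock
    4, 4,                              // MPerThreadSubC, NPerThreadSubC
    4, 4, 4, 4,                        // M/N Level0 and Level1 clusters (4*4*4*4 = 256 threads)
    1,                                 // KPerThreadLoop
    4, 4,                              // ThreadGemmDataPerReadM/N
    Sequence<4, 1>, Sequence<2, 128>, 1, // A blockwise copy: sub lengths, cluster lengths, data per access
    Sequence<4, 1>, Sequence<2, 128>, 1, // B blockwise copy
    1>;                                // CThreadCopyDataPerAccess_N
With InMemoryDataOperation::none the final threadwise copy overwrites C; with atomic_add it is intended to accumulate, so several such GEMMs can add partial results into the same C buffer. GemmAtomicC{}.Run(p_a, p_b, p_c) is then invoked from a __global__ kernel.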
...@@ -21,7 +21,10 @@ template <typename SrcDesc, ...@@ -21,7 +21,10 @@ template <typename SrcDesc,
typename DimAccessOrder, typename DimAccessOrder,
index_t VectorAccessDim, index_t VectorAccessDim,
index_t SrcDataPerAccess, index_t SrcDataPerAccess,
index_t DstDataPerAccess> index_t DstDataPerAccess,
AddressSpace SrcAddressSpace = AddressSpace::generic,
AddressSpace DstAddressSpace = AddressSpace::generic,
InMemoryDataOperation DstInMemOp = InMemoryDataOperation::none>
struct ThreadwiseGenericTensorSliceCopy_v4r2 struct ThreadwiseGenericTensorSliceCopy_v4r2
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
...@@ -66,18 +69,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -66,18 +69,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// Will do padding check on src data: Read 0 if src data is in padding area. // Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area. // Will do padding check on dst data: No write if dst data is in padding area.
template <typename SrcData, template <typename SrcData, typename DstData>
typename DstData, __device__ void Run(const SrcData* p_src, DstData* p_dst) const
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace>
__device__ void Run(const SrcData* p_src,
DstData* p_dst,
integral_constant<AddressSpace, SrcAddressSpace>,
integral_constant<AddressSpace, DstAddressSpace>) const
{ {
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
constexpr auto vector_access_dim = Number<VectorAccessDim>{}; constexpr auto vector_access_dim = Number<VectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{}; constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
...@@ -120,20 +114,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -120,20 +114,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the same padding situation // has the same padding situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto fwd) { move_data<SrcData,
#if CK_USE_AMD_BUFFER_ADDRESSING SrcDataPerAccess,
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) = SrcAddressSpace,
amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>( AddressSpace::vgpr,
fwd(p_src), src_coord.GetOffset(), 0); InMemoryDataOperation::none>(
#else p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
#endif
}).Else([&](auto) {
// src can be all kinds of memory-space.
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
});
} }
} }
...@@ -160,36 +146,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -160,36 +146,17 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the same padding situation // has the same padding situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto fwd) { move_data<DstData,
#if CK_USE_AMD_BUFFER_ADDRESSING DstDataPerAccess,
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>( AddressSpace::vgpr,
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]), DstAddressSpace,
fwd(p_dst), DstInMemOp>(
dst_coord.GetOffset(), p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
0);
#else
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
#endif
}).Else([&](auto) {
// dst can be all kinds of memory-space
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) =
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
});
} }
} }
}); });
} }
template <typename SrcData, typename DstData>
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
{
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
Run(p_src, p_dst, generic_address_space, generic_address_space);
}
// Modify Length to 1, if Mask is set to false // Modify Length to 1, if Mask is set to false
// Used for isolating linear dimension from non-linear dimensions // Used for isolating linear dimension from non-linear dimensions
template <index_t... Lengths, index_t... Mask> template <index_t... Lengths, index_t... Mask>
...@@ -198,26 +165,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -198,26 +165,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
return Sequence<(Mask ? Lengths : 1)...>{}; return Sequence<(Mask ? Lengths : 1)...>{};
} }
// p_src must be global-memory, p_dst can be any memory-space.
// User should make sure p_src is a block-invariant pointer, because
// buffer_load is used for loading from global-memory into register buffer.
// Will do padding check on src data: Read 0 if src data is in padding area. // Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area. // Will do padding check on dst data: No write if dst data is in padding area.
// This version is optimized for address calculation of src tensor // This version is optimized for address calculation of src tensor
// TODO: this function does not compile to the expected ISA // TODO: this function does not compile to the expected ISA
template <typename SrcData, template <typename SrcData, typename DstData>
typename DstData, __device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
AddressSpace SrcAddressSpace, DstData* p_dst) const
AddressSpace DstAddressSpace>
__device__ void
Run_optimized_src_address_calculation(const SrcData* p_src,
DstData* p_dst,
integral_constant<AddressSpace, SrcAddressSpace>,
integral_constant<AddressSpace, DstAddressSpace>) const
{ {
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
constexpr auto vector_access_dim = Number<VectorAccessDim>{}; constexpr auto vector_access_dim = Number<VectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{}; constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
...@@ -308,21 +263,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -308,21 +263,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// the src vector has the same padding situation // the src vector has the same padding situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto) { move_data<SrcData,
#if CK_USE_AMD_BUFFER_ADDRESSING SrcDataPerAccess,
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) = SrcAddressSpace,
amd_intrinsic_buffer_load<SrcData, SrcDataPerAccess>( AddressSpace::vgpr,
p_src, src_nonlinear_coord.GetOffset(), src_linear_offset); InMemoryDataOperation::none>(p_src,
#else src_nonlinear_coord.GetOffset() +
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) = src_linear_offset,
*reinterpret_cast<const src_vector_t*>( p_src_long_vector,
&p_src[src_nonlinear_coord.GetOffset() + src_linear_offset]); buffer_offset);
#endif
}).Else([&](auto) {
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
*reinterpret_cast<const src_vector_t*>(
&p_src[src_nonlinear_coord.GetOffset() + src_linear_offset]);
});
} }
} }
...@@ -352,34 +301,26 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -352,34 +301,26 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// the dst vector has the same padding situation // the dst vector has the same padding situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_coord.GetOffset()]) = move_data<DstData,
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]); DstDataPerAccess,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
} }
} }
}); });
}); });
} }
// p_src could be any memory space, p_dst must be global memory.
// User should make sure p_dst is a block-invariant pointer, because
// buffer_store is used for storing data from the register buffer into global memory.
// Will do padding check on src data: Read 0 if src data is in padding area. // Will do padding check on src data: Read 0 if src data is in padding area.
// Will do padding check on dst data: No write if dst data is in padding area. // Will do padding check on dst data: No write if dst data is in padding area.
// This version is optimized for address calculation of dst tensor // This version is optimized for address calculation of dst tensor
// TODO: this function does not compile to the expected ISA // TODO: this function does not compile to the expected ISA
template <typename SrcData, template <typename SrcData, typename DstData>
typename DstData, __device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
AddressSpace SrcAddressSpace, DstData* p_dst) const
AddressSpace DstAddressSpace>
__device__ void
Run_optimized_dst_address_calculation(const SrcData* p_src,
DstData* p_dst,
integral_constant<AddressSpace, SrcAddressSpace>,
integral_constant<AddressSpace, DstAddressSpace>) const
{ {
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
constexpr auto vector_access_dim = Number<VectorAccessDim>{}; constexpr auto vector_access_dim = Number<VectorAccessDim>{};
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{}; constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
...@@ -461,8 +402,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -461,8 +402,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// the src vector has the same padding situation // the src vector has the same padding situation
if(src_coord.IsUpperIndexMappedToValidOffset()) if(src_coord.IsUpperIndexMappedToValidOffset())
{ {
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) = move_data<SrcData,
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]); SrcDataPerAccess,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
} }
} }
...@@ -501,23 +446,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 ...@@ -501,23 +446,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// the dst vector has the same padding situation // the dst vector has the same padding situation
if(dst_coord.IsUpperIndexMappedToValidOffset()) if(dst_coord.IsUpperIndexMappedToValidOffset())
{ {
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto) { move_data<DstData,
#if CK_USE_AMD_BUFFER_ADDRESSING DstDataPerAccess,
amd_intrinsic_buffer_store<DstData, DstDataPerAccess>( AddressSpace::vgpr,
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]), DstAddressSpace,
DstInMemOp>(p_dst_long_vector,
buffer_offset,
p_dst, p_dst,
dst_nonlinear_coord.GetOffset(), dst_nonlinear_coord.GetOffset() + dst_linear_offset);
dst_linear_offset);
#else
*reinterpret_cast<dst_vector_t*>(
&p_dst[dst_nonlinear_coord.GetOffset() + dst_linear_offset]) =
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
#endif
}).Else([&](auto) {
*reinterpret_cast<dst_vector_t*>(
&p_dst[dst_nonlinear_coord.GetOffset() + dst_linear_offset]) =
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
});
} }
} }
}); });
......
...@@ -54,10 +54,18 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata, ...@@ -54,10 +54,18 @@ __device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
bool glc, bool glc,
bool slc) __asm("llvm.amdgcn.buffer.store.v4f32"); bool slc) __asm("llvm.amdgcn.buffer.store.v4f32");
// buffer_load requires:
// 1) p_src must be in global memory space, the destination must be in vgpr
// 2) p_src must be a block-invariant pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize> template <typename T, index_t VectorSize>
__device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_load( __device__ typename vector_type<T, VectorSize>::MemoryType amd_intrinsic_buffer_load(
const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset); const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);
// buffer_store requires:
// 1) the source must be in vgpr, p_dst must be in global memory space
// 2) p_dst must be a block-invariant pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t VectorSize> template <typename T, index_t VectorSize>
__device__ void __device__ void
amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src, amd_intrinsic_buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src,
......
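The move_data<> helper used throughout the threadwise copy above is defined in the new in_memory_operation.hpp, which is not part of this excerpt. The sketch below is only a guess at its general shape, to show how the address-space and operation template parameters can replace the per-call static_if branches; the name move_data_sketch, the branch structure, and the use of if constexpr and atomicAdd are assumptions, not the actual file:
template <typename T,
          index_t DataPerAccess,
          AddressSpace SrcAddrSpace,
          AddressSpace DstAddrSpace,
          InMemoryDataOperation DstInMemOp>
__device__ void move_data_sketch(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
    using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;

#if CK_USE_AMD_BUFFER_ADDRESSING
    if constexpr(SrcAddrSpace == AddressSpace::global && DstAddrSpace == AddressSpace::vgpr)
    {
        // global -> register: buffer_load path (p_src must be a block-invariant pointer);
        // the corresponding buffer_store path is omitted for brevity
        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
            amd_intrinsic_buffer_load<T, DataPerAccess>(p_src, src_offset, 0);
        return;
    }
#endif

    if constexpr(DstInMemOp == InMemoryDataOperation::atomic_add)
    {
        // accumulate into the destination instead of overwriting it
        // (assumes T has an atomicAdd overload, e.g. float)
        for(index_t i = 0; i < DataPerAccess; ++i)
        {
            atomicAdd(&p_dst[dst_offset + i], p_src[src_offset + i]);
        }
    }
    else
    {
        // plain vectorized copy for all other address-space combinations
        *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
            *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
    }
}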
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
#include "functional4.hpp" #include "functional4.hpp"
#include "in_memory_operation.hpp"
#if CK_USE_AMD_INLINE_ASM #if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp" #include "amd_inline_asm.hpp"
......
...@@ -54,7 +54,15 @@ namespace ck { ...@@ -54,7 +54,15 @@ namespace ck {
enum AddressSpace enum AddressSpace
{ {
generic, generic,
global global,
lds,
vgpr
};
enum InMemoryDataOperation
{
none,
atomic_add
}; };
#if CK_UNSIGNED_INDEX_TYPE #if CK_UNSIGNED_INDEX_TYPE
......