Unverified commit fcbb9788 authored by Chao Liu, committed by GitHub

Dynamic tensor descriptor (#24)



* support dynamic tensor descriptor

* use buffer load OOB feature for padding case

* add navi support

* add int8x4 inference kernel
Co-authored-by: Chao Liu <chao@ixt-rack-81.local.lan>
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent bbcb67d0
@@ -3,7 +3,7 @@ project(modular_convolution)
 #c++
 enable_language(CXX)
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
 message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
@@ -53,6 +53,7 @@ include_directories(BEFORE
 ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
 ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
 ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
+${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
 ${PROJECT_SOURCE_DIR}/external/half/include
 ${PROJECT_SOURCE_DIR}/driver/include
 ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
...
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
index_t ABlockTransferSrcScalarPerVector_E,
index_t ABlockTransferDstScalarPerVector_K,
index_t BThreadTransferSrcScalarPerVector_W,
index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
{
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo)),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)),
make_tuple(make_pass_through_transform(K),
make_pass_through_transform(N),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo)),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C * Y * X;
if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 &&
E % EPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size not divisible");
}
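// In this implicit-GEMM view the weight tensor is read as an [E, K] = [C*Y*X, K] matrix,
// the padded input as an [E, N, Ho, Wo] tensor and the output as a [K, N, Ho, Wo] tensor,
// so E is the reduction (GEMM K) dimension that is tiled by EPerBlock below.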
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_k_m_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
constexpr auto b_k_n_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_h_w_global_tensor_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
#if 1
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_n_ho_wo_global_desc),
decltype(out_gemmm_n_ho_wo_global_desc),
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
Sequence<1, 0>,
Sequence<1, 0>,
0,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 1>,
3,
BThreadTransferSrcScalarPerVector_W,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>,
3,
CThreadTransferDstScalarPerVector_W,
decltype(a_k_m_global_iterator_hacks),
decltype(b_k_n_global_iterator_hacks),
decltype(c_k_n_h_w_global_tensor_iterator_hacks),
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
<< std::endl;
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k_ho_wo_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
} // namespace ck
#endif
@@ -2,7 +2,11 @@
 #define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
 template <typename GridwiseOp, typename... Xs>
-__global__ void run_gridwise_operation(Xs... xs)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+run_gridwise_operation(Xs... xs)
 {
 GridwiseOp{}.Run(xs...);
 }
...
@@ -107,8 +107,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
-const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
-const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
+const index_t e_block_data_on_global = block_work_id[Number<0>{}] * EPerBlock;
+const index_t b_block_data_on_global = block_work_id[Number<1>{}] * BPerBlock;
 // output tensor
 // global tensor in global memory, src of blockwise copy
@@ -151,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 AddressSpace::Vgpr,
 AddressSpace::Lds,
 InMemoryDataOperation::Set>(
-{0, b_block_data_on_global, 0}, {0, 0, 0});
+make_multi_index(0, b_block_data_on_global, 0), make_multi_index(0, 0, 0));
 // weight tensor
 // global tensor in global memory, src of blockwise copy
@@ -191,7 +191,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 AddressSpace::Vgpr,
 AddressSpace::Lds,
 InMemoryDataOperation::Set>(
-{0, e_block_data_on_global, 0}, {0, 0, 0});
+make_multi_index(0, e_block_data_on_global, 0), make_multi_index(0, 0, 0));
 // GEMM definition
 // c_mtx += transpose(a_mtx) * b_mtx
@@ -354,7 +354,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 {
 #if 1 // debug
 // input: register to global memory, atomic add
 constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
 ? InMemoryDataOperation::Set
 : InMemoryDataOperation::AtomicAdd;
@@ -434,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
 InThreadCopyDstDataPerWrite_B,
 AddressSpace::Vgpr,
 AddressSpace::Global,
-in_memory_op>({0, 0, 0, 0, 0, 0},
-{e_thread_data_on_global / E1,
+in_memory_op>(make_multi_index(0, 0, 0, 0, 0, 0),
+make_multi_index(e_thread_data_on_global / E1,
 e_thread_data_on_global % E1,
 0,
 b_thread_data_on_global / B1,
 b_thread_data_on_global % B1,
-0})
+0))
 .Run(p_in_thread, p_in_global);
 }
 }
...
@@ -125,7 +125,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk
 index_t GemmK1 = XDotSlice;
 index_t GemmK2 = K;
-return Array<index_t, 5>{GemmM, GemmN, GemmK0, GemmK1, GemmK2};
+return make_multi_index(GemmM, GemmN, GemmK0, GemmK1, GemmK2);
 }
 __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id)
...
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
+#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
@@ -49,7 +49,7 @@ template <index_t GridSize,
 typename WeiBlockCopyDstAccessOrder,
 index_t WeiBlockCopySrcDataPerRead_E,
 index_t WeiBlockCopyDstDataPerWrite_K>
-struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
+struct GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
 {
 __device__ void Run(const Float* const __restrict__ p_in_global,
 const Float* const __restrict__ p_wei_global,
@@ -119,8 +119,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
 const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
-const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
-const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
+const index_t k_block_data_on_global = block_work_id[I0] * KPerBlock;
+const index_t b_block_data_on_global = block_work_id[I1] * BPerBlock;
 // input tensor
 // global tensor in global memory
@@ -183,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
 AddressSpace::Vgpr,
 AddressSpace::Lds,
 InMemoryDataOperation::Set>(
-{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
+make_multi_index(0, 0, b_block_data_on_global, 0), make_multi_index(0, 0, 0, 0));
 // weight tensor
 // global tensor in global memory, src of blockwise copy
@@ -226,7 +226,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
 AddressSpace::Vgpr,
 AddressSpace::Lds,
 InMemoryDataOperation::Set>(
-{0, k_block_data_on_global}, {0, 0});
+make_multi_index(0, k_block_data_on_global), make_multi_index(0, 0));
 // GEMM definition
 // c_mtx += transpose(a_mtx) * b_mtx
@@ -439,12 +439,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
 1,
 AddressSpace::Vgpr,
 AddressSpace::Global,
-InMemoryDataOperation::Set>({0, 0, 0, 0, 0},
-{k_thread_data_on_global / K1,
+InMemoryDataOperation::Set>(make_multi_index(0, 0, 0, 0, 0),
+make_multi_index(k_thread_data_on_global / K1,
 k_thread_data_on_global % K1,
 0,
 b_thread_data_on_global,
-0})
+0))
 .Run(p_out_thread, p_out_global);
 }
 }
...
-#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
-#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
@@ -43,7 +43,7 @@ template <index_t GridSize,
 index_t GemmBBlockCopySrcDataPerRead_GemmN,
 index_t GemmBBlockCopyDstDataPerWrite_GemmN,
 index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
-struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
+struct GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
 {
 __device__ void Run(const Float* const __restrict__ p_in_global,
 const Float* const __restrict__ p_wei_global,
...
#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
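// e.g., for a hypothetical 1x32x32x8 NHWC input convolved with 16x3x3x8 KYXC weights at
// stride 1, dilation 1 and zero padding (so Ho = Wo = 30), this mapping gives
// GemmM = 16, GemmN = 1 * 30 * 30 = 900, GemmK = 8 * 3 * 3 = 72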
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t ThreadGemmDataPerRead_GemmM,
index_t ThreadGemmDataPerRead_GemmN,
typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockCopySrcDataPerRead_GemmK,
index_t GemmABlockCopyDstDataPerWrite_GemmM,
typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockCopySrcDataPerRead_GemmK,
index_t GemmBBlockCopyDstDataPerWrite_GemmN,
index_t GemmCThreadCopyDstDataPerWrite_GemmM1>
struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{};
constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{};
constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[I0];
constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[I1];
constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[I2];
constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[I3];
constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[I3];
constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[I1];
constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[I2];
constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[I1];
constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[I2];
constexpr index_t ConvStrideH = ConvStrides{}[I0];
constexpr index_t ConvStrideW = ConvStrides{}[I1];
constexpr index_t ConvDilationH = ConvDilations{}[I0];
constexpr index_t ConvDilationW = ConvDilations{}[I1];
// weight tensor
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_y_x_c_global_desc, I1, I3), Sequence<1, 0>{});
// input tensor
constexpr auto in_n_hip_wip_c_global_desc =
transform_tensor_descriptor(in_n_hi_wi_c_global_desc,
make_tuple(PassThrough<N>{},
Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[I1];
constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[I2];
constexpr auto in_n_y_ho_x_wo_c_global_desc = transform_tensor_descriptor(
in_n_hip_wip_c_global_desc,
make_tuple(PassThrough<N>{},
Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{},
PassThrough<C>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_global_desc,
make_tuple(Merge<Sequence<Y, X, C>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
unfold_tensor_descriptor(out_n_ho_wo_k_global_desc, I0, I2),
make_tuple(PassThrough<K>{}, Merge<Sequence<N * Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
constexpr auto gridwise_gemm =
GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
BlockSize,
Float,
AccFloat,
decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_gemmn_global_desc),
decltype(out_gemmm_gemmn_global_desc),
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
ThreadGemmDataPerRead_GemmM,
ThreadGemmDataPerRead_GemmN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockCopySrcDataPerRead_GemmK,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockCopySrcDataPerRead_GemmK,
GemmBBlockCopyDstDataPerWrite_GemmN,
Sequence<2, 3, 0, 1>,
1,
GemmCThreadCopyDstDataPerWrite_GemmM1>{};
gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
}
};
} // namespace ck
#endif
#ifndef CK_ARRAY_MULTI_INDEX_HPP
#define CK_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
namespace ck {
template <index_t N>
using MultiIndex = Array<index_t, N>;
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
{
return make_array<index_t>(index_t{xs}...);
}
template <index_t NSize>
__host__ __device__ constexpr auto make_zero_multi_index()
{
return unpack([](auto... xs) { return make_multi_index(xs...); },
typename uniform_sequence_gen<NSize, 0>::type{});
}
template <typename T>
__host__ __device__ constexpr auto to_multi_index(const T& x)
{
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
template <index_t NSize, typename X>
__host__ __device__ constexpr auto operator+=(MultiIndex<NSize>& y, const X& x)
{
static_assert(X::Size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });
return y;
}
template <index_t NSize, typename X>
__host__ __device__ constexpr auto operator-=(MultiIndex<NSize>& y, const X& x)
{
static_assert(X::Size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });
return y;
}
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator+(const MultiIndex<NSize>& a, const T& b)
{
using type = MultiIndex<NSize>;
static_assert(T::Size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; });
return r;
}
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator-(const MultiIndex<NSize>& a, const T& b)
{
using type = MultiIndex<NSize>;
static_assert(T::Size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; });
return r;
}
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator*(const MultiIndex<NSize>& a, const T& b)
{
using type = MultiIndex<NSize>;
static_assert(T::Size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; });
return r;
}
} // namespace ck
#endif
#ifndef CK_CLUSTER_DESCRIPTOR_HPP
#define CK_CLUSTER_DESCRIPTOR_HPP
#include "common_header.hpp"
// TODO remove dependency on deprecated tensor descriptor
#include "tensor_descriptor.hpp"
namespace ck {
// a cluster descriptor maps a 1-d index to an N-d cluster index
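// (e.g., for Lengths = Sequence<2, 3> with the default arrange order, the 1-d index 4 is
// expected to decompose row-major into the cluster index (1, 1))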
template <typename Lengths, typename ArrangeOrder>
struct ClusterDescriptor
{
static constexpr index_t nDim = Lengths::Size();
static constexpr auto mDesc = transform_tensor_descriptor(
make_native_tensor_descriptor_packed(Lengths{}),
make_tuple(Merge<decltype(Lengths::ReorderGivenNew2Old(ArrangeOrder{}))>{}),
make_tuple(ArrangeOrder{}),
make_tuple(Sequence<0>{}));
__host__ __device__ constexpr ClusterDescriptor()
{
static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim,
"wrong! size not the same");
static_assert(is_valid_sequence_map<ArrangeOrder>{}, "wrong! ArrangeOrder is wrong");
}
__host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); }
__host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d)
{
return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d});
}
};
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
return ClusterDescriptor<Lengths, decltype(order)>{};
}
} // namespace ck
#endif
#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform.hpp"
namespace ck {
template <typename LowLength>
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
{
return DynamicPassThrough<LowLength>{low_length};
}
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto
make_pad_transform(const LowLength& low_length,
const LeftPad& left_pad,
const RightPad& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return DynamicPad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{
low_length, left_pad, right_pad};
}
template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_left_pad_transform(
const LowLength& low_length,
const LeftPad& left_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return DynamicLeftPad<LowLength, LeftPad, SkipIsValidCheck>{low_length, left_pad};
}
template <typename LowLength, typename RightPad, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_right_pad_transform(
const LowLength& low_length,
const RightPad& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return DynamicRightPad<LowLength, RightPad, SkipIsValidCheck>{low_length, right_pad};
}
template <typename UpLengths,
typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients)
{
return DynamicEmbed<UpLengths, Coefficients>{up_lengths, coefficients};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
return DynamicMerge<LowLengths>{low_lengths};
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths,
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
{
return DynamicUnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
return DynamicFreeze<LowerIndex>{low_idx};
}
} // namespace ck
#endif
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
namespace ck {
/*
* These functions create tensor descriptors at runtime. If they are not constexpr, you will
* likely see scratch memory being used during their construction. So it's better to call
* these functions on the host and then pass the constructed tensor descriptors to the GPU.
* If the tensor descriptors being constructed are constexpr, then you can call these
* functions on the GPU without worrying about scratch memory usage.
*/
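// A minimal host-side usage sketch (mirroring the v5r1 driver above; N, C, Hi, Wi and the
// pad amounts are assumed to be runtime values):
//
//   const auto in_desc =
//       make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi));
//   const auto in_padded_desc = transform_dynamic_tensor_descriptor(
//       in_desc,
//       make_tuple(make_pass_through_transform(N),
//                  make_pass_through_transform(C),
//                  make_pad_transform(Hi, InLeftPadH, InRightPadH),
//                  make_pad_transform(Wi, InLeftPadW, InRightPadW)),
//       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
//       make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
//
// The resulting descriptor objects are then passed by value to the kernel launch.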
#if CK_WORKAROUND_SWDEV_275126
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
Number<I> i,
AccOld acc_old)
{
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
if constexpr(i.value < Lengths::Size() - 1)
{
return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new);
}
else
{
return acc_new;
}
}
#endif
template <typename... Lengths,
typename... Strides,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides)
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_embed_transform(lengths, strides));
constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
#if !CK_WORKAROUND_SWDEV_275126
// the ROCm 4.1 compiler would crash on a recursive lambda
// recursive function for reduction
auto f = [&](auto fs, auto i, auto acc_old) {
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
if constexpr(i.value < N - 1)
{
return fs(fs, i + Number<1>{}, acc_new);
}
else
{
return acc_new;
}
};
const auto element_space_size = f(f, Number<0>{}, Number<1>{});
#else
const auto element_space_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
#endif
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}
// Lengths... can be:
// 1) index_t, which is known at run-time
// 2) Number<>, which is known at compile-time
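// For example (host-side sketches; N, C, Hi, Wi are assumed to be runtime index_t values):
//   // all lengths known only at run time
//   const auto desc_dynamic =
//       make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi));
//   // mixing run-time and compile-time lengths
//   const auto desc_mixed =
//       make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Number<3>{}, Number<3>{}));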
template <typename... Lengths>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_unmerge_transform(lengths));
constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}
template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
{
constexpr index_t N = sizeof...(Lengths);
auto strides = generate_tuple(
[&](auto i) {
if constexpr(i.value == N - 1)
{
return Number<1>{};
}
else if constexpr(i.value == N - 2)
{
return math::lcm(lengths[Number<N - 1>{}], align);
}
else
{
return container_reduce(lengths,
math::multiplies_v2{},
math::lcm(lengths[Number<N - 1>{}], align),
i,
Number<N - 2>{},
Number<1>{});
}
},
Number<N>{});
return make_dynamic_naive_tensor_descriptor_v2(lengths, strides);
}
} // namespace ck
#endif
This diff is collapsed.