Commit 8b5e63ed authored by Jing Zhang
Browse files

mock up

parent 95a5af02
...@@ -10,8 +10,9 @@ ...@@ -10,8 +10,9 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
index_t WoPerBlock, index_t WoPerBlock,
...@@ -42,9 +43,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -42,9 +43,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const Float* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -166,8 +167,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -166,8 +167,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
BlockSize, BlockSize,
Float, FloatAB,
AccFloat, FloatAcc,
FloatC,
InMemoryDataOperation::Set, InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
decltype(in_gemmk_n_ho_wo_global_desc), decltype(in_gemmk_n_ho_wo_global_desc),
...@@ -227,11 +229,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -227,11 +229,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc), decltype(in_gemmk_n_ho_wo_global_desc),
const Float*, const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc), decltype(out_gemmm_n_ho_wo_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -254,11 +256,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -254,11 +256,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc), decltype(in_gemmk_n_ho_wo_global_desc),
const Float*, const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc), decltype(out_gemmm_n_ho_wo_global_desc),
Float*, FloatC*,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -281,11 +283,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -281,11 +283,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc), decltype(in_gemmk_n_ho_wo_global_desc),
const Float*, const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc), decltype(out_gemmm_n_ho_wo_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -308,11 +310,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -308,11 +310,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc), decltype(wei_gemmk_gemmm_global_desc),
const Float*, const FloatAB*,
decltype(in_gemmk_n_ho_wo_global_desc), decltype(in_gemmk_n_ho_wo_global_desc),
const Float*, const FloatAB*,
decltype(out_gemmm_n_ho_wo_global_desc), decltype(out_gemmm_n_ho_wo_global_desc),
Float*, FloatC*,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
......
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
namespace ck { namespace ck {
template <index_t BlockSize, template <index_t BlockSize,
typename Float, typename FloatAB,
typename AccFloat, typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc, typename AGlobalDesc,
typename BGlobalDesc, typename BGlobalDesc,
...@@ -64,17 +65,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -64,17 +65,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(Float); return a_block_space_size * sizeof(FloatAB);
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc, __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
Float* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -177,8 +178,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -177,8 +178,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
ABlockTransferThreadSliceLengths_E_K, ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K, ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
Float, FloatAB,
Float, FloatAB,
decltype(a_e_k_global_desc), decltype(a_e_k_global_desc),
decltype(a_e_k_desc), decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
...@@ -203,8 +204,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -203,8 +204,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{})); Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2< auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
Float, FloatAB,
Float, FloatAB,
decltype(b_e_n_ho_wo_global_desc), decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc), decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>, Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
...@@ -218,10 +219,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -218,10 +219,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
true>(b_e_n_ho_wo_global_desc, true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
Float* p_a_block = p_shared_block; FloatAB* p_a_block = p_shared_block;
// register allocation for output // register allocation for output
AccFloat p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()]; FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output // zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread); threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
...@@ -240,9 +241,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -240,9 +241,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
BGlobalMoveSliceWindowIteratorHacks{}; BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize(); constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
Float p_b_thread[b_thread_space_size * 2]; FloatAB p_b_thread[b_thread_space_size * 2];
Float* p_b_thread_double = p_b_thread; FloatAB* p_b_thread_double = p_b_thread;
// LDS double buffer: preload data into LDS // LDS double buffer: preload data into LDS
{ {
...@@ -265,8 +266,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -265,8 +266,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
#if 1 #if 1
if constexpr(HasMainKBlockLoop) if constexpr(HasMainKBlockLoop)
{ {
Float* p_b_thread_even = p_b_thread_double; FloatAB* p_b_thread_even = p_b_thread_double;
Float* p_b_thread_odd = p_b_thread_double + b_thread_space_size; FloatAB* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
// LDS double buffer: main body // LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow // use Do-While loop instead of For loop to simplify control flow
...@@ -359,8 +360,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -359,8 +360,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
k_block_data_on_global + k_thread_id * KPerThread; k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat, FloatAcc,
Float, FloatC,
decltype(c_k_n_ho_wo_thread_desc), decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc), decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>, Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
...@@ -388,17 +389,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -388,17 +389,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptor by reference // pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc, __device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc, const BGlobalDesc& b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc, const CGlobalDesc& c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ Float p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
Run(a_e_k_global_desc, Run(a_e_k_global_desc,
p_a_global, p_a_global,
...@@ -414,11 +415,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -414,11 +415,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptors by their pointers // pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc, __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc, const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc, const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -439,11 +440,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -439,11 +440,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// pass tensor descriptors by void* // pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc, __device__ void Run(const void* p_a_e_k_global_desc,
const Float* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc, const void* p_b_e_n_ho_wo_global_desc,
const Float* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc, const void* p_c_k_n_ho_wo_global_desc,
Float* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
......
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
template <class T, template <class TInWei,
ck::index_t InWeiVectorSize,
class TAcc,
class TOut,
class InDesc, class InDesc,
class WeiDesc, class WeiDesc,
class OutDesc, class OutDesc,
...@@ -11,33 +14,31 @@ template <class T, ...@@ -11,33 +14,31 @@ template <class T,
class ConvDilations, class ConvDilations,
class InLeftPads, class InLeftPads,
class InRightPads> class InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc, void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
const Tensor<T>& in_nchw, InDesc,
WeiDesc, const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<T>& wei_kcyx, WeiDesc,
OutDesc, const Tensor<TInWei>& wei_k_c_y_x,
Tensor<T>& out_nkhw, OutDesc,
ConvStrides, Tensor<TOut>& out_n_k_ho_wo,
ConvDilations, ConvStrides,
InLeftPads, ConvDilations,
InRightPads, InLeftPads,
ck::index_t nrepeat) InRightPads,
ck::index_t nrepeat)
{ {
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
<< std::endl;
using namespace ck; using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type; std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
<< std::endl;
std::size_t data_sz = sizeof(T); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
#if 0 #if 0
// run-time variables // run-time variables
...@@ -70,18 +71,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -70,18 +71,18 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 4;
constexpr index_t HoPerBlock = 16; constexpr index_t HoPerBlock = 16;
constexpr index_t WoPerBlock = 16; constexpr index_t WoPerBlock = 16;
constexpr index_t EPerBlock = 4; constexpr index_t EPerBlock = 1;
constexpr index_t KPerThread = 16; constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 4; constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>; using ABlockTransferThreadSliceLengths_E_K = Sequence<1, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<4, 16>; using ABlockTransferThreadClusterLengths_E_K = Sequence<9, 4>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
...@@ -93,8 +94,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -93,8 +94,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr auto conv_driver = constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad< DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
BlockSize, BlockSize,
TDevice, typename vector_type<TInWei, InWeiVectorSize>::type,
TDevice, TAcc,
TOut,
KPerBlock, KPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
...@@ -117,9 +119,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -117,9 +119,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()), wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer())); static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
} }
...@@ -80,10 +80,10 @@ int main(int argc, char* argv[]) ...@@ -80,10 +80,10 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 1;
constexpr index_t HI = 1024; constexpr index_t HI = 1024;
constexpr index_t WI = 2048; constexpr index_t WI = 2048;
constexpr index_t K = 16; constexpr index_t K = 4;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -630,7 +630,7 @@ int main(int argc, char* argv[]) ...@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1; constexpr index_t in_vector_size = 1;
using acc_data_t = float; using acc_data_t = float;
...@@ -754,17 +754,21 @@ int main(int argc, char* argv[]) ...@@ -754,17 +754,21 @@ int main(int argc, char* argv[])
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(in_nchw_desc, device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
in_nchw, in_vector_size,
wei_kcyx_desc, acc_data_t,
wei_kcyx, out_data_t>(
out_nkhw_desc, in_nchw_desc,
out_nkhw_device, in_nchw,
ConvStrides{}, wei_kcyx_desc,
ConvDilations{}, wei_kcyx,
LeftPads{}, out_nkhw_desc,
RightPads{}, out_nkhw_device,
nrepeat); ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#endif #endif
if(do_verification) if(do_verification)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment