Commit 5a88b0b0 authored by Chao Liu's avatar Chao Liu
Browse files

int8x4 for v4r4 nchw

parent 9c3f33c0
...@@ -3,7 +3,10 @@ ...@@ -3,7 +3,10 @@
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template <class T, template <class TInWei,
ck::index_t InWeiVectorSize,
class TAcc,
class TOut,
class InDesc, class InDesc,
class WeiDesc, class WeiDesc,
class OutDesc, class OutDesc,
...@@ -11,35 +14,33 @@ template <class T, ...@@ -11,35 +14,33 @@ template <class T,
class ConvDilations, class ConvDilations,
class InLeftPads, class InLeftPads,
class InRightPads> class InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
const Tensor<T>& in_nchw, InDesc,
const Tensor<TInWei>& in_n_c_hi_wi,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcyx, const Tensor<TInWei>& wei_k_c_y_x,
OutDesc, OutDesc,
Tensor<T>& out_nkhw, Tensor<TOut>& out_n_k_ho_wo,
ConvStrides, ConvStrides,
ConvDilations, ConvDilations,
InLeftPads, InLeftPads,
InRightPads, InRightPads,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
<< std::endl;
using namespace ck; using namespace ck;
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type; std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
<< std::endl;
std::size_t data_sz = sizeof(T); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
#if 1 #if 0
// run-time variables // run-time variables
const auto in_n_c_hi_wi_desc = const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
...@@ -67,7 +68,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -67,7 +68,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
#if 0 #if 1
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64; constexpr index_t BlockSize = 64;
...@@ -368,9 +369,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -368,9 +369,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
#endif #endif
<BlockSize, <BlockSize,
TDevice, typename vector_type<TInWei, InWeiVectorSize>::type,
TDevice, TAcc,
TDevice, TOut,
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
...@@ -398,9 +399,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -398,9 +399,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()), wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer())); static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
} }
...@@ -27,11 +27,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -27,11 +27,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
InRightPads, InRightPads,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
<< std::endl; << std::endl;
using namespace ck;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
...@@ -373,7 +373,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -373,7 +373,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
#endif #endif
constexpr auto conv_driver = constexpr auto conv_driver =
#if 0 #if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
#elif 0 #elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
......
...@@ -21,7 +21,7 @@ int main(int argc, char* argv[]) ...@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 1 #if 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
...@@ -704,8 +704,13 @@ int main(int argc, char* argv[]) ...@@ -704,8 +704,13 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size,
acc_data_t,
out_data_t>
(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment