"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "35da37f964de08bc64432935d4d8384a9309dc11"
Commit 0ca0103c authored by Chao Liu's avatar Chao Liu
Browse files

debugging int8x4_t

parent 3443835c
...@@ -31,7 +31,35 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz ...@@ -31,7 +31,35 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_siz
return wave_buffer_resource.data; return wave_buffer_resource.data;
} }
// fp32 load // load
__device__ int8_t
__llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
__device__ int16_t
__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32_t
__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32x2_t
__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
__device__ int32x4_t
__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
__device__ float __device__ float
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, __llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset, index_t voffset,
...@@ -50,7 +78,42 @@ __llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, ...@@ -50,7 +78,42 @@ __llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// fp32 store // store
__device__ void
__llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
__device__ void
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void
__llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
__device__ void __device__ void
__llvm_amdgcn_raw_buffer_store_fp32(float vdata, __llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -72,41 +135,6 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, ...@@ -72,41 +135,6 @@ __llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
// i32 load
__device__ int32_t
__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32x2_t
__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
__device__ int32x4_t
__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// i32 store
__device__ void
__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
// i16 store
__device__ void
__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
template <typename T, index_t N> template <typename T, index_t N>
__device__ typename vector_type<T, N>::type __device__ typename vector_type<T, N>::type
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
...@@ -187,7 +215,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -187,7 +215,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
{ {
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) || static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) || (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 2)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented"); "wrong! not implemented");
if constexpr(is_same<T, float>::value) if constexpr(is_same<T, float>::value)
...@@ -246,7 +274,15 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -246,7 +274,15 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
} }
else if constexpr(is_same<T, int8_t>::value) else if constexpr(is_same<T, int8_t>::value)
{ {
if constexpr(N == 2) if constexpr(N == 1)
{
__llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{ {
__llvm_amdgcn_raw_buffer_store_i16(src_thread_data, __llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
...@@ -254,6 +290,14 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type ...@@ -254,6 +290,14 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
dst_wave_addr_offset, dst_wave_addr_offset,
0); 0);
} }
else if constexpr(N == 4)
{
__llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
} }
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
template <class TInWei, template <class TInWei,
ck::index_t InWeiVectorSize,
class TAcc, class TAcc,
class TOut, class TOut,
class InDesc, class InDesc,
...@@ -14,17 +15,18 @@ template <class TInWei, ...@@ -14,17 +15,18 @@ template <class TInWei,
class InLeftPads, class InLeftPads,
class InRightPads, class InRightPads,
class T> class T>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc, void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
const Tensor<T>& in_nchw, InDesc,
WeiDesc, const Tensor<T>& in_n_c_hi_wi,
const Tensor<T>& wei_kcyx, WeiDesc,
OutDesc, const Tensor<T>& wei_k_c_y_x,
Tensor<T>& out_nkhw, OutDesc,
ConvStrides, Tensor<T>& out_n_k_ho_wo,
ConvDilations, ConvStrides,
InLeftPads, ConvDilations,
InRightPads, InLeftPads,
ck::index_t nrepeat) InRightPads,
ck::index_t nrepeat)
{ {
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
<< std::endl; << std::endl;
...@@ -49,6 +51,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -49,6 +51,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
constexpr auto Y = WeiDesc::GetLengths()[I2]; constexpr auto Y = WeiDesc::GetLengths()[I2];
constexpr auto X = WeiDesc::GetLengths()[I3]; constexpr auto X = WeiDesc::GetLengths()[I3];
constexpr auto C0 = C / Number<InWeiVectorSize>{};
constexpr auto C1 = Number<InWeiVectorSize>{};
#if 0 #if 0
// run-time variables // run-time variables
constexpr auto in_n_hi_wi_c_desc = constexpr auto in_n_hi_wi_c_desc =
...@@ -64,10 +69,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -64,10 +69,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = to_multi_index(InRightPads{}); const auto in_right_pads = to_multi_index(InRightPads{});
#else #else
// compile-time variables // compile-time variables
constexpr auto in_n_hi_wi_c_desc = constexpr auto in_n_hi_wi_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0));
constexpr auto wei_k_y_x_c_desc = constexpr auto wei_k_y_x_c0_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0));
constexpr auto out_n_ho_wo_k_desc = constexpr auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K)); make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K));
...@@ -77,36 +82,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -77,36 +82,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
#endif #endif
Tensor<TInWei> in_nhwc( Tensor<TInWei> in_n_hi_wi_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
Tensor<TInWei> wei_kyxc( Tensor<TInWei> wei_k_y_x_c(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
Tensor<TOut> out_nhwk( Tensor<TOut> out_n_ho_wo_k(
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{}))); make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi); in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi);
}; };
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) { auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x); wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x);
}; };
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) { auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo); out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo);
}; };
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(); make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(); make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)();
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(); make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)();
DeviceMem in_nhwc_device_buf(sizeof(TInWei) * in_nhwc.mDesc.GetElementSpace()); DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_kyxc_device_buf(sizeof(TInWei) * wei_kyxc.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_nhwk_device_buf(sizeof(TOut) * out_nhwk.mDesc.GetElementSpace()); DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data()); #if 1
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data()); in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data()); #endif
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
#if 1 #if 1
// cdata = 16, BlockSize = 64, 16x64x4 // cdata = 16, BlockSize = 64, 16x64x4
...@@ -377,7 +384,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -377,7 +384,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif #endif
<BlockSize, <BlockSize,
TInWei, typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc, TAcc,
TOut, TOut,
GemmMPerBlock, GemmMPerBlock,
...@@ -400,21 +407,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc ...@@ -400,21 +407,25 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmM1>{}; GemmCThreadTransferDstScalarPerVector_GemmM1>{};
conv_driver.Run(wei_k_y_x_c_desc, conv_driver.Run(wei_k_y_x_c0_desc,
in_n_hi_wi_c_desc, in_n_hi_wi_c0_desc,
out_n_ho_wo_k_desc, out_n_ho_wo_k_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
static_cast<TInWei*>(wei_kyxc_device_buf.GetDeviceBuffer()), static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
static_cast<TInWei*>(in_nhwc_device_buf.GetDeviceBuffer()), wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_nhwk_device_buf.GetDeviceBuffer())); static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()));
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data()); #if 1
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
#endif
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) { auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k); out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
}; };
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(); make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
......
...@@ -23,7 +23,21 @@ int main(int argc, char* argv[]) ...@@ -23,7 +23,21 @@ int main(int argc, char* argv[])
#if 0 #if 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 1;
constexpr index_t WI = 64;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
constexpr index_t WI = 1920; constexpr index_t WI = 1920;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -37,7 +51,7 @@ int main(int argc, char* argv[]) ...@@ -37,7 +51,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
constexpr index_t WI = 960; constexpr index_t WI = 960;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -51,7 +65,7 @@ int main(int argc, char* argv[]) ...@@ -51,7 +65,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 270; constexpr index_t HI = 270;
constexpr index_t WI = 480; constexpr index_t WI = 480;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -63,23 +77,9 @@ int main(int argc, char* argv[]) ...@@ -63,23 +77,9 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0
constexpr index_t N = 1;
constexpr index_t C = 4;
constexpr index_t HI = 1080;
constexpr index_t WI = 1920;
constexpr index_t K = 16;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1 #elif 1
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 1080; constexpr index_t HI = 1080;
constexpr index_t WI = 1920; constexpr index_t WI = 1920;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -93,7 +93,7 @@ int main(int argc, char* argv[]) ...@@ -93,7 +93,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 540; constexpr index_t HI = 540;
constexpr index_t WI = 960; constexpr index_t WI = 960;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -107,7 +107,7 @@ int main(int argc, char* argv[]) ...@@ -107,7 +107,7 @@ int main(int argc, char* argv[])
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 0
constexpr index_t N = 1; constexpr index_t N = 1;
constexpr index_t C = 4; constexpr index_t C = 16;
constexpr index_t HI = 270; constexpr index_t HI = 270;
constexpr index_t WI = 480; constexpr index_t WI = 480;
constexpr index_t K = 16; constexpr index_t K = 16;
...@@ -629,12 +629,16 @@ int main(int argc, char* argv[]) ...@@ -629,12 +629,16 @@ int main(int argc, char* argv[])
print_array("ConvStrides", to_multi_index(ConvStrides{})); print_array("ConvStrides", to_multi_index(ConvStrides{}));
print_array("ConvDilations", to_multi_index(ConvDilations{})); print_array("ConvDilations", to_multi_index(ConvDilations{}));
#if 1 #if 0
using in_data_t = float; using in_data_t = float;
constexpr index_t in_vector_size = 1;
using out_data_t = float; using out_data_t = float;
using acc_data_t = float;
#else #else
using in_data_t = half_float::half; using in_data_t = int8_t;
using out_data_t = half_float::half; constexpr index_t in_vector_size = 4;
using acc_data_t = int32_t;
using out_data_t = int8_t;
#endif #endif
Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc)); Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
...@@ -665,12 +669,9 @@ int main(int argc, char* argv[]) ...@@ -665,12 +669,9 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0 #elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
...@@ -730,15 +731,10 @@ int main(int argc, char* argv[]) ...@@ -730,15 +731,10 @@ int main(int argc, char* argv[])
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 1
#if 1 device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<float, float, float>( in_vector_size,
#elif 1 acc_data_t,
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, out_data_t>(
int32_t,
int32_t>(
#elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<int8x4_t, int32_t, int8_t>(
#endif
in_nchw_desc, in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -754,25 +750,19 @@ int main(int argc, char* argv[]) ...@@ -754,25 +750,19 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
#if 0 #if 1
if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 && host_direct_convolution(in_nchw,
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1) wei_kcyx,
{ out_nkhw_host,
host_winograd_3x3_convolution( ConvStrides{},
in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{}); ConvDilations{},
} LeftPads{},
else RightPads{});
#endif #endif
{
host_direct_convolution(in_nchw, #if 1
wei_kcyx,
out_nkhw_host,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{});
}
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
#endif
if(do_log) if(do_log)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment