Commit 08c00140 authored by Jing Zhang

int8

parent e273d4d3
@@ -47,7 +47,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
                     const InRightPads& in_right_pads,
                     const FloatAB* __restrict__ p_wei_global,
                     const FloatAB* __restrict__ p_in_global,
-                    FloatAB* __restrict__ p_d_global,
+                    FloatC* __restrict__ p_d_global,
                     FloatC* __restrict__ p_out_global) const
     {
         constexpr auto I0 = Number<0>{};
@@ -151,12 +151,12 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
         // add tensor
         const auto add_k_n_hopx2_wopx2_global_desc = transform_dynamic_tensor_descriptor(
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, 1)),
-            make_tuple(make_merge_transform(make_tuple(K0, 1)),
+            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2)),
+            make_tuple(make_pass_through_transform(K0),
                        make_pass_through_transform(N),
                        make_pad_transform(Hox2, 0, AddRightPadH),
                        make_pad_transform(Wox2, 0, AddRightPadW)),
-            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

         const auto E = C * Y * X;
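Note on the hunk above: dropping the trailing dummy dimension turns the 5-D (N, K0, Hox2, Wox2, 1) descriptor into a plain 4-D one, so the merge of (K0, 1) over source dimensions <1, 4> reduces to a pass-through of K0 over dimension <1>; the right-pad transforms on Hox2 and Wox2 are unchanged. As a rough host-side illustration of what a right-pad transform does (the struct below is a hypothetical analog, not the composable_kernel implementation):

#include <cstdio>

// Hypothetical analog of make_pad_transform(size, 0, right_pad):
// upper indices in [0, size + right_pad) map one-to-one onto lower
// indices, but only indices below `size` touch real tensor data.
struct RightPad
{
    int size;      // unpadded length, e.g. Hox2
    int right_pad; // e.g. AddRightPadH

    int upper_size() const { return size + right_pad; }
    bool is_valid(int upper_idx) const { return upper_idx < size; }
    int to_lower(int upper_idx) const { return upper_idx; }
};

int main()
{
    RightPad pad_h{6, 2}; // illustrative sizes: pad 6 up to 8
    for(int i = 0; i < pad_h.upper_size(); ++i)
        std::printf("%d -> %d (%s)\n",
                    i,
                    pad_h.to_lower(i),
                    pad_h.is_valid(i) ? "valid" : "pad");
    return 0;
}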
@@ -382,12 +382,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         FloatAB p_d_thread[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()];

-        constexpr auto vector_len = sizeof(FloatAB) / sizeof(FloatC);
-        static_assert(vector_len == CThreadTransferDstScalarPerVector);
+        constexpr auto vector_len = CThreadTransferDstScalarPerVector;
+        static_assert(vector_len == 16);

         constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};

-#if 0
+#if 1
         ThreadwiseDynamicTensorSliceTransfer_v2<
             FloatAB,
             FloatAB,
@@ -423,16 +423,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 {
                     for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
                     {
-                        vector_type<FloatC, vector_len> d_vec;
+                        vector_type<int8_t, vector_len> d_vec;
                         d_vec.Vector() = p_d_thread[d_k_n_hox2_wox2_thread_desc.CalculateOffset(
                             make_tuple(k_i, 0, h_i, w_i))];

                         static_for<0, vector_len, 1>{}([&](auto i) {
-                            d_vec.Scalars()(i) = 0;
-                            //p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
-                            //make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))];
+                            d_vec.Scalars()(i) += 1;
+                            // p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
+                            // make_tuple(k_i * vector_len + i, 0, h_i / 2, w_i / 2))];
                         });

                         p_d_thread[d_k_n_hox2_wox2_thread_desc.CalculateOffset(
                             make_tuple(k_i, 0, h_i, w_i))] = d_vec.Vector();
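Note on the hunk above: vector_type<int8_t, vector_len> packs vector_len int8 lanes into one register-wide value, so the thread buffer is read and written one 16-byte group at a time while the static_for loop edits individual lanes in between. A self-contained sketch of that pack/unpack idiom, using a union as a stand-in for composable_kernel's vector_type (the names here are illustrative, not the library's):

#include <cstdint>
#include <cstdio>

// 16 x int8 viewed as one 16-byte vector (GCC/Clang extension).
typedef int8_t int8x16_t __attribute__((vector_size(16)));

// Stand-in for vector_type<int8_t, 16>: the same bytes addressable
// either as a whole vector or lane by lane. GCC and Clang document
// union punning like this as well-defined.
union Int8Vec16
{
    int8x16_t vector;
    int8_t scalars[16];
};

int main()
{
    Int8Vec16 d_vec = {};       // zero all 16 lanes
    for(int i = 0; i < 16; ++i) // per-lane update, like d_vec.Scalars()(i) += 1
        d_vec.scalars[i] += 1;
    int8x16_t whole = d_vec.vector; // whole-group move, like d_vec.Vector()
    std::printf("lane 0 = %d\n", (int)whole[0]); // prints "lane 0 = 1"
    return 0;
}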
@@ -465,7 +465,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 make_tuple(I0, I0, I0, I0),
                 p_d_thread,
                 d_k_n_hox2_wox2_global_desc,
-                p_d_global,
+                p_c_global,
                 c_k_n_ho_wo_global_tensor_iterator_hacks);
 #endif
         }
@@ -91,8 +91,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
         make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
     const auto out_n_k0_ho_wo_k1_desc =
         make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
-    const auto add_n_k0_hox2_wox2_k1_desc =
-        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2, 1));
+    const auto add_n_k0_hox2_wox2_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Hox2, Wox2));

     const auto conv_strides   = sequence_to_tuple_of_number(ConvStrides{});
     const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
@@ -156,7 +156,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
     constexpr index_t CThreadTransferDstScalarPerVector_W = K1;

-    static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
+    // static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
 #else
     constexpr index_t BlockSize = 64;
@@ -192,7 +192,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
         <BlockSize,
          typename vector_type<TInWei, InWeiVectorSize>::type,
          TAcc,
-         TOut,
+         typename vector_type<TOut, InWeiVectorSize>::type,
          KPerBlock,
          HoPerBlock,
          WoPerBlock,
@@ -210,7 +210,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
     conv_driver.Run(wei_k_c0_y_x_desc,
                     in_n_c0_hi_wi_desc,
-                    add_n_k0_hox2_wox2_k1_desc,
+                    add_n_k0_hox2_wox2_desc,
                     out_n_k0_ho_wo_k1_desc,
                     conv_strides,
                     conv_dilations,
@@ -220,9 +220,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
                         wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                     static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                         in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                    static_cast<typename vector_type<TOut, InWeiVectorSize>::type*>(
                         add_n_k_hox2_wox2_device_buf.GetDeviceBuffer()),
-                    static_cast<TOut*>(out_n_k_hox2_wox2_device_buf.GetDeviceBuffer()));
+                    static_cast<typename vector_type<TOut, InWeiVectorSize>::type*>(
+                        out_n_k_hox2_wox2_device_buf.GetDeviceBuffer()));

     out_n_k_hox2_wox2_device_buf.FromDevice(out_n_k0_hox2_wox2_k1.mData.data());
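Note on the hunk above: the add and out buffers are now handed to the driver as pointers to vector elements rather than to individual TOut scalars, so each addressed element moves InWeiVectorSize int8 values at once. A minimal host-side sketch of that reinterpretation (illustrative only; it assumes the buffer is suitably aligned and a multiple of 16 bytes):

#include <cstdint>
#include <cstdio>

typedef int8_t int8x16_t __attribute__((vector_size(16)));

int main()
{
    // 64 int8 elements viewed as 4 vector elements of 16 bytes each,
    // mirroring the static_cast applied to the device buffers.
    alignas(16) int8_t buf[64] = {};
    int8x16_t* vec_buf = reinterpret_cast<int8x16_t*>(buf);
    std::printf("%zu scalar elements -> %zu vector elements\n",
                sizeof(buf), sizeof(buf) / sizeof(vec_buf[0])); // 64 -> 4
    return 0;
}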
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<1, 1>;
     using RightPads = Sequence<1, 1>;
-#elif 1
+#elif 0
     constexpr index_t N  = 1;
     constexpr index_t C  = 4;
     constexpr index_t HI = 64;
@@ -637,7 +637,7 @@ int main(int argc, char* argv[])
     print_array("ConvStrides", to_multi_index(ConvStrides{}));
     print_array("ConvDilations", to_multi_index(ConvDilations{}));

-#if 1
+#if 0
     using in_data_t = float;
     constexpr index_t in_vector_size = 1;
     using acc_data_t = float;
@@ -654,7 +654,7 @@ int main(int argc, char* argv[])
     using out_data_t = int8_t;
 #elif 1
     using in_data_t = int8_t;
-    constexpr index_t in_vector_size = 4;
+    constexpr index_t in_vector_size = 16;
     using acc_data_t = int32_t;
     using out_data_t = int8_t;
 #endif
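Note on the hunk above: raising in_vector_size from 4 to 16 on the int8 path keeps the byte width of one vectorized access constant; 16 one-byte elements occupy the same 16 bytes (128 bits) as 4 floats, which is commonly the widest per-instruction global access on GPUs. This reading is inferred from the diff, not stated in the commit. A trivial compile-time check of the arithmetic:

#include <cstdint>

// 16 x int8 == 4 x float == 16 bytes (128 bits) per vector access.
static_assert(16 * sizeof(int8_t) == 4 * sizeof(float),
              "int8 and float paths move the same number of bytes");

int main() { return 0; }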