Commit b53926e9 authored by Jing Zhang's avatar Jing Zhang
Browse files

activ_type argument

parent fe427fd1
...@@ -35,7 +35,8 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -35,7 +35,8 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename ConvStrides, typename ConvStrides,
typename ConvDilations, typename ConvDilations,
typename InLeftPads, typename InLeftPads,
typename InRightPads> typename InRightPads,
index_t activ_type>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc, __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc, const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc, const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
...@@ -43,6 +44,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -43,6 +44,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads_, const InRightPads& in_right_pads_,
Number<activ_type>,
const FloatAB* __restrict__ p_wei_global, const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global, const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const FloatC* __restrict__ p_out_global) const
...@@ -297,6 +299,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -297,6 +299,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
const FloatAB*, const FloatAB*,
FloatC*, FloatC*,
Number<activ_type>,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -308,6 +311,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -308,6 +311,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
p_in_global, p_in_global,
p_out_global, p_out_global,
Number<activ_type>{},
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
...@@ -317,6 +321,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -317,6 +321,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
const FloatAB*, const FloatAB*,
FloatC*, FloatC*,
Number<activ_type>,
integral_constant<bool, true>, integral_constant<bool, true>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -328,6 +333,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -328,6 +333,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
p_in_global, p_in_global,
p_out_global, p_out_global,
Number<activ_type>{},
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
...@@ -337,6 +343,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -337,6 +343,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
const FloatAB*, const FloatAB*,
FloatC*, FloatC*,
Number<activ_type>,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
...@@ -348,6 +355,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -348,6 +355,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
p_in_global, p_in_global,
p_out_global, p_out_global,
Number<activ_type>{},
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
...@@ -357,6 +365,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -357,6 +365,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const FloatAB*, const FloatAB*,
const FloatAB*, const FloatAB*,
FloatC*, FloatC*,
Number<activ_type>,
integral_constant<bool, false>, integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
...@@ -368,6 +377,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad ...@@ -368,6 +377,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global, p_wei_global,
p_in_global, p_in_global,
p_out_global, p_out_global,
Number<activ_type>{},
integral_constant<bool, false>{}, integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
......
...@@ -72,11 +72,12 @@ struct GridwiseStaticGemm_km_kn_mn_v3 ...@@ -72,11 +72,12 @@ struct GridwiseStaticGemm_km_kn_mn_v3
return a_block_space_size * sizeof(FloatAB); return a_block_space_size * sizeof(FloatAB);
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <index_t activ_type, bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const FloatAB* __restrict__ p_a_global, __device__ void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
Number<activ_type>,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -348,7 +349,6 @@ struct GridwiseStaticGemm_km_kn_mn_v3 ...@@ -348,7 +349,6 @@ struct GridwiseStaticGemm_km_kn_mn_v3
// activ // activ
{ {
constexpr index_t activ_type = 2;
static_for<0, c_k_n_ho_wo_thread_desc.GetElementSpaceSize(), 1>{}([&](auto i) { static_for<0, c_k_n_ho_wo_thread_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1) if constexpr(activ_type == 1)
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0; c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0;
...@@ -392,10 +392,11 @@ struct GridwiseStaticGemm_km_kn_mn_v3 ...@@ -392,10 +392,11 @@ struct GridwiseStaticGemm_km_kn_mn_v3
} }
// pass tensor descriptor by reference // pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <index_t activ_type, bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const FloatAB* __restrict__ p_a_global, __device__ void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
Number<activ_type>,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
...@@ -407,6 +408,7 @@ struct GridwiseStaticGemm_km_kn_mn_v3 ...@@ -407,6 +408,7 @@ struct GridwiseStaticGemm_km_kn_mn_v3
p_b_global, p_b_global,
p_c_global, p_c_global,
p_shared_block, p_shared_block,
Number<activ_type>{},
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
......
...@@ -437,6 +437,8 @@ int main(int argc, char* argv[]) ...@@ -437,6 +437,8 @@ int main(int argc, char* argv[])
} }
#endif #endif
constexpr ck::index_t activ_type = 2;
#if USE_CONV_FWD_V5R1_NCHW #if USE_CONV_FWD_V5R1_NCHW
if(algo == ConvForwardAlgo::V5R1NCHW) if(algo == ConvForwardAlgo::V5R1NCHW)
{ {
...@@ -452,17 +454,17 @@ int main(int argc, char* argv[]) ...@@ -452,17 +454,17 @@ int main(int argc, char* argv[])
#else #else
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
#endif #endif
<in_data_t, 8, 8, acc_data_t, out_data_t>(tmp[I0], <in_data_t, 8, 8, activ_type, acc_data_t, out_data_t>(tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
tmp[I3], tmp[I3],
tmp[I4], tmp[I4],
tmp[I5], tmp[I5],
tmp[I6], tmp[I6],
in, in,
wei, wei,
out_device, out_device,
nrepeat); nrepeat);
} }
#endif #endif
...@@ -529,8 +531,8 @@ int main(int argc, char* argv[]) ...@@ -529,8 +531,8 @@ int main(int argc, char* argv[])
make_tuple(conv_dilation_h, conv_dilation_w), make_tuple(conv_dilation_h, conv_dilation_w),
make_tuple(in_left_pad_h, in_left_pad_w), make_tuple(in_left_pad_h, in_left_pad_w),
make_tuple(in_right_pad_h, in_right_pad_w), make_tuple(in_right_pad_h, in_right_pad_w),
layout, activ_type,
ActivType_t::sigmoid); layout);
check_error(out_host, out_device); check_error(out_host, out_device);
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
template <typename TInWei, template <typename TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
ck::index_t OutVectorSize, ck::index_t OutVectorSize,
ck::index_t activ_type,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
typename InLengths, typename InLengths,
...@@ -152,6 +153,7 @@ void device_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -152,6 +153,7 @@ void device_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
Number<activ_type>{},
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
......
#pragma once #pragma once
#include "host_tensor.hpp" #include "host_tensor.hpp"
typedef enum
{
passthrough = 0,
relu,
sigmoid
} ActivType_t;
template <typename TIn, template <typename TIn,
typename TWei, typename TWei,
typename TOut, typename TOut,
...@@ -96,13 +89,13 @@ void host_direct_convolution(const Tensor<TIn>& in, ...@@ -96,13 +89,13 @@ void host_direct_convolution(const Tensor<TIn>& in,
} }
template <typename T> template <typename T>
inline auto activ(T v, const ActivType_t activ_type) inline auto activ(T v, const ck::index_t activ_type)
{ {
switch(activ_type) switch(activ_type)
{ {
case passthrough: return v; case 0: return v;
case relu: return (v >= 0 ? v : 0); case 1: return (v >= 0 ? v : 0);
case sigmoid: return (1 / (1 + exp(-v))); case 2: return (1 / (1 + exp(-v)));
default: throw std::runtime_error("unsupported activ type"); break; default: throw std::runtime_error("unsupported activ type"); break;
} }
} }
...@@ -121,8 +114,8 @@ void host_direct_convolution_activ(const Tensor<TIn>& in, ...@@ -121,8 +114,8 @@ void host_direct_convolution_activ(const Tensor<TIn>& in,
const ConvDilations& conv_dilations, const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads, const InLeftPads& in_left_pads,
const InRightPads& in_right_pads, const InRightPads& in_right_pads,
const ConvTensorLayout layout = ConvTensorLayout::NCHW, const ck::index_t activ_type,
const ActivType_t activ_type = ActivType_t::passthrough) const ConvTensorLayout layout = ConvTensorLayout::NCHW)
{ {
using namespace ck; using namespace ck;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment