Commit da207144 authored by Jing Zhang
Browse files

test

parent 26c42b94
...@@ -315,6 +315,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -315,6 +315,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
static constexpr auto NPerBlock = I1; static constexpr auto NPerBlock = I1;
static constexpr FloatAcc alpha = 0.30000001192092896;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{}; constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
...@@ -995,28 +997,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -995,28 +997,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
} }
} }
// activ
if constexpr(activ_type > 0)
{
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1)
{
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0;
}
else if constexpr(activ_type == 2)
{
FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
asm volatile("\n \
v_rcp_f32 %0, %1 \n"
: "=v"(x)
: "0"(x));
c_thread_buf(i) = x;
}
});
}
// Bias
if constexpr(bias_type == 1) if constexpr(bias_type == 1)
{ {
constexpr auto bias_k0_k1_thread_desc = constexpr auto bias_k0_k1_thread_desc =
...@@ -1068,6 +1050,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -1068,6 +1050,28 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
#endif #endif
} }
// Activ
if constexpr(activ_type > 0)
{
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1)
{
c_thread_buf(i) =
c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
}
else if constexpr(activ_type == 2)
{
FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
asm volatile("\n \
v_rcp_f32 %0, %1 \n"
: "=v"(x)
: "0"(x));
c_thread_buf(i) = x;
}
});
}
#if 1 #if 1
// Output // Output
if constexpr(out_type == 1) if constexpr(out_type == 1)
......
...@@ -303,7 +303,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -303,7 +303,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack), decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack), decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
activ_type, activ_type,
0, // bias_type 1, // bias_type
1, // out_type 1, // out_type
0 // add_type 0 // add_type
>; >;
......
...@@ -263,7 +263,7 @@ int main(int argc, char* argv[]) ...@@ -263,7 +263,7 @@ int main(int argc, char* argv[])
in_right_pads_dev); in_right_pads_dev);
}; };
constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::None; constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::LeakyRelu;
#if USE_CONV_FWD_V5R1_NCHWC #if USE_CONV_FWD_V5R1_NCHWC
if(algo == ConvForwardAlgo::V5R1NCHWC) if(algo == ConvForwardAlgo::V5R1NCHWC)
......
...@@ -93,9 +93,9 @@ int main(int argc, char* argv[]) ...@@ -93,9 +93,9 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[4]); const bool do_log = std::stoi(argv[4]);
const int nrepeat = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[5]);
constexpr index_t activ_type = 0; constexpr index_t activ_type = 1;
#if 0 #if 1
constexpr auto N = Number<1>{}; constexpr auto N = Number<1>{};
constexpr auto Hi = Number<1080>{}; constexpr auto Hi = Number<1080>{};
constexpr auto Wi = Number<1920>{}; constexpr auto Wi = Number<1920>{};
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
template <typename T> template <typename T>
inline auto activ(T v, const ck::index_t activ_type) inline auto activ(T v, const ck::index_t activ_type)
{ {
const T alpha = 0.30000001192092896;
switch(activ_type) switch(activ_type)
{ {
case 0: return v; case 0: return v;
case 1: return (v >= 0 ? v : 0); case 1: return (v >= 0 ? v : alpha * v);
case 2: return (1 / (1 + exp(-v))); case 2: return (1 / (1 + exp(-v)));
default: throw std::runtime_error("unsupported activ type"); break; default: throw std::runtime_error("unsupported activ type"); break;
} }
...@@ -273,7 +274,8 @@ void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in, ...@@ -273,7 +274,8 @@ void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in,
} }
} }
v = activ(v, activ_type) + bias(k0, k1); v += bias(k0, k1);
v = activ(v, activ_type);
out_host(n, k0, ho, wo, k1) = v; out_host(n, k0, ho, wo, k1) = v;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment