Commit af84fba3 authored by Jing Zhang's avatar Jing Zhang
Browse files

use activ_enum

parent 3a3136ce
......@@ -143,7 +143,7 @@ template <index_t BlockSize,
typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks,
index_t activ_type = 0>
ActivTypeEnum_t activ_type = ActivTypeEnum_t::None>
struct GridwiseGemmDlops_km_kn_mn_v3
{
static constexpr auto I0 = Number<0>{};
......@@ -159,6 +159,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static constexpr auto NPerBlock = I1;
static constexpr FloatC alpha = 0.3;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
......@@ -729,7 +731,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1)
{
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0;
c_thread_buf(i) =
c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
}
else if constexpr(activ_type == 2)
{
......
......@@ -127,6 +127,13 @@ enum InMemoryDataOperationEnum_t
AtomicAdd
};
enum ActivTypeEnum_t
{
None = 0,
LeakyRelu,
Sigmoid
};
// index type
using index_t = int32_t;
......
......@@ -6,7 +6,7 @@
template <typename TInWei,
typename TAcc,
typename TOut,
ck::index_t activ_type,
ck::ActivTypeEnum_t activ_type,
typename InLengths,
typename WeiLengths,
typename OutLengths,
......
......@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
ck::index_t ABlockTransferDstScalarPerVector_E2,
ck::index_t BThreadTransferSrcScalarPerVector_E2,
ck::index_t CThreadTransferDstScalarPerVector_K,
ck::index_t activ_type>
ck::ActivTypeEnum_t activ_type>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad
{
template <typename... Wei,
......@@ -61,7 +61,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
//const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
......
......@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
in_right_pads_dev);
};
constexpr index_t activ_type = 0;
constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::None;
#if USE_CONV_FWD_V5R1_NCHWC
if(algo == ConvForwardAlgo::V5R1NCHWC)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment