"...composable_kernel_rocm.git" did not exist on "9c54eaab04e6db605dc86f1d1ab16bd04f51fc89"
Commit af84fba3 authored by Jing Zhang's avatar Jing Zhang
Browse files

use activ_enum

parent 3a3136ce
...@@ -143,7 +143,7 @@ template <index_t BlockSize, ...@@ -143,7 +143,7 @@ template <index_t BlockSize,
typename CGlobalStepHacks, typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks, typename BGlobalMoveSliceWindowStepHacks,
index_t activ_type = 0> ActivTypeEnum_t activ_type = ActivTypeEnum_t::None>
struct GridwiseGemmDlops_km_kn_mn_v3 struct GridwiseGemmDlops_km_kn_mn_v3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -159,6 +159,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -159,6 +159,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static constexpr auto NPerBlock = I1; static constexpr auto NPerBlock = I1;
static constexpr FloatC alpha = 0.3;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{}; constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
...@@ -729,7 +731,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -729,7 +731,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1) if constexpr(activ_type == 1)
{ {
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0; c_thread_buf(i) =
c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
} }
else if constexpr(activ_type == 2) else if constexpr(activ_type == 2)
{ {
......
...@@ -127,6 +127,13 @@ enum InMemoryDataOperationEnum_t ...@@ -127,6 +127,13 @@ enum InMemoryDataOperationEnum_t
AtomicAdd AtomicAdd
}; };
// Activation selector. It is passed as a non-type template argument
// (`activ_type`) to the gridwise GEMM and the convolution driver templates.
// The numeric values are written out explicitly because kernel code compares
// `activ_type` against the integer literals 1 and 2; keep them stable.
enum ActivTypeEnum_t
{
    None      = 0,
    LeakyRelu = 1,
    Sigmoid   = 2
};
// index type // index type
using index_t = int32_t; using index_t = int32_t;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
ck::index_t activ_type, ck::ActivTypeEnum_t activ_type,
typename InLengths, typename InLengths,
typename WeiLengths, typename WeiLengths,
typename OutLengths, typename OutLengths,
......
...@@ -27,7 +27,7 @@ template <ck::index_t BlockSize, ...@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
ck::index_t ABlockTransferDstScalarPerVector_E2, ck::index_t ABlockTransferDstScalarPerVector_E2,
ck::index_t BThreadTransferSrcScalarPerVector_E2, ck::index_t BThreadTransferSrcScalarPerVector_E2,
ck::index_t CThreadTransferDstScalarPerVector_K, ck::index_t CThreadTransferDstScalarPerVector_K,
ck::index_t activ_type> ck::ActivTypeEnum_t activ_type>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad
{ {
template <typename... Wei, template <typename... Wei,
...@@ -61,7 +61,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -61,7 +61,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
//const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
......
...@@ -250,7 +250,7 @@ int main(int argc, char* argv[]) ...@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
in_right_pads_dev); in_right_pads_dev);
}; };
constexpr index_t activ_type = 0; constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::None;
#if USE_CONV_FWD_V5R1_NCHWC #if USE_CONV_FWD_V5R1_NCHWC
if(algo == ConvForwardAlgo::V5R1NCHWC) if(algo == ConvForwardAlgo::V5R1NCHWC)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment