Commit 1b79fce9 authored by Jing Zhang's avatar Jing Zhang
Browse files

create seperate fusion fun

parent 8e897da7
......@@ -152,10 +152,9 @@ __global__ void
constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
CBlockIdToBlockClusterAdaptor_K_N_H_W{};
GridwiseGemm::Run(p_a_grid,
GridwiseGemm::ConvBiasActivResizeAddRun(p_a_grid,
p_b_grid,
p_bias_grid,
nullptr,
p_d_grid,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
......@@ -198,7 +197,7 @@ __global__ void
constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
CBlockIdToBlockClusterAdaptor_K_N_H_W{};
GridwiseGemm::Run(p_a_grid,
GridwiseGemm::ConvBiasActivMaxpoolRun(p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
......@@ -241,16 +240,14 @@ __global__ void
constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
CBlockIdToBlockClusterAdaptor_K_N_H_W{};
GridwiseGemm::Run(p_a_grid,
GridwiseGemm::ConvBiasActiv(p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
nullptr,
p_shared_block,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{});
}
......@@ -296,13 +293,10 @@ template <index_t BlockSize,
typename CGlobalStepHacks,
typename DGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks,
index_t activ_type = 0,
index_t bias_type = 0,
index_t out_type = 1,
index_t add_type = 0>
typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -318,6 +312,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static constexpr FloatAcc alpha = 0.3;
static constexpr auto activ_type = I1;
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
......@@ -539,23 +535,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
}
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptor(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
if constexpr(add_type == 1)
{
return MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(d_k_n_hx_wx_grid_desc);
}
else if constexpr(add_type == 2)
{
return MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
}
else
{
return MakeCK0K1NH0H1H2W0W1W2GridDescriptor(d_k_n_hx_wx_grid_desc);
}
}
__host__ __device__ static constexpr auto
MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
{
......@@ -584,18 +563,19 @@ struct GridwiseGemmDlops_km_kn_mn_v3
return c_blockid_to_k_n_ho_wo_block_cluster_adaptor;
}
using AGridDesc_E0_E1_K0_K1_E2 =
decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
// using AGridDesc_E0_E1_K0_K1_E2 =
// decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
// using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
// decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
// using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
// decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
// using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
// decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
template <typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>
__host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
{
......@@ -728,18 +708,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3
});
}
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2>
__device__ static void Activation(CThreadBuff& c_thread_buf, const CThreadDesc_K1_N_H2_W2&)
template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, index_t activ_type_>
__device__ static void
Activation(CThreadBuff& c_thread_buf, const CThreadDesc_K1_N_H2_W2&, Number<activ_type_>)
{
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
if constexpr(activ_type == 1)
if constexpr(activ_type_ == 1)
{
c_thread_buf(i) =
c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
}
else if constexpr(activ_type == 2)
else if constexpr(activ_type_ == 2)
{
FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
......@@ -1024,6 +1004,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
typename CThreadBuff,
typename CBlockIndex,
typename CThreadIndex,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CThreadDesc_K1_N_H2_W2,
bool HasMainE0BlockLoop>
__device__ static void
......@@ -1394,9 +1376,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3
}
}
template <bool HasMainE0BlockLoop>
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_global,
Conv(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global,
......@@ -1437,6 +1424,69 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
c_thread_buf,
p_shared_block,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc,
integral_constant<bool, HasMainE0BlockLoop>{});
// Output
WriteOut(c_thread_buf,
c_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
}
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__device__ static void ConvBiasActiv(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>)
{
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
const auto c_k_n_h_w_block_cluster_idx =
GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
......@@ -1450,7 +1500,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
integral_constant<bool, HasMainE0BlockLoop>{});
// Bias
if constexpr(bias_type > 0)
BiasOp(bias_global_buf,
c_thread_buf,
c_k_n_h_w_block_cluster_idx,
......@@ -1459,26 +1508,94 @@ struct GridwiseGemmDlops_km_kn_mn_v3
c_k1_n_h2_w2_thread_gemm_desc);
// Activ
if constexpr(activ_type > 0)
Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc);
Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
// Output
if constexpr(out_type > 0)
WriteOut(c_thread_buf,
c_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
}
if constexpr(add_type == 1)
// Resize_Add
ResizeAdd(c_thread_buf,
d_global_buf,
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__device__ static void ConvBiasActivMaxpoolRun(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global,
FloatC* __restrict__ p_d_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>)
{
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
const auto c_k_n_h_w_block_cluster_idx =
GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
c_thread_buf,
p_shared_block,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
else if constexpr(add_type == 2)
integral_constant<bool, HasMainE0BlockLoop>{});
// Bias
BiasOp(bias_global_buf,
c_thread_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
bias_k0_k1_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc);
// Activ
Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
// Output
WriteOut(c_thread_buf,
c_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
// MaxPool
MaxPool(c_thread_buf,
d_global_buf,
......@@ -1487,6 +1604,83 @@ struct GridwiseGemmDlops_km_kn_mn_v3
c_k1_n_h2_w2_thread_gemm_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
}
template <typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__device__ static void ConvBiasActivResizeAddRun(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_d_global,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>)
{
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
true>
c_thread_buf;
const auto c_k_n_h_w_block_cluster_idx =
GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
const auto c_thread_mtx_index = GetCThreadIndex();
// GemmOp
GemmOp(a_global_buf,
b_global_buf,
c_thread_buf,
p_shared_block,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc,
integral_constant<bool, HasMainE0BlockLoop>{});
// Bias
BiasOp(bias_global_buf,
c_thread_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
bias_k0_k1_grid_desc,
c_k1_n_h2_w2_thread_gemm_desc);
// Activ
Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
// Resize_Add
ResizeAdd(c_thread_buf,
d_global_buf,
c_k_n_h_w_block_cluster_idx,
c_thread_mtx_index,
c_k1_n_h2_w2_thread_gemm_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
}
};
} // namespace ck
......
......@@ -336,12 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks),
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
activ_type,
1, // bias_type
0, // out_type
1 // add_type
>;
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
const auto a_e0_e1_k0_k1_e2_grid_desc =
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
......@@ -350,7 +345,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(
d_k_n_hopx2_wopx2_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
......
......@@ -301,12 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
activ_type,
1, // bias_type
1, // out_type
0 // add_type
>;
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
const auto a_e0_e1_k0_k1_e2_grid_desc =
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
......
......@@ -340,12 +340,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks),
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
activ_type,
1, // bias_type
1, // out_type
2 // add_type
>;
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
const auto a_e0_e1_k0_k1_e2_grid_desc =
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
......@@ -354,7 +349,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptor(d_k_n_hx_wx_grid_desc);
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
......
......@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[4]);
const int nrepeat = std::stoi(argv[5]);
constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::Sigmoid;
// constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::LeakyRelu;
// constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::Sigmoid;
constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::LeakyRelu;
#if 0
constexpr auto N = Number<1>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment