Commit 1b79fce9 authored by Jing Zhang

create separate fusion functions

parent 8e897da7
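Summary of the change: the single flag-dispatched `GridwiseGemm::Run` is split into separately named fusion entry points (`Conv`, `ConvBiasActiv`, `ConvBiasActivMaxpoolRun`, `ConvBiasActivResizeAddRun`), and the `activ_type` / `bias_type` / `out_type` / `add_type` integer template parameters are dropped from `GridwiseGemmDlops_km_kn_mn_v3` (`activ_type` becomes a hard-coded member for now). The sketch below shows the same pattern in a tiny standalone form; the names and arithmetic are illustrative only and are not composable_kernel code.

// Standalone illustration of the refactor pattern; hypothetical names only.
#include <cstdio>

struct GemmBefore
{
    // One Run() whose fused epilogue is selected by integer template flags.
    template <int bias_type, int add_type>
    static void Run(float x)
    {
        float y = x;
        if constexpr(bias_type > 0) { y += 1.0f; }      // bias
        if constexpr(add_type == 1) { y *= 2.0f; }      // stand-in for resize-add
        else if constexpr(add_type == 2) { y -= 3.0f; } // stand-in for maxpool
        std::printf("before: %g\n", y);
    }
};

struct GemmAfter
{
    // One named entry point per fusion; call sites pick a function instead of
    // threading flags through the template arguments.
    static void ConvBiasResizeAdd(float x) { std::printf("after: %g\n", (x + 1.0f) * 2.0f); }
    static void ConvBiasMaxpool(float x) { std::printf("after: %g\n", (x + 1.0f) - 3.0f); }
};

int main()
{
    GemmBefore::Run<1, 1>(4.0f);        // before: 10
    GemmAfter::ConvBiasResizeAdd(4.0f); // after: 10
    return 0;
}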
@@ -152,18 +152,17 @@ __global__ void
     constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
         CBlockIdToBlockClusterAdaptor_K_N_H_W{};
-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_bias_grid,
-                      nullptr,
-                      p_d_grid,
-                      p_shared_block,
-                      a_e0_e1_k0_k1_e2_grid_desc,
-                      b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
-                      c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
-                      d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
-                      c_blockid_to_k_n_h_w_block_cluster_adaptor,
-                      integral_constant<bool, HasMainE0BlockLoop>{});
+    GridwiseGemm::ConvBiasActivResizeAddRun(p_a_grid,
+                                            p_b_grid,
+                                            p_bias_grid,
+                                            p_d_grid,
+                                            p_shared_block,
+                                            a_e0_e1_k0_k1_e2_grid_desc,
+                                            b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+                                            c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+                                            d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
+                                            c_blockid_to_k_n_h_w_block_cluster_adaptor,
+                                            integral_constant<bool, HasMainE0BlockLoop>{});
 }
 
 template <typename GridwiseGemm,
@@ -198,18 +197,18 @@ __global__ void
     constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
         CBlockIdToBlockClusterAdaptor_K_N_H_W{};
-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_bias_grid,
-                      p_c_grid,
-                      p_d_grid,
-                      p_shared_block,
-                      a_e0_e1_k0_k1_e2_grid_desc,
-                      b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
-                      c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
-                      d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
-                      c_blockid_to_k_n_h_w_block_cluster_adaptor,
-                      integral_constant<bool, HasMainE0BlockLoop>{});
+    GridwiseGemm::ConvBiasActivMaxpoolRun(p_a_grid,
+                                          p_b_grid,
+                                          p_bias_grid,
+                                          p_c_grid,
+                                          p_d_grid,
+                                          p_shared_block,
+                                          a_e0_e1_k0_k1_e2_grid_desc,
+                                          b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+                                          c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+                                          d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
+                                          c_blockid_to_k_n_h_w_block_cluster_adaptor,
+                                          integral_constant<bool, HasMainE0BlockLoop>{});
 }
 
 template <typename GridwiseGemm,
@@ -241,18 +240,16 @@ __global__ void
     constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
         CBlockIdToBlockClusterAdaptor_K_N_H_W{};
-    GridwiseGemm::Run(p_a_grid,
-                      p_b_grid,
-                      p_bias_grid,
-                      p_c_grid,
-                      nullptr,
-                      p_shared_block,
-                      a_e0_e1_k0_k1_e2_grid_desc,
-                      b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
-                      c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
-                      c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
-                      c_blockid_to_k_n_h_w_block_cluster_adaptor,
-                      integral_constant<bool, HasMainE0BlockLoop>{});
+    GridwiseGemm::ConvBiasActiv(p_a_grid,
+                                p_b_grid,
+                                p_bias_grid,
+                                p_c_grid,
+                                p_shared_block,
+                                a_e0_e1_k0_k1_e2_grid_desc,
+                                b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+                                c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+                                c_blockid_to_k_n_h_w_block_cluster_adaptor,
+                                integral_constant<bool, HasMainE0BlockLoop>{});
 }
 #endif
@@ -296,13 +293,10 @@ template <index_t BlockSize,
           typename CGlobalStepHacks,
           typename DGlobalStepHacks,
           typename AGlobalMoveSliceWindowStepHacks,
-          typename BGlobalMoveSliceWindowStepHacks,
-          index_t activ_type = 0,
-          index_t bias_type = 0,
-          index_t out_type = 1,
-          index_t add_type = 0>
+          typename BGlobalMoveSliceWindowStepHacks>
 struct GridwiseGemmDlops_km_kn_mn_v3
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -318,6 +312,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
     static constexpr FloatAcc alpha = 0.3;
 
+    static constexpr auto activ_type = I1;
+
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
@@ -539,23 +535,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
         return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
     }
 
-    __host__ __device__ static constexpr auto
-    MakeDK0K1NH0H1HxW0W1WxGridDescriptor(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
-    {
-        if constexpr(add_type == 1)
-        {
-            return MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(d_k_n_hx_wx_grid_desc);
-        }
-        else if constexpr(add_type == 2)
-        {
-            return MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
-        }
-        else
-        {
-            return MakeCK0K1NH0H1H2W0W1W2GridDescriptor(d_k_n_hx_wx_grid_desc);
-        }
-    }
-
     __host__ __device__ static constexpr auto
     MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc)
     {
@@ -584,18 +563,19 @@ struct GridwiseGemmDlops_km_kn_mn_v3
         return c_blockid_to_k_n_ho_wo_block_cluster_adaptor;
     }
 
-    using AGridDesc_E0_E1_K0_K1_E2 =
-        decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
-    using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
-        decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
-    using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
-        decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
-    using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
-        decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
+    // using AGridDesc_E0_E1_K0_K1_E2 =
+    //     decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{}));
+    // using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
+    //     decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
+    // using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
+    //     decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
+    // using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
+    //     decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
 
     using CBlockIdToBlockClusterAdaptor_K_N_H_W =
         decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
 
+    template <typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>
     __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
         const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)
     {
@@ -728,29 +708,29 @@ struct GridwiseGemmDlops_km_kn_mn_v3
         });
     }
 
-    template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2>
-    __device__ static void Activation(CThreadBuff& c_thread_buf, const CThreadDesc_K1_N_H2_W2&)
+    template <typename CThreadBuff, typename CThreadDesc_K1_N_H2_W2, index_t activ_type_>
+    __device__ static void
+    Activation(CThreadBuff& c_thread_buf, const CThreadDesc_K1_N_H2_W2&, Number<activ_type_>)
     {
         constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{};
 
         static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
-            if constexpr(activ_type == 1)
+            if constexpr(activ_type_ == 1)
             {
-                c_thread_buf(i) =
-                    c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
+                c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i];
             }
-            else if constexpr(activ_type == 2)
+            else if constexpr(activ_type_ == 2)
            {
                 FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
 
                 asm volatile("\n \
                     v_rcp_f32 %0, %1 \n"
                              : "=v"(x)
                              : "0"(x));
 
                 c_thread_buf(i) = x;
             }
         });
     }
 
     template <typename CThreadBuff,
@@ -1024,6 +1004,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
               typename CThreadBuff,
               typename CBlockIndex,
               typename CThreadIndex,
+              typename AGridDesc_E0_E1_K0_K1_E2,
+              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
               typename CThreadDesc_K1_N_H2_W2,
               bool HasMainE0BlockLoop>
     __device__ static void
@@ -1394,9 +1376,156 @@
         }
     }
 
-    template <bool HasMainE0BlockLoop>
+    template <typename AGridDesc_E0_E1_K0_K1_E2,
+              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
+              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
+              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
+              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
+              bool HasMainE0BlockLoop>
     __device__ static void
-    Run(const FloatAB* __restrict__ p_a_global,
+    Conv(const FloatAB* __restrict__ p_a_global,
+         const FloatAB* __restrict__ p_b_global,
+         const FloatC* __restrict__ p_bias_global,
+         FloatC* __restrict__ p_c_global,
+         FloatC* __restrict__ p_d_global,
+         FloatAB* __restrict__ p_shared_block,
+         const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
+         const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+         const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+         const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
+         const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
+         integral_constant<bool, HasMainE0BlockLoop>)
+    {
+        const auto bias_k0_k1_grid_desc =
+            MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
+
+        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
+        auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
+        auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
+
+        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
+
+        // register allocation for output
+        StaticBuffer<AddressSpaceEnum_t::Vgpr,
+                     FloatAcc,
+                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
+                     true>
+            c_thread_buf;
+
+        const auto c_k_n_h_w_block_cluster_idx =
+            GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
+        const auto c_thread_mtx_index = GetCThreadIndex();
+
+        // GemmOp
+        GemmOp(a_global_buf,
+               b_global_buf,
+               c_thread_buf,
+               p_shared_block,
+               c_k_n_h_w_block_cluster_idx,
+               c_thread_mtx_index,
+               a_e0_e1_k0_k1_e2_grid_desc,
+               b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+               c_k1_n_h2_w2_thread_gemm_desc,
+               integral_constant<bool, HasMainE0BlockLoop>{});
+
+        // Output
+        WriteOut(c_thread_buf,
+                 c_global_buf,
+                 c_k_n_h_w_block_cluster_idx,
+                 c_thread_mtx_index,
+                 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
+    }
+
+    template <typename AGridDesc_E0_E1_K0_K1_E2,
+              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
+              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
+              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
+              bool HasMainE0BlockLoop>
+    __device__ static void ConvBiasActiv(
+        const FloatAB* __restrict__ p_a_global,
+        const FloatAB* __restrict__ p_b_global,
+        const FloatC* __restrict__ p_bias_global,
+        FloatC* __restrict__ p_c_global,
+        FloatAB* __restrict__ p_shared_block,
+        const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
+        const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+        const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+        const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
+        integral_constant<bool, HasMainE0BlockLoop>)
+    {
+        const auto bias_k0_k1_grid_desc =
+            MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
+
+        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
+        auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
+
+        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
+
+        // register allocation for output
+        StaticBuffer<AddressSpaceEnum_t::Vgpr,
+                     FloatAcc,
+                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
+                     true>
+            c_thread_buf;
+
+        const auto c_k_n_h_w_block_cluster_idx =
+            GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
+        const auto c_thread_mtx_index = GetCThreadIndex();
+
+        // GemmOp
+        GemmOp(a_global_buf,
+               b_global_buf,
+               c_thread_buf,
+               p_shared_block,
+               c_k_n_h_w_block_cluster_idx,
+               c_thread_mtx_index,
+               a_e0_e1_k0_k1_e2_grid_desc,
+               b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+               c_k1_n_h2_w2_thread_gemm_desc,
+               integral_constant<bool, HasMainE0BlockLoop>{});
+
+        // Bias
+        BiasOp(bias_global_buf,
+               c_thread_buf,
+               c_k_n_h_w_block_cluster_idx,
+               c_thread_mtx_index,
+               bias_k0_k1_grid_desc,
+               c_k1_n_h2_w2_thread_gemm_desc);
+
+        // Activ
+        Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
+
+        // Output
+        WriteOut(c_thread_buf,
+                 c_global_buf,
+                 c_k_n_h_w_block_cluster_idx,
+                 c_thread_mtx_index,
+                 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
+    }
+
+    template <typename AGridDesc_E0_E1_K0_K1_E2,
+              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
+              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
+              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
+              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
+              bool HasMainE0BlockLoop>
+    __device__ static void ConvBiasActivMaxpoolRun(
+        const FloatAB* __restrict__ p_a_global,
         const FloatAB* __restrict__ p_b_global,
         const FloatC* __restrict__ p_bias_global,
         FloatC* __restrict__ p_c_global,
@@ -1450,42 +1579,107 @@
               integral_constant<bool, HasMainE0BlockLoop>{});
 
         // Bias
-        if constexpr(bias_type > 0)
-            BiasOp(bias_global_buf,
-                   c_thread_buf,
-                   c_k_n_h_w_block_cluster_idx,
-                   c_thread_mtx_index,
-                   bias_k0_k1_grid_desc,
-                   c_k1_n_h2_w2_thread_gemm_desc);
+        BiasOp(bias_global_buf,
+               c_thread_buf,
+               c_k_n_h_w_block_cluster_idx,
+               c_thread_mtx_index,
+               bias_k0_k1_grid_desc,
+               c_k1_n_h2_w2_thread_gemm_desc);
 
         // Activ
-        if constexpr(activ_type > 0)
-            Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc);
+        Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
 
         // Output
-        if constexpr(out_type > 0)
-            WriteOut(c_thread_buf,
-                     c_global_buf,
-                     c_k_n_h_w_block_cluster_idx,
-                     c_thread_mtx_index,
-                     c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
+        WriteOut(c_thread_buf,
+                 c_global_buf,
+                 c_k_n_h_w_block_cluster_idx,
+                 c_thread_mtx_index,
+                 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
 
-        if constexpr(add_type == 1)
-            // Resize_Add
-            ResizeAdd(c_thread_buf,
-                      d_global_buf,
-                      c_k_n_h_w_block_cluster_idx,
-                      c_thread_mtx_index,
-                      c_k1_n_h2_w2_thread_gemm_desc,
-                      d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
-        else if constexpr(add_type == 2)
-            // MaxPool
-            MaxPool(c_thread_buf,
-                    d_global_buf,
-                    c_k_n_h_w_block_cluster_idx,
-                    c_thread_mtx_index,
-                    c_k1_n_h2_w2_thread_gemm_desc,
-                    d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
+        // MaxPool
+        MaxPool(c_thread_buf,
+                d_global_buf,
+                c_k_n_h_w_block_cluster_idx,
+                c_thread_mtx_index,
+                c_k1_n_h2_w2_thread_gemm_desc,
+                d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
+    }
+
+    template <typename AGridDesc_E0_E1_K0_K1_E2,
+              typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
+              typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
+              typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
+              typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
+              bool HasMainE0BlockLoop>
+    __device__ static void ConvBiasActivResizeAddRun(
+        const FloatAB* __restrict__ p_a_global,
+        const FloatAB* __restrict__ p_b_global,
+        const FloatC* __restrict__ p_bias_global,
+        FloatC* __restrict__ p_d_global,
+        FloatAB* __restrict__ p_shared_block,
+        const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
+        const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+        const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
+        const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
+        const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
+        integral_constant<bool, HasMainE0BlockLoop>)
+    {
+        const auto bias_k0_k1_grid_desc =
+            MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
+
+        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize());
+        auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
+        auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
+            p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
+
+        constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor();
+
+        // register allocation for output
+        StaticBuffer<AddressSpaceEnum_t::Vgpr,
+                     FloatAcc,
+                     c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(),
+                     true>
+            c_thread_buf;
+
+        const auto c_k_n_h_w_block_cluster_idx =
+            GetCBlockIndex(c_blockid_to_k_n_h_w_block_cluster_adaptor);
+        const auto c_thread_mtx_index = GetCThreadIndex();
+
+        // GemmOp
+        GemmOp(a_global_buf,
+               b_global_buf,
+               c_thread_buf,
+               p_shared_block,
+               c_k_n_h_w_block_cluster_idx,
+               c_thread_mtx_index,
+               a_e0_e1_k0_k1_e2_grid_desc,
+               b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
+               c_k1_n_h2_w2_thread_gemm_desc,
+               integral_constant<bool, HasMainE0BlockLoop>{});
+
+        // Bias
+        BiasOp(bias_global_buf,
+               c_thread_buf,
+               c_k_n_h_w_block_cluster_idx,
+               c_thread_mtx_index,
+               bias_k0_k1_grid_desc,
+               c_k1_n_h2_w2_thread_gemm_desc);
+
+        // Activ
+        Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type);
+
+        // Resize_Add
+        ResizeAdd(c_thread_buf,
+                  d_global_buf,
+                  c_k_n_h_w_block_cluster_idx,
+                  c_thread_mtx_index,
+                  c_k1_n_h2_w2_thread_gemm_desc,
+                  d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
     }
 };
...
@@ -336,12 +336,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
         decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks),
         decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
-        decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
-        activ_type,
-        1, // bias_type
-        0, // out_type
-        1 // add_type
-        >;
+        decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
 
     const auto a_e0_e1_k0_k1_e2_grid_desc =
         GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
@@ -350,7 +345,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
     const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
         GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
     const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
-        GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
+        GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(
+            d_k_n_hopx2_wopx2_grid_desc);
 
     using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
     using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
...
@@ -301,12 +301,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
         decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
         decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
-        decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
-        activ_type,
-        1, // bias_type
-        1, // out_type
-        0 // add_type
-        >;
+        decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
 
     const auto a_e0_e1_k0_k1_e2_grid_desc =
         GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
...
@@ -340,12 +340,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
         decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks),
         decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
-        decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
-        activ_type,
-        1, // bias_type
-        1, // out_type
-        2 // add_type
-        >;
+        decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack)>;
 
     const auto a_e0_e1_k0_k1_e2_grid_desc =
         GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
@@ -354,7 +349,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
     const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
         GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
     const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
-        GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptor(d_k_n_hx_wx_grid_desc);
+        GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
 
     using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
     using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
...
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
     const bool do_log = std::stoi(argv[4]);
     const int nrepeat = std::stoi(argv[5]);
 
-    constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::Sigmoid;
-    // constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::LeakyRelu;
+    // constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::Sigmoid;
+    constexpr ck::ActivTypeEnum_t activ_type = ActivTypeEnum_t::LeakyRelu;
 
 #if 0
     constexpr auto N = Number<1>{};
...
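For reference, the two `activ_type` variants handled in `Activation()` above correspond to leaky ReLU with slope `alpha = 0.3` and a sigmoid evaluated as 1 / (1 + exp(-x)), where the device code produces the reciprocal with `v_rcp_f32`. Below is a host-side sketch of the same formulas; it is illustrative only and not part of the commit.

// Host-side reference for the activations selected by activ_type:
//   1 -> leaky ReLU with alpha = 0.3
//   2 -> sigmoid, computed on-device via v_rcp_f32 for the reciprocal
#include <cmath>
#include <cstdio>

static float leaky_relu(float x, float alpha = 0.3f) { return x >= 0.0f ? x : alpha * x; }
static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main()
{
    std::printf("leaky_relu(-2.0) = %g\n", leaky_relu(-2.0f)); // -0.6
    std::printf("sigmoid(0.0)     = %g\n", sigmoid(0.0f));     // 0.5
    return 0;
}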