Commit dd6a8de4 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

Merge branch 'develop' into jd/dev_pkg

parents 0aa899aa abf4bdb9
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
ck::ActivTypeEnum_t activ_type, ck::ActivTypeEnum activ_type,
typename InLengths, typename InLengths,
typename WeiLengths, typename WeiLengths,
typename AddLengths, typename AddLengths,
......
...@@ -231,7 +231,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( ...@@ -231,7 +231,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
......
...@@ -303,14 +303,14 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -303,14 +303,14 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH; const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW; const auto XTilde = ConvStrideW / GcdStrideDilationW;
float ave_time = 0; float ave_time = 0;
for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda) for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{ {
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{ {
const auto descs = const auto descs =
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
...@@ -321,8 +321,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -321,8 +321,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
i_ytilda, i_ytilde,
i_xtilda, i_xtilde,
Number<GemmK1>{}); Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
...@@ -338,7 +338,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -338,7 +338,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
......
...@@ -307,7 +307,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk ...@@ -307,7 +307,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc), decltype(in_gemmm_gemmn_grid_desc),
......
...@@ -171,7 +171,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_ ...@@ -171,7 +171,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
TIn, TIn,
TAcc, TAcc,
TWei, TWei,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc), decltype(wei_gemmm_gemmn_grid_desc),
......
...@@ -168,7 +168,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk ...@@ -168,7 +168,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
TIn, TIn,
TAcc, TAcc,
TWei, TWei,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc), decltype(wei_gemmm_gemmn_grid_desc),
......
...@@ -200,7 +200,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -200,7 +200,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
TIn, TIn,
TAcc, TAcc,
TWei, TWei,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc), decltype(wei_gemmm_gemmn_grid_desc),
......
...@@ -199,7 +199,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -199,7 +199,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
TIn, TIn,
TAcc, TAcc,
TWei, TWei,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc), decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc), decltype(wei_gemmm_gemmn_grid_desc),
......
...@@ -367,7 +367,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_ ...@@ -367,7 +367,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
TIn, TIn,
TAcc, TAcc,
TWei, TWei,
InMemoryDataOperationEnum_t::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc), decltype(wei_gemmm_gemmn_grid_desc),
......
...@@ -138,7 +138,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( ...@@ -138,7 +138,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(wei_gemmk_gemmm_grid_desc), decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc), decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -202,7 +202,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( ...@@ -202,7 +202,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -167,7 +167,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -167,7 +167,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -522,7 +522,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( ...@@ -522,7 +522,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc), decltype(out_gemmm_gemmn_grid_desc),
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
ck::ActivTypeEnum_t activ_type, ck::ActivTypeEnum activ_type,
typename InLengths, typename InLengths,
typename WeiLengths, typename WeiLengths,
typename OutLengths, typename OutLengths,
......
...@@ -182,7 +182,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( ...@@ -182,7 +182,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
TInWei, TInWei,
TAcc, TAcc,
TOut, TOut,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(wei_grid_desc_gk0_gm0_gm1_gk1), decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
decltype(in_grid_desc_gk0_gn0_gn1_gk1), decltype(in_grid_desc_gk0_gn0_gn1_gk1),
decltype(out_grid_desc_gm0_gm1_gn0_gn1), decltype(out_grid_desc_gm0_gm1_gn0_gn1),
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
template <typename TInWei, template <typename TInWei,
typename TAcc, typename TAcc,
typename TOut, typename TOut,
ck::ActivTypeEnum_t activ_type, ck::ActivTypeEnum activ_type,
typename InLengths, typename InLengths,
typename WeiLengths, typename WeiLengths,
typename MaxLengths, typename MaxLengths,
......
...@@ -398,7 +398,7 @@ void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m, ...@@ -398,7 +398,7 @@ void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m,
ABType, ABType,
AccType, AccType,
CType, CType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc), decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc), decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc), decltype(c_m_n_grid_desc),
......
...@@ -202,7 +202,7 @@ void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m, ...@@ -202,7 +202,7 @@ void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m,
ABType, ABType,
AccType, AccType,
CType, CType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc), decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc), decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc), decltype(c_m_n_grid_desc),
......
...@@ -398,7 +398,7 @@ void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m, ...@@ -398,7 +398,7 @@ void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m,
ABType, ABType,
AccType, AccType,
CType, CType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc), decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc), decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc), decltype(c_m_n_grid_desc),
......
...@@ -202,7 +202,7 @@ void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m, ...@@ -202,7 +202,7 @@ void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m,
ABType, ABType,
AccType, AccType,
CType, CType,
InMemoryDataOperationEnum_t::Set, InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc), decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc), decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc), decltype(c_m_n_grid_desc),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment