Commit dd6a8de4 authored by Jehandad Khan's avatar Jehandad Khan
Browse files

Merge branch 'develop' into jd/dev_pkg

parents 0aa899aa abf4bdb9
......@@ -6,7 +6,7 @@
template <typename TInWei,
typename TAcc,
typename TOut,
ck::ActivTypeEnum_t activ_type,
ck::ActivTypeEnum activ_type,
typename InLengths,
typename WeiLengths,
typename AddLengths,
......
......@@ -231,7 +231,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
......
......@@ -303,14 +303,14 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
float ave_time = 0;
for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
const auto descs =
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
......@@ -321,8 +321,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
conv_dilations,
in_left_pads,
in_right_pads,
i_ytilda,
i_xtilda,
i_ytilde,
i_xtilde,
Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
......@@ -338,7 +338,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
......
......@@ -307,7 +307,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
......
......@@ -171,7 +171,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
......
......@@ -168,7 +168,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
......
......@@ -200,7 +200,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd,
decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
......
......@@ -199,7 +199,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
......
......@@ -367,7 +367,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_
TIn,
TAcc,
TWei,
InMemoryDataOperationEnum_t::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd,
decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
......
......@@ -138,7 +138,7 @@ void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
......
......@@ -202,7 +202,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
......
......@@ -167,7 +167,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
......
......@@ -522,7 +522,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
......
......@@ -6,7 +6,7 @@
template <typename TInWei,
typename TAcc,
typename TOut,
ck::ActivTypeEnum_t activ_type,
ck::ActivTypeEnum activ_type,
typename InLengths,
typename WeiLengths,
typename OutLengths,
......
......@@ -182,7 +182,7 @@ void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
......
......@@ -6,7 +6,7 @@
template <typename TInWei,
typename TAcc,
typename TOut,
ck::ActivTypeEnum_t activ_type,
ck::ActivTypeEnum activ_type,
typename InLengths,
typename WeiLengths,
typename MaxLengths,
......
......@@ -398,7 +398,7 @@ void device_gemm_xdlops_km_kn_mn(const Tensor<ABType>& a_k_m,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
......
......@@ -202,7 +202,7 @@ void device_gemm_xdlops_km_kn_nm(const Tensor<ABType>& a_k_m,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
......
......@@ -398,7 +398,7 @@ void device_gemm_xdlops_km_nk_mn(const Tensor<ABType>& a_k_m,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
......
......@@ -202,7 +202,7 @@ void device_gemm_xdlops_km_nk_nm(const Tensor<ABType>& a_k_m,
ABType,
AccType,
CType,
InMemoryDataOperationEnum_t::Set,
InMemoryDataOperationEnum::Set,
decltype(a_k0_m_k1_grid_desc),
decltype(b_k0_n_k1_grid_desc),
decltype(c_m_n_grid_desc),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment