Commit f328497e authored by Jing Zhang's avatar Jing Zhang
Browse files

fix

parent 3317bfe2
......@@ -8,6 +8,50 @@
namespace ck {
// Builds the "vectorized" weight tensor descriptor used by the xdlops implicit-GEMM kernel.
//
// Template parameter:
//   GemmKPACK - number of elements packed together along the flattened C * Y * X dimension.
//
// get() takes a 4-d weight descriptor whose lengths are read as K, C, Y, X (dimensions 0..3)
// and returns a 3-d descriptor whose dimensions are ordered [GemmK, K, GemmKPACK], where
// GemmK = (C * Y * X) / GemmKPACK.
template <index_t GemmKPACK>
struct make_vectorized_WeiDesc_Xdlops
{
    // The argument is only used for type deduction; the descriptor is re-created from its
    // type (WeiDesc{}) so that everything below stays a compile-time constant.
    template <typename WeiDesc>
    __device__ constexpr auto get(WeiDesc&)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};

        constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
        constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        /* kpack comes from c*y*x: the whole flattened C * Y * X extent (not C alone) must be
           divisible by GemmKPACK */
        static_assert((C * Y * X) % GemmKPACK == 0,
                      "C * Y * X needs to be a multiple of vectorized GemmKPACK");

        constexpr index_t GemmK = (C * Y * X) / GemmKPACK;

        // Step 1: [K, C, Y, X] -> [GemmM = K, C * Y * X].
        // NOTE(review): unfold_tensor_descriptor(desc, I1, I3) presumably collapses
        // dimensions 1..3 into one; the PassThrough<C * Y * X> below relies on that length.
        constexpr auto wei_gemmm_gemmk_global_desc =
            transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3),
                                        make_tuple(PassThrough<K>{}, PassThrough<C * Y * X>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        // Step 2: split the flattened dimension: [K, C * Y * X] -> [K, GemmK, GemmKPACK].
        constexpr auto wei_gemmm_gemmk_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmm_gemmk_global_desc,
            make_tuple(PassThrough<K>{}, UnMerge<Sequence<GemmK, GemmKPACK>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        // Step 3: swap the first two dimensions: [K, GemmK, GemmKPACK] -> [GemmK, K, GemmKPACK].
        constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
            wei_gemmm_gemmk_gemmkpack_global_desc,
            make_tuple(PassThrough<GemmK>{}, PassThrough<K>{}, PassThrough<GemmKPACK>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return wei_gemmk_gemmm_gemmkpack_global_desc;
    }
};
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = (C * Y * X) / GemmKPACK
......@@ -65,11 +109,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t GemmM = K;
constexpr index_t GemmK = (C * Y * X) / GemmKPACK;
......@@ -80,13 +120,17 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
"wrong! cannot divide work evenly among block");
// sanity-check for vectorized memory load
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
(X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated");
// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
......@@ -124,20 +168,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
constexpr auto wei_gemmm_gemmk_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmk_gemmm_global_desc,
make_tuple(PassThrough<K>{}, UnMerge<Sequence<GemmK, GemmKPACK>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc =
make_vectorized_WeiDesc_Xdlops<GemmKPACK>{}.get(wei_k_c_y_x_global_desc);
constexpr auto wei_gemmk_gemmm_gemmkpack_global_desc = transform_tensor_descriptor(
wei_gemmm_gemmk_gemmkpack_global_desc,
make_tuple(PassThrough<GemmK>{}, PassThrough<K>{}, PassThrough<GemmKPACK>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment