Commit 059e1c96 authored by Chao Liu

tweak

parent fe1a31b0
@@ -20,7 +20,8 @@ template <typename... In,
           typename ConvDilations,
           typename InLeftPads,
           typename InRightPads,
-          index_t GemmK1Value>
+          index_t GemmK1Value,
+          typename GemmKBatchType>
 __host__ __device__ constexpr auto
 transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
     const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
@@ -30,7 +31,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
     const ConvDilations& conv_dilations,
     const InLeftPads& in_left_pads,
     const InRightPads& in_right_pads,
-    Number<GemmK1Value>)
+    Number<GemmK1Value>,
+    GemmKBatchType GemmKBatch)
 {
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
@@ -64,10 +66,11 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
     const auto InRightPadH = in_right_pads[I0];
     const auto InRightPadW = in_right_pads[I1];

     const auto GemmM = Y * X * C;
     const auto GemmN = K;
-    const auto GemmK = N * Ho * Wo;
-    const auto GemmK0 = GemmK / GemmK1;
+    const auto GemmKTotal = N * Ho * Wo;
+    const auto GemmK = GemmKTotal / GemmKBatch;
+    const auto GemmK0 = GemmK / GemmK1;

     // A: input tensor
     const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
@@ -88,30 +91,30 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
         make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

-    const auto in_gemmk_gemmm_grid_desc =
+    const auto in_gemmktotal_gemmm_grid_desc =
         transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                     make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                                make_merge_transform(make_tuple(N, Ho, Wo))),
                                     make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                     make_tuple(Sequence<1>{}, Sequence<0>{}));

-    const auto in_gemmk0_gemmm_gemmk1_grid_desc =
-        transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
-                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
-                                               make_pass_through_transform(GemmM)),
-                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+        in_gemmktotal_gemmm_grid_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
+                   make_pass_through_transform(GemmM)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

     // B: output tensor
-    const auto out_gemmk_gemmn_grid_desc =
+    const auto out_gemmktotal_gemmn_grid_desc =
         make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));

-    const auto out_gemmk0_gemmn_gemmk1_grid_desc =
-        transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
-                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
-                                               make_pass_through_transform(GemmN)),
-                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+        out_gemmktotal_gemmn_grid_desc,
+        make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
+                   make_pass_through_transform(GemmN)),
+        make_tuple(Sequence<0>{}, Sequence<1>{}),
+        make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

     // C: weight tensor
     const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
@@ -120,8 +123,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
         make_tuple(Sequence<0>{}, Sequence<1>{}),
         make_tuple(Sequence<1>{}, Sequence<0>{}));

-    return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
-                      out_gemmk0_gemmn_gemmk1_grid_desc,
+    return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+                      out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc);
 }
...
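Net effect of the hunks above: the GEMM K dimension (N * Ho * Wo) is now split three ways instead of two, into an outer GemmKBatch factor, GemmK0, and the compile-time GemmK1. A minimal sketch of the resulting size arithmetic, with illustrative numbers only (divisibility by GemmKBatch and GemmK1 is assumed, as in the transform itself):

    // Illustrative sizes only; mirrors the GemmKTotal -> (GemmKBatch, GemmK0, GemmK1) split above.
    const int N = 128, Ho = 28, Wo = 28; // hypothetical convolution output shape
    const int GemmKBatch = 4;            // number of K-splits
    const int GemmK1     = 8;            // compile-time inner K length
    const int GemmKTotal = N * Ho * Wo;             // 100352
    const int GemmK      = GemmKTotal / GemmKBatch; // 25088
    const int GemmK0     = GemmK / GemmK1;          // 3136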
@@ -97,8 +97,8 @@ template <index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
+          typename ABK0MK1GridDesc,
+          typename BBK0NK1GridDesc,
           typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
@@ -171,33 +171,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
     }

     __host__ __device__ static constexpr bool
-    CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                  const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
-                  const CMNGridDesc& c_m_n_grid_desc,
-                  index_t KBatch)
+    CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
+                  const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
+                  const CMNGridDesc& c_m_n_grid_desc)
     {
         // TODO: turn on this
         static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                       "wrong! K1 need to be known at compile-time");

-        const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
-        const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+        const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
+        const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
+        const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
+        const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);

         static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
                           (NPerBlock % (NRepeat * NPerXDL)) == 0,
                       "Invalid tuning param!");

-        if(K0 % (KBatch * KPerBlock) != 0)
-        {
-            return false;
-        }
-
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
-                K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
-                K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
-                K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
+                K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
+                K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
+                K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
+                KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)) &&
                (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
     }
@@ -212,42 +208,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         return grid_size;
     }

-    __host__ __device__ static constexpr auto
-    MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, index_t KBatch)
-    {
-        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
-        const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
-
-        assert(K0 % KBatch == 0);
-
-        const auto a_b_k0_m_k1_grid_desc = transform_tensor_descriptor(
-            a_k0_m_k1_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
-                       make_pass_through_transform(M),
-                       make_pass_through_transform(K1Value)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        return a_b_k0_m_k1_grid_desc;
-    }
-
-    __host__ __device__ static constexpr auto
-    MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc, index_t KBatch)
-    {
-        const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
-        const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
-
-        assert(K0 % KBatch == 0);
-
-        const auto b_b_k0_n_k1_grid_desc = transform_tensor_descriptor(
-            b_k0_n_k1_grid_desc,
-            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
-                       make_pass_through_transform(N),
-                       make_pass_through_transform(K1Value)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        return b_b_k0_n_k1_grid_desc;
-    }
-
     __host__ __device__ static constexpr auto
     MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
@@ -300,8 +260,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         return c_blockid_to_m0_n0_block_cluster_adaptor;
     }

-    using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}, 1));
-    using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}, 1));
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
     using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1));
...
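For orientation, the reworked CheckValidity above now takes the already-batched 4-D descriptors, reads KBatch from dimension I0, and only cross-checks shapes; the old K0 % (KBatch * KPerBlock) test and the MakeABK0MK1GridDescriptor / MakeBBK0NK1GridDescriptor helpers are gone. A standalone restatement of the same shape checks with plain integers (names and array layout are illustrative, not the repository's API):

    // Illustrative restatement of the shape checks in CheckValidity above.
    // a_lengths / b_lengths hold {KBatch, K0, M-or-N, K1}; c_lengths holds {M, N}.
    bool check_validity_sketch(const int a_lengths[4],
                               const int b_lengths[4],
                               const int c_lengths[2],
                               int MPerBlock, int NPerBlock, int KPerBlock)
    {
        const int KBatch = a_lengths[0], K0 = a_lengths[1], M = a_lengths[2], K1 = a_lengths[3];
        const int N = b_lengths[2];
        return M == c_lengths[0] && N == c_lengths[1] &&
               K0 == b_lengths[1] && K1 == b_lengths[3] && KBatch == b_lengths[0] &&
               M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0;
    }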
@@ -13,7 +13,8 @@ template <typename TInWei,
           typename ConvStrides,
           typename ConvDilations,
           typename InLeftPads,
-          typename InRightPads>
+          typename InRightPads,
+          typename GemmKBatchType>
 void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
     const InLengths& in_n_hi_wi_c_lengths,
     const WeiLengths& wei_k_y_x_c_lengths,
@@ -25,7 +26,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
     const Tensor<TInWei>& in_n_hi_wi_c,
     Tensor<TInWei>& wei_k_y_x_c,
     const Tensor<TOut>& out_n_ho_wo_k,
-    ck::index_t KBatch,
+    GemmKBatchType GemmKBatch,
     ck::index_t nrepeat)
 {
     using namespace ck;
@@ -115,32 +116,33 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
                                                                            conv_dilations,
                                                                            in_left_pads,
                                                                            in_right_pads,
-                                                                           Number<GemmK1>{});
+                                                                           Number<GemmK1>{},
+                                                                           GemmKBatch);

-    const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
-    const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
+    const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
+    const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
     const auto wei_gemmm_gemmn_grid_desc = descs[I2];

     // HACK: hacks that control index calculation when iterating over A, B, C matrix
-    constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
-        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
-                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
+    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1

-    constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{}, // 1+: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0>{}, // 1-: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
+    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: GemmK0
+                              Sequence<0, 0, 0>{}, // 0+: GemmK0
+                              Sequence<0, 0, 0>{}, // 1+: GemmN
+                              Sequence<0, 0, 0>{}), // 2+: GemmK1
+                   make_tuple(Sequence<0, 0, 0>{}, // 0+: GemmK0
+                              Sequence<0, 0, 0>{}, // 0-: GemmK0
+                              Sequence<0, 0, 0>{}, // 1-: GemmN
+                              Sequence<0, 0, 0>{})); // 2-: GemmK1

     constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
@@ -160,15 +162,16 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

-    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0>{};
+    constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};

-    constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
-        Sequence<0, 0, 0, 0, 0, 0>{};
+    constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
+        Sequence<0, 0, 0>{};

     std::function<void()> clear_weight = [&wei_k_y_x_c_device_buf, &wei_k_y_x_c]() {
         wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
     };

     for(index_t i = 0; i < 5; ++i)
     {
         float ave_time = driver_gemm_xdlops_v2r4<
@@ -177,8 +180,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
             TAcc,
             TOut,
             InMemoryDataOperationEnum_t::AtomicAdd,
-            decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
-            decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
+            decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
+            decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
             decltype(wei_gemmm_gemmn_grid_desc),
             GemmMPerBlock,
             GemmNPerBlock,
@@ -207,24 +210,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
             Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
             6,
             GemmCThreadTransferDstScalarPerVector,
-            decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
-            decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
+            decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
+            decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
             decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
-            decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
-            decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
+            decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
+            decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
             false // CAccessOrderMRepeatNRepeat
             >(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
               static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
               static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
-              in_gemmk0_gemmm_gemmk1_grid_desc,
-              out_gemmk0_gemmn_gemmk1_grid_desc,
+              in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+              out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
               wei_gemmm_gemmn_grid_desc,
-              KBatch,
-              in_gemmk0_gemmm_gemmk1_grid_step_hacks,
-              out_gemmk0_gemmn_gemmk1_grid_step_hacks,
+              in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
+              out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
-              in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
-              out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+              in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+              out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
               nrepeat,
               &clear_weight);
...
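The device-side changes above keep the split-K pattern that gives this kernel its atomic suffix: each of the GemmKBatch slices computes a partial weight gradient and accumulates it into the same buffer via InMemoryDataOperationEnum_t::AtomicAdd, which is why clear_weight restores the weight buffer before every timed run. A scalar CPU sketch of that accumulation (layouts and names are illustrative only, not the repository's API):

    // Illustrative only: split-K accumulation of C[M x N] = sum_k A[k][m] * B[k][n],
    // where C is the weight gradient and must hold its initial contents on entry.
    #include <vector>
    void splitk_accumulate_sketch(const std::vector<float>& a, // [KTotal x M], row-major
                                  const std::vector<float>& b, // [KTotal x N], row-major
                                  std::vector<float>& c,       // [M x N], row-major
                                  int M, int N, int KTotal, int KBatch)
    {
        const int KPerBatch = KTotal / KBatch; // assumes KTotal % KBatch == 0
        for(int kb = 0; kb < KBatch; ++kb)     // on the GPU these slices run concurrently
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.f;
                    for(int k = kb * KPerBatch; k < (kb + 1) * KPerBatch; ++k)
                        partial += a[k * M + m] * b[k * N + n];
                    c[m * N + n] += partial; // done with an atomic add on the device
                }
    }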
@@ -11,8 +11,8 @@ template <ck::index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
+          typename ABK0MK1GridDesc,
+          typename BBK0NK1GridDesc,
           typename CMNGridDesc,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
@@ -50,10 +50,9 @@ template <ck::index_t BlockSize,
 __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
-                                       const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                                       const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
+                                       const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
+                                       const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                        const CMNGridDesc& c_m_n_grid_desc,
-                                       ck::index_t KBatch,
                                        AGridStepHacks,
                                        BGridStepHacks,
                                        CGridStepHacks,
@@ -68,6 +67,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};

     using GridwiseGemm =
         GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4<BlockSize,
@@ -75,8 +75,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                                 FloatAcc,
                                                 FloatC,
                                                 CGlobalMemoryDataOperation,
-                                                AK0MK1GridDesc,
-                                                BK0NK1GridDesc,
+                                                ABK0MK1GridDesc,
+                                                BBK0NK1GridDesc,
                                                 CMNGridDesc,
                                                 MPerBlock,
                                                 NPerBlock,
@@ -113,25 +113,21 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                                 CAccessOrderMRepeatNRepeat>;

     {
-        std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
-                  << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
-                  << "}" << std::endl;
+        std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
+                  << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
+                  << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
+                  << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;

-        std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
-                  << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
-                  << "}" << std::endl;
+        std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
+                  << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
+                  << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
+                  << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;

         std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
                   << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
     }

-    const auto a_b_k0_m_k1_grid_desc =
-        GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, KBatch);
-
-    const auto b_b_k0_n_k1_grid_desc =
-        GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, KBatch);
-
-    if(!GridwiseGemm::CheckValidity(
-           a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, KBatch))
+    if(!GridwiseGemm::CheckValidity(a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc))
     {
         throw std::runtime_error(
             "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
@@ -140,10 +136,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
         GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);

-    using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
-    using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
     using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);

+    const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
+
     const auto c_block_cluster_adaptor =
         GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, KBatch);
@@ -153,6 +149,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     {
         std::cout << "gridSize : " << grid_size << std::endl;
     }
+
     const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                 FloatAB,
                                                 FloatC,
...