Commit fe1a31b0 authored by Chao Liu
Browse files

tweak

parent 6c97007c
...@@ -103,11 +103,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( ...@@ -103,11 +103,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: output tensor // B: output tensor
const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor( const auto out_gemmk_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk0_gemmn_gemmk1_grid_desc = const auto out_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(out_gemmk_gemmn_grid_desc, transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
......
...@@ -132,8 +132,7 @@ template <index_t BlockSize, ...@@ -132,8 +132,7 @@ template <index_t BlockSize,
typename CGridStepHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks, typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat, bool CAccessOrderMRepeatNRepeat>
index_t KBatch>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -174,7 +173,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -174,7 +173,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc) const CMNGridDesc& c_m_n_grid_desc,
index_t KBatch)
{ {
// TODO: turn on this // TODO: turn on this
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
...@@ -188,6 +188,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -188,6 +188,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
(NPerBlock % (NRepeat * NPerXDL)) == 0, (NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
if(K0 % (KBatch * KPerBlock) != 0)
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
...@@ -197,7 +202,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -197,7 +202,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
} }
__host__ __device__ static constexpr index_t __host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0); const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1); const auto N = c_m_n_grid_desc.GetLength(I1);
...@@ -208,7 +213,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -208,7 +213,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, index_t KBatch)
{ {
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
...@@ -226,7 +231,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -226,7 +231,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc, index_t KBatch)
{ {
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
...@@ -269,7 +274,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -269,7 +274,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0); const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1); const auto N = c_m_n_grid_desc.GetLength(I1);
...@@ -295,10 +300,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -295,10 +300,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return c_blockid_to_m0_n0_block_cluster_adaptor; return c_blockid_to_m0_n0_block_cluster_adaptor;
} }
using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{})); using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}, 1));
using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{})); using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}, 1));
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
......
...@@ -25,6 +25,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -25,6 +25,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const Tensor<TInWei>& in_n_hi_wi_c, const Tensor<TInWei>& in_n_hi_wi_c,
Tensor<TInWei>& wei_k_y_x_c, Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k, const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t KBatch,
ck::index_t nrepeat) ck::index_t nrepeat)
{ {
using namespace ck; using namespace ck;
...@@ -49,6 +50,34 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -49,6 +50,34 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1 #if 1
// [M, N, K0, K1] = [128, 256, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerXDL = 32;
constexpr index_t GemmNPerXDL = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32 // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -76,8 +105,6 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -76,8 +105,6 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
constexpr index_t KBatch = 32;
#endif #endif
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
...@@ -106,14 +133,14 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -106,14 +133,14 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN Sequence<0, 0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 Sequence<0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 Sequence<0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 Sequence<0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
...@@ -137,7 +164,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -137,7 +164,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0>{};
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0>{};
std::function<void()> clear_weight = [&wei_k_y_x_c_device_buf, &wei_k_y_x_c]() { std::function<void()> clear_weight = [&wei_k_y_x_c_device_buf, &wei_k_y_x_c]() {
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
...@@ -163,42 +190,43 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_ ...@@ -163,42 +190,43 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
NRepeat, NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<0, 1, 3, 2>, Sequence<0, 1, 2, 3>,
Sequence<0, 1, 3, 2>, Sequence<0, 1, 2, 3>,
2, 2,
GemmABlockTransferSrcScalarPerVector_GemmM, GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1, GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 1, 3, 2>, Sequence<0, 1, 2, 3>,
Sequence<0, 1, 3, 2>, Sequence<0, 1, 2, 3>,
2, 2,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1, GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7, 6,
GemmCThreadTransferDstScalarPerVector, GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false, // CAccessOrderMRepeatNRepeat false // CAccessOrderMRepeatNRepeat
KBatch>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), >(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc, wei_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm_gemmk1_grid_step_hacks, KBatch,
out_gemmk0_gemmn_gemmk1_grid_step_hacks, in_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
nrepeat, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
&clear_weight); nrepeat,
&clear_weight);
{ {
const auto N = out_n_ho_wo_k_lengths[I0]; const auto N = out_n_ho_wo_k_lengths[I0];
......
...@@ -46,14 +46,14 @@ template <ck::index_t BlockSize, ...@@ -46,14 +46,14 @@ template <ck::index_t BlockSize,
typename CGridStepHacks, typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks, typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks, typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat, bool CAccessOrderMRepeatNRepeat>
ck::index_t KBatch>
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc,
ck::index_t KBatch,
AGridStepHacks, AGridStepHacks,
BGridStepHacks, BGridStepHacks,
CGridStepHacks, CGridStepHacks,
...@@ -110,8 +110,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -110,8 +110,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
CGridStepHacks, CGridStepHacks,
AGridMoveSliceWindowStepHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks, BGridMoveSliceWindowStepHacks,
CAccessOrderMRepeatNRepeat, CAccessOrderMRepeatNRepeat>;
KBatch>;
{ {
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
...@@ -125,11 +124,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -125,11 +124,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
} }
// const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const auto a_b_k0_m_k1_grid_desc = GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc);
const auto b_b_k0_n_k1_grid_desc = GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc);
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) const auto a_b_k0_m_k1_grid_desc =
GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, KBatch);
const auto b_b_k0_n_k1_grid_desc =
GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, KBatch);
if(!GridwiseGemm::CheckValidity(
a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, KBatch))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
...@@ -142,11 +144,12 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -142,11 +144,12 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc); using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); const auto c_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, KBatch);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);
{ {
std::cout << "gridSize : " << grid_size << std::endl; std::cout << "gridSize : " << grid_size << std::endl;
} }
......
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1 #define USE_DYNAMIC_MODE 1
#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 #define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_NHWC 1 #define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 1 #define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1 #define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 1
enum ConvBackwardWeightAlgo enum ConvBackwardWeightAlgo
...@@ -45,7 +45,7 @@ int main(int argc, char* argv[]) ...@@ -45,7 +45,7 @@ int main(int argc, char* argv[])
#if USE_DYNAMIC_MODE #if USE_DYNAMIC_MODE
// dynamic mode // dynamic mode
if(argc != 22) if(argc != 23)
{ {
printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
...@@ -76,6 +76,8 @@ int main(int argc, char* argv[]) ...@@ -76,6 +76,8 @@ int main(int argc, char* argv[])
const index_t in_right_pad_h = std::stoi(argv[20]); const index_t in_right_pad_h = std::stoi(argv[20]);
const index_t in_right_pad_w = std::stoi(argv[21]); const index_t in_right_pad_w = std::stoi(argv[21]);
const index_t k_batch = std::stoi(argv[22]);
const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1;
...@@ -363,6 +365,7 @@ int main(int argc, char* argv[]) ...@@ -363,6 +365,7 @@ int main(int argc, char* argv[])
in, in,
wei_device, wei_device,
out, out,
k_batch,
nrepeat); nrepeat);
} }
#endif #endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment