Commit 059e1c96 authored by Chao Liu

tweak

parent fe1a31b0
@@ -20,7 +20,8 @@ template <typename... In,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
- index_t GemmK1Value>
+ index_t GemmK1Value,
+ typename GemmKBatchType>
__host__ __device__ constexpr auto
transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
@@ -30,7 +31,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
- Number<GemmK1Value>)
+ Number<GemmK1Value>,
+ GemmKBatchType GemmKBatch)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -64,10 +66,11 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
- const auto GemmM = Y * X * C;
- const auto GemmN = K;
- const auto GemmK = N * Ho * Wo;
- const auto GemmK0 = GemmK / GemmK1;
+ const auto GemmM = Y * X * C;
+ const auto GemmN = K;
+ const auto GemmKTotal = N * Ho * Wo;
+ const auto GemmK = GemmKTotal / GemmKBatch;
+ const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
@@ -88,30 +91,30 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
- const auto in_gemmk_gemmm_grid_desc =
+ const auto in_gemmktotal_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
- const auto in_gemmk0_gemmm_gemmk1_grid_desc =
- transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
- make_pass_through_transform(GemmM)),
- make_tuple(Sequence<0>{}, Sequence<1>{}),
- make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+ const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
+ in_gemmktotal_gemmm_grid_desc,
+ make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
+ make_pass_through_transform(GemmM)),
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
+ make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// B: output tensor
- const auto out_gemmk_gemmn_grid_desc =
+ const auto out_gemmktotal_gemmn_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
- const auto out_gemmk0_gemmn_gemmk1_grid_desc =
- transform_tensor_descriptor(out_gemmk_gemmn_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
- make_pass_through_transform(GemmN)),
- make_tuple(Sequence<0>{}, Sequence<1>{}),
- make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+ const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
+ out_gemmktotal_gemmn_grid_desc,
+ make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)),
+ make_pass_through_transform(GemmN)),
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
+ make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
// C: weight tensor
const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
@@ -120,8 +123,8 @@ transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
- return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
- out_gemmk0_gemmn_gemmk1_grid_desc,
+ return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+ out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc);
}
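The hunks above change the A/B descriptor construction so that the flat GemmKTotal = N * Ho * Wo dimension is unmerged into (GemmKBatch, GemmK0, GemmK1) rather than (GemmK0, GemmK1), with GemmKBatch as the slowest-varying coordinate (Sequence<0, 1, 3> for the split dimension, Sequence<2> for GemmM/GemmN). A minimal host-side sketch of that index mapping, in plain C++ with made-up sizes and no ck:: machinery, assuming the usual convention that the first unmerged length varies slowest:

```cpp
// Sketch only: illustrates the KBatch/K0/K1 split done by make_unmerge_transform
// above. All sizes are illustrative assumptions, not values from this commit.
#include <cassert>
#include <cstdint>

int main()
{
    const int64_t N = 2, Ho = 4, Wo = 4;                 // hypothetical output sizes
    const int64_t GemmK1     = 4;                        // GemmK1Value
    const int64_t GemmKBatch = 2;                        // number of K batches
    const int64_t GemmKTotal = N * Ho * Wo;              // 32
    const int64_t GemmK      = GemmKTotal / GemmKBatch;  // 16
    const int64_t GemmK0     = GemmK / GemmK1;           // 4

    // Unmerge maps the flat GemmKTotal index to (kbatch, k0, k1),
    // kbatch being the slowest-varying coordinate.
    for(int64_t kt = 0; kt < GemmKTotal; ++kt)
    {
        const int64_t kbatch = kt / (GemmK0 * GemmK1);
        const int64_t k0     = (kt / GemmK1) % GemmK0;
        const int64_t k1     = kt % GemmK1;

        // Recombining the coordinates must give back the flat index.
        assert(kt == (kbatch * GemmK0 + k0) * GemmK1 + k1);
    }
    return 0;
}
```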
@@ -97,8 +97,8 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
- typename AK0MK1GridDesc,
- typename BK0NK1GridDesc,
+ typename ABK0MK1GridDesc,
+ typename BBK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
@@ -171,33 +171,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
__host__ __device__ static constexpr bool
- CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
- const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
- const CMNGridDesc& c_m_n_grid_desc,
- index_t KBatch)
+ CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
+ const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
+ const CMNGridDesc& c_m_n_grid_desc)
{
// TODO: turn on this
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
- const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
- const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
- const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+ const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
+ const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
+ const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
+ const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) &&
(NPerBlock % (NRepeat * NPerXDL)) == 0,
"Invalid tuning param!");
- if(K0 % (KBatch * KPerBlock) != 0)
- {
- return false;
- }
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
- K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
- K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
- K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
+ K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) &&
+ K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) &&
+ K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) &&
+ KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
}
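CheckValidity now reads KBatch, K0, M and N directly from the batched 4-D descriptors and only requires the per-batch K0 to be divisible by KPerBlock; the old K0 % (KBatch * KPerBlock) check disappears because the K split already happens upstream in the descriptor transform. A small standalone example, using hypothetical problem and tile sizes, of the divisibility relations involved:

```cpp
// Illustrative check (standard C++, hypothetical sizes, not values from this
// commit) of the divisibility constraints enforced on the batched descriptors.
#include <cstdio>

int main()
{
    const int M = 1152, N = 256, K1 = 4;
    const int KBatch = 2, GemmKTotal = 4096;
    const int MPerBlock = 128, NPerBlock = 128, KPerBlock = 4;

    const int K  = GemmKTotal / KBatch; // per-batch GemmK
    const int K0 = K / K1;              // per-batch GemmK0

    const bool valid = (M % MPerBlock == 0) && (N % NPerBlock == 0) &&
                       (K0 % KPerBlock == 0) &&
                       (GemmKTotal % (KBatch * K1) == 0);

    std::printf("per-batch K0 = %d, valid = %s\n", K0, valid ? "yes" : "no");
    return 0;
}
```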
@@ -212,42 +208,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return grid_size;
}
- __host__ __device__ static constexpr auto
- MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, index_t KBatch)
- {
- const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
- const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
- assert(K0 % KBatch == 0);
- const auto a_b_k0_m_k1_grid_desc = transform_tensor_descriptor(
- a_k0_m_k1_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
- make_pass_through_transform(M),
- make_pass_through_transform(K1Value)),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
- make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
- return a_b_k0_m_k1_grid_desc;
- }
- __host__ __device__ static constexpr auto
- MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc, index_t KBatch)
- {
- const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
- const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
- assert(K0 % KBatch == 0);
- const auto b_b_k0_n_k1_grid_desc = transform_tensor_descriptor(
- b_k0_n_k1_grid_desc,
- make_tuple(make_unmerge_transform(make_tuple(KBatch, K0 / KBatch)),
- make_pass_through_transform(N),
- make_pass_through_transform(K1Value)),
- make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
- make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
- return b_b_k0_n_k1_grid_desc;
- }
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
@@ -300,8 +260,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
- using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}, 1));
- using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}, 1));
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1));
@@ -13,7 +13,8 @@ template <typename TInWei,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
- typename InRightPads>
+ typename InRightPads,
+ typename GemmKBatchType>
void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
@@ -25,7 +26,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
const Tensor<TInWei>& in_n_hi_wi_c,
Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
- ck::index_t KBatch,
+ GemmKBatchType GemmKBatch,
ck::index_t nrepeat)
{
using namespace ck;
@@ -115,32 +116,33 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
conv_dilations,
in_left_pads,
in_right_pads,
- Number<GemmK1>{});
+ Number<GemmK1>{},
+ GemmKBatch);
- const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
- const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
- const auto wei_gemmm_gemmn_grid_desc = descs[I2];
+ const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
+ const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
+ const auto wei_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
- constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
- make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
- make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
- constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks =
- make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
- Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
- Sequence<0, 0, 0, 0, 0, 0>{}, // 1+: GemmN
- Sequence<0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1
- make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
- Sequence<0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
- Sequence<0, 0, 0, 0, 0, 0>{}, // 1-: GemmN
- Sequence<0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
+ constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks =
+ make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1
+ make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1
+ constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks =
+ make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: GemmK0
+ Sequence<0, 0, 0>{}, // 0+: GemmK0
+ Sequence<0, 0, 0>{}, // 1+: GemmN
+ Sequence<0, 0, 0>{}), // 2+: GemmK1
+ make_tuple(Sequence<0, 0, 0>{}, // 0+: GemmK0
+ Sequence<0, 0, 0>{}, // 0-: GemmK0
+ Sequence<0, 0, 0>{}, // 1-: GemmN
+ Sequence<0, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
@@ -160,15 +162,16 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
- constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
- Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0>{};
+ constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};
- constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
- Sequence<0, 0, 0, 0, 0, 0>{};
+ constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
+ Sequence<0, 0, 0>{};
std::function<void()> clear_weight = [&wei_k_y_x_c_device_buf, &wei_k_y_x_c]() {
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r4<
@@ -177,8 +180,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
TAcc,
TOut,
InMemoryDataOperationEnum_t::AtomicAdd,
- decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
- decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
+ decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc),
+ decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc),
decltype(wei_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
@@ -207,24 +210,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
- decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
- decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
+ decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks),
+ decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
- decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
- decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
+ decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
+ decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
- in_gemmk0_gemmm_gemmk1_grid_desc,
- out_gemmk0_gemmn_gemmk1_grid_desc,
+ in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
+ out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_gemmm_gemmn_grid_desc,
- KBatch,
- in_gemmk0_gemmm_gemmk1_grid_step_hacks,
- out_gemmk0_gemmn_gemmk1_grid_step_hacks,
+ in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks,
+ out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
- in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
- out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+ in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+ out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat,
&clear_weight);
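Because CGlobalMemoryDataOperation is AtomicAdd, each K-batch slice adds its partial GEMM result into the same weight-gradient tensor, which is why the clear_weight callback re-initializes the device buffer between timed runs. A toy host-side model of that split-K accumulation (plain C++, illustrative sizes, not the actual kernel):

```cpp
// Conceptual model of split-K accumulation: every k-batch adds its partial
// GEMM into C, so C must be (re)initialized before each full pass -- the role
// of the clear_weight callback above. Sizes are made up for illustration.
#include <cstdio>
#include <vector>

int main()
{
    const int M = 4, N = 3, KBatch = 2, KPerBatch = 8;
    std::vector<float> A(KBatch * KPerBatch * M, 1.0f);
    std::vector<float> B(KBatch * KPerBatch * N, 1.0f);
    std::vector<float> C(M * N, 0.0f); // must start at zero before accumulation

    for(int b = 0; b < KBatch; ++b)       // on the GPU: one grid slice per k-batch
        for(int k = 0; k < KPerBatch; ++k)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    // the device version does this with an atomic add to global C
                    C[m * N + n] += A[(b * KPerBatch + k) * M + m] *
                                    B[(b * KPerBatch + k) * N + n];

    std::printf("C[0] = %f (expected %d)\n", C[0], KBatch * KPerBatch);
    return 0;
}
```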
@@ -11,8 +11,8 @@ template <ck::index_t BlockSize,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
- typename AK0MK1GridDesc,
- typename BK0NK1GridDesc,
+ typename ABK0MK1GridDesc,
+ typename BBK0NK1GridDesc,
typename CMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
@@ -50,10 +50,9 @@ template <ck::index_t BlockSize,
__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
- const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
- const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
+ const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
+ const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
- ck::index_t KBatch,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
@@ -68,6 +67,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
+ constexpr auto I3 = Number<3>{};
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4<BlockSize,
@@ -75,8 +75,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
- AK0MK1GridDesc,
- BK0NK1GridDesc,
+ ABK0MK1GridDesc,
+ BBK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
@@ -113,25 +113,21 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
CAccessOrderMRepeatNRepeat>;
{
- std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
- << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
- << "}" << std::endl;
+ std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", "
+ << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", "
+ << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", "
+ << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;
- std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
- << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
- << "}" << std::endl;
+ std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", "
+ << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", "
+ << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", "
+ << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
- const auto a_b_k0_m_k1_grid_desc =
- GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, KBatch);
- const auto b_b_k0_n_k1_grid_desc =
- GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, KBatch);
- if(!GridwiseGemm::CheckValidity(
- a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, KBatch))
+ if(!GridwiseGemm::CheckValidity(a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
@@ -140,10 +136,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
+ using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
+ using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
+ const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
const auto c_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, KBatch);
@@ -153,6 +149,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
{
std::cout << "gridSize : " << grid_size << std::endl;
}
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
FloatAB,
FloatC,