Commit 149296c0 authored by ltqin's avatar ltqin
Browse files

add MakeCGM0N0M1N1M2M3M4N2GridDescriptor

parent 973978aa
...@@ -90,15 +90,16 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad( ...@@ -90,15 +90,16 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{})); make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
const auto in_gemmg_gemmk_gemmm_grid_desc = const auto in_gemmg_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
transform_tensor_descriptor(in_g_n_y_ho_x_wo_c_grid_desc, in_g_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_pass_through_transform(G), make_tuple(make_pass_through_transform(G),
make_merge_transform(make_tuple(Y, X, C)), make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<0>{}, Sequence<2, 4, 6>{}, Sequence<1, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<2, 4, 6>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmg_gemmk0_gemmm_gemmk1_grid_desc = const auto in_gemmg_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmg_gemmk_gemmm_grid_desc, transform_tensor_descriptor(in_gemmg_gemmk_gemmm_grid_desc,
...@@ -112,7 +113,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad( ...@@ -112,7 +113,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
const auto wei_gemmg_gemmk_gemmn_grid_desc = transform_tensor_descriptor( const auto wei_gemmg_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(G, K, Y * X * C)), make_naive_tensor_descriptor_packed(make_tuple(G, K, Y * X * C)),
make_tuple(make_pass_through_transform(G), make_tuple(make_pass_through_transform(G),
make_pass_through_transform(K), make_pass_through_transform(K),
make_pass_through_transform(Y * X * C)), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
...@@ -129,7 +130,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad( ...@@ -129,7 +130,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(
const auto out_gemmg_gemmm_gemmn_grid_desc = transform_tensor_descriptor( const auto out_gemmg_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, G, K)), make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, G, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_tuple(make_pass_through_transform(N * Ho * Wo),
make_pass_through_transform(G), make_pass_through_transform(G),
make_pass_through_transform(K)), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{})); make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
......
...@@ -158,6 +158,22 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -158,6 +158,22 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
} }
// Builds the grouped C grid descriptor consumed by the xdlops GEMM.
//
// c_g_m_n_grid_desc: 3-d descriptor; its dimensions are used here as (G, M, N).
//
// Dimension 0 (G) is passed through unchanged. Dimension 1 (M) is unmerged into
// (MRepeat, MWaves, MPerXDL) and dimension 2 (N) into (NRepeat, NWaves, NPerXDL).
// The resulting 7-d descriptor is ordered (G, M0, N0, M1, N1, M2, N2) and is handed
// to xdlops_gemm.MakeCGM0N0M1N1M2M3M4N2Descriptor, whose result is returned.
template <typename CGMNGridDesc>
__host__ __device__ static constexpr auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor(const CGMNGridDesc& c_g_m_n_grid_desc)
{
    const auto G = c_g_m_n_grid_desc.GetLength(I0);

    // One dimension-id entry is required per transform. There are three transforms
    // (G, M, N), so both id tuples carry three Sequences. The leading G dimension
    // shifts the interleaved M/N ids by one: M -> <1, 3, 5>, N -> <2, 4, 6>.
    const auto c_g_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
        c_g_m_n_grid_desc,
        make_tuple(make_pass_through_transform(G),
                   make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
                   make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

    return xdlops_gemm.MakeCGM0N0M1N1M2M3M4N2Descriptor(c_g_m0_n0_m1_n1_m2_n2_grid_desc);
}
__host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
{ {
return transform_tensor_descriptor( return transform_tensor_descriptor(
......
...@@ -727,6 +727,43 @@ struct XdlopsGemm ...@@ -727,6 +727,43 @@ struct XdlopsGemm
Sequence<7>{})); Sequence<7>{}));
} }
// Expands a 7-d C descriptor into a 9-d one.
//
// Input dimensions 0..4 are passed through with their own lengths and keep their
// positions. Input dimension 5 is unmerged into output dimensions 5, 6 and 7 with
// lengths (num_groups_per_blk, num_input_blks, group_size) taken from mfma_instr.
// Input dimension 6 is passed through to output dimension 8 using
// mfma_instr.num_threads_per_blk as its length.
template <typename CGM0N0M1N1M2N2Desc>
__host__ __device__ static constexpr auto
MakeCGM0N0M1N1M2M3M4N2Descriptor(const CGM0N0M1N1M2N2Desc& c_g_m0_n0_m1_n1_m2_n2_desc)
{
    // Lengths of the five leading dimensions, forwarded as-is.
    const auto len_g  = c_g_m0_n0_m1_n1_m2_n2_desc.GetLength(I0);
    const auto len_m0 = c_g_m0_n0_m1_n1_m2_n2_desc.GetLength(I1);
    const auto len_n0 = c_g_m0_n0_m1_n1_m2_n2_desc.GetLength(I2);
    const auto len_m1 = c_g_m0_n0_m1_n1_m2_n2_desc.GetLength(I3);
    const auto len_n1 = c_g_m0_n0_m1_n1_m2_n2_desc.GetLength(I4);

    // Split of input dimension 5 into three dimensions sized by the MFMA instruction.
    const auto split_dim5 = make_unmerge_transform(make_tuple(
        mfma_instr.num_groups_per_blk, mfma_instr.num_input_blks, mfma_instr.group_size));

    // One transform per input dimension, in input-dimension order.
    const auto transforms =
        make_tuple(make_pass_through_transform(len_g),
                   make_pass_through_transform(len_m0),
                   make_pass_through_transform(len_n0),
                   make_pass_through_transform(len_m1),
                   make_pass_through_transform(len_n1),
                   split_dim5,
                   make_pass_through_transform(mfma_instr.num_threads_per_blk));

    // Input dimension consumed by each transform.
    const auto src_dim_ids = make_tuple(Sequence<0>{},
                                        Sequence<1>{},
                                        Sequence<2>{},
                                        Sequence<3>{},
                                        Sequence<4>{},
                                        Sequence<5>{},
                                        Sequence<6>{});

    // Output dimension(s) produced by each transform.
    const auto dst_dim_ids = make_tuple(Sequence<0>{},
                                        Sequence<1>{},
                                        Sequence<2>{},
                                        Sequence<3>{},
                                        Sequence<4>{},
                                        Sequence<5, 6, 7>{},
                                        Sequence<8>{});

    return transform_tensor_descriptor(
        c_g_m0_n0_m1_n1_m2_n2_desc, transforms, src_dim_ids, dst_dim_ids);
}
__device__ static constexpr index_t GetRegSizePerXdlops() __device__ static constexpr index_t GetRegSizePerXdlops()
{ {
return MPerXdlops * NPerXdlops / mfma_instr.wave_size; return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
......
...@@ -221,28 +221,28 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk( ...@@ -221,28 +221,28 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk(
const auto descs = const auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(in_n_hi_wi_g_c_desc, transform_forward_convolution_into_gemm_v4r4r4_nhwgc_gkyxc_nhwgk_pad(in_n_hi_wi_g_c_desc,
wei_g_k_y_x_c_desc, wei_g_k_y_x_c_desc,
out_n_ho_wo_g_k_desc, out_n_ho_wo_g_k_desc,
conv_strides, conv_strides,
conv_dilations, conv_dilations,
in_left_pads, in_left_pads,
in_right_pads, in_right_pads,
Number<GemmK1>{}); Number<GemmK1>{});
const auto in_gemmg_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; const auto in_gemmg_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; const auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmg_gemmm_gemmn_grid_desc = descs[I2]; const auto out_gemmg_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix // HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks = constexpr auto in_gemmg_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmG make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmG
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmG make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmG
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks = constexpr auto wei_gemmg_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmG make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmG
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp" #include "gridwise_gemm_xdlops_v3r1.hpp"
template <ck::index_t BlockSize, template <ck::index_t BlockSize,
typename FloatAB, typename FloatAB,
...@@ -22,16 +22,16 @@ template <ck::index_t BlockSize, ...@@ -22,16 +22,16 @@ template <ck::index_t BlockSize,
ck::index_t K1, ck::index_t K1,
ck::index_t MRepeat, ck::index_t MRepeat,
ck::index_t NRepeat, ck::index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1, typename ABlockTransferThreadSliceLengths_G_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1, typename ABlockTransferThreadClusterLengths_G_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder, typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector, ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1, ck::index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun, bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1, typename BBlockTransferThreadSliceLengths_G_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterLengths_G_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder, typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcVectorDim,
...@@ -50,9 +50,9 @@ template <ck::index_t BlockSize, ...@@ -50,9 +50,9 @@ template <ck::index_t BlockSize,
__host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const AK0MK1GridDesc& a_g_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const BK0NK1GridDesc& b_g_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_g_m_n_grid_desc,
AGridStepHacks, AGridStepHacks,
BGridStepHacks, BGridStepHacks,
CGridStepHacks, CGridStepHacks,
...@@ -66,9 +66,10 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, ...@@ -66,9 +66,10 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
/* using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize, GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
...@@ -84,16 +85,16 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, ...@@ -84,16 +85,16 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
K1, K1,
MRepeat, MRepeat,
NRepeat, NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadSliceLengths_G_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterLengths_G_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadSliceLengths_G_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterLengths_G_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
...@@ -111,84 +112,88 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, ...@@ -111,84 +112,88 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
CAccessOrderMRepeatNRepeat>; CAccessOrderMRepeatNRepeat>;
{ {
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " std::cout << "a_g_k0_m_k1_grid_desc{" << a_g_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) << a_g_k0_m_k1_grid_desc.GetLength(I1) << ", "
<< a_g_k0_m_k1_grid_desc.GetLength(I2) << ", "
<< a_g_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "b_k0_n_k1_grid_desc{" << b_g_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_g_k0_n_k1_grid_desc.GetLength(I1) << ", "
<< b_g_k0_n_k1_grid_desc.GetLength(I2) << ", "
<< b_g_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_g_m_n_grid_desc.GetLength(I0) << ", "
<< c_g_m_n_grid_desc.GetLength(I1) << ", " << c_g_m_n_grid_desc.GetLength(I2)
<< "}" << std::endl; << "}" << std::endl;
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
} }
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(
a_g_k0_m_k1_grid_desc, b_g_k0_n_k1_grid_desc, c_g_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
} }
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = const auto c_gemmg_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); /* using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AK0MK1GridDesc>, remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>, remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>, remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>; remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel, float ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor); c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel( float ave_time = launch_and_time_kernel(
kernel, kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif #endif
return ave_time;*/ return ave_time;*/
return 0.0; return 0.0;
} }
#endif #endif
...@@ -36,8 +36,8 @@ enum ConvForwardAlgo ...@@ -36,8 +36,8 @@ enum ConvForwardAlgo
V6R1NCHW, // 2 V6R1NCHW, // 2
V5R1NCHW, // 3 V5R1NCHW, // 3
V4R4R2XDLNCHW, // 4 V4R4R2XDLNCHW, // 4
V4R4R4XDLNHWC, // 5 V4R4R4XDLNHWC, // 5
V4R4R4XDLNHWGC // 6 V4R4R4XDLNHWGC // 6
}; };
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -68,10 +68,10 @@ int main(int argc, char* argv[]) ...@@ -68,10 +68,10 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[5]); const bool do_log = std::stoi(argv[5]);
const int nrepeat = std::stoi(argv[6]); const int nrepeat = std::stoi(argv[6]);
index_t G = 1; index_t G = 1;
const index_t N = std::stoi(argv[7]); const index_t N = std::stoi(argv[7]);
index_t K = std::stoi(argv[8]); index_t K = std::stoi(argv[8]);
index_t C = std::stoi(argv[9]); index_t C = std::stoi(argv[9]);
const index_t Y = std::stoi(argv[10]); const index_t Y = std::stoi(argv[10]);
const index_t X = std::stoi(argv[11]); const index_t X = std::stoi(argv[11]);
const index_t Hi = std::stoi(argv[12]); const index_t Hi = std::stoi(argv[12]);
...@@ -85,12 +85,12 @@ int main(int argc, char* argv[]) ...@@ -85,12 +85,12 @@ int main(int argc, char* argv[])
const index_t in_left_pad_w = std::stoi(argv[19]); const index_t in_left_pad_w = std::stoi(argv[19]);
const index_t in_right_pad_h = std::stoi(argv[20]); const index_t in_right_pad_h = std::stoi(argv[20]);
const index_t in_right_pad_w = std::stoi(argv[21]); const index_t in_right_pad_w = std::stoi(argv[21]);
if (argc == 23){ if(argc == 23)
G = std::stoi(argv[22]); {
K = K / G; G = std::stoi(argv[22]);
C = C / G; K = K / G;
C = C / G;
} }
const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1;
...@@ -480,8 +480,8 @@ int main(int argc, char* argv[]) ...@@ -480,8 +480,8 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nhwgc(); const auto tmp = f_make_for_device_nhwgc();
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk<in_data_t, device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwgc_gkyxc_nhwgk<in_data_t,
acc_data_t, acc_data_t,
out_data_t>( out_data_t>(
tmp[I0], tmp[I0],
tmp[I1], tmp[I1],
tmp[I2], tmp[I2],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment