Commit c982e753 authored by Jing Zhang

add MakeC descriptor into xdlops gemm

parent 62ebdfde
@@ -40,14 +40,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
__device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDesc()
{
constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, I1, M2, I1));
}
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
@@ -131,39 +123,39 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
}
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
{
constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, I1, M2, I1));
}
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
{
constexpr auto M2 = Number<CXdlopsLayout.M1()>{};
constexpr auto M3 = Number<CXdlopsLayout.N1()>{};
constexpr auto M4 = Number<CXdlopsLayout.M0()>{};
constexpr auto N2 = Number<CXdlopsLayout.N0()>{};
return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<M2>{},
Number<M3>{},
Number<M4>{},
Number<N2>{}));
constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
}
template <typename CMNGridDesc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
///\To-do: pass CGrid desc transform deep inside xdlops gemm
constexpr auto M2 = Number<CXdlopsLayout.M1()>{};
constexpr auto M3 = Number<CXdlopsLayout.N1()>{};
constexpr auto M4 = Number<CXdlopsLayout.M0()>{};
constexpr auto N2 = Number<CXdlopsLayout.N0()>{};
return transform_tensor_descriptor(
const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M2, M3, M4)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N2))),
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
}
__host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
......
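The grid descriptor above first unmerges the C tensor's M extent into (MRepeat, MWaves, MPerXDL) and its N extent into (NRepeat, NWaves, NPerXDL), then hands the resulting 6-d descriptor to xdlops_gemm for the finer M2/M3/M4/N2 split. The following is a standalone sketch of that unmerge index arithmetic only, not the composable_kernel descriptor machinery; the tile sizes are example values (MRepeat = 2, MWaves = 2, MPerXDL = 32, i.e. a 128-row block tile).

```cpp
// Standalone sketch: index arithmetic of make_unmerge_transform(MRepeat, MWaves, MPerXDL).
#include <cstdio>

int main()
{
    constexpr int MRepeat = 2, MWaves = 2, MPerXDL = 32; // example tile: MPerBlock = 128
    constexpr int m  = 77;                      // a row index inside the block tile
    constexpr int m0 = m / (MWaves * MPerXDL);  // which M-repeat            -> 1
    constexpr int m1 = (m / MPerXDL) % MWaves;  // which wave along M        -> 0
    constexpr int m2 = m % MPerXDL;             // row inside one XDL output -> 13
    static_assert(m0 < MRepeat, "");
    static_assert((m0 * MWaves + m1) * MPerXDL + m2 == m, "unmerge round-trips");
    std::printf("m=%d -> (m0, m1, m2) = (%d, %d, %d)\n", m, m0, m1, m2);
    return 0;
}
```

The N extent is split the same way into (NRepeat, NWaves, NPerXDL), which is why the new lower-dimension index sequences are <0, 2, 4> and <1, 3, 5>.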
@@ -376,7 +376,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDesc();
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr,
......
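For orientation on the GetElementSpaceSize() call above: the 8-d thread descriptor has unit extents in every slot except M0 (register groups per lane) and M2 (accumulators per group), so CBlkSize is simply their product. The values below are assumptions for a 32x32 xdlops output tile (4 groups of 4 accumulators per lane) and are illustrative only.

```cpp
// Illustrative compile-time check only; M0 and M2 are assumed values, not read
// from the library's CXdlopsLayout.
constexpr int M0 = 4; // register groups per lane (assumed)
constexpr int M2 = 4; // accumulators per group (assumed)
// Packed size of the (1, 1, 1, 1, M0, 1, M2, 1) thread descriptor:
constexpr int CBlkSize = 1 * 1 * 1 * 1 * M0 * 1 * M2 * 1;
static_assert(CBlkSize == 16, "16 accumulator VGPRs per thread per C block");
```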
@@ -690,24 +690,17 @@ struct XdlopsGemm
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
}
template <typename CM0N0M1N1M2N2GridDesc>
template <typename CM0N0M1N1M2N2Desc>
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CM0N0M1N1M2N2GridDesc& c_m0_n0_m1_n1_m2_n2_grid_desc)
{
constexpr auto M0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I4);
constexpr auto N2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I5);
static_assert(N2 == mfma_type.num_threads_per_blk, "");
static_assert(
M2 == (mfma_type.num_groups_per_blk * mfma_type.num_output_blks * mfma_type.group_size),
"");
return transform_dynamic_tensor_descriptor(
c_m0_n0_m1_n1_m2_n2_grid_desc,
MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc)
{
const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0);
const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1);
const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2);
const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3);
return transform_tensor_descriptor(
c_m0_n0_m1_n1_m2_n2_desc,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
......
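The generalized MakeCM0N0M1N1M2M3M4N2Descriptor passes the four outer lengths (M0, N0, M1, N1) through and decomposes the remaining per-XDL extents according to the MFMA output layout, the kind of relationship the removed static_asserts were spelling out (M2 a product of the MFMA group factors, N2 equal to num_threads_per_blk). The compile-time sketch below illustrates that length bookkeeping under assumed parameters for a 32x32x8 fp16 MFMA; both the numbers and the factor order are assumptions for illustration, not taken from this diff.

```cpp
// Compile-time sketch of the 6-d -> 8-d length bookkeeping (assumed MFMA parameters).
constexpr int group_size          = 4;  // contiguous rows per register group (assumed)
constexpr int num_groups_per_blk  = 4;  // register groups per output block (assumed)
constexpr int num_input_blks      = 2;  // row blocks interleaved across the wave (assumed)
constexpr int num_threads_per_blk = 32; // lanes covering the N extent (assumed)

constexpr int MPerXDL = 32, NPerXDL = 32;

// (M0, N0, M1, N1, MPerXDL, NPerXDL) -> (M0, N0, M1, N1, M2, M3, M4, N2):
// the per-XDL M extent splits into three factors; the per-XDL N extent stays
// the lane dimension.
static_assert(num_groups_per_blk * num_input_blks * group_size == MPerXDL, "");
static_assert(num_threads_per_blk == NPerXDL, "");
```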
@@ -48,10 +48,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
@@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
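A quick consistency check on the retuned tile (a sketch: GemmMPerWave = 32 is assumed to mirror the GemmNPerWave shown above, and the wave size is taken to be 64 lanes): GemmMPerBlock = 128 with MRepeat = 2 gives 2 waves along M, likewise 2 along N, so 4 waves of 64 threads fill the 256-thread block, and the A-block-transfer slice <1, 2, 8> times the thread cluster <4, 64, 1> exactly tiles the [GemmK0, GemmM, GemmK1] = [4, 128, 8] block.

```cpp
// Compile-time consistency check for the tile parameters above
// (GemmMPerWave = 32 and WaveSize = 64 are assumptions, not shown in this hunk).
constexpr int BlockSize = 256, WaveSize = 64;
constexpr int GemmMPerBlock = 128, GemmNPerBlock = 128, GemmKPerBlock = 4, GemmK1 = 8;
constexpr int GemmMPerWave = 32, GemmNPerWave = 32;
constexpr int MRepeat = 2, NRepeat = 2;

constexpr int MWaves = GemmMPerBlock / (MRepeat * GemmMPerWave); // 2
constexpr int NWaves = GemmNPerBlock / (NRepeat * GemmNPerWave); // 2
static_assert(MWaves * NWaves * WaveSize == BlockSize, "every thread belongs to a wave");

// A block transfer: per-thread slice <1, 2, 8> times thread cluster <4, 64, 1>
// must cover the whole [GemmK0, GemmM, GemmK1] tile.
static_assert(1 * 4 == GemmKPerBlock && 2 * 64 == GemmMPerBlock && 8 * 1 == GemmK1, "");
static_assert(4 * 64 * 1 == BlockSize, "cluster uses the full thread block");
```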
@@ -106,22 +106,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
......