"...composable_kernel_rocm.git" did not exist on "c981f6d033d5af81aa3809c05beccab219aa8027"
Commit c982e753 authored by Jing Zhang's avatar Jing Zhang
Browse files

add make c into xldops-gemm

parent 62ebdfde
...@@ -40,14 +40,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -40,14 +40,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
__device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDesc()
{
constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, I1, M2, I1));
}
__device__ static auto GetWaveIdx() __device__ static auto GetWaveIdx()
{ {
const index_t thread_id = get_thread_local_1d_id(); const index_t thread_id = get_thread_local_1d_id();
...@@ -131,39 +123,39 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -131,39 +123,39 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
} }
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor()
{
constexpr auto M0 = Number<CXdlopsLayout.M1()>{};
constexpr auto M2 = Number<CXdlopsLayout.M0()>{};
return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, I1, M2, I1));
}
__host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor()
{ {
constexpr auto M2 = Number<CXdlopsLayout.M1()>{}; constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc =
constexpr auto M3 = Number<CXdlopsLayout.N1()>{}; make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
constexpr auto M4 = Number<CXdlopsLayout.M0()>{}; Number<NRepeat>{},
constexpr auto N2 = Number<CXdlopsLayout.N0()>{}; Number<MWaves>{},
Number<NWaves>{},
return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<MPerXDL>{},
Number<NRepeat>{}, Number<NPerXDL>{}));
Number<MWaves>{},
Number<NWaves>{}, return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc);
Number<M2>{},
Number<M3>{},
Number<M4>{},
Number<N2>{}));
} }
template <typename CMNGridDesc> template <typename CMNGridDesc>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{ {
///\To-do: pass CGrid desc transform deep inside xdlops gemm const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor(
constexpr auto M2 = Number<CXdlopsLayout.M1()>{};
constexpr auto M3 = Number<CXdlopsLayout.N1()>{};
constexpr auto M4 = Number<CXdlopsLayout.M0()>{};
constexpr auto N2 = Number<CXdlopsLayout.N0()>{};
return transform_tensor_descriptor(
c_m_n_grid_desc, c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M2, M3, M4)), make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N2))), make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
} }
__host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor()
......
...@@ -376,7 +376,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -376,7 +376,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{})); make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDesc(); blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor();
constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize(); constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
......
...@@ -690,24 +690,17 @@ struct XdlopsGemm ...@@ -690,24 +690,17 @@ struct XdlopsGemm
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
} }
template <typename CM0N0M1N1M2N2GridDesc> template <typename CM0N0M1N1M2N2Desc>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CM0N0M1N1M2N2GridDesc& c_m0_n0_m1_n1_m2_n2_grid_desc) MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc)
{ {
constexpr auto M0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I0); const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0);
constexpr auto N0 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I1); const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1);
constexpr auto M1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I2); const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2);
constexpr auto N1 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I3); const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3);
constexpr auto M2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I4);
constexpr auto N2 = c_m0_n0_m1_n1_m2_n2_grid_desc.GetLength(I5); return transform_tensor_descriptor(
c_m0_n0_m1_n1_m2_n2_desc,
static_assert(N2 == mfma_type.num_threads_per_blk, "");
static_assert(
M2 == (mfma_type.num_groups_per_blk * mfma_type.num_output_blks * mfma_type.group_size),
"");
return transform_dynamic_tensor_descriptor(
c_m0_n0_m1_n1_m2_n2_grid_desc,
make_tuple(make_pass_through_transform(M0), make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0), make_pass_through_transform(N0),
make_pass_through_transform(M1), make_pass_through_transform(M1),
......
...@@ -48,10 +48,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -48,10 +48,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1 #if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256; constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4; constexpr index_t GemmKPerBlock = 4;
...@@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
constexpr index_t GemmNPerWave = 32; constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8; constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
...@@ -106,22 +106,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( ...@@ -106,22 +106,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks = constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}), Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})); Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment