Commit dbc971be authored by wangshaojie6's avatar wangshaojie6
Browse files

wip for gridwise

parent 6985af40
...@@ -74,6 +74,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1 ...@@ -74,6 +74,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
static constexpr auto I6 = Number<6>{}; static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{}; static constexpr auto I7 = Number<7>{};
static constexpr auto K0PerBlock = Number<KPerBlock / AK1>{};
static constexpr auto BaseMultK0 = 2; static constexpr auto BaseMultK0 = 2;
static constexpr auto MultiK0 = BaseMultK0 * 1; static constexpr auto MultiK0 = BaseMultK0 * 1;
...@@ -81,30 +83,36 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1 ...@@ -81,30 +83,36 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
static constexpr auto K1 = Number<AK1>{}; static constexpr auto K1 = Number<AK1>{};
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); static constexpr index_t M0Waves = M0PerBlock / (M0XdlPerWave * M0PerXDL);
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); static constexpr index_t N0Waves = N0PerBlock / (N0XdlPerWave * N0PerXDL);
static constexpr auto xdlops_gemm0 = XdlopsGemm<FloatAB, M0PerXDL, N0PerXDL, K1>{};
static constexpr index_t K0PerThread0 = K0PerBlock / xdlops_gemm0.K0PerXdlops;
static constexpr index_t M1Waves = M1PerBlock / (M1XdlPerWave * M1PerXDL);
static constexpr index_t N1Waves = N1PerBlock / (N1XdlPerWave * N1PerXDL);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{}; static constexpr auto xdlops_gemm1 = XdlopsGemm<FloatAB, M1PerXDL, N1PerXDL, K1>{};
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t K0PerThread1 = K0PerBlock / xdlops_gemm1.K0PerXdlops;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = AK1;
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() { constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM) if constexpr(ABlockLdsExtraM)
{ {
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1), make_tuple(Number<K0PerBlock * MultiK0>{}, Number<M0PerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1)); make_tuple(Number<M0PerBlock + 1>{} * K1, K1, I1));
} }
else else
{ {
return make_naive_tensor_descriptor_aligned( return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1), make_tuple(Number<K0PerBlock * MultiK0>{}, Number<M0PerBlock>{}, K1),
max_lds_align); max_lds_align);
} }
}(); }();
...@@ -112,10 +120,34 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1 ...@@ -112,10 +120,34 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
return a_block_desc_k0_m_k1; return a_block_desc_k0_m_k1;
} }
// Builds, at compile time, the 3-D block descriptor for the B1 tile.
//
// Lengths are (K0PerBlock * MultiK0, N1PerBlock, K1). Two layouts are possible:
//   - padded:  explicit strides ((N1PerBlock + 1) * K1, K1, I1), i.e. each
//              K0 slice is pitched one K1-vector wider than N1PerBlock entries;
//   - aligned: strides chosen by make_naive_tensor_descriptor_aligned using
//              an alignment of B1K1.
//
// Returns the descriptor object by value; takes no arguments.
__host__ __device__ static constexpr auto GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1()
{
// Alignment handed to the aligned-descriptor helper in the non-padded branch.
constexpr auto max_lds_align = B1K1;
// B1 matrix block descriptor. The comment here previously read "A matrix in LDS
// memory, dst of blockwise copy", apparently copied from the A-side helper.
// Presumably this tile is also LDS-resident -- TODO confirm.
constexpr auto b1_block_desc_k0_n_k1 = [&]() {
// NOTE(review): the layout is selected by ABlockLdsExtraM, which by name is the
// A-side padding flag -- confirm whether a B1-specific flag was intended.
if constexpr(ABlockLdsExtraM)
{
// Padded layout: the K0 stride is (N1PerBlock + 1) * K1 instead of
// N1PerBlock * K1. Presumably this avoids LDS bank conflicts -- TODO confirm.
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<N1PerBlock>{}, K1),
make_tuple(Number<N1PerBlock + 1>{} * K1, K1, I1));
}
else
{
// NOTE(review): the innermost length is K1 (declared as Number<AK1>{} in this
// struct) while the alignment is B1K1 -- confirm these are meant to differ.
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<N1PerBlock>{}, K1),
max_lds_align);
}
}();
return b1_block_desc_k0_n_k1;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b1_block_desc_k0_n_k1 = GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1; constexpr auto max_lds_align = K1;
...@@ -179,7 +211,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1 ...@@ -179,7 +211,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); const index_t grid_size = (M / M0PerBlock) * (N / N1PerBlock);
return grid_size; return grid_size;
} }
...@@ -193,21 +225,21 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1 ...@@ -193,21 +225,21 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const B0GridDesc_K0_N_K1& b0_grid_desc_k0_n_k1)
{ {
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); const auto K0 = b0_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto N = b0_grid_desc_k0_n_k1.GetLength(I1);
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor( const auto b0_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1, b0_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform( make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)), make_tuple(K0 / K0PerBlock, xdlops_gemm0.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple( make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)), N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)), make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{})); make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1; return b0_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
} }
__device__ static auto GetWaveIdx() __device__ static auto GetWaveIdx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment