Commit dbc971be authored by wangshaojie6's avatar wangshaojie6
Browse files

wip for gridwise

parent 6985af40
......@@ -74,6 +74,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto K0PerBlock = Number<KPerBlock / AK1>{};
static constexpr auto BaseMultK0 = 2;
static constexpr auto MultiK0 = BaseMultK0 * 1;
......@@ -81,30 +83,36 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
static constexpr auto K1 = Number<AK1>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL);
static constexpr index_t M0Waves = M0PerBlock / (M0XdlPerWave * M0PerXDL);
static constexpr index_t N0Waves = N0PerBlock / (N0XdlPerWave * N0PerXDL);
static constexpr auto xdlops_gemm0 = XdlopsGemm<FloatAB, M0PerXDL, N0PerXDL, K1>{};
static constexpr index_t K0PerThread0 = K0PerBlock / xdlops_gemm0.K0PerXdlops;
static constexpr index_t M1Waves = M1PerBlock / (M1XdlPerWave * M1PerXDL);
static constexpr index_t N1Waves = N1PerBlock / (N1XdlPerWave * N1PerXDL);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
static constexpr auto xdlops_gemm1 = XdlopsGemm<FloatAB, M1PerXDL, N1PerXDL, K1>{};
static constexpr index_t K0PerThread1 = K0PerBlock / xdlops_gemm1.K0PerXdlops;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
constexpr auto max_lds_align = K1;
constexpr auto max_lds_align = AK1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0_m_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<M0PerBlock>{}, K1),
make_tuple(Number<M0PerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<M0PerBlock>{}, K1),
max_lds_align);
}
}();
......@@ -112,10 +120,34 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
return a_block_desc_k0_m_k1;
}
__host__ __device__ static constexpr auto GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1()
{
constexpr auto max_lds_align = B1K1;
// A matrix in LDS memory, dst of blockwise copy
constexpr auto b1_block_desc_k0_n_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<N1PerBlock>{}, K1),
make_tuple(Number<N1PerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock * MultiK0>{}, Number<N1PerBlock>{}, K1),
max_lds_align);
}
}();
return b1_block_desc_k0_n_k1;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b1_block_desc_k0_n_k1 = GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1();
constexpr auto max_lds_align = K1;
......@@ -179,7 +211,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
const index_t grid_size = (M / M0PerBlock) * (N / N1PerBlock);
return grid_size;
}
......@@ -193,21 +225,21 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
}
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const B0GridDesc_K0_N_K1& b0_grid_desc_k0_n_k1)
{
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = b0_grid_desc_k0_n_k1.GetLength(I0);
const auto N = b0_grid_desc_k0_n_k1.GetLength(I1);
const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b_grid_desc_k0_n_k1,
const auto b0_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor(
b0_grid_desc_k0_n_k1,
make_tuple(make_unmerge_transform(
make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
make_tuple(K0 / K0PerBlock, xdlops_gemm0.K0PerXdlops, K0PerThread)),
make_unmerge_transform(make_tuple(
N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
return b0_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
}
__device__ static auto GetWaveIdx()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment