Commit b8442b51 authored by ltqin's avatar ltqin
Browse files

add a matrix unmerge

parent 5efcb64b
......@@ -198,14 +198,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
__host__ __device__ static constexpr index_t
CalculateKBatch(const CMNGridDesc& c_m_n_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{
constexpr auto MAX_GRID = 2048;
const index_t grid_size = CalculateGridSize(c_m_n_grid_desc);
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
constexpr auto MAX_GRID = 2048;
const index_t grid_size_mn = CalculateMNGridSize(c_m_n_grid_desc);
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
auto batch = K0 / KPerBlock;
assert(K0 % KPerBlock == 0);
index_t div = 1;
while(batch * grid_size > MAX_GRID && batch > div)
while(batch * grid_size_mn > MAX_GRID && batch > div)
{
div++;
if(batch % div == 0)
......@@ -217,16 +217,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
CalculateMNGridSize(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
const index_t grid_size_mn = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
return grid_size_mn;
}
__host__ __device__ static constexpr auto
MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const index_t kbatch)
{
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto a_b_k0_m_k1_grid_desc = transform_tensor_descriptor(
a_k0_m_k1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(kbatch, K0 / kbatch)),
make_pass_through_transform(M),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
return a_b_k0_m_k1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
......@@ -298,6 +313,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto kbatch = CalculateKBatch(CMNGridDesc{}, b_k0_n_k1_grid_desc);
if(get_block_1d_id() == 0)
printf("*****kbatch : %d", kbatch);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// divide block work by [M, N]
......
......@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const auto a_b_k0_m_k1_grid_desc =
GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
{
std::cout << "k batch number is: " << kbatch << std::endl;
}
......@@ -135,13 +138,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size_mn = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const index_t grid_size_mn = GridwiseGemm::CalculateMNGridSize(c_m_n_grid_desc);
const index_t grid_size = grid_size_mn * kbatch;
{
std::cout << "mxn gridSize : " << grid_size_mn << " finally grid_size : " << grid_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment