Commit a088771c authored by ltqin's avatar ltqin
Browse files

add b matrix unmerge k0

parent b8442b51
......@@ -242,6 +242,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
return a_b_k0_m_k1_grid_desc;
}
// Builds a 4-D view of the B-matrix grid descriptor in which the leading
// dimension (length K0) is unmerged into two dimensions of lengths
// (kbatch, K0 / kbatch); the remaining two dimensions are passed through.
//
// b_k0_n_k1_grid_desc: source descriptor; the lengths of its dimensions 0
//                      and 1 are read here as K0 and N.
// kbatch:              number of pieces dimension 0 is divided into.
//
// Returns the transformed descriptor. Source dimensions 0, 1, 2 map to
// result dimensions {0, 1}, 2, 3 respectively.
//
// NOTE(review): K0 / kbatch is evaluated with no remainder check in this
// function — presumably the caller guarantees K0 is a multiple of kbatch;
// confirm against the code that computes kbatch.
__host__ __device__ static constexpr auto
MakeBBK0NK1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const index_t kbatch)
{
    const auto k0_len = b_k0_n_k1_grid_desc.GetLength(I0);
    const auto n_len  = b_k0_n_k1_grid_desc.GetLength(I1);

    // One transform per source dimension, in source-dimension order.
    const auto dim_transforms =
        make_tuple(make_unmerge_transform(make_tuple(kbatch, k0_len / kbatch)),
                   make_pass_through_transform(n_len),
                   make_pass_through_transform(K1Value));

    // Source dimension ids consumed by each transform above.
    const auto src_dim_ids = make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{});

    // Result dimension ids produced by each transform above; the unmerge
    // yields two result dimensions (0 and 1).
    const auto dst_dim_ids = make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{});

    return transform_tensor_descriptor(
        b_k0_n_k1_grid_desc, dim_transforms, src_dim_ids, dst_dim_ids);
}
__host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
......
......@@ -123,9 +123,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const auto a_b_k0_m_k1_grid_desc =
GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
// const auto a_b_k0_m_k1_grid_desc =
GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
// const auto b_b_k0_n_k1_grid_desc =
GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, kbatch);
{
std::cout << "k batch number is: " << kbatch << std::endl;
}
......@@ -138,7 +139,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
// using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
// using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment