"symphony/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "238679917e522273f25b71989c6f486111b0b8b7"
Commit b8442b51 authored by ltqin's avatar ltqin
Browse files

add a matrix unmerge

parent 5efcb64b
...@@ -199,13 +199,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -199,13 +199,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
CalculateKBatch(const CMNGridDesc& c_m_n_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc) CalculateKBatch(const CMNGridDesc& c_m_n_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{ {
constexpr auto MAX_GRID = 2048; constexpr auto MAX_GRID = 2048;
const index_t grid_size = CalculateGridSize(c_m_n_grid_desc); const index_t grid_size_mn = CalculateMNGridSize(c_m_n_grid_desc);
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
auto batch = K0 / KPerBlock; auto batch = K0 / KPerBlock;
assert(K0 % KPerBlock == 0); assert(K0 % KPerBlock == 0);
index_t div = 1; index_t div = 1;
while(batch * grid_size > MAX_GRID && batch > div) while(batch * grid_size_mn > MAX_GRID && batch > div)
{ {
div++; div++;
if(batch % div == 0) if(batch % div == 0)
...@@ -217,16 +217,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -217,16 +217,31 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
} }
__host__ __device__ static constexpr index_t __host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) CalculateMNGridSize(const CMNGridDesc& c_m_n_grid_desc)
{ {
const auto M = c_m_n_grid_desc.GetLength(I0); const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1); const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); const index_t grid_size_mn = (M / MPerBlock) * (N / NPerBlock);
return grid_size; return grid_size_mn;
} }
__host__ __device__ static constexpr auto
MakeABK0MK1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const index_t kbatch)
{
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto a_b_k0_m_k1_grid_desc = transform_tensor_descriptor(
a_k0_m_k1_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(kbatch, K0 / kbatch)),
make_pass_through_transform(M),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
return a_b_k0_m_k1_grid_desc;
}
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{ {
...@@ -298,6 +313,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -298,6 +313,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto kbatch = CalculateKBatch(CMNGridDesc{}, b_k0_n_k1_grid_desc);
if(get_block_1d_id() == 0)
printf("*****kbatch : %d", kbatch);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// divide block work by [M, N] // divide block work by [M, N]
......
...@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
} }
auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc); const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const auto a_b_k0_m_k1_grid_desc =
GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
{ {
std::cout << "k batch number is: " << kbatch << std::endl; std::cout << "k batch number is: " << kbatch << std::endl;
} }
...@@ -135,13 +138,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -135,13 +138,14 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size_mn = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size_mn = GridwiseGemm::CalculateMNGridSize(c_m_n_grid_desc);
const index_t grid_size = grid_size_mn * kbatch; const index_t grid_size = grid_size_mn * kbatch;
{ {
std::cout << "mxn gridSize : " << grid_size_mn << " finally grid_size : " << grid_size std::cout << "mxn gridSize : " << grid_size_mn << " finally grid_size : " << grid_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment