Commit 72f3eb67 authored by ltqin's avatar ltqin
Browse files

add split k function

parent 1d4f5453
...@@ -195,6 +195,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -195,6 +195,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0); (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0);
} }
// Chooses the split-K batch count for the problem described by the grid descriptors.
//
// The K dimension contains K0 / KPerBlock block-sized steps, and the output is covered by
// (M / MPerBlock) * (N / NPerBlock) tiles. The result is the largest divisor of the step
// count for which
//     batch * (M / MPerBlock) * (N / NPerBlock) <= MAX_GRID
// holds, and it is never smaller than 1. Returning a divisor means every batch receives the
// same whole number of KPerBlock steps.
//
// @param c_m_n_grid_desc      descriptor whose lengths along I0 / I1 are read as M / N
// @param b_k0_n_k1_grid_desc  descriptor whose length along I0 is read as K0
// @return batch count >= 1; 1 when no larger divisor fits under the cap
//
// Precondition (checked with assert only): K0 is a multiple of KPerBlock.
__host__ __device__ static constexpr index_t
CalculateKBatch(const CMNGridDesc& c_m_n_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{
    const auto M  = c_m_n_grid_desc.GetLength(I0);
    const auto N  = c_m_n_grid_desc.GetLength(I1);
    const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);

    // Validate the precondition before the quotient below relies on it.
    assert(K0 % KPerBlock == 0);

    // Upper bound on batch * tile count.
    // NOTE(review): 2048 is taken over from the previous implementation; its origin
    // (hardware limit vs. tuning value) is not visible here - confirm.
    constexpr index_t MAX_GRID = 2048;

    const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
    const index_t k_steps   = K0 / KPerBlock;

    // Nothing to split (also covers K0 < KPerBlock, where the quotient is 0).
    if(k_steps <= 1)
    {
        return 1;
    }

    // With no output tiles the cap cannot be exceeded; also avoids dividing by zero below.
    if(grid_size <= 0)
    {
        return k_steps;
    }

    // Largest batch count the cap allows; computed by division so that
    // batch * grid_size is never formed and cannot overflow.
    const index_t max_batch = MAX_GRID / grid_size;

    // Walk down from the largest admissible value to the first exact divisor of k_steps.
    // At most MAX_GRID iterations, because the start value is bounded by max_batch.
    index_t batch = k_steps < max_batch ? k_steps : max_batch;
    while(batch > 1 && k_steps % batch != 0)
    {
        --batch;
    }

    // max_batch is 0 when the tile count alone exceeds the cap; fall back to a single batch.
    return batch > 1 ? batch : 1;
}
__host__ __device__ static constexpr index_t __host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
{ {
......
...@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -122,7 +122,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
} }
auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
{
std::cout << "k batch number is: " << kbatch << std::endl;
}
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{ {
throw std::runtime_error( throw std::runtime_error(
...@@ -138,7 +141,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -138,7 +141,7 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc) * kbatch;
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm, const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
FloatAB, FloatAB,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment