Commit 3279fca1 authored by ltqin's avatar ltqin
Browse files

rewrite calculate gridsize

parent 62fdce6d
...@@ -223,12 +223,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -223,12 +223,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
const auto M = c_m_n_grid_desc.GetLength(I0); const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1); const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size_mn = (M / MPerBlock) * (N / NPerBlock); const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch;
return grid_size_mn; return grid_size;
} }
__host__ __device__ static constexpr index_t CalculateGridSize(const index_t M, const index_t N) __host__ __device__ static constexpr index_t CalculateBatchGridSize(const index_t M,
const index_t N)
{ {
const index_t grid_size_mn = (M / MPerBlock) * (N / NPerBlock); const index_t grid_size_mn = (M / MPerBlock) * (N / NPerBlock);
...@@ -347,7 +348,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4 ...@@ -347,7 +348,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2);
const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2);
const auto b_grid_size = CalculateGridSize(M, N); const auto b_grid_size = CalculateBatchGridSize(M, N);
const auto k_batch_id = get_block_1d_id() / b_grid_size; const auto k_batch_id = get_block_1d_id() / b_grid_size;
const auto block_id_in_batch = get_block_1d_id() % b_grid_size; const auto block_id_in_batch = get_block_1d_id() % b_grid_size;
if(get_block_1d_id() == 200000) if(get_block_1d_id() == 200000)
......
...@@ -147,11 +147,9 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, ...@@ -147,11 +147,9 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size_mn = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const index_t grid_size = grid_size_mn * KBatch;
{ {
std::cout << "mxn gridSize : " << grid_size_mn << " finally grid_size : " << grid_size std::cout << "gridSize : " << grid_size << grid_size << std::endl;
<< std::endl;
} }
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm, const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
FloatAB, FloatAB,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment