Commit 5efcb64b authored by ltqin's avatar ltqin
Browse files

redefine code

parent 72f3eb67
......@@ -198,13 +198,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
__host__ __device__ static constexpr index_t
CalculateKBatch(const CMNGridDesc& c_m_n_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
constexpr auto MAX_GRID = 2048;
const index_t grid_size = CalculateGridSize(c_m_n_grid_desc);
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
auto batch = K0 / KPerBlock;
assert(K0 % KPerBlock == 0);
index_t div = 1;
......
......@@ -141,8 +141,12 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc) * kbatch;
const index_t grid_size_mn = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
const index_t grid_size = grid_size_mn * kbatch;
{
std::cout << "mxn gridSize : " << grid_size_mn << " finally grid_size : " << grid_size
<< std::endl;
}
const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
FloatAB,
FloatC,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment