Commit 7e2b9da3 authored by Adam Osewski's avatar Adam Osewski
Browse files

B2E map calculation of workspace size

parent 98220c32
...@@ -1176,7 +1176,7 @@ struct BlockToCTileMap_LinearKSplit ...@@ -1176,7 +1176,7 @@ struct BlockToCTileMap_LinearKSplit
{ {
} }
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{ {
const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock);
...@@ -1185,7 +1185,8 @@ struct BlockToCTileMap_LinearKSplit ...@@ -1185,7 +1185,8 @@ struct BlockToCTileMap_LinearKSplit
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ __device__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) __host__ __device__ constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
} }
...@@ -1267,6 +1268,20 @@ struct BlockToCTileMap_LinearKSplit ...@@ -1267,6 +1268,20 @@ struct BlockToCTileMap_LinearKSplit
__host__ __device__ index_t GetTileNIdx() const { return N0_idx_; } __host__ __device__ index_t GetTileNIdx() const { return N0_idx_; }
__host__ __device__ index_t GetTileKIdx() const { return K0_idx_; } __host__ __device__ index_t GetTileKIdx() const { return K0_idx_; }
static __host__ uint32_t GetAccWorkspaceSize(uint32_t acc_element_bytes, uint32_t grid_size)
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes = MPerBlock * NPerBlock * grid_size * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
static __device__ uint32_t GetAccWorkspaceSize(uint32_t acc_element_bytes)
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes = MPerBlock * NPerBlock * get_grid_size() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
private: private:
index_t M_; index_t M_;
index_t N_; index_t N_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment