"src/include/Array.hpp" did not exist on "8c385cf5cf25219d235e52be50d1d3f4a0a21f87"
Commit e0d67738 authored by Adam Osewski's avatar Adam Osewski
Browse files

Take contiguous dim size when calculating dram vector load size.

parent d79d1a38
...@@ -17,38 +17,43 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -17,38 +17,43 @@ struct UniversalGemmPipelineAgBgCrPolicy
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{}; static constexpr auto I2 = number<2>{};
static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked;
/** /**
* @brief Get the maximum global memory vector load size. * @brief Get the maximum global memory vector load size.
* *
* @tparam Problem The UniversalGemmPipelineProblem object. * @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering. * @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B). * @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size. * @return Maximum DRAM vector load size.
*/ */
template <typename Problem, typename DataType, index_t MNPerBlock> template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{ {
// TODO this does not take into accout the size of contiguous dim!
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
// Assume DataType is even! // Assume DataType is even!
if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0) if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
elements_per_thread % (16 / sizeof(DataType)) == 0)
{ {
return (16 / sizeof(DataType)); return (16 / sizeof(DataType));
} }
else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0) else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
elements_per_thread % (8 / sizeof(DataType)) == 0)
{ {
return (8 / sizeof(DataType)); return (8 / sizeof(DataType));
} }
else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 4) elements_per_thread % (4 / sizeof(DataType)) == 0)
{ {
return (4 / sizeof(DataType)); return (4 / sizeof(DataType));
} }
else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 2) elements_per_thread % (2 / sizeof(DataType)) == 0)
{ {
return (2 / sizeof(DataType)); return (2 / sizeof(DataType));
} }
...@@ -61,19 +66,37 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -61,19 +66,37 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{ {
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock>(); constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{ {
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock>(); constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
} }
/** /**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment