Commit 77190058 authored by Alan Turner
Browse files

Formatting

parent 421734ae
...@@ -120,22 +120,24 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -120,22 +120,24 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default; __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
default; const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
default; BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) __host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01) : M_(M), N_(N), M01_(M01)
{ {
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, __host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8) index_t M01 = 8)
: BlockToCTileMap_M00_N0_M01Adapt( : BlockToCTileMap_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
...@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void> ...@@ -151,13 +153,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ __device__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) __host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
} }
template <typename CGridDesc_M_N> template <typename CGridDesc_M_N>
__host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const __host__ __device__ constexpr bool
CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{ {
return true; return true;
} }
......
...@@ -17,7 +17,8 @@ struct Solution ...@@ -17,7 +17,8 @@ struct Solution
std::size_t grid_size; std::size_t grid_size;
}; };
enum class DataType { enum class DataType
{
Half, Half,
Float, Float,
Int8, Int8,
...@@ -26,7 +27,7 @@ enum class DataType { ...@@ -26,7 +27,7 @@ enum class DataType {
std::string ToString(DataType dt); std::string ToString(DataType dt);
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders(); std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders();
std::size_t integer_divide_ceil(std::size_t x, std::size_t y); std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include <numeric> #include <numeric>
#include "ck/host/common.hpp" #include "ck/host/common.hpp"
namespace ck { namespace ck {
namespace host { namespace host {
namespace device_gemm_multiple_d { namespace device_gemm_multiple_d {
...@@ -49,7 +48,7 @@ struct Problem ...@@ -49,7 +48,7 @@ struct Problem
std::vector<Solution> GetSolutions(const std::string& arch) const; std::vector<Solution> GetSolutions(const std::string& arch) const;
private: private:
std::vector<std::string> GetInstances(const std::string& arch) const; std::vector<std::string> GetInstances(const std::string& arch) const;
Solution MakeSolution(std::size_t idx, const std::string& arch) const; Solution MakeSolution(std::size_t idx, const std::string& arch) const;
......
...@@ -8,7 +8,8 @@ namespace host { ...@@ -8,7 +8,8 @@ namespace host {
std::string ToString(DataType dt) std::string ToString(DataType dt)
{ {
switch (dt) { switch(dt)
{
case DataType::Float: return "float"; case DataType::Float: return "float";
case DataType::Half: return "ck::half_t"; case DataType::Half: return "ck::half_t";
case DataType::Int8: return "int8_t"; case DataType::Int8: return "int8_t";
...@@ -17,7 +18,7 @@ std::string ToString(DataType dt) ...@@ -17,7 +18,7 @@ std::string ToString(DataType dt)
throw std::runtime_error("Incorrect data type"); throw std::runtime_error("Incorrect data type");
} }
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders() std::unordered_map<std::string, std::pair<const char*, const char*>> GetHeaders()
{ {
return ck_headers(); return ck_headers();
} }
......
...@@ -33,8 +33,7 @@ std::size_t GetGridSize(const std::size_t m, ...@@ -33,8 +33,7 @@ std::size_t GetGridSize(const std::size_t m,
const std::size_t m_per_block, const std::size_t m_per_block,
const std::size_t n_per_block) const std::size_t n_per_block)
{ {
return integer_divide_ceil(m, m_per_block) * return integer_divide_ceil(m, m_per_block) * integer_divide_ceil(n, n_per_block);
integer_divide_ceil(n, n_per_block);
} }
const std::unordered_set<std::string>& get_xdlop_archs() const std::unordered_set<std::string>& get_xdlop_archs()
...@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const ...@@ -47,7 +46,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{ {
std::vector<std::string> instances; std::vector<std::string> instances;
const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8; const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end()) if(get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{ {
ck::host::instance::gemm_add_add_fastgelu_instances all_instances{}; ck::host::instance::gemm_add_add_fastgelu_instances all_instances{};
if(TransA and TransB) if(TransA and TransB)
...@@ -68,9 +67,10 @@ std::string MakeLayoutTuple(const std::vector<bool>& layouts) ...@@ -68,9 +67,10 @@ std::string MakeLayoutTuple(const std::vector<bool>& layouts)
auto it = layouts.begin(); auto it = layouts.begin();
while(it != layouts.end()) while(it != layouts.end())
{ {
layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; layout_tuple +=
*it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
it = std::next(it); it = std::next(it);
if (it != layouts.end()) if(it != layouts.end())
layout_tuple += ", "; layout_tuple += ", ";
} }
...@@ -85,7 +85,7 @@ std::string MakeTypeTuple(const std::vector<DataType>& types) ...@@ -85,7 +85,7 @@ std::string MakeTypeTuple(const std::vector<DataType>& types)
{ {
type_tuple += ToString(*it); type_tuple += ToString(*it);
it = std::next(it); it = std::next(it);
if (it != types.end()) if(it != types.end())
type_tuple += ", "; type_tuple += ", ";
} }
return type_tuple + ">"; return type_tuple + ">";
...@@ -98,14 +98,16 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -98,14 +98,16 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::vector<std::string> params(std::istream_iterator<std::string>{iss}, std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>()); std::istream_iterator<std::string>());
if (ADataType == DataType::Int8 and BDataType == DataType::Int8) if(ADataType == DataType::Int8 and BDataType == DataType::Int8)
{ {
// Change CBlockTransfer ScalarPerVector if Ds contains other types // Change CBlockTransfer ScalarPerVector if Ds contains other types
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; })) if(std::any_of(
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
{ {
params[params.size() - 3] = "8"; params[params.size() - 3] = "8";
} }
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; })) if(std::any_of(
DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
{ {
params[params.size() - 3] = "4"; params[params.size() - 3] = "4";
} }
...@@ -128,10 +130,11 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -128,10 +130,11 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block); const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block);
params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block); params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block);
std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{}, std::string str = std::accumulate(
[](const std::string& a, const std::string& b) { params.begin() + 1,
return a.empty() ? b : a + ", " + b; params.end(),
}); std::string{},
[](const std::string& a, const std::string& b) { return a.empty() ? b : a + ", " + b; });
str = params.front() + "< " + str + ">"; str = params.front() + "< " + str + ">";
return Solution{str, block_size, grid_size}; return Solution{str, block_size, grid_size};
...@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const ...@@ -146,7 +149,7 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
{ {
std::vector<Solution> solutions; std::vector<Solution> solutions;
const std::size_t num_instances = GetInstances(arch).size(); const std::size_t num_instances = GetInstances(arch).size();
for (std::size_t i = 0; i < num_instances; ++i) for(std::size_t i = 0; i < num_instances; ++i)
{ {
solutions.push_back(MakeSolution(i, arch)); solutions.push_back(MakeSolution(i, arch));
} }
...@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const ...@@ -154,7 +157,6 @@ std::vector<Solution> Problem::GetSolutions(const std::string& arch) const
return solutions; return solutions;
} }
} // namespace device_gemm_multiple_d } // namespace device_gemm_multiple_d
} // namespace host } // namespace host
} // namespace ck } // namespace ck
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
bool test_Problem() bool test_Problem()
{ {
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -27,9 +28,20 @@ bool test_Problem() ...@@ -27,9 +28,20 @@ bool test_Problem()
bool pass = true; bool pass = true;
pass &= include_header == "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; pass &= include_header ==
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp";
pass &= solutions.size() == 42; pass &= solutions.size() == 42;
pass &= template_str == "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, ck::half_t, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::element_wise::Passthrough, ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, 1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>"; pass &= template_str ==
"ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< "
"ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::Tuple<>, "
"ck::tensor_layout::gemm::RowMajor, ck::half_t, ck::half_t, float, float, ck::Tuple<>, "
"ck::half_t, ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::element_wise::Passthrough, "
"ck::tensor_operation::device::GemmSpecialization::Default, 1, 256, 256, 128, 32, 8, "
"8, 32, 32, 4, 2, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, "
"8, 8, 1, ck::Sequence<4,64,1>, ck::Sequence<1,0,2>, ck::Sequence<1,0,2>, 2, 8, 8, 1, "
"1, 1, ck::Sequence<1,32,1,8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v1>";
pass &= grid_size == 2; pass &= grid_size == 2;
pass &= block_size == 256; pass &= block_size == 256;
...@@ -40,8 +52,9 @@ bool test_GetGemmSpec() ...@@ -40,8 +52,9 @@ bool test_GetGemmSpec()
{ {
bool pass = true; bool pass = true;
{ {
//PadMNK // PadMNK
auto problem = ck::host::device_gemm_multiple_d::Problem{255, auto problem = ck::host::device_gemm_multiple_d::Problem{
255,
255, 255,
255, 255,
false, false,
...@@ -62,8 +75,9 @@ bool test_GetGemmSpec() ...@@ -62,8 +75,9 @@ bool test_GetGemmSpec()
pass &= template_str.find("GemmSpecialization::MNKPadding") != std::string::npos; pass &= template_str.find("GemmSpecialization::MNKPadding") != std::string::npos;
} }
{ {
//Default // Default
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -91,8 +105,9 @@ bool test_GetInstances() ...@@ -91,8 +105,9 @@ bool test_GetInstances()
{ {
bool pass = true; bool pass = true;
{ {
//Col Col Fp16 // Col Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
true, true,
...@@ -109,8 +124,9 @@ bool test_GetInstances() ...@@ -109,8 +124,9 @@ bool test_GetInstances()
pass &= problem.GetSolutions("gfx90a").size() == 51; pass &= problem.GetSolutions("gfx90a").size() == 51;
} }
{ {
//Col Row Fp16 // Col Row Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
true, true,
...@@ -127,8 +143,9 @@ bool test_GetInstances() ...@@ -127,8 +143,9 @@ bool test_GetInstances()
pass &= problem.GetSolutions("gfx90a").size() == 51; pass &= problem.GetSolutions("gfx90a").size() == 51;
} }
{ {
//Row Col Fp16 // Row Col Fp16
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -145,8 +162,9 @@ bool test_GetInstances() ...@@ -145,8 +162,9 @@ bool test_GetInstances()
pass &= problem.GetSolutions("gfx90a").size() == 42; pass &= problem.GetSolutions("gfx90a").size() == 42;
} }
{ {
//Row Row Int8 // Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -163,8 +181,9 @@ bool test_GetInstances() ...@@ -163,8 +181,9 @@ bool test_GetInstances()
pass &= problem.GetSolutions("gfx90a").size() == 48; pass &= problem.GetSolutions("gfx90a").size() == 48;
} }
{ {
//Col Col Int8 // Col Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
true, true,
...@@ -181,8 +200,9 @@ bool test_GetInstances() ...@@ -181,8 +200,9 @@ bool test_GetInstances()
pass &= problem.GetSolutions("gfx90a").size() == 48; pass &= problem.GetSolutions("gfx90a").size() == 48;
} }
{ {
//Col Row Int8 // Col Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
true, true,
...@@ -199,8 +219,9 @@ bool test_GetInstances() ...@@ -199,8 +219,9 @@ bool test_GetInstances()
pass &= problem.GetSolutions("gfx90a").size() == 48; pass &= problem.GetSolutions("gfx90a").size() == 48;
} }
{ {
//Row Col Int8 // Row Col Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -217,8 +238,9 @@ bool test_GetInstances() ...@@ -217,8 +238,9 @@ bool test_GetInstances()
pass &= problem.GetSolutions("gfx90a").size() == 39; pass &= problem.GetSolutions("gfx90a").size() == 39;
} }
{ {
//Row Row Int8 // Row Row Int8
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -243,7 +265,8 @@ bool test_MakeLayoutsTuple() ...@@ -243,7 +265,8 @@ bool test_MakeLayoutsTuple()
bool pass = true; bool pass = true;
{ {
// Empty Tuple // Empty Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -264,7 +287,8 @@ bool test_MakeLayoutsTuple() ...@@ -264,7 +287,8 @@ bool test_MakeLayoutsTuple()
} }
{ {
// RowColRow Tuple // RowColRow Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -281,7 +305,10 @@ bool test_MakeLayoutsTuple() ...@@ -281,7 +305,10 @@ bool test_MakeLayoutsTuple()
const auto solutions = problem.GetSolutions("gfx90a"); const auto solutions = problem.GetSolutions("gfx90a");
const auto& solution = solutions.at(0); const auto& solution = solutions.at(0);
const auto template_str = solution.template_str; const auto template_str = solution.template_str;
pass &= template_str.find("ck::Tuple<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>") != std::string::npos; pass &= template_str.find(
"ck::Tuple<ck::tensor_layout::gemm::RowMajor, "
"ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor>") !=
std::string::npos;
} }
return pass; return pass;
...@@ -292,7 +319,8 @@ bool test_MakeTypeTuple() ...@@ -292,7 +319,8 @@ bool test_MakeTypeTuple()
bool pass = true; bool pass = true;
{ {
// Empty Tuple // Empty Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
...@@ -313,7 +341,8 @@ bool test_MakeTypeTuple() ...@@ -313,7 +341,8 @@ bool test_MakeTypeTuple()
} }
{ {
// Half Int8 Tuple // Half Int8 Tuple
auto problem = ck::host::device_gemm_multiple_d::Problem{256, auto problem = ck::host::device_gemm_multiple_d::Problem{
256,
256, 256,
256, 256,
false, false,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment