Commit 420c0312 authored by Paul's avatar Paul
Browse files

Use enum for data types

parent 3905f4a2
...@@ -24,6 +24,8 @@ enum class DataType { ...@@ -24,6 +24,8 @@ enum class DataType {
Int32 Int32
}; };
std::string ToString(DataType dt);
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders(); std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders();
std::size_t integer_divide_ceil(std::size_t x, std::size_t y); std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
......
...@@ -24,11 +24,11 @@ struct Problem ...@@ -24,11 +24,11 @@ struct Problem
bool TransA = false; bool TransA = false;
bool TransB = false; bool TransB = false;
bool TransE = false; bool TransE = false;
std::vector<bool> DsLayout = {}; std::vector<bool> DsTrans = {};
std::string ADataType = "ck::half_t"; DataType ADataType = DataType::Half;
std::string BDataType = "ck::half_t"; DataType BDataType = DataType::Half;
std::string EDataType = "ck::half_t"; DataType EDataType = DataType::Half;
std::vector<std::string> DsDataType = {}; std::vector<DataType> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>"; std::string CDEElementOp = "ck::Tuple<>";
...@@ -45,19 +45,14 @@ struct Problem ...@@ -45,19 +45,14 @@ struct Problem
static const std::size_t n_per_block_idx = 18; static const std::size_t n_per_block_idx = 18;
static const std::size_t k_per_block_idx = 19; static const std::size_t k_per_block_idx = 19;
private: std::string GetIncludeHeader() const;
std::vector<std::string> GetInstances(const std::string& arch) const;
std::string MakeLayoutTuple(const std::vector<bool>& layouts) const; std::vector<Solution> GetSolutions(const std::string& arch) const;
std::string MakeTypeTuple(const std::vector<std::string>& types) const; private:
std::vector<std::string> GetInstances(const std::string& arch) const;
Solution MakeSolution(std::size_t idx, const std::string& arch) const; Solution MakeSolution(std::size_t idx, const std::string& arch) const;
public:
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
}; };
} // namespace device_gemm_multiple_d } // namespace device_gemm_multiple_d
......
...@@ -5,6 +5,17 @@ ...@@ -5,6 +5,17 @@
namespace ck { namespace ck {
namespace host { namespace host {
std::string ToString(DataType dt)
{
switch (dt) {
case DataType::Float: return "float";
case DataType::Half: return "ck::half_t";
case DataType::Int8: return "int8_t";
case DataType::Int32: return "int32_t";
}
throw std::runtime_error("Incorrect data type");
}
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders() std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders()
{ {
return ck_headers(); return ck_headers();
......
...@@ -46,7 +46,7 @@ const std::unordered_set<std::string>& get_xdlop_archs() ...@@ -46,7 +46,7 @@ const std::unordered_set<std::string>& get_xdlop_archs()
std::vector<std::string> Problem::GetInstances(const std::string& arch) const std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{ {
std::vector<std::string> instances; std::vector<std::string> instances;
const bool quantize = ADataType == "int8_t" and BDataType == "int8_t"; const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end()) if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{ {
instance::gemm_add_add_fastgelu_instances all_instances{}; instance::gemm_add_add_fastgelu_instances all_instances{};
...@@ -62,7 +62,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const ...@@ -62,7 +62,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
return instances; return instances;
} }
std::string Problem::MakeLayoutTuple(const std::vector<bool>& layouts) const std::string MakeLayoutTuple(const std::vector<bool>& layouts)
{ {
std::string layout_tuple = "ck::Tuple<"; std::string layout_tuple = "ck::Tuple<";
auto it = layouts.begin(); auto it = layouts.begin();
...@@ -77,13 +77,13 @@ std::string Problem::MakeLayoutTuple(const std::vector<bool>& layouts) const ...@@ -77,13 +77,13 @@ std::string Problem::MakeLayoutTuple(const std::vector<bool>& layouts) const
return layout_tuple + ">"; return layout_tuple + ">";
} }
std::string Problem::MakeTypeTuple(const std::vector<std::string>& types) const std::string MakeTypeTuple(const std::vector<DataType>& types)
{ {
std::string type_tuple = "ck::Tuple<"; std::string type_tuple = "ck::Tuple<";
auto it = types.begin(); auto it = types.begin();
while(it != types.end()) while(it != types.end())
{ {
type_tuple += *it; type_tuple += ToString(*it);
it = std::next(it); it = std::next(it);
if (it != types.end()) if (it != types.end())
type_tuple += ", "; type_tuple += ", ";
...@@ -98,14 +98,14 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -98,14 +98,14 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::vector<std::string> params(std::istream_iterator<std::string>{iss}, std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>()); std::istream_iterator<std::string>());
if (ADataType == "int8_t" and BDataType == "int8_t") if (ADataType == DataType::Int8 and BDataType == DataType::Int8)
{ {
// Change CBlockTransfer ScalarPerVector if Ds contains other types // Change CBlockTransfer ScalarPerVector if Ds contains other types
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "ck::half_t"; })) if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
{ {
params[params.size() - 3] = "8"; params[params.size() - 3] = "8";
} }
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "float"; })) if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
{ {
params[params.size() - 3] = "4"; params[params.size() - 3] = "4";
} }
...@@ -113,10 +113,10 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const ...@@ -113,10 +113,10 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
params[a_elementwise_op_idx] = AElementOp; params[a_elementwise_op_idx] = AElementOp;
params[b_elementwise_op_idx] = BElementOp; params[b_elementwise_op_idx] = BElementOp;
params[ds_layout_idx] = MakeLayoutTuple(DsLayout); params[ds_layout_idx] = MakeLayoutTuple(DsTrans);
params[ds_data_type_idx] = MakeTypeTuple(DsDataType); params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
params[ds_elementwise_op_idx] = CDEElementOp; params[ds_elementwise_op_idx] = CDEElementOp;
params[e_data_type_idx] = EDataType; params[e_data_type_idx] = ToString(EDataType);
auto block_size_str = params[block_size_idx]; auto block_size_str = params[block_size_idx];
auto m_per_block_str = params[m_per_block_idx]; auto m_per_block_str = params[m_per_block_idx];
auto n_per_block_str = params[n_per_block_idx]; auto n_per_block_str = params[n_per_block_idx];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment