Commit 420c0312 authored by Paul's avatar Paul
Browse files

Use enum for data types

parent 3905f4a2
......@@ -24,6 +24,8 @@ enum class DataType {
Int32
};
std::string ToString(DataType dt);
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders();
std::size_t integer_divide_ceil(std::size_t x, std::size_t y);
......
......@@ -24,11 +24,11 @@ struct Problem
bool TransA = false;
bool TransB = false;
bool TransE = false;
std::vector<bool> DsLayout = {};
std::string ADataType = "ck::half_t";
std::string BDataType = "ck::half_t";
std::string EDataType = "ck::half_t";
std::vector<std::string> DsDataType = {};
std::vector<bool> DsTrans = {};
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType EDataType = DataType::Half;
std::vector<DataType> DsDataType = {};
std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
std::string CDEElementOp = "ck::Tuple<>";
......@@ -45,19 +45,14 @@ struct Problem
static const std::size_t n_per_block_idx = 18;
static const std::size_t k_per_block_idx = 19;
private:
std::vector<std::string> GetInstances(const std::string& arch) const;
std::string GetIncludeHeader() const;
std::string MakeLayoutTuple(const std::vector<bool>& layouts) const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
std::string MakeTypeTuple(const std::vector<std::string>& types) const;
private:
std::vector<std::string> GetInstances(const std::string& arch) const;
Solution MakeSolution(std::size_t idx, const std::string& arch) const;
public:
std::string GetIncludeHeader() const;
std::vector<Solution> GetSolutions(const std::string& arch) const;
};
} // namespace device_gemm_multiple_d
......
......@@ -5,6 +5,17 @@
namespace ck {
namespace host {
std::string ToString(DataType dt)
{
switch (dt) {
case DataType::Float: return "float";
case DataType::Half: return "ck::half_t";
case DataType::Int8: return "int8_t";
case DataType::Int32: return "int32_t";
}
throw std::runtime_error("Incorrect data type");
}
std::unordered_map<std::string, std::pair<const char*,const char*>> GetHeaders()
{
return ck_headers();
......
......@@ -46,7 +46,7 @@ const std::unordered_set<std::string>& get_xdlop_archs()
std::vector<std::string> Problem::GetInstances(const std::string& arch) const
{
std::vector<std::string> instances;
const bool quantize = ADataType == "int8_t" and BDataType == "int8_t";
const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8;
if (get_xdlop_archs().find(arch) != get_xdlop_archs().end())
{
instance::gemm_add_add_fastgelu_instances all_instances{};
......@@ -62,7 +62,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
return instances;
}
std::string Problem::MakeLayoutTuple(const std::vector<bool>& layouts) const
std::string MakeLayoutTuple(const std::vector<bool>& layouts)
{
std::string layout_tuple = "ck::Tuple<";
auto it = layouts.begin();
......@@ -77,13 +77,13 @@ std::string Problem::MakeLayoutTuple(const std::vector<bool>& layouts) const
return layout_tuple + ">";
}
std::string Problem::MakeTypeTuple(const std::vector<std::string>& types) const
std::string MakeTypeTuple(const std::vector<DataType>& types)
{
std::string type_tuple = "ck::Tuple<";
auto it = types.begin();
while(it != types.end())
{
type_tuple += *it;
type_tuple += ToString(*it);
it = std::next(it);
if (it != types.end())
type_tuple += ", ";
......@@ -98,14 +98,14 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>());
if (ADataType == "int8_t" and BDataType == "int8_t")
if (ADataType == DataType::Int8 and BDataType == DataType::Int8)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "ck::half_t"; }))
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; }))
{
params[params.size() - 3] = "8";
}
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "float"; }))
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; }))
{
params[params.size() - 3] = "4";
}
......@@ -113,10 +113,10 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
params[a_elementwise_op_idx] = AElementOp;
params[b_elementwise_op_idx] = BElementOp;
params[ds_layout_idx] = MakeLayoutTuple(DsLayout);
params[ds_layout_idx] = MakeLayoutTuple(DsTrans);
params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
params[ds_elementwise_op_idx] = CDEElementOp;
params[e_data_type_idx] = EDataType;
params[e_data_type_idx] = ToString(EDataType);
auto block_size_str = params[block_size_idx];
auto m_per_block_str = params[m_per_block_idx];
auto n_per_block_str = params[n_per_block_idx];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment