Commit dc65f4c6 authored by Alan Turner

Use vectors for Ds types and layout params

parent e2878e25
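As a hedged illustration that is not part of the commit, the following minimal sketch shows how a caller might fill in the new vector-valued Ds fields introduced by this diff. The Problem field names, their defaults, and the GetSolutions/MakeLayoutTuple/MakeTypeTuple helpers are taken from the diff below; the enclosing function, the sample sizes, and the "gfx90a" arch string are assumptions for illustration only.

// Hypothetical usage sketch (not from this commit): populate the vector-based
// Ds params and let the codegen render them as ck::Tuple<...> strings.
void example_usage()
{
    Problem prob;                                     // field defaults come from the diff below
    prob.M = 1024;
    prob.N = 1024;
    prob.K = 64;
    prob.TransB     = true;                           // B stored column-major
    prob.DsLayout   = {false, false};                 // two row-major D tensors
    prob.DsDataType = {"ck::half_t", "ck::half_t"};   // -> "ck::Tuple<ck::half_t, ck::half_t>"

    // MakeLayoutTuple(DsLayout) yields
    //   "ck::Tuple<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor>"
    // and is substituted into params[ds_layout_idx] when solutions are built.
    auto solutions = prob.GetSolutions("gfx90a");     // arch string is an assumed example
}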
@@ -26,25 +26,6 @@ struct Solution
    std::string template_str;
    index_t block_size;
    index_t grid_size;
    Solution(std::string s, index_t b, index_t g) : template_str(s), block_size(b), grid_size(g)
    {}
    auto GetStr() const
    {
        return template_str;
    }
    auto GetBlockSize() const
    {
        return block_size;
    }
    auto GetGridSize() const
    {
        return grid_size;
    }
};
std::string GetGemmSpec(const index_t m,
@@ -84,20 +65,20 @@ const std::unordered_set<std::string>& get_xdlop_archs()
struct Problem
{
    index_t M;
    index_t N;
    index_t K;
    index_t NumDTensors;
    bool TransA;
    bool TransB;
    bool TransCDE;
    std::string ADataType;
    std::string BDataType;
    std::string CDEDataType;
    std::string AElementOp;
    std::string BElementOp;
    std::string CDEElementOp;
    std::string CDELayout;
    index_t M = 0;
    index_t N = 0;
    index_t K = 0;
    bool TransA = false;
    bool TransB = false;
    bool TransE = false;
    std::vector<bool> DsLayout = {};
    std::string ADataType = "ck::half_t";
    std::string BDataType = "ck::half_t";
    std::string EDataType = "ck::half_t";
    std::vector<std::string> DsDataType = {};
    std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough";
    std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough";
    std::string CDEElementOp = "ck::Tuple<>";
    static const index_t ds_layout_idx = 3;
    static const index_t ds_data_type_idx = 9;
@@ -110,6 +91,7 @@ struct Problem
    static const index_t n_per_block_idx = 18;
    static const index_t k_per_block_idx = 19;

    private:
    auto GetInstances(const std::string& arch) const
    {
        std::vector<std::string> instances;
@@ -128,45 +110,33 @@ struct Problem
        return instances;
    }
    auto GetHeaders() const
    {
        return ck_headers();
    }
    auto GetIncludeHeader() const
    auto MakeLayoutTuple(const std::vector<bool>& layouts) const
    {
        return instance::gemm_add_add_fastgelu_instances{}.get_include_header();
        std::string layout_tuple = "ck::Tuple<";
        auto it = layouts.begin();
        while(it != layouts.end())
        {
            layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor";
            it = std::next(it);
            if (it != layouts.end())
                layout_tuple += ", ";
        }
        return layout_tuple + ">";
    }
    Problem(index_t m,
            index_t n,
            index_t k,
            index_t numDTensors,
            bool transA,
            bool transB,
            bool transCDE,
            std::string aDataType,
            std::string bDataType,
            std::string cdeDataType,
            std::string aElementOp,
            std::string bElementOp,
            std::string cdeElementOp,
            std::string cdeLayout)
        : M(m),
          N(n),
          K(k),
          NumDTensors(numDTensors),
          TransA(transA),
          TransB(transB),
          TransCDE(transCDE),
          ADataType(aDataType),
          BDataType(bDataType),
          CDEDataType(cdeDataType),
          AElementOp(aElementOp),
          BElementOp(bElementOp),
          CDEElementOp(cdeElementOp),
          CDELayout(cdeLayout)
    auto MakeTypeTuple(const std::vector<std::string>& types) const
    {
        std::string type_tuple = "ck::Tuple<";
        auto it = types.begin();
        while(it != types.end())
        {
            type_tuple += *it;
            it = std::next(it);
            if (it != types.end())
                type_tuple += ", ";
        }
        return type_tuple + ">";
    }
    auto MakeSolution(index_t idx, const std::string& arch) const
@@ -178,8 +148,8 @@ struct Problem
        params[a_elementwise_op_idx] = AElementOp;
        params[b_elementwise_op_idx] = BElementOp;
        params[ds_layout_idx] = CDELayout;
        params[ds_data_type_idx] = CDEDataType;
        params[ds_layout_idx] = MakeLayoutTuple(DsLayout);
        params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
        params[ds_elementwise_op_idx] = CDEElementOp;
        auto block_size_str = params[block_size_idx];
        auto m_per_block_str = params[m_per_block_idx];
@@ -201,6 +171,17 @@ struct Problem
        return Solution{str, block_size, grid_size};
    }

    public:
    auto GetHeaders() const
    {
        return ck_headers();
    }
    auto GetIncludeHeader() const
    {
        return instance::gemm_add_add_fastgelu_instances{}.get_include_header();
    }
    auto GetSolutions(const std::string& arch) const
    {
        std::vector<Solution> solutions;
......