"vscode:/vscode.git/clone" did not exist on "a81695584f537f921df8adcfe856ce12a95b88d1"
Commit 61386bf9 authored by Alan Turner's avatar Alan Turner
Browse files

Add edatatype and scalars_per_vector workaround

parent 6289e36f
......@@ -82,6 +82,7 @@ struct Problem
static const index_t ds_layout_idx = 3;
static const index_t ds_data_type_idx = 9;
static const index_t e_data_type_idx = 10;
static const index_t a_elementwise_op_idx = 11;
static const index_t b_elementwise_op_idx = 12;
static const index_t ds_elementwise_op_idx = 13;
......@@ -147,11 +148,25 @@ private:
std::vector<std::string> params(std::istream_iterator<std::string>{iss},
std::istream_iterator<std::string>());
if (ADataType == "int8_t" and BDataType == "int8_t")
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "ck::half_t"; }))
{
params[params.size() - 3] = "8";
}
if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "float"; }))
{
params[params.size() - 3] = "4";
}
}
params[a_elementwise_op_idx] = AElementOp;
params[b_elementwise_op_idx] = BElementOp;
params[ds_layout_idx] = MakeLayoutTuple(DsLayout);
params[ds_data_type_idx] = MakeTypeTuple(DsDataType);
params[ds_elementwise_op_idx] = CDEElementOp;
params[e_data_type_idx] = EDataType;
auto block_size_str = params[block_size_idx];
auto m_per_block_str = params[m_per_block_idx];
auto n_per_block_str = params[n_per_block_idx];
......
......@@ -154,9 +154,9 @@ def get_int8_instances(src, file, template_name):
for key in aliases:
new_line = new_line.replace(key, aliases[key])
versions.append(new_line.replace("GemmPipeline", "ck:PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
versions.append(new_line.replace("GemmPipeline", "ck:PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Interwave"))
versions.append(new_line.replace("GemmPipeline", "ck:PipelineVersion::v2").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v1").replace("GemmLoopScheduler", "ck::LoopScheduler::Interwave"))
versions.append(new_line.replace("GemmPipeline", "ck::PipelineVersion::v2").replace("GemmLoopScheduler", "ck::LoopScheduler::Default"))
if "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::RowMajor" in new_line:
instances["row_row"].extend(versions)
elif "ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::ColumnMajor" in new_line:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment