Commit 830dff7a authored by Alan Turner

Formatting

parent d46c7224
@@ -167,19 +167,19 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
                        auto path = fs::path{"migraphx"} / "kernels" / name;
                        return src_file{path, c};
                    });
-    if (not options.embedded_headers.empty())
+    if(not options.embedded_headers.empty())
     {
         std::transform(options.embedded_headers.begin(),
-            options.embedded_headers.end(),
-            std::back_inserter(srcs),
-            [](auto&& p) {
-                auto&& name = p.first;
-                auto&& c = p.second;
-                auto path = fs::path{"migraphx"} / "kernels" / name;
-                return src_file{path, c};
-            });
+                       options.embedded_headers.end(),
+                       std::back_inserter(srcs),
+                       [](auto&& p) {
+                           auto&& name = p.first;
+                           auto&& c = p.second;
+                           auto path = fs::path{"migraphx"} / "kernels" / name;
+                           return src_file{path, c};
+                       });
     }
     srcs.push_back(src_file{fs::path{"main.cpp"},
                             std::make_pair(content.data(), content.data() + content.size())});
     auto args_hpp =
......
@@ -42,8 +42,7 @@ struct hip_compile_options
     std::string kernel_name = "kernel";
     std::string params = "";
     std::vector<shape> virtual_inputs = {};
-    std::unordered_map<std::string, std::pair<const char*,const char*>> embedded_headers;
+    std::unordered_map<std::string, std::pair<const char*, const char*>> embedded_headers;
     /**
      * @brief Set the launch parameters but allow v to override the values
......
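For context on the field being reformatted above: each entry in embedded_headers maps a header file name to a [begin, end) pointer pair over that header's text, and compile_hip_code_object (first hunk) writes each entry under migraphx/kernels/ next to main.cpp before compiling. A minimal sketch of filling it in, assuming a hypothetical header name and contents:

    // Sketch only; "helpers.hpp" and its contents are hypothetical.
    #include <string>
    #include <utility>

    static const std::string helpers_src = "#pragma once\n/* device helpers */\n";

    hip_compile_options options; // struct from the hunk above
    // Value type matches the declaration above:
    // std::pair<const char*, const char*> holding [begin, end) of the text.
    options.embedded_headers["helpers.hpp"] =
        std::make_pair(helpers_src.data(), helpers_src.data() + helpers_src.size());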
@@ -40,7 +40,6 @@
 #include "ck/include/device_gemm_multiple_d.hpp"
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -79,8 +78,6 @@ __global__ void ${kernel}(${params})
 )__migraphx__";
 static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
 template <class F, class Action>
@@ -237,41 +234,46 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         auto n = c_shape.lens().back();
         auto k = a_shape.lens().back();
-    const bool transA = transposed_matrix(a_shape);
-    const bool transB = transposed_matrix(b_shape);
-    const bool transE = transposed_matrix(c_shape);
-    const auto a_type = get_type(a_shape);
-    const auto b_type = get_type(b_shape);
-    const auto e_type = get_type(c_shape);
+        const bool transA = transposed_matrix(a_shape);
+        const bool transB = transposed_matrix(b_shape);
+        const bool transE = transposed_matrix(c_shape);
+        const auto a_type = get_type(a_shape);
+        const auto b_type = get_type(b_shape);
+        const auto e_type = get_type(c_shape);
         std::vector<bool> ds_layout;
-    std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_layout), [](const auto& i){ return transposed_matrix(i); });
+        std::transform(inputs.begin() + 2,
+                       inputs.end() - 1,
+                       std::back_inserter(ds_layout),
+                       [](const auto& i) { return transposed_matrix(i); });
         std::vector<std::string> ds_type;
-    std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_type), [](const auto& i){ return get_type(i); });
+        std::transform(inputs.begin() + 2,
+                       inputs.end() - 1,
+                       std::back_inserter(ds_type),
+                       [](const auto& i) { return get_type(i); });
-    std::string ck_passthrough = "ck_passthrough";
-    std::string cde_op = ck_passthrough;
+        std::string ck_passthrough = "ck_passthrough";
+        std::string cde_op = ck_passthrough;
         assert(inputs.size() < 4 or v.contains("post"));
         if(v.contains("post"))
         {
             cde_op = v.at("post").to<std::string>();
         }
-    auto problem =
-        ck::tensor_operation::device::device_gemm_multiple_d::
-            Problem{static_cast<ck::index_t>(m),
-                    static_cast<ck::index_t>(n),
-                    static_cast<ck::index_t>(k),
-                    transA,
-                    transB,
-                    transE,
-                    ds_layout,
-                    a_type,
-                    b_type,
-                    e_type,
-                    ds_type,
-                    ck_passthrough,
-                    ck_passthrough,
-                    cde_op};
+        auto problem = ck::tensor_operation::device::device_gemm_multiple_d::Problem{
+            static_cast<ck::index_t>(m),
+            static_cast<ck::index_t>(n),
+            static_cast<ck::index_t>(k),
+            transA,
+            transB,
+            transE,
+            ds_layout,
+            a_type,
+            b_type,
+            e_type,
+            ds_type,
+            ck_passthrough,
+            ck_passthrough,
+            cde_op};
         const auto include_header = problem.GetIncludeHeader();
         const auto ck_headers = problem.GetHeaders();
@@ -281,7 +283,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         const auto blocks_per_batch = solution.grid_size;
         const auto block_size = solution.block_size;
         hip_compile_options options;
         options.embedded_headers = ck_headers;
         auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
......
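Taken together, the hunks wire the CK-generated headers into the HIP compile step: the Problem descriptor produces its headers, they are handed to hip_compile_options::embedded_headers, and the first hunk writes them under migraphx/kernels/ alongside main.cpp. A rough sketch of that flow, simplified from the hunks above (not the literal compile() body; kernel_source is a hypothetical name for the generated source):

    // Simplified flow; problem, ck_headers, and options are the names used above.
    const auto include_header = problem.GetIncludeHeader(); // header the kernel source includes
    const auto ck_headers = problem.GetHeaders();           // name -> [begin, end) text pairs

    hip_compile_options options;
    options.embedded_headers = ck_headers; // compiled alongside main.cpp (first hunk)
    // ... set kernel_name, params, and launch dimensions ...
    // return compile_hip_code_object(kernel_source, options);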