Commit 830dff7a authored by Alan Turner

Formatting

parent d46c7224
@@ -167,19 +167,19 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
                        auto path = fs::path{"migraphx"} / "kernels" / name;
                        return src_file{path, c};
                    });
-    if (not options.embedded_headers.empty())
+    if(not options.embedded_headers.empty())
     {
         std::transform(options.embedded_headers.begin(),
                        options.embedded_headers.end(),
                        std::back_inserter(srcs),
                        [](auto&& p) {
                            auto&& name = p.first;
                            auto&& c = p.second;
                            auto path = fs::path{"migraphx"} / "kernels" / name;
                            return src_file{path, c};
                        });
     }
     srcs.push_back(src_file{fs::path{"main.cpp"},
                             std::make_pair(content.data(), content.data() + content.size())});
     auto args_hpp =
...
@@ -42,8 +42,7 @@ struct hip_compile_options
     std::string kernel_name = "kernel";
     std::string params = "";
     std::vector<shape> virtual_inputs = {};
-    std::unordered_map<std::string, std::pair<const char*,const char*>> embedded_headers;
+    std::unordered_map<std::string, std::pair<const char*, const char*>> embedded_headers;
     /**
      * @brief Set the launch parameters but allow v to override the values
...
@@ -40,7 +40,6 @@
 #include "ck/include/device_gemm_multiple_d.hpp"
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -79,8 +78,6 @@ __global__ void ${kernel}(${params})
 )__migraphx__";
 static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
 template <class F, class Action>
@@ -237,41 +234,46 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         auto n = c_shape.lens().back();
         auto k = a_shape.lens().back();
         const bool transA = transposed_matrix(a_shape);
         const bool transB = transposed_matrix(b_shape);
         const bool transE = transposed_matrix(c_shape);
         const auto a_type = get_type(a_shape);
         const auto b_type = get_type(b_shape);
         const auto e_type = get_type(c_shape);
         std::vector<bool> ds_layout;
-        std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_layout), [](const auto& i){ return transposed_matrix(i); });
+        std::transform(inputs.begin() + 2,
+                       inputs.end() - 1,
+                       std::back_inserter(ds_layout),
+                       [](const auto& i) { return transposed_matrix(i); });
         std::vector<std::string> ds_type;
-        std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_type), [](const auto& i){ return get_type(i); });
+        std::transform(inputs.begin() + 2,
+                       inputs.end() - 1,
+                       std::back_inserter(ds_type),
+                       [](const auto& i) { return get_type(i); });
         std::string ck_passthrough = "ck_passthrough";
         std::string cde_op = ck_passthrough;
         assert(inputs.size() < 4 or v.contains("post"));
         if(v.contains("post"))
         {
             cde_op = v.at("post").to<std::string>();
         }
-        auto problem =
-            ck::tensor_operation::device::device_gemm_multiple_d::
-                Problem{static_cast<ck::index_t>(m),
-                        static_cast<ck::index_t>(n),
-                        static_cast<ck::index_t>(k),
-                        transA,
-                        transB,
-                        transE,
-                        ds_layout,
-                        a_type,
-                        b_type,
-                        e_type,
-                        ds_type,
-                        ck_passthrough,
-                        ck_passthrough,
-                        cde_op};
+        auto problem = ck::tensor_operation::device::device_gemm_multiple_d::Problem{
+            static_cast<ck::index_t>(m),
+            static_cast<ck::index_t>(n),
+            static_cast<ck::index_t>(k),
+            transA,
+            transB,
+            transE,
+            ds_layout,
+            a_type,
+            b_type,
+            e_type,
+            ds_type,
+            ck_passthrough,
+            ck_passthrough,
+            cde_op};
         const auto include_header = problem.GetIncludeHeader();
         const auto ck_headers = problem.GetHeaders();
@@ -281,7 +283,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         const auto blocks_per_batch = solution.grid_size;
         const auto block_size = solution.block_size;
         hip_compile_options options;
         options.embedded_headers = ck_headers;
         auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
...
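
For context on the hunks above: embedded_headers maps a header file name to a (begin, end) character range, and compile_hip_code_object copies each entry into the srcs list under migraphx/kernels/<name>. Below is a minimal standalone sketch of that transform; src_file here is a simplified stand-in for the MIGraphX type, and the header name and contents are made up for illustration.

// Standalone sketch of the embedded_headers -> srcs transform shown above.
// src_file is a simplified stand-in, not the real MIGraphX definition.
#include <algorithm>
#include <filesystem>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace fs = std::filesystem;

struct src_file
{
    fs::path path;
    std::pair<const char*, const char*> content; // [begin, end) of the file text
};

int main()
{
    // Same shape as hip_compile_options::embedded_headers in the diff:
    // header name -> (begin, end) pointers into the header text.
    const std::string hdr = "#pragma once\n// example header body\n";
    std::unordered_map<std::string, std::pair<const char*, const char*>> embedded_headers = {
        {"example.hpp", {hdr.data(), hdr.data() + hdr.size()}}};

    std::vector<src_file> srcs;
    if(not embedded_headers.empty())
    {
        std::transform(embedded_headers.begin(),
                       embedded_headers.end(),
                       std::back_inserter(srcs),
                       [](auto&& p) {
                           auto&& name = p.first;
                           auto&& c = p.second;
                           auto path = fs::path{"migraphx"} / "kernels" / name;
                           return src_file{path, c};
                       });
    }

    // Print each generated source path and its size to show the result.
    for(const auto& s : srcs)
        std::cout << s.path.generic_string() << ": "
                  << std::string(s.content.first, s.content.second).size() << " bytes\n";
}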