Commit 830dff7a authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent d46c7224
......@@ -167,7 +167,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
auto path = fs::path{"migraphx"} / "kernels" / name;
return src_file{path, c};
});
if (not options.embedded_headers.empty())
if(not options.embedded_headers.empty())
{
std::transform(options.embedded_headers.begin(),
options.embedded_headers.end(),
......
......@@ -42,8 +42,7 @@ struct hip_compile_options
std::string kernel_name = "kernel";
std::string params = "";
std::vector<shape> virtual_inputs = {};
std::unordered_map<std::string, std::pair<const char*,const char*>> embedded_headers;
std::unordered_map<std::string, std::pair<const char*, const char*>> embedded_headers;
/**
* @brief Set the launch parameters but allow v to override the values
......
......@@ -40,7 +40,6 @@
#include "ck/include/device_gemm_multiple_d.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -79,8 +78,6 @@ __global__ void ${kernel}(${params})
)__migraphx__";
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
template <class F, class Action>
......@@ -244,9 +241,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto b_type = get_type(b_shape);
const auto e_type = get_type(c_shape);
std::vector<bool> ds_layout;
std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_layout), [](const auto& i){ return transposed_matrix(i); });
std::transform(inputs.begin() + 2,
inputs.end() - 1,
std::back_inserter(ds_layout),
[](const auto& i) { return transposed_matrix(i); });
std::vector<std::string> ds_type;
std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_type), [](const auto& i){ return get_type(i); });
std::transform(inputs.begin() + 2,
inputs.end() - 1,
std::back_inserter(ds_type),
[](const auto& i) { return get_type(i); });
std::string ck_passthrough = "ck_passthrough";
std::string cde_op = ck_passthrough;
......@@ -256,9 +259,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op = v.at("post").to<std::string>();
}
auto problem =
ck::tensor_operation::device::device_gemm_multiple_d::
Problem{static_cast<ck::index_t>(m),
auto problem = ck::tensor_operation::device::device_gemm_multiple_d::Problem{
static_cast<ck::index_t>(m),
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(k),
transA,
......@@ -281,7 +283,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size;
hip_compile_options options;
options.embedded_headers = ck_headers;
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment