Commit 830dff7a authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent d46c7224
...@@ -167,7 +167,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -167,7 +167,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
auto path = fs::path{"migraphx"} / "kernels" / name; auto path = fs::path{"migraphx"} / "kernels" / name;
return src_file{path, c}; return src_file{path, c};
}); });
if (not options.embedded_headers.empty()) if(not options.embedded_headers.empty())
{ {
std::transform(options.embedded_headers.begin(), std::transform(options.embedded_headers.begin(),
options.embedded_headers.end(), options.embedded_headers.end(),
......
...@@ -42,8 +42,7 @@ struct hip_compile_options ...@@ -42,8 +42,7 @@ struct hip_compile_options
std::string kernel_name = "kernel"; std::string kernel_name = "kernel";
std::string params = ""; std::string params = "";
std::vector<shape> virtual_inputs = {}; std::vector<shape> virtual_inputs = {};
std::unordered_map<std::string, std::pair<const char*,const char*>> embedded_headers; std::unordered_map<std::string, std::pair<const char*, const char*>> embedded_headers;
/** /**
* @brief Set the launch parameters but allow v to override the values * @brief Set the launch parameters but allow v to override the values
......
...@@ -40,7 +40,6 @@ ...@@ -40,7 +40,6 @@
#include "ck/include/device_gemm_multiple_d.hpp" #include "ck/include/device_gemm_multiple_d.hpp"
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -79,8 +78,6 @@ __global__ void ${kernel}(${params}) ...@@ -79,8 +78,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; } static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
template <class F, class Action> template <class F, class Action>
...@@ -244,9 +241,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -244,9 +241,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto b_type = get_type(b_shape); const auto b_type = get_type(b_shape);
const auto e_type = get_type(c_shape); const auto e_type = get_type(c_shape);
std::vector<bool> ds_layout; std::vector<bool> ds_layout;
std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_layout), [](const auto& i){ return transposed_matrix(i); }); std::transform(inputs.begin() + 2,
inputs.end() - 1,
std::back_inserter(ds_layout),
[](const auto& i) { return transposed_matrix(i); });
std::vector<std::string> ds_type; std::vector<std::string> ds_type;
std::transform(inputs.begin() + 2, inputs.end() - 1, std::back_inserter(ds_type), [](const auto& i){ return get_type(i); }); std::transform(inputs.begin() + 2,
inputs.end() - 1,
std::back_inserter(ds_type),
[](const auto& i) { return get_type(i); });
std::string ck_passthrough = "ck_passthrough"; std::string ck_passthrough = "ck_passthrough";
std::string cde_op = ck_passthrough; std::string cde_op = ck_passthrough;
...@@ -256,9 +259,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -256,9 +259,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op = v.at("post").to<std::string>(); cde_op = v.at("post").to<std::string>();
} }
auto problem = auto problem = ck::tensor_operation::device::device_gemm_multiple_d::Problem{
ck::tensor_operation::device::device_gemm_multiple_d:: static_cast<ck::index_t>(m),
Problem{static_cast<ck::index_t>(m),
static_cast<ck::index_t>(n), static_cast<ck::index_t>(n),
static_cast<ck::index_t>(k), static_cast<ck::index_t>(k),
transA, transA,
...@@ -281,7 +283,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -281,7 +283,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto blocks_per_batch = solution.grid_size; const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size; const auto block_size = solution.block_size;
hip_compile_options options; hip_compile_options options;
options.embedded_headers = ck_headers; options.embedded_headers = ck_headers;
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch; auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment