Commit 96e062d1 authored by Paul's avatar Paul
Browse files

Add additional src files

parent 36154263
...@@ -161,7 +161,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -161,7 +161,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert(not options.inputs.empty()); assert(not options.inputs.empty());
assert(options.inputs.size() == options.virtual_inputs.size() or assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty()); options.virtual_inputs.empty());
std::vector<src_file> srcs; std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(), std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(), migraphx_kernels().end(),
std::back_inserter(srcs), std::back_inserter(srcs),
......
...@@ -29,6 +29,9 @@ ...@@ -29,6 +29,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct src_file;
namespace gpu { namespace gpu {
struct context; struct context;
...@@ -42,6 +45,7 @@ struct hip_compile_options ...@@ -42,6 +45,7 @@ struct hip_compile_options
std::string kernel_name = "kernel"; std::string kernel_name = "kernel";
std::string params = ""; std::string params = "";
std::vector<shape> virtual_inputs = {}; std::vector<shape> virtual_inputs = {};
std::vector<src_file> additional_src_files = {};
/** /**
* @brief Set the launch parameters but allow v to override the values * @brief Set the launch parameters but allow v to override the values
......
...@@ -80,15 +80,6 @@ __global__ void ${kernel}(${params}) ...@@ -80,15 +80,6 @@ __global__ void ${kernel}(${params})
static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; } static bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
template <class F, class Action>
auto action_decorate(F f, Action action)
{
return [=](auto&&... xs) {
action();
f(std::forward<decltype(xs)>(xs)...);
};
}
using tuning_entry = std::pair<std::vector<shape>, size_t>; using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s) static std::vector<tuning_entry> read_tuning(const std::string& s)
{ {
...@@ -300,7 +291,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -300,7 +291,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto block_size = solution.block_size; const auto block_size = solution.block_size;
hip_compile_options options; hip_compile_options options;
options.embedded_headers = ck_headers; std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(options.additional_src_files), [&](auto&& p) {
return src_file{fs::path{p.first}, p.second};
});
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch; auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment