Commit b57f58e1 authored by Paul's avatar Paul
Browse files

Move to cpp file

parent 873f6c0c
......@@ -30,6 +30,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
......@@ -38,13 +39,17 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/env.hpp>
#include "ck_gemm_instances.hpp"
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
......@@ -95,6 +100,15 @@ static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t
return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
}
template<class F, class Action>
auto action_decorate(F f, Action action)
{
return [=](auto&&... xs) {
action();
f(std::forward<decltype(xs)>(xs)...);
};
}
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
static std::string get_layout(const shape& s)
......@@ -153,7 +167,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
auto shapes = to_shapes(ins->inputs());
return action_decorate(replace(compile_op(ctx, shapes, op.to_value())), [=] {
if (enabled(MIGRAPHX_LOG_CK_GEMM{}))
std::cout << "ck_gemm: " << to_json_string(to_value(shapes)) << std::endl;
});
}
};
......
#ifndef MIGRAPHX_GUARD_JIT_CK_INSTANCES_HPP
#define MIGRAPHX_GUARD_JIT_CK_INSTANCES_HPP
#include <algorithm>
#include <vector>
#include <string>
#include <functional>
inline const std::vector<std::string>&
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred)
{
static std::vector<std::vector<std::vector<std::string>>> instances = {
......@@ -8051,5 +8048,3 @@ get_instance(std::size_t i, const std::function<bool(const std::vector<std::stri
std::find_if(instances.begin(), instances.end(), [&](const auto& v) { return pred(v[0]); });
return it->at(i);
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment