Commit b57f58e1 authored by Paul's avatar Paul
Browse files

Move to cpp file

parent 873f6c0c
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
...@@ -38,13 +39,17 @@ ...@@ -38,13 +39,17 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include "ck_gemm_instances.hpp"
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__( static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp> #include <args.hpp>
...@@ -95,6 +100,15 @@ static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t ...@@ -95,6 +100,15 @@ static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t
return int_div_ceil(m, mpb) * int_div_ceil(n, npb); return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
} }
template<class F, class Action>
auto action_decorate(F f, Action action)
{
return [=](auto&&... xs) {
action();
f(std::forward<decltype(xs)>(xs)...);
};
}
struct ck_gemm_compiler : compiler<ck_gemm_compiler> struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{ {
static std::string get_layout(const shape& s) static std::string get_layout(const shape& s)
...@@ -153,7 +167,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -153,7 +167,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); auto shapes = to_shapes(ins->inputs());
return action_decorate(replace(compile_op(ctx, shapes, op.to_value())), [=] {
if (enabled(MIGRAPHX_LOG_CK_GEMM{}))
std::cout << "ck_gemm: " << to_json_string(to_value(shapes)) << std::endl;
});
} }
}; };
......
#ifndef MIGRAPHX_GUARD_JIT_CK_INSTANCES_HPP
#define MIGRAPHX_GUARD_JIT_CK_INSTANCES_HPP
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include <string> #include <string>
#include <functional> #include <functional>
inline const std::vector<std::string>& const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred) get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred)
{ {
static std::vector<std::vector<std::vector<std::string>>> instances = { static std::vector<std::vector<std::vector<std::string>>> instances = {
...@@ -8051,5 +8048,3 @@ get_instance(std::size_t i, const std::function<bool(const std::vector<std::stri ...@@ -8051,5 +8048,3 @@ get_instance(std::size_t i, const std::function<bool(const std::vector<std::stri
std::find_if(instances.begin(), instances.end(), [&](const auto& v) { return pred(v[0]); }); std::find_if(instances.begin(), instances.end(), [&](const auto& v) { return pred(v[0]); });
return it->at(i); return it->at(i);
} }
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment