Commit f52c2a4d authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Address PR comments.

parent e3d444c8
...@@ -14,65 +14,33 @@ ...@@ -14,65 +14,33 @@
#include "ck/host/stringutils.hpp" #include "ck/host/stringutils.hpp"
// NOLINTNEXTLINE // NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__( const char* const content_wrapper = R"__ck__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content} ${content}
#pragma clang diagnostic pop )__ck__";
)__migraphx__";
template <class P> template <class P>
inline std::string ck_disable_warnings(P p) inline std::string ck_content_wrapper(P p)
{ {
return ck::host::InterpolateString(disable_warning_pragma, return ck::host::InterpolateString(content_wrapper,
{{"content", std::string{p.data(), p.size()}}}); {{"content", std::string{p.data(), p.size()}}});
} }
inline std::vector<rtc::src_file> create_headers_for_hiprtc_test() inline std::vector<rtc::src_file> create_headers_for_test()
{ {
auto ck_headers = ck::host::GetHeaders(); auto ck_headers = ck::host::GetHeaders();
std::vector<rtc::src_file> result; std::vector<rtc::src_file> result;
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) { std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) {
return rtc::src_file{p.first, ck_disable_warnings(p.second)}; return rtc::src_file{p.first, ck_content_wrapper(p.second)};
}); });
return result; return result;
} }
inline const std::vector<rtc::src_file>& get_headers_for_hiprtc_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_hiprtc_test();
return headers;
}
inline std::vector<rtc::src_file> create_headers_for_clang_test()
{
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, {p.second.begin(), p.second.end()}};
});
return result;
}
inline const std::vector<rtc::src_file>& get_headers_for_clang_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_clang_test();
return headers;
}
inline const std::vector<rtc::src_file>& get_headers_for_test() inline const std::vector<rtc::src_file>& get_headers_for_test()
{ {
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC))) static const std::vector<rtc::src_file> headers = create_headers_for_test();
{ return headers;
return get_headers_for_hiprtc_test();
}
else
{
return get_headers_for_clang_test();
}
} }
template <typename V> template <typename V>
......
...@@ -71,11 +71,11 @@ TEST_CASE(test_problem_kernel) ...@@ -71,11 +71,11 @@ TEST_CASE(test_problem_kernel)
{"m", std::to_string(prob.M)}, {"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)}, {"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}}); {"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options; rtc::compile_options options;
options.kernel_name = "f"; options.kernel_name = "f";
options.additional_src_files = get_headers_for_test(); auto k = rtc::compile_kernel(srcs, options);
auto k = rtc::compile_kernel(src, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize"); auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock"); auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock"); auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
......
...@@ -19,40 +19,11 @@ struct compile_options ...@@ -19,40 +19,11 @@ struct compile_options
{ {
std::string flags = ""; std::string flags = "";
std::string kernel_name = "main"; std::string kernel_name = "main";
std::vector<src_file> additional_src_files = {};
std::string params = "";
}; };
struct hip_compile_options kernel compile_kernel(const std::vector<src_file>& srcs,
{
std::size_t global;
std::size_t local;
std::string kernel_name = "kernel";
std::string params = "";
std::vector<src_file> additional_src_files = {};
/**
* @brief Set the launch parameters but allow v to override the values
*
* @param v A value class which can have a "global" and/or "local" keys to override the default
* global and local
* @param compute_global A function used to compute the global based on the local
* @param default_local The defaul local to use if its missing from the v parameter
*/
void set_launch_params(const std::function<std::size_t(std::size_t local)>& compute_global,
std::size_t default_local = 1024);
void set_launch_params(std::size_t default_global, std::size_t default_local = 1024)
{
set_launch_params([=](auto) { return default_global; }, default_local);
}
};
kernel compile_kernel(const std::vector<src_file>& src,
compile_options options = compile_options{}); compile_options options = compile_options{});
kernel compile_kernel(const std::string& content, compile_options options = compile_options{});
} // namespace rtc } // namespace rtc
#endif #endif
...@@ -131,32 +131,17 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str ...@@ -131,32 +131,17 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
} }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \ #define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg)) #define RTC_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
template <class F, F f> // NOLINT struct hiprtc_program_destroy
struct manage_deleter
{ {
template <class T> void operator()(hiprtcProgram prog) const { hiprtcDestroyProgram(&prog); }
void operator()(T* x) const
{
if(x != nullptr)
{
(void)f(x);
}
}
}; };
template <class T, class F, F f> // NOLINT using hiprtc_program_ptr =
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>; std::unique_ptr<std::remove_pointer_t<hiprtcProgram>, hiprtc_program_destroy>;
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
// Workaround hiprtc's broken API
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
template <class... Ts> template <class... Ts>
hiprtc_program_ptr hiprtc_program_create(Ts... xs) hiprtc_program_ptr hiprtc_program_create(Ts... xs)
...@@ -165,7 +150,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs) ...@@ -165,7 +150,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
auto result = hiprtcCreateProgram(&prog, xs...); auto result = hiprtcCreateProgram(&prog, xs...);
hiprtc_program_ptr p{prog}; hiprtc_program_ptr p{prog};
if(result != HIPRTC_SUCCESS) if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Create program failed."); RTC_HIPRTC_THROW(result, "Create program failed.");
return p; return p;
} }
...@@ -252,11 +237,11 @@ struct hiprtc_program ...@@ -252,11 +237,11 @@ struct hiprtc_program
std::string log() const std::string log() const
{ {
std::size_t n = 0; std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n)); RTC_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n == 0) if(n == 0)
return {}; return {};
std::string buffer(n, '\0'); std::string buffer(n, '\0');
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); RTC_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() != 0); assert(buffer.back() != 0);
return buffer; return buffer;
} }
...@@ -264,108 +249,28 @@ struct hiprtc_program ...@@ -264,108 +249,28 @@ struct hiprtc_program
std::vector<char> get_code_obj() const std::vector<char> get_code_obj() const
{ {
std::size_t n = 0; std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n)); RTC_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
std::vector<char> buffer(n); std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data())); RTC_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
return buffer; return buffer;
} }
}; };
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<src_file> srcs, std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src_file>& srcs,
const std::string& params, const compile_options& options)
const std::string& arch)
{ {
hiprtc_program prog(std::move(srcs)); hiprtc_program prog(srcs);
auto options = ck::host::SplitString(params, ' '); auto flags = ck::host::SplitString(options.flags, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1"); prog.compile(flags);
if(true)
{
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-unused-parameter");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}
if(true)
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
return ck::host::StartsWith(s, "--std=") or ck::host::StartsWith(s, "-std=");
}))
options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc");
options.push_back("-O3");
options.push_back("-Wno-cuda-compat");
options.push_back("--offload-arch=" + arch);
prog.compile(options);
return {prog.get_code_obj()}; return {prog.get_code_obj()};
} }
bool hip_has_flags(const std::vector<std::string>& flags) static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{ {
hiprtc_program prog{" "}; options.flags += " -I. -O3";
try options.flags += " -std=c++17";
{ options.flags += " --offload-arch=" + get_device_name();
prog.compile(flags, true); auto cos = compile_hip_src_with_hiprtc(srcs, options);
return true;
}
catch(...)
{
return false;
}
}
bool hip_accept_non_uniform_wg()
{
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
return non_uniform_wg;
}
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = get_compiler_warnings();
return warnings;
}
static kernel hiprtc_compile_kernel(const std::string& content, compile_options options)
{
std::vector<src_file> srcs = options.additional_src_files;
srcs.push_back(src_file{std::string("main.cpp"), content});
options.params += " " + ck::host::JoinStrings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src_with_hiprtc(srcs, options.params, get_device_name());
if(cos.size() != 1) if(cos.size() != 1)
std::runtime_error("No code object"); std::runtime_error("No code object");
auto& obj = cos.front(); auto& obj = cos.front();
...@@ -373,22 +278,16 @@ static kernel hiprtc_compile_kernel(const std::string& content, compile_options ...@@ -373,22 +278,16 @@ static kernel hiprtc_compile_kernel(const std::string& content, compile_options
return kernel{obj.data(), options.kernel_name}; return kernel{obj.data(), options.kernel_name};
} }
kernel compile_kernel(const std::string& content, compile_options options) kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{ {
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC))) if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
{ {
return hiprtc_compile_kernel(content, options); return hiprtc_compile_kernel(srcs, options);
} }
else else
{ {
options.additional_src_files.push_back({"main.cpp", content}); return clang_compile_kernel(srcs, options);
return clang_compile_kernel(options.additional_src_files, options);
} }
} }
kernel compile_kernel(const std::vector<src_file>& src, compile_options options)
{
return clang_compile_kernel(src, options);
}
} // namespace rtc } // namespace rtc
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
#pragma once #pragma once
#include "ck/config.h" #include "ck/config.h"
#ifndef __HIPCC_RTC__ #ifndef __HIPCC_RTC__
#include "ck/utility/env.hpp"
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h" #include "hip/hip_runtime.h"
#include "hip/hip_fp16.h" #include "hip/hip_fp16.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment