"tests/distributed/test_new_kvstore.py.bak" did not exist on "9bcce7beade62b12c23b1a7b7a169f5c06159ee2"
Commit f52c2a4d authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Address PR comments.

parent e3d444c8
......@@ -14,65 +14,33 @@
#include "ck/host/stringutils.hpp"
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
const char* const content_wrapper = R"__ck__(
${content}
#pragma clang diagnostic pop
)__migraphx__";
)__ck__";
template <class P>
inline std::string ck_disable_warnings(P p)
inline std::string ck_content_wrapper(P p)
{
return ck::host::InterpolateString(disable_warning_pragma,
return ck::host::InterpolateString(content_wrapper,
{{"content", std::string{p.data(), p.size()}}});
}
inline std::vector<rtc::src_file> create_headers_for_hiprtc_test()
inline std::vector<rtc::src_file> create_headers_for_test()
{
auto ck_headers = ck::host::GetHeaders();
std::vector<rtc::src_file> result;
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) {
return rtc::src_file{p.first, ck_disable_warnings(p.second)};
return rtc::src_file{p.first, ck_content_wrapper(p.second)};
});
return result;
}
inline const std::vector<rtc::src_file>& get_headers_for_hiprtc_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_hiprtc_test();
return headers;
}
inline std::vector<rtc::src_file> create_headers_for_clang_test()
{
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, {p.second.begin(), p.second.end()}};
});
return result;
}
inline const std::vector<rtc::src_file>& get_headers_for_clang_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_clang_test();
return headers;
}
inline const std::vector<rtc::src_file>& get_headers_for_test()
{
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
{
return get_headers_for_hiprtc_test();
}
else
{
return get_headers_for_clang_test();
}
static const std::vector<rtc::src_file> headers = create_headers_for_test();
return headers;
}
template <typename V>
......
......@@ -71,11 +71,11 @@ TEST_CASE(test_problem_kernel)
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
options.kernel_name = "f";
options.additional_src_files = get_headers_for_test();
auto k = rtc::compile_kernel(src, options);
auto k = rtc::compile_kernel(srcs, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
......
......@@ -19,40 +19,11 @@ struct compile_options
{
std::string flags = "";
std::string kernel_name = "main";
std::vector<src_file> additional_src_files = {};
std::string params = "";
};
struct hip_compile_options
{
std::size_t global;
std::size_t local;
std::string kernel_name = "kernel";
std::string params = "";
std::vector<src_file> additional_src_files = {};
/**
* @brief Set the launch parameters but allow v to override the values
*
* @param v A value class which can have a "global" and/or "local" keys to override the default
* global and local
* @param compute_global A function used to compute the global based on the local
* @param default_local The defaul local to use if its missing from the v parameter
*/
void set_launch_params(const std::function<std::size_t(std::size_t local)>& compute_global,
std::size_t default_local = 1024);
void set_launch_params(std::size_t default_global, std::size_t default_local = 1024)
{
set_launch_params([=](auto) { return default_global; }, default_local);
}
};
kernel compile_kernel(const std::vector<src_file>& src,
kernel compile_kernel(const std::vector<src_file>& srcs,
compile_options options = compile_options{});
kernel compile_kernel(const std::string& content, compile_options options = compile_options{});
} // namespace rtc
#endif
......@@ -131,32 +131,17 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
}
// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
#define RTC_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
template <class F, F f> // NOLINT
struct manage_deleter
struct hiprtc_program_destroy
{
template <class T>
void operator()(T* x) const
{
if(x != nullptr)
{
(void)f(x);
}
}
void operator()(hiprtcProgram prog) const { hiprtcDestroyProgram(&prog); }
};
template <class T, class F, F f> // NOLINT
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
// Workaround hiprtc's broken API
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
using hiprtc_program_ptr =
std::unique_ptr<std::remove_pointer_t<hiprtcProgram>, hiprtc_program_destroy>;
template <class... Ts>
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
......@@ -165,7 +150,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
auto result = hiprtcCreateProgram(&prog, xs...);
hiprtc_program_ptr p{prog};
if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Create program failed.");
RTC_HIPRTC_THROW(result, "Create program failed.");
return p;
}
......@@ -252,11 +237,11 @@ struct hiprtc_program
std::string log() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
RTC_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n == 0)
return {};
std::string buffer(n, '\0');
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
RTC_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() != 0);
return buffer;
}
......@@ -264,108 +249,28 @@ struct hiprtc_program
std::vector<char> get_code_obj() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
RTC_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
RTC_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
return buffer;
}
};
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<src_file> srcs,
const std::string& params,
const std::string& arch)
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src_file>& srcs,
const compile_options& options)
{
hiprtc_program prog(std::move(srcs));
auto options = ck::host::SplitString(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
if(true)
{
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-unused-parameter");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}
if(true)
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
return ck::host::StartsWith(s, "--std=") or ck::host::StartsWith(s, "-std=");
}))
options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc");
options.push_back("-O3");
options.push_back("-Wno-cuda-compat");
options.push_back("--offload-arch=" + arch);
prog.compile(options);
hiprtc_program prog(srcs);
auto flags = ck::host::SplitString(options.flags, ' ');
prog.compile(flags);
return {prog.get_code_obj()};
}
bool hip_has_flags(const std::vector<std::string>& flags)
static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
bool hip_accept_non_uniform_wg()
{
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
return non_uniform_wg;
}
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = get_compiler_warnings();
return warnings;
}
static kernel hiprtc_compile_kernel(const std::string& content, compile_options options)
{
std::vector<src_file> srcs = options.additional_src_files;
srcs.push_back(src_file{std::string("main.cpp"), content});
options.params += " " + ck::host::JoinStrings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src_with_hiprtc(srcs, options.params, get_device_name());
options.flags += " -I. -O3";
options.flags += " -std=c++17";
options.flags += " --offload-arch=" + get_device_name();
auto cos = compile_hip_src_with_hiprtc(srcs, options);
if(cos.size() != 1)
std::runtime_error("No code object");
auto& obj = cos.front();
......@@ -373,22 +278,16 @@ static kernel hiprtc_compile_kernel(const std::string& content, compile_options
return kernel{obj.data(), options.kernel_name};
}
kernel compile_kernel(const std::string& content, compile_options options)
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
{
return hiprtc_compile_kernel(content, options);
return hiprtc_compile_kernel(srcs, options);
}
else
{
options.additional_src_files.push_back({"main.cpp", content});
return clang_compile_kernel(options.additional_src_files, options);
return clang_compile_kernel(srcs, options);
}
}
kernel compile_kernel(const std::vector<src_file>& src, compile_options options)
{
return clang_compile_kernel(src, options);
}
} // namespace rtc
......@@ -4,7 +4,9 @@
#pragma once
#include "ck/config.h"
#ifndef __HIPCC_RTC__
#include "ck/utility/env.hpp"
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment