Commit 08255e1b authored by Dino Musić's avatar Dino Musić
Browse files

Implement hiprtc for codegen tests

parent 1658c0dc
cmake_minimum_required(VERSION 3.16)
project(composable_kernel_host LANGUAGES CXX HIP)
find_package(ROCM)
include(ROCMInstallTargets)
include(ROCMTest)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
add_compile_options(-std=c++17)
find_package(hip)
add_custom_target(codegen)
# add include directories
include_directories(BEFORE
${PROJECT_BINARY_DIR}/include
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/library/include
${HIP_INCLUDE_DIRS}
${CK_ROOT}/include/
)
list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
include(Embed)
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
......@@ -25,29 +32,31 @@ file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
#message(STATUS "RELATIVE: ${CK_ROOT}/include")
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
##message(STATUS "SOURCE_FILES: ${SOURCES}")
# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
target_link_libraries(ck_host PRIVATE ck_headers)
set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON)
target_include_directories(ck_host PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/solution_instances>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include>
)
add_executable(ck-template-driver driver/main.cpp)
target_link_libraries(ck-template-driver ck_host)
rocm_install(
TARGETS ck_host ck_headers
EXPORT ck_hostTargets
)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
add_subdirectory(test)
add_subdirectory(test)
\ No newline at end of file
......@@ -100,5 +100,33 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f)
return result;
}
inline bool StartsWith(const std::string& value, const std::string& prefix)
{
if(prefix.size() > value.size())
return false;
else
return std::equal(prefix.begin(), prefix.end(), value.begin());
}
inline bool EndsWith(const std::string& value, const std::string& suffix)
{
if(suffix.size() > value.size())
return false;
else
return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin());
}
inline std::vector<std::string> SplitString(const std::string& s, char delim)
{
std::vector<std::string> elems;
std::stringstream ss(s + delim);
std::string item;
while(std::getline(ss, item, delim))
{
elems.push_back(item);
}
return elems;
}
} // namespace host
} // namespace ck
......@@ -8,18 +8,73 @@
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
#include <unordered_set>
#include "ck/host/headers.hpp"
#include "rtc/hiprtc_enable_env.hpp"
#include "ck/host/stringutils.hpp"
std::vector<rtc::src_file> get_headers_for_test()
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
inline std::string ck_disable_warnings(P p)
{
return ck::host::InterpolateString(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}
inline std::vector<rtc::src_file> create_headers_for_hiprtc_test()
{
auto ck_headers = ck::host::GetHeaders();
std::vector<rtc::src_file> result;
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) {
return rtc::src_file{p.first, ck_disable_warnings(p.second)};
});
return result;
}
inline const std::vector<rtc::src_file>& get_headers_for_hiprtc_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_hiprtc_test();
return headers;
}
inline std::vector<rtc::src_file> create_headers_for_clang_test()
{
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
return {p.first, {p.second.begin(), p.second.end()}};
});
return result;
}
inline const std::vector<rtc::src_file>& get_headers_for_clang_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_clang_test();
return headers;
}
inline const std::vector<rtc::src_file>& get_headers_for_test()
{
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
{
return get_headers_for_hiprtc_test();
}
else
{
return get_headers_for_clang_test();
}
}
template <typename V>
std::size_t GetSize(V mLens, V mStrides)
{
......@@ -34,18 +89,24 @@ std::size_t GetSize(V mLens, V mStrides)
return space;
}
template <class T, typename V>
rtc::buffer<T> generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
template <class T>
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
{
std::size_t space = GetSize(mLens, mStrides);
rtc::buffer<T> result(space);
rtc::buffer<T> result(n);
std::mt19937 gen(seed);
std::uniform_real_distribution<double> dis(-1.0);
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
// std::fill(result.begin(), result.end(), 1);
return result;
}
template <class T, typename V>
std::enable_if_t<!std::is_integral_v<V>, rtc::buffer<T>>
generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
{
std::size_t space = GetSize(mLens, mStrides);
return generate_buffer<T>(space, seed);
}
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
......
#include "common.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp"
......@@ -15,116 +16,6 @@
using half = _Float16;
// using half = __fp16;
std::vector<rtc::src_file> get_headers_for_test()
{
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
});
return result;
}
template <class T>
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
{
rtc::buffer<T> result(n);
std::mt19937 gen(seed);
std::uniform_real_distribution<double> dis(-1.0);
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
return result;
}
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) {
return fabs(x - y) < atol + rtol * fabs(y);
});
}
std::string classify(double x)
{
switch(std::fpclassify(x))
{
case FP_INFINITE: return "inf";
case FP_NAN: return "nan";
case FP_NORMAL: return "normal";
case FP_SUBNORMAL: return "subnormal";
case FP_ZERO: return "zero";
default: return "unknown";
}
}
template <class Buffer>
void print_classification(const Buffer& x)
{
std::unordered_set<std::string> result;
for(const auto& i : x)
result.insert(classify(i));
for(const auto& c : result)
std::cout << c << ", ";
std::cout << std::endl;
}
template <class Buffer>
void print_statistics(const Buffer& x)
{
std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", ";
std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", ";
double num_elements = x.size();
auto mean =
std::accumulate(x.begin(), x.end(), double{0.0}, std::plus<double>{}) / num_elements;
auto stddev = std::sqrt(
std::accumulate(x.begin(),
x.end(),
double{0.0},
[&](double r, double v) { return r + std::pow((v - mean), 2.0); }) /
num_elements);
std::cout << "Mean: " << mean << ", ";
std::cout << "StdDev: " << stddev << "\n";
}
template <class Buffer>
void print_preview(const Buffer& x)
{
if(x.size() <= 10)
{
std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; });
}
else
{
std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; });
std::cout << "..., ";
std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; });
}
std::cout << std::endl;
}
template <class T>
struct check_all
{
rtc::buffer<T> data{};
bool operator()(const rtc::buffer<T>& x)
{
if(data.empty())
{
data = x;
return true;
}
if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); }))
return false;
return allclose(data, x);
}
};
template <class Solution>
auto report(const Solution& solution, bool pass)
{
return test::make_predicate(solution.ToTemplateString(), [=] { return pass; });
}
const std::string gemm_compile_check = R"__ck__(
#include <${include}>
......@@ -163,23 +54,28 @@ TEST_CASE(test_problem_kernel)
std::string epilogue = "";
std::string prologue = "";
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i)
{
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
auto&& solution = solutions[i];
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
rtc::compile_options options;
options.kernel_name = "f";
auto k = rtc::compile_kernel(srcs, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
options.kernel_name = "f";
options.additional_src_files = get_headers_for_test();
auto k = rtc::compile_kernel(src, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
......
......@@ -4,24 +4,29 @@
#include <rtc/kernel.hpp>
#include <ck/filesystem.hpp>
#include <string>
#include <functional>
namespace rtc {
struct src_file
{
CK::fs::path path;
std::string_view content;
std::string content;
};
struct compile_options
{
std::string flags = "";
std::string kernel_name = "main";
std::vector<src_file> additional_src_files = {};
std::string params = "";
};
kernel compile_kernel(const std::vector<src_file>& src,
compile_options options = compile_options{});
kernel compile_kernel(const std::string& content, compile_options options = compile_options{});
} // namespace rtc
#endif
......@@ -4,6 +4,7 @@
#include <hip/hip_runtime_api.h>
#include <memory>
#include <string>
#include <stdexcept>
namespace rtc {
......
#include <ck/utility/env.hpp>
CK_DECLARE_ENV_VAR_BOOL(CK_CODEGEN_TESTS_ENABLE_HIPRTC)
\ No newline at end of file
#include "rtc/hip.hpp"
#include <rtc/compile_kernel.hpp>
// TODO include only if USE_RTC is set?
#include <hip/hiprtc.h>
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <iostream>
#include <fstream>
#include <cassert>
#include <numeric>
#include <deque>
#include <rtc/hiprtc_enable_env.hpp>
#include <ck/host/stringutils.hpp>
namespace rtc {
......@@ -59,7 +65,7 @@ std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device
// TODO: undo after extracting the codeobj
// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; }
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
assert(not srcs.empty());
tmp_dir td{"compile"};
......@@ -100,4 +106,289 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
return kernel{obj.data(), options.kernel_name};
}
struct hiprtc_src_file
{
hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
std::string path;
std::string content;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.path, "path"), f(self.content, "content"));
}
};
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
}
void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx)
{
if(err != HIPRTC_SUCCESS)
throw std::runtime_error(hiprtc_error(err, msg));
}
// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
template <class F, F f> // NOLINT
struct manage_deleter
{
template <class T>
void operator()(T* x) const
{
if(x != nullptr)
{
(void)f(x);
}
}
};
template <class T, class F, F f> // NOLINT
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
// Workaround hiprtc's broken API
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
template <class... Ts>
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
{
hiprtcProgram prog = nullptr;
auto result = hiprtcCreateProgram(&prog, xs...);
hiprtc_program_ptr p{prog};
if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Create program failed.");
return p;
}
struct hiprtc_program
{
struct string_array
{
std::deque<std::string> strings{};
std::vector<const char*> c_strs{};
string_array() {}
string_array(const string_array&) = delete;
std::size_t size() const { return strings.size(); }
const char** data() { return c_strs.data(); }
void push_back(std::string s)
{
strings.push_back(std::move(s));
c_strs.push_back(strings.back().c_str());
}
};
hiprtc_program_ptr prog = nullptr;
string_array headers{};
string_array include_names{};
std::string cpp_src = "";
std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<src_file> srcs)
{
for(auto&& src : srcs)
{
if(ck::host::EndsWith(src.path, ".cpp"))
{
cpp_src = std::move(src.content);
cpp_name = std::move(src.path);
}
else
{
headers.push_back(std::string(src.content.begin(), src.content.end()));
include_names.push_back(std::move(src.path));
}
}
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(),
headers.size(),
headers.data(),
include_names.data());
}
void compile(const std::vector<std::string>& options, bool quiet = false) const
{
std::vector<const char*> c_options;
std::transform(options.begin(),
options.end(),
std::back_inserter(c_options),
[](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log();
if(not prog_log.empty() and not quiet)
{
std::cerr << prog_log << std::endl;
}
if(result != HIPRTC_SUCCESS)
throw std::runtime_error("Compilation failed.");
}
std::string log() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n == 0)
return {};
std::string buffer(n, '\0');
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() != 0);
return buffer;
}
std::vector<char> get_code_obj() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
return buffer;
}
};
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<src_file> srcs,
const std::string& params,
const std::string& arch)
{
hiprtc_program prog(std::move(srcs));
auto options = ck::host::SplitString(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
if(true)
{
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-unused-parameter");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}
if(true)
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
return ck::host::StartsWith(s, "--std=") or ck::host::StartsWith(s, "-std=");
}))
options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc");
options.push_back("-O3");
options.push_back("-Wno-cuda-compat");
options.push_back("--offload-arch=" + arch);
prog.compile(options);
return {prog.get_code_obj()};
}
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
bool hip_accept_non_uniform_wg()
{
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
return non_uniform_wg;
}
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = get_compiler_warnings();
return warnings;
}
static kernel hiprtc_compile_kernel(const std::string& content, compile_options options)
{
std::vector<src_file> srcs = options.additional_src_files;
srcs.push_back(src_file{std::string("main.cpp"), content});
options.params += " " + ck::host::JoinStrings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src_with_hiprtc(srcs, options.params, get_device_name());
if(cos.size() != 1)
std::runtime_error("No code object");
auto& obj = cos.front();
return kernel{obj.data(), options.kernel_name};
}
kernel compile_kernel(const std::string& content, compile_options options)
{
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
{
return hiprtc_compile_kernel(content, options);
}
else
{
options.additional_src_files.push_back({"main.cpp", content});
return clang_compile_kernel(options.additional_src_files, options);
}
}
kernel compile_kernel(const std::vector<src_file>& src, compile_options options)
{
return clang_compile_kernel(src, options);
}
} // namespace rtc
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment