"vscode:/vscode.git/clone" did not exist on "34bd1a7083e5875e6a4b2d4f61c0b356cc5d53fc"
Commit 3b2a7aee authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Address PR comments.

parent d3a96e51
......@@ -100,33 +100,5 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f)
return result;
}
inline bool StartsWith(const std::string& value, const std::string& prefix)
{
if(prefix.size() > value.size())
return false;
else
return std::equal(prefix.begin(), prefix.end(), value.begin());
}
inline bool EndsWith(const std::string& value, const std::string& suffix)
{
if(suffix.size() > value.size())
return false;
else
return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin());
}
inline std::vector<std::string> SplitString(const std::string& s, char delim)
{
std::vector<std::string> elems;
std::stringstream ss(s + delim);
std::string item;
while(std::getline(ss, item, delim))
{
elems.push_back(item);
}
return elems;
}
} // namespace host
} // namespace ck
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
add_compile_definitions(HIPRTC_FOR_CODEGEN_TESTS)
message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
endif()
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
add_subdirectory(rtc)
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
......
......@@ -2,3 +2,9 @@ file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
add_library(ck_rtc ${RTC_SOURCES})
target_include_directories(ck_rtc PUBLIC include)
target_link_libraries(ck_rtc PUBLIC hip::host)
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS)
message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
endif()
#include <ck/host/stringutils.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#ifdef HIPRTC_FOR_CODEGEN_TESTS
#include <hip/hiprtc.h>
#include <rtc/manage_ptr.hpp>
#endif
#include <rtc/tmp_dir.hpp>
#include <cassert>
......@@ -14,6 +14,26 @@
namespace rtc {
bool EndsWith(const std::string& value, const std::string& suffix)
{
if(suffix.size() > value.size())
return false;
else
return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin());
}
std::vector<std::string> SplitString(const std::string& s, char delim)
{
std::vector<std::string> elems;
std::stringstream ss(s + delim);
std::string item;
while(std::getline(ss, item, delim))
{
elems.push_back(item);
}
return elems;
}
template <class T>
T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
......@@ -108,42 +128,27 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
#ifdef HIPRTC_FOR_CODEGEN_TESTS
struct hiprtc_src_file
{
hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
std::string path;
std::string content;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.path, "path"), f(self.content, "content"));
}
};
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
}
void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx)
void hiprtc_check_error(hiprtcResult err, const std::string& msg = "")
{
if(err != HIPRTC_SUCCESS)
throw std::runtime_error(hiprtc_error(err, msg));
}
// NOLINTNEXTLINE
#define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define RTC_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
struct hiprtc_program_destroy
struct hiprtc_src_file
{
void operator()(hiprtcProgram prog) const { hiprtcDestroyProgram(&prog); }
hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
std::string path;
std::string content;
};
using hiprtc_program_ptr =
std::unique_ptr<std::remove_pointer_t<hiprtcProgram>, hiprtc_program_destroy>;
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
using hiprtc_program_ptr = RTC_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
template <class... Ts>
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
......@@ -151,8 +156,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
hiprtcProgram prog = nullptr;
auto result = hiprtcCreateProgram(&prog, xs...);
hiprtc_program_ptr p{prog};
if(result != HIPRTC_SUCCESS)
RTC_HIPRTC_THROW(result, "Create program failed.");
hiprtc_check_error(result, "Create program failed.");
return p;
}
......@@ -193,7 +197,7 @@ struct hiprtc_program
{
for(auto&& src : srcs)
{
if(ck::host::EndsWith(src.path, ".cpp"))
if(EndsWith(src.path, ".cpp"))
{
cpp_src = std::move(src.content);
cpp_name = std::move(src.path);
......@@ -239,11 +243,11 @@ struct hiprtc_program
std::string log() const
{
std::size_t n = 0;
RTC_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
hiprtc_check_error(hiprtcGetProgramLogSize(prog.get(), &n));
if(n == 0)
return {};
std::string buffer(n, '\0');
RTC_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
hiprtc_check_error(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() != 0);
return buffer;
}
......@@ -251,9 +255,9 @@ struct hiprtc_program
std::vector<char> get_code_obj() const
{
std::size_t n = 0;
RTC_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
hiprtc_check_error(hiprtcGetCodeSize(prog.get(), &n));
std::vector<char> buffer(n);
RTC_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
hiprtc_check_error(hiprtcGetCode(prog.get(), buffer.data()));
return buffer;
}
};
......@@ -262,7 +266,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src
const compile_options& options)
{
hiprtc_program prog(srcs);
auto flags = ck::host::SplitString(options.flags, ' ');
auto flags = SplitString(options.flags, ' ');
prog.compile(flags);
return {prog.get_code_obj()};
}
......
......@@ -9,6 +9,7 @@
namespace ck {
#ifdef __HIPCC_RTC__
template <bool B>
using bool_constant = integral_constant<bool, B>;
......@@ -113,51 +114,75 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
return static_cast<T&&>(t_);
}
template<class T> struct is_const : false_type {};
template<class T> struct is_const<const T> : true_type {};
template< class T >
template <class T>
struct is_const : false_type
{
};
template <class T>
struct is_const<const T> : true_type
{
};
template <class T>
inline constexpr bool is_const_v = is_const<T>::value;
template< class T >
template <class T>
inline constexpr bool is_reference_v = is_reference<T>::value;
template<class T> struct remove_const { typedef T type; };
template<class T> struct remove_const<const T> { typedef T type; };
template< class T >
template <class T>
struct remove_const
{
typedef T type;
};
template <class T>
struct remove_const<const T>
{
typedef T type;
};
template <class T>
using remove_const_t = typename remove_const<T>::type;
template< class T >
template <class T>
inline constexpr bool is_class_v = is_class<T>::value;
template< class T >
template <class T>
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
template< class... >
template <class...>
using void_t = void;
using __hip::declval;
template <class T, class U = T&&>
U private_declval(int);
template <class T>
T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
#else
#include <utility>
#include <type_traits>
using std::declval;
using std::false_type;
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_class_v;
using std::is_const_v;
using std::is_pointer;
using std::is_reference;
using std::is_reference_v;
using std::is_trivially_copyable;
using std::is_trivially_copyable_v;
using std::is_unsigned;
using std::remove_const_t;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::is_const_v;
using std::is_reference_v;
using std::remove_const_t;
using std::is_class_v;
using std::is_trivially_copyable_v;
using std::void_t;
using std::false_type;
using std::true_type;
using std::declval;
using std::void_t;
#endif
template <typename X, typename Y>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment