"vscode:/vscode.git/clone" did not exist on "a0addb6114d7f66cf9abc0f06592259b09abf68d"
Commit d3a96e51 authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Replace ENV variable with CMake option for toggling hipRTC in codegen

tests.
parent 47486abb
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
add_compile_definitions(HIPRTC_FOR_CODEGEN_TESTS)
message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
endif()
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
add_subdirectory(rtc)
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
......
#pragma once
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iterator>
#include <numeric>
#include <random>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
#include <unordered_set>
#include "ck/host/headers.hpp"
#include "rtc/hiprtc_enable_env.hpp"
#include "ck/host/stringutils.hpp"
// NOLINTNEXTLINE
const char* const content_wrapper = R"__ck__(
const char* const ck_content_wrapper = R"__ck__(
${content}
)__ck__";
template <class P>
inline std::string ck_content_wrapper(P p)
inline std::string content_wrapper(P p)
{
return ck::host::InterpolateString(content_wrapper,
return ck::host::InterpolateString(ck_content_wrapper,
{{"content", std::string{p.data(), p.size()}}});
}
......@@ -29,11 +29,9 @@ inline std::vector<rtc::src_file> create_headers_for_test()
{
auto ck_headers = ck::host::GetHeaders();
std::vector<rtc::src_file> result;
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) {
return rtc::src_file{p.first, ck_content_wrapper(p.second)};
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) {
return rtc::src_file{p.first, content_wrapper(p.second)};
});
return result;
}
......@@ -83,7 +81,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
});
}
std::string classify(double x)
inline std::string classify(double x)
{
switch(std::fpclassify(x))
{
......
#include "common.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "common.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iterator>
#include <random>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
using half = _Float16;
// using half = __fp16;
const std::string gemm_compile_check = R"__ck__(
#include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template};
constexpr auto desc =
G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
${k})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n},
${k}), ck::make_tuple(1, ${n})), ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
${n})));
constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
static_assert(desc.IsValid(), "Invalid ck gemm.");
......
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
#include <rtc/kernel.hpp>
#include <ck/filesystem.hpp>
#include <string>
#include <rtc/kernel.hpp>
#include <functional>
#include <string>
namespace rtc {
......
......@@ -3,8 +3,8 @@
#include <hip/hip_runtime_api.h>
#include <memory>
#include <string>
#include <stdexcept>
#include <string>
namespace rtc {
......
#include <ck/utility/env.hpp>
CK_DECLARE_ENV_VAR_BOOL(CK_CODEGEN_TESTS_ENABLE_HIPRTC)
\ No newline at end of file
#include "rtc/hip.hpp"
#include <ck/host/stringutils.hpp>
#include <rtc/compile_kernel.hpp>
// TODO include only if USE_RTC is set?
#include <rtc/hip.hpp>
#ifdef HIPRTC_FOR_CODEGEN_TESTS
#include <hip/hiprtc.h>
#endif
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <iostream>
#include <fstream>
#include <cassert>
#include <numeric>
#include <deque>
#include <rtc/hiprtc_enable_env.hpp>
#include <ck/host/stringutils.hpp>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
namespace rtc {
......@@ -106,6 +106,8 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
return kernel{obj.data(), options.kernel_name};
}
#ifdef HIPRTC_FOR_CODEGEN_TESTS
struct hiprtc_src_file
{
hiprtc_src_file() = default;
......@@ -274,20 +276,18 @@ static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_o
if(cos.size() != 1)
std::runtime_error("No code object");
auto& obj = cos.front();
return kernel{obj.data(), options.kernel_name};
}
#endif
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
{
#ifdef HIPRTC_FOR_CODEGEN_TESTS
return hiprtc_compile_kernel(srcs, options);
}
else
{
#else
return clang_compile_kernel(srcs, options);
}
#endif
}
} // namespace rtc
......@@ -340,8 +340,8 @@ struct Bilinear
};
template <>
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t>(
int8_t& y, const int32_t& x0, const int8_t& x1) const
__host__ __device__ constexpr void
operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
{
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));
......
......@@ -113,7 +113,6 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
return static_cast<T&&>(t_);
}
// TODO
template<class T> struct is_const : false_type {};
template<class T> struct is_const<const T> : true_type {};
template< class T >
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment