Commit d3a96e51 authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Replace ENV variable with CMake option for toggling hipRTC in codegen

tests.
parent 47486abb
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
add_compile_definitions(HIPRTC_FOR_CODEGEN_TESTS)
message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
endif()
list(APPEND CMAKE_PREFIX_PATH /opt/rocm) list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
add_subdirectory(rtc) add_subdirectory(rtc)
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
......
#pragma once #pragma once
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <fstream>
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <random> #include <random>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
#include <unordered_set> #include <unordered_set>
#include "ck/host/headers.hpp"
#include "rtc/hiprtc_enable_env.hpp"
#include "ck/host/stringutils.hpp"
// NOLINTNEXTLINE // NOLINTNEXTLINE
const char* const content_wrapper = R"__ck__( const char* const ck_content_wrapper = R"__ck__(
${content} ${content}
)__ck__"; )__ck__";
template <class P> template <class P>
inline std::string ck_content_wrapper(P p) inline std::string content_wrapper(P p)
{ {
return ck::host::InterpolateString(content_wrapper, return ck::host::InterpolateString(ck_content_wrapper,
{{"content", std::string{p.data(), p.size()}}}); {{"content", std::string{p.data(), p.size()}}});
} }
...@@ -29,11 +29,9 @@ inline std::vector<rtc::src_file> create_headers_for_test() ...@@ -29,11 +29,9 @@ inline std::vector<rtc::src_file> create_headers_for_test()
{ {
auto ck_headers = ck::host::GetHeaders(); auto ck_headers = ck::host::GetHeaders();
std::vector<rtc::src_file> result; std::vector<rtc::src_file> result;
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) {
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) { return rtc::src_file{p.first, content_wrapper(p.second)};
return rtc::src_file{p.first, ck_content_wrapper(p.second)};
}); });
return result; return result;
} }
...@@ -83,7 +81,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) ...@@ -83,7 +81,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
}); });
} }
std::string classify(double x) inline std::string classify(double x)
{ {
switch(std::fpclassify(x)) switch(std::fpclassify(x))
{ {
......
#include "common.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp" #include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp" #include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp" #include "ck/host/utils.hpp"
#include "common.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <fstream>
#include <iterator> #include <iterator>
#include <random> #include <random>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
using half = _Float16; using half = _Float16;
// using half = __fp16;
const std::string gemm_compile_check = R"__ck__( const std::string gemm_compile_check = R"__ck__(
#include <${include}> #include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template}; using G = ${template};
constexpr auto desc = constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
${k})), ck::make_tuple(),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
${k}), ck::make_tuple(1, ${n})), ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
${n})));
static_assert(desc.IsValid(), "Invalid ck gemm."); static_assert(desc.IsValid(), "Invalid ck gemm.");
...@@ -69,15 +65,15 @@ TEST_CASE(test_problem_kernel) ...@@ -69,15 +65,15 @@ TEST_CASE(test_problem_kernel)
{"m", std::to_string(prob.M)}, {"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)}, {"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}}); {"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test(); auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src}); srcs.push_back({"main.cpp", src});
rtc::compile_options options; rtc::compile_options options;
options.kernel_name = "f"; options.kernel_name = "f";
auto k = rtc::compile_kernel(srcs, options); auto k = rtc::compile_kernel(srcs, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize"); auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock"); auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock"); auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block); ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
......
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
#include <rtc/kernel.hpp>
#include <ck/filesystem.hpp> #include <ck/filesystem.hpp>
#include <string> #include <rtc/kernel.hpp>
#include <functional> #include <functional>
#include <string>
namespace rtc { namespace rtc {
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <memory> #include <memory>
#include <string>
#include <stdexcept> #include <stdexcept>
#include <string>
namespace rtc { namespace rtc {
......
#include <ck/utility/env.hpp>
CK_DECLARE_ENV_VAR_BOOL(CK_CODEGEN_TESTS_ENABLE_HIPRTC)
\ No newline at end of file
#include "rtc/hip.hpp" #include <ck/host/stringutils.hpp>
#include <rtc/compile_kernel.hpp> #include <rtc/compile_kernel.hpp>
// TODO include only if USE_RTC is set? #include <rtc/hip.hpp>
#ifdef HIPRTC_FOR_CODEGEN_TESTS
#include <hip/hiprtc.h> #include <hip/hiprtc.h>
#endif
#include <rtc/tmp_dir.hpp> #include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <iostream>
#include <fstream>
#include <cassert> #include <cassert>
#include <numeric>
#include <deque> #include <deque>
#include <rtc/hiprtc_enable_env.hpp> #include <fstream>
#include <ck/host/stringutils.hpp> #include <iostream>
#include <numeric>
#include <stdexcept>
namespace rtc { namespace rtc {
...@@ -106,6 +106,8 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o ...@@ -106,6 +106,8 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
return kernel{obj.data(), options.kernel_name}; return kernel{obj.data(), options.kernel_name};
} }
#ifdef HIPRTC_FOR_CODEGEN_TESTS
struct hiprtc_src_file struct hiprtc_src_file
{ {
hiprtc_src_file() = default; hiprtc_src_file() = default;
...@@ -274,20 +276,18 @@ static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_o ...@@ -274,20 +276,18 @@ static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_o
if(cos.size() != 1) if(cos.size() != 1)
std::runtime_error("No code object"); std::runtime_error("No code object");
auto& obj = cos.front(); auto& obj = cos.front();
return kernel{obj.data(), options.kernel_name}; return kernel{obj.data(), options.kernel_name};
} }
#endif
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options) kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{ {
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC))) #ifdef HIPRTC_FOR_CODEGEN_TESTS
{ return hiprtc_compile_kernel(srcs, options);
return hiprtc_compile_kernel(srcs, options); #else
} return clang_compile_kernel(srcs, options);
else #endif
{
return clang_compile_kernel(srcs, options);
}
} }
} // namespace rtc } // namespace rtc
...@@ -1127,4 +1127,4 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -1127,4 +1127,4 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
\ No newline at end of file
...@@ -340,8 +340,8 @@ struct Bilinear ...@@ -340,8 +340,8 @@ struct Bilinear
}; };
template <> template <>
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t>( __host__ __device__ constexpr void
int8_t& y, const int32_t& x0, const int8_t& x1) const operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
{ {
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) + y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1)); beta_ * type_convert<float>(x1));
......
...@@ -36,7 +36,7 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>> ...@@ -36,7 +36,7 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{ {
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})..., return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
ck::forward<Y>(y).At(Number<Js>{})...); ck::forward<Y>(y).At(Number<Js>{})...);
} }
}; };
......
...@@ -113,7 +113,6 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept ...@@ -113,7 +113,6 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
return static_cast<T&&>(t_); return static_cast<T&&>(t_);
} }
// TODO
template<class T> struct is_const : false_type {}; template<class T> struct is_const : false_type {};
template<class T> struct is_const<const T> : true_type {}; template<class T> struct is_const<const T> : true_type {};
template< class T > template< class T >
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment