Unverified Commit 581d31b0 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add initial code generation (#762)



* Add code object op

* Formatting

* Add more value tests

* Formatting

* Fix from_value conversion from binary

* Formatting

* Don't use offload copy

* Remove iostream header

* Fix compilation errors

* Formatting

* Rename var

* Add missing files

* Formatting

* Remove duplicate variable

* Remove comment

* Template the function so sfinae will work

* Formatting

* Use template specialization since ADL is broken on hcc

* Formatting

* Annotate the constructor with HD for hcc

* Make variable const
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e1819f81
...@@ -28,7 +28,9 @@ function(generate_embed_source EMBED_NAME) ...@@ -28,7 +28,9 @@ function(generate_embed_source EMBED_NAME)
extern const char ${END_SYMBOL}[]; extern const char ${END_SYMBOL}[];
") ")
# TODO: Should use NAME_WLE
get_filename_component(BASE_NAME "${OBJECT}" NAME) get_filename_component(BASE_NAME "${OBJECT}" NAME)
string(REGEX REPLACE ".[A-Za-z0-9_]$" "" BASE_NAME ${BASE_NAME})
string(APPEND INIT_KERNELS " string(APPEND INIT_KERNELS "
{ \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} }, { \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} },
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <unordered_map>
#include <vector> #include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -91,6 +92,38 @@ inline bool starts_with(const std::string& value, const std::string& prefix) ...@@ -91,6 +92,38 @@ inline bool starts_with(const std::string& value, const std::string& prefix)
return std::equal(prefix.begin(), prefix.end(), value.begin()); return std::equal(prefix.begin(), prefix.end(), value.begin());
} }
/// Replace every placeholder of the form `start` key `end` found in `input`.
///
/// Text outside placeholders is copied through unchanged. For each
/// placeholder, `f` is invoked with a pair of `std::string::const_iterator`s
/// delimiting the key (the delimiters themselves are excluded) and must
/// return an object exposing `begin()`/`end()`; that range is appended in
/// place of the placeholder.
///
/// Both delimiters may be longer than one character and are expected to be
/// non-empty.
///
/// Throws std::runtime_error if a `start` delimiter has no matching `end`.
template <class F>
inline std::string
interpolate_string(const std::string& input, F f, std::string start = "${", std::string end = "}")
{
    std::string result = "";
    result.reserve(input.size());
    auto it = input.begin();
    while(it != input.end())
    {
        auto next_start = std::search(it, input.end(), start.begin(), start.end());
        // Everything up to the next placeholder (or the end) is literal text.
        result.append(it, next_start);
        if(next_start == input.end())
            break;
        // Look for the closing delimiter only after the opening one, so the
        // two delimiters can never overlap.
        auto key_start = next_start + start.size();
        auto next_end  = std::search(key_start, input.end(), end.begin(), end.end());
        if(next_end == input.end())
            throw std::runtime_error("Unbalanced brackets");
        auto r = f(key_start, next_end);
        result.append(r.begin(), r.end());
        // Resume after the whole closing delimiter, not just one character.
        it = next_end + end.size();
    }
    return result;
}
/// Expand `${key}` placeholders in `input` using the `vars` lookup table.
///
/// The text between the delimiters is passed through trim() before being
/// looked up. Throws std::runtime_error ("Unknown key: <key>") when a key is
/// not present in `vars`.
inline std::string interpolate_string(const std::string& input,
                                      const std::unordered_map<std::string, std::string>& vars)
{
    auto lookup = [&](auto first, auto last) {
        auto name  = trim({first, last});
        auto found = vars.find(name);
        if(found == vars.end())
            throw std::runtime_error("Unknown key: " + name);
        return found->second;
    };
    return interpolate_string(input, lookup);
}
inline std::string remove_prefix(std::string s, const std::string& prefix) inline std::string remove_prefix(std::string s, const std::string& prefix)
{ {
if(starts_with(s, prefix)) if(starts_with(s, prefix))
......
...@@ -10,6 +10,12 @@ if(NOT TARGET MIOpen) ...@@ -10,6 +10,12 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
# Load the Embed module (presumably the provider of add_embed_library - confirm).
include(Embed)
# Gather every kernel header under kernels/include/migraphx/kernels.
file(GLOB KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
# Log the collected list at configure time.
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
# Build the migraphx_kernels embed library from the collected headers.
add_embed_library(migraphx_kernels ${KERNEL_FILES})
add_library(migraphx_device add_library(migraphx_device
device/acos.cpp device/acos.cpp
device/acosh.cpp device/acosh.cpp
...@@ -105,6 +111,7 @@ add_library(migraphx_gpu ...@@ -105,6 +111,7 @@ add_library(migraphx_gpu
argmax.cpp argmax.cpp
argmin.cpp argmin.cpp
code_object_op.cpp code_object_op.cpp
compile_hip_code_object.cpp
eliminate_workspace.cpp eliminate_workspace.cpp
fuse_ops.cpp fuse_ops.cpp
hip.cpp hip.cpp
...@@ -280,7 +287,7 @@ endif() ...@@ -280,7 +287,7 @@ endif()
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1) target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_compile_options(migraphx_gpu PRIVATE -std=c++17) target_compile_options(migraphx_gpu PRIVATE -std=c++17)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_gpu migraphx_device TARGETS migraphx_gpu migraphx_device
......
...@@ -30,6 +30,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -30,6 +30,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER))); std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
assert(not srcs.empty()); assert(not srcs.empty());
tmp_dir td{}; tmp_dir td{};
params += " -Wno-cuda-compat";
if(params.find("-std=") == std::string::npos) if(params.find("-std=") == std::string::npos)
params += " --std=c++17"; params += " --std=c++17";
params += " -fno-gpu-rdc"; params += " -fno-gpu-rdc";
......
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Fallback overload: builds the architecture name as "gfx" followed by the
// numeric `gcnArch` field of the device properties.
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
return "gfx" + std::to_string(props.gcnArch);
}
// Overload that takes part in overload resolution only when
// `std::string(props.gcnArchName)` is a valid expression (the trailing
// decltype removes it otherwise). It returns that field as a string.
// NOTE(review): callers pass rank<1>{}; presumably rank<1> converts to rank<0>
// so the fallback above is used when gcnArchName is absent - confirm in rank.hpp.
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
return std::string(props.gcnArchName);
}
// Returns the device id reported by hipGetDevice().
// Throws (via MIGRAPHX_THROW) when the call does not return hipSuccess.
int get_device_id()
{
    int device;
    if(hipGetDevice(&device) != hipSuccess)
        MIGRAPHX_THROW("No device");
    return device;
}
// Returns the architecture name of the device given by get_device_id().
// Throws (via MIGRAPHX_THROW) when the device properties cannot be queried.
std::string get_device_name()
{
    hipDeviceProp_t props{};
    if(hipGetDeviceProperties(&props, get_device_id()) != hipSuccess)
        MIGRAPHX_THROW("Failed to get device properties");
    // rank<1> selects the gcnArchName-based overload when it is available.
    return get_arch_name(rank<1>{}, props);
}
// Emits source text of the form `index_ints<...>{}` where the template
// arguments are produced by to_string_range(v).
template <class T>
std::string generate_index_ints(const std::vector<T>& v)
{
    std::string code = "index_ints<";
    code += to_string_range(v);
    code += ">{}";
    return code;
}
// Returns the spelling of the C++ type associated with the shape type `t`.
// Throws (via MIGRAPHX_THROW) if `t` matches none of the generated cases.
std::string generate_cpp_type(shape::type_t t)
{
switch(t)
{
// One `case` per (enumerator, type) pair supplied by
// MIGRAPHX_SHAPE_VISIT_TYPES; each returns the stringized type token.
#define MIGRAPHX_GPU_GENERATE_TYPE_STRING(x, t) \
case shape::x: return #t;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_GPU_GENERATE_TYPE_STRING)
}
MIGRAPHX_THROW("Invalid type");
}
// Emits a `make_shape(<lens>, <strides>)` expression for the shape `s`, with
// both arguments rendered through generate_index_ints().
std::string generate_make_shape(const shape& s)
{
    const std::string lens    = generate_index_ints(s.lens());
    const std::string strides = generate_index_ints(s.strides());
    return "make_shape(" + lens + ", " + strides + ")";
}
// Source template for one `make_tensor<N>` specialization. The `${n}`,
// `${type}`, `${lens}` and `${strides}` placeholders are filled in by
// generate_make_tensor() through interpolate_string().
static const char* const make_tensor_template = R"__migraphx__(
template<>
struct make_tensor<${n}>
{
static __device__ auto apply(void* p)
{
return make_tensor_view(reinterpret_cast<${type}*>(p), make_shape(${lens}, ${strides}));
}
};
)__migraphx__";
// Instantiates make_tensor_template for argument number `n`, using the element
// type, lengths and strides of the shape `s`.
std::string generate_make_tensor(std::size_t n, const shape& s)
{
    const std::string index   = std::to_string(n);
    const std::string type    = generate_cpp_type(s.type());
    const std::string lens    = generate_index_ints(s.lens());
    const std::string strides = generate_index_ints(s.strides());
    return interpolate_string(
        make_tensor_template,
        {{"n", index}, {"type", type}, {"lens", lens}, {"strides", strides}});
}
std::string generate_args_hpp(const std::vector<shape>& inputs)
{
std::string inner;
for(std::size_t i = 0; i < inputs.size(); i++)
{
inner += generate_make_tensor(i, inputs[i]);
}
const std::string args_hpp = R"__migraphx__(
#ifndef MIGRAPHX_GUARD_AUTO_ARGS_HPP
#define MIGRAPHX_GUARD_AUTO_ARGS_HPP
#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tensor_view.hpp>
namespace migraphx {
__content__
} // namespace migraphx
#endif
)__migraphx__";
return replace_string(args_hpp, "__content__", inner);
}
// Compiles `content` as main.cpp together with the embedded kernel headers and
// a generated args.hpp, then wraps the single resulting code object in a
// code_object_op configured from `options`.
// Throws (via MIGRAPHX_THROW) unless exactly one code object is produced.
operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{
    std::vector<src_file> srcs;
    // Expose each embedded header under migraphx/kernels/<name>.
    for(auto&& entry : migraphx_kernels())
    {
        auto path = fs::path{"migraphx"} / "kernels" / entry.first;
        srcs.push_back(src_file{path, entry.second});
    }

    srcs.push_back(src_file{fs::path{"main.cpp"},
                            std::make_pair(content.data(), content.data() + content.size())});

    // args_hpp must stay alive until compilation finishes: src_file only
    // stores a pointer range into it.
    auto args_hpp = generate_args_hpp(options.inputs);
    srcs.push_back(src_file{fs::path{"args.hpp"},
                            std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});

    options.params += " -I.";
    auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
    if(cos.size() != 1)
        MIGRAPHX_THROW("No code object");

    return code_object_op{value::binary{cos.front()},
                          options.kernel_name,
                          options.global,
                          options.local,
                          options.inputs,
                          options.output};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -21,7 +21,7 @@ struct code_object_op ...@@ -21,7 +21,7 @@ struct code_object_op
std::size_t local; std::size_t local;
std::vector<shape> expected_inputs; std::vector<shape> expected_inputs;
shape output; shape output;
kernel k; kernel k{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
......
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/// Options consumed by compile_hip_code_object().
struct hip_compile_options
{
    /// Forwarded to code_object_op as its `global` value.
    /// Zero-initialized so a default-constructed object never carries an
    /// indeterminate value.
    /// NOTE(review): presumably the total number of work-items - confirm.
    std::size_t global = 0;
    /// Forwarded to code_object_op as its `local` value.
    /// NOTE(review): presumably the work-group size - confirm.
    std::size_t local = 0;
    /// Shapes of the kernel arguments; one make_tensor specialization is
    /// generated per entry.
    std::vector<shape> inputs;
    /// Output shape forwarded to code_object_op.
    shape output;
    /// Name of the kernel forwarded to code_object_op.
    std::string kernel_name = "kernel";
    /// Extra compiler flags; " -I." is appended before compilation.
    std::string params = "";
};
/// Compiles the source text `content` using `options` and returns the result
/// as an operation.
operation compile_hip_code_object(const std::string& content, hip_compile_options options);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_HIP_CODE_OBJECT_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ALGORITHM_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ALGORITHM_HPP
namespace migraphx {
// Function object that compares two values with `<`.
struct less
{
    template <class T, class U>
    constexpr auto operator()(T lhs, U rhs) const
    {
        return lhs < rhs;
    }
};
// Function object that compares two values with `>`.
struct greater
{
    template <class T, class U>
    constexpr auto operator()(T lhs, U rhs) const
    {
        return lhs > rhs;
    }
};
// Returns an iterator to the first element of [first, last) that `comp` orders
// before its predecessor, or `last` if there is no such element.
template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{
    if(first == last)
        return last;

    Iterator prev = first;
    Iterator cur  = first;
    while(++cur != last)
    {
        if(comp(*cur, *prev))
            return cur;
        prev = cur;
    }
    return last;
}
// Returns true when no element of [first, last) is ordered by `comp` before
// its predecessor.
template <class Iterator, class Compare>
constexpr bool is_sorted(Iterator first, Iterator last, Compare comp)
{
    Iterator stop = is_sorted_until(first, last, comp);
    return stop == last;
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_ARGS_HPP
#define MIGRAPHX_GUARD_KERNELS_ARGS_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
// Empty tag type parameterized by an index N.
template <std::size_t N>
struct arg
{
};
// Compile-time list of indices; `type` names the list itself.
template <std::size_t...>
struct seq
{
using type = seq;
};
// Concatenates two index lists, offsetting every index of the second list by
// the length of the first.
template <class, class>
struct merge_seq;
template <std::size_t... Xs, std::size_t... Ys>
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
{
};
// gens<N> derives from seq<0, 1, ..., N-1>. It is built by splitting N into
// two halves and merging their sequences.
template <std::size_t N>
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
{
};
// Base cases that terminate the recursion.
template <>
struct gens<0> : seq<>
{
};
template <>
struct gens<1> : seq<0>
{
};
// Use template specialization since ADL is broken on hcc
// Primary template is only declared; a specialization per argument index is
// expected to provide a static `apply(pointer)` function.
template <std::size_t>
struct make_tensor;
// Calls `f` with make_tensor<Ns>::apply(xs) for each pointer in `xs`, pairing
// the i-th pointer with the i-th index of the sequence.
template <class F, std::size_t... Ns, class... Ts>
__device__ auto make_tensors_impl(F f, seq<Ns...>, Ts*... xs)
{
f(make_tensor<Ns>::apply(xs)...);
}
// Returns a callable that captures the pointers `xs` by value. When invoked
// with a function `f`, it forwards `f` and the pointers to make_tensors_impl
// with the index sequence 0..sizeof...(Ts)-1.
template <class... Ts>
__device__ auto make_tensors(Ts*... xs)
{
return [=](auto f) { make_tensors_impl(f, gens<sizeof...(Ts)>{}, xs...); };
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_ARGS_HPP
\ No newline at end of file
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP
#include <migraphx/kernels/types.hpp>
#include <type_traits>
namespace migraphx {
// Generates the element-wise operators of `array` for one operator pair.
// It must be expanded inside the `array<T, N>` class body, because it refers
// to `array`, `T`, `N` and the member `d`.
//   op        - compound assignment operator, e.g. +=
//   binary_op - matching binary operator, e.g. +
// For each pair it defines: array op= array, array op= scalar,
// array op array, array op scalar and scalar op array.
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op)                               \
    constexpr array& operator op(const array& x)                              \
    {                                                                         \
        for(index_int i = 0; i < N; i++)                                      \
            d[i] op x[i];                                                     \
        return *this;                                                         \
    }                                                                         \
    constexpr array& operator op(const T& x)                                  \
    {                                                                         \
        for(index_int i = 0; i < N; i++)                                      \
            d[i] op x;                                                        \
        return *this;                                                         \
    }                                                                         \
    friend constexpr array operator binary_op(const array& x, const array& y) \
    {                                                                         \
        auto z = x;                                                           \
        return z op y;                                                        \
    }                                                                         \
    friend constexpr array operator binary_op(const array& x, const T& y)     \
    {                                                                         \
        auto z = x;                                                           \
        return z op y;                                                        \
    }                                                                         \
    friend constexpr array operator binary_op(const T& x, const array& y)     \
    {                                                                         \
        auto z = y;                                                           \
        for(index_int i = 0; i < N; i++)                                      \
            z[i] = x binary_op y[i];                                          \
        return z;                                                             \
    }
// Fixed-size array of N elements of type T with constexpr element access,
// reductions and element-wise arithmetic.
template <class T, index_int N>
struct array
{
// Element storage.
T d[N];
constexpr T& operator[](index_int i) { return d[i]; }
constexpr const T& operator[](index_int i) const { return d[i]; }
constexpr T& front() { return d[0]; }
constexpr const T& front() const { return d[0]; }
constexpr T& back() { return d[N - 1]; }
constexpr const T& back() const { return d[N - 1]; }
constexpr T* data() { return d; }
constexpr const T* data() const { return d; }
// The size is returned as an integral_constant, so N is carried in the type.
constexpr std::integral_constant<index_int, N> size() const { return {}; }
constexpr T* begin() { return d; }
constexpr const T* begin() const { return d; }
constexpr T* end() { return d + size(); }
constexpr const T* end() const { return d + size(); }
// Sum of the element-wise products of this array and `x`.
constexpr T dot(const array& x) const
{
T result = 0;
for(index_int i = 0; i < N; i++)
result += x[i] * d[i];
return result;
}
// Product of all elements.
constexpr T product() const
{
T result = 1;
for(index_int i = 0; i < N; i++)
result *= d[i];
return result;
}
// Folds the elements into one value by treating them as digits in base
// `width`, with the last element as the least significant digit.
constexpr T single(index_int width = 100) const
{
T result = 0;
T a = 1;
for(index_int i = 0; i < N; i++)
{
result += d[N - i - 1] * a;
a *= width;
}
return result;
}
// Element-wise compound-assignment and binary operators.
MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(-=, -)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
MIGRAPHX_DEVICE_ARRAY_OP(/=, /)
MIGRAPHX_DEVICE_ARRAY_OP(%=, %)
MIGRAPHX_DEVICE_ARRAY_OP(&=, &)
MIGRAPHX_DEVICE_ARRAY_OP(|=, |)
MIGRAPHX_DEVICE_ARRAY_OP(^=, ^)
// True when every pair of corresponding elements compares equal.
friend constexpr bool operator==(const array& x, const array& y)
{
for(index_int i = 0; i < N; i++)
{
if(x[i] != y[i])
return false;
}
return true;
}
friend constexpr bool operator!=(const array& x, const array& y) { return !(x == y); }
// This uses the product order rather than lexical order
// (true only when every element of x is less than the matching element of y).
friend constexpr bool operator<(const array& x, const array& y)
{
for(index_int i = 0; i < N; i++)
{
if(not(x[i] < y[i]))
return false;
}
return true;
}
friend constexpr bool operator>(const array& x, const array& y) { return y < x; }
// The two operators below are defined as "strictly ordered or equal".
friend constexpr bool operator<=(const array& x, const array& y) { return (x < y) or (x == y); }
friend constexpr bool operator>=(const array& x, const array& y) { return (y < x) or (x == y); }
// Uses the elements of this array as per-position limits for `result`.
// Walking from the last index down to index 1, each element is reduced below
// its limit, and the number of subtractions is carried into the next lower
// index. Index 0 receives the final carry and is not reduced.
constexpr array carry(array result) const
{
uint32_t overflow = 0;
for(std::ptrdiff_t i = result.size() - 1; i > 0; i--)
{
auto z = result[i] + overflow;
// Reset overflow
overflow = 0;
// Compute overflow using while loop instead of mod
while(z >= d[i])
{
z -= d[i];
overflow += 1;
}
result[i] = z;
}
result[0] += overflow;
return result;
}
};
// An array whose element values are given as template arguments, so the
// contents are part of the type. A default-constructed object holds `xs...`.
template <class T, T... xs>
struct integral_const_array : array<T, sizeof...(xs)>
{
using base_array = array<T, sizeof...(xs)>;
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({xs...}) {}
};
// Shorthand for an integral_const_array of index_int values.
template <index_int... Ns>
using index_ints = integral_const_array<index_int, Ns...>;
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/types.hpp>
namespace migraphx {
// Holds the position of the calling thread; make_index() fills it in from the
// HIP built-in block and thread variables.
struct index
{
// blockIdx.x * blockDim.x + threadIdx.x when created by make_index().
index_int global = 0;
// threadIdx.x when created by make_index().
index_int local = 0;
// blockIdx.x when created by make_index().
index_int group = 0;
// Total number of threads along x: threads per block times number of blocks.
__device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
// Number of threads per block along x.
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT
};
// Builds the index of the calling thread from the HIP built-in variables along
// the x dimension.
inline __device__ index make_index()
{
    const index_int group  = blockIdx.x;
    const index_int local  = threadIdx.x;
    const index_int global = group * blockDim.x + local;
    return index{global, local, group};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_SHAPE_HPP
#include <migraphx/kernels/array.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
// Describes the layout of a tensor: `lens` holds the extent of each dimension
// and `strides` the element step of each dimension. Both are expected to be
// array-like types exposing `base_array` (e.g. index_ints<...>).
template <class Lens, class Strides>
struct shape
{
    using index_array = typename Lens::base_array;
    Lens lens         = {};
    Strides strides   = {};
    constexpr shape() = default;
    constexpr shape(Lens l, Strides s) : lens(l), strides(s) {}

    // Number of logical elements (product of the lengths).
    constexpr index_int elements() const { return lens.product(); }
    // Number of elements spanned in memory: the offset of the last element
    // plus one. Without the "+ 1" a dense shape would never compare equal to
    // elements(), and packed()/standard() would always be false.
    constexpr index_int element_space() const { return strides.dot(lens - 1) + 1; }
    // True when the elements occupy memory without gaps or overlap.
    constexpr bool packed() const { return elements() == element_space(); }
    // True when at least one stride is zero.
    constexpr bool broadcasted() const { return strides.product() == 0; }
    // True when the (non-zero) strides are not in non-increasing order.
    constexpr bool transposed() const
    {
        if(broadcasted())
        {
            // Compact the non-zero strides to the front before checking order.
            index_array s{};
            index_int j = 0;
            for(index_int i = 0; i < s.size(); i++)
            {
                if(strides[i] != 0)
                {
                    s[j] = strides[i];
                    j++;
                }
            }
            return not is_sorted(s.begin(), s.begin() + j, greater{});
        }
        else
        {
            return not is_sorted(strides.begin(), strides.end(), greater{});
        }
    }
    // True for a packed, non-transposed layout.
    constexpr bool standard() const { return packed() and not transposed(); }

    // Memory offset of the element at multi-dimensional position `x`.
    constexpr index_int index(index_array x) const { return x.dot(strides); }
    // Same as above for a brace-enclosed list of coordinates.
    constexpr index_int index(std::initializer_list<index_int> x) const
    {
        index_int idx = 0;
        for(index_int i = 0; i < x.size(); i++)
            idx += *(x.begin() + i) * strides[i];
        return idx;
    }
    // Memory offset of the i-th logical element (last dimension fastest).
    constexpr index_int index(index_int i) const
    {
        if(this->standard())
            return i;
        else
        {
            const index_int rank = this->lens.size();
            index_int s          = 1;
            index_int result     = 0;
            // Peel coordinates off from the innermost dimension outwards.
            for(index_int j = 0; j < this->lens.size(); j++)
            {
                const index_int k      = rank - j - 1;
                const index_int stride = this->strides[k];
                const index_int len    = this->lens[k];
                const index_int slen   = s * len;
                const index_int idx    = (i % slen) / s;
                result += stride * idx;
                s = slen;
            }
            return result;
        }
    }
    // Converts a logical element number into multi-dimensional coordinates.
    // Dimension 0 receives whatever remains and is not wrapped by its length.
    constexpr index_array multi(index_int idx) const
    {
        index_array result{};
        index_int tidx = idx;
        for(std::ptrdiff_t is = result.size() - 1; is > 0; is--)
        {
            result[is] = tidx % lens[is];
            tidx       = tidx / lens[is];
        }
        result[0] = tidx;
        return result;
    }
};
// Convenience factory that deduces the template arguments of shape from the
// given lengths and strides.
template <class Lens, class Strides>
constexpr shape<Lens, Strides> make_shape(Lens lens, Strides strides)
{
    return shape<Lens, Strides>{lens, strides};
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP
#include <migraphx/kernels/shape.hpp>
namespace migraphx {
// Non-owning view over the memory at `x`, laid out as described by `Shape`.
// The shape is carried in the type: get_shape() returns a default-constructed
// Shape.
template <class T, class Shape>
struct tensor_view
{
constexpr Shape get_shape() const { return Shape{}; }
// Number of logical elements of the shape.
constexpr index_int size() const { return get_shape().elements(); }
// Element access: `i` is translated to a memory offset by Shape::index.
template <class U>
constexpr T& operator[](U i) const
{
return x[get_shape().index(i)];
}
constexpr T* data() const { return x; }
// begin()/end() span the first size() elements starting at data().
constexpr T* begin() const { return data(); }
constexpr T* end() const { return data() + size(); }
// Pointer to the first element; not owned by the view.
T* x;
};
// Creates a tensor_view over `x`. The shape argument is used only to deduce
// the Shape type; its value is ignored.
template <class T, class Shape>
constexpr tensor_view<T, Shape> make_tensor_view(T* x, Shape)
{
    return tensor_view<T, Shape>{x};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
#include <hip/hip_runtime.h>
namespace migraphx {
// Unsigned 32-bit integer type used for sizes and indices in the kernel headers.
using index_int = std::uint32_t;
// Marks a function as constexpr and callable from both host and device code.
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment