Unverified Commit 29fa2666 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add gpu driver and improvements to pointwise codegen (#851)



* Add method to compile pointwise

* Formatting

* Add lambda

* Add semicolon

* Rename variable

* Add driver to run jit kernels

* Formatting

* Add context

* Formatting

* Make separate driver folder

* Add more general gpu driver

* Formatting

* Print out wall time

* Formatting

* Run multiple times and skip first run

* Formatting

* Separate time_op

* Run an op for comparison

* Formatting

* Add debug asserts

* Formatting

* Change parameter name

* Formatting

* Fix argument order

* Formatting

* Add preloading

* Formatting

* Allow a different data type

* Formatting

* Pipeline transformations

* Formatting

* Add vectorization

* Formatting

* Reduce dims

* Formatting

* Compile with launch params as constant

* Formatting

* Make sure buffer can be vectorized

* Formatting

* Enable vectorization and preloading

* Formatting

* Add print header

* Formatting

* Avoid allocating too much LDS

* Formatting

* Add some vec functions to a separate header

* Formatting

* Add stride loops

* Formatting

* Improve the transform pipeline

* Formatting

* Add const

* Fix shape check

* Formatting

* Just check stride axis is zero

* Remove extra find_vector_axis overload

* Simplify some more functions

* Formatting

* Remove some more extra functions

* Formatting

* Simplify more decltypes

* Add another const

* Fix test

* Get buffer pointer different for older compilers
Co-authored-by: default avatarShucai Xiao <shucai@gmail.com>
Co-authored-by: default avatarChris Austen <causten@users.noreply.github.com>
parent 30966f6b
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/driver/perf.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {
// Driver action: builds one operator from the "name" entry of the request, passes
// it together with the shapes parsed from "inputs" to time_op, and prints the
// operator followed by the returned number and "ms".
struct run_op : action<run_op>
{
    // p parses the shape descriptions; v is the request holding "inputs" and "name".
    static void apply(const parser& p, const value& v)
    {
        context ctx;
        auto shapes          = p.parse_shapes(v.at("inputs"));
        const auto requested = v.at("name").to<std::string>();
        // Names without an explicit "::" qualifier get the "gpu::" prefix.
        const std::string op_name =
            contains(requested, "::") ? requested : "gpu::" + requested;
        auto op              = make_op(op_name);
        const double elapsed = time_op(ctx, op, shapes);
        std::cout << op << ": " << elapsed << "ms" << std::endl;
    }
};
} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -51,6 +51,7 @@ struct code_object_op
os << "symbol_name=" << op.symbol_name << ",";
os << "global=" << op.global << ",";
os << "local=" << op.local << ",";
os << "]";
return os;
}
};
......
......@@ -14,8 +14,9 @@ struct hip_compile_options
std::size_t local;
std::vector<shape> inputs;
shape output;
std::string kernel_name = "kernel";
std::string params = "";
std::string kernel_name = "kernel";
std::string params = "";
std::vector<shape> reduced_inputs = {};
};
operation compile_hip_code_object(const std::string& content, hip_compile_options options);
......
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// Declaration only; the definition is not part of this header.
// Takes the GPU context, the shapes of the kernel arguments and a string named
// `lambda`, and returns an operation.
// NOTE(review): `lambda` presumably holds the source text of the pointwise
// function to compile — confirm against the implementation.
operation
compile_pointwise(context& ctx, const std::vector<shape>& inputs, const std::string& lambda);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
......@@ -43,6 +43,59 @@ constexpr bool is_sorted(Iterator first, Iterator last, Compare comp)
return is_sorted_until(first, last, comp) == last;
}
// Applies `f` to every element of [first, last) in order and hands the callable
// back, so any state it accumulated is available to the caller.
template <class Iterator, class F>
constexpr F for_each(Iterator first, Iterator last, F f)
{
    while(first != last)
    {
        f(*first);
        ++first;
    }
    return f;
}
// Returns an iterator to the first element of [first, last) for which `p` is
// true, or `last` when no element matches.
template <class Iterator, class Predicate>
constexpr Iterator find_if(Iterator first, Iterator last, Predicate p)
{
    // Advance until the range is exhausted or the predicate accepts the element.
    while(first != last and not p(*first))
        ++first;
    return first;
}
// Returns an iterator to the first element of [first, last) that compares equal
// to `value`, or `last` when there is none.
template <class Iterator, class T>
constexpr Iterator find(Iterator first, Iterator last, const T& value)
{
    const auto equals_value = [&](const auto& candidate) { return candidate == value; };
    return find_if(first, last, equals_value);
}
// Finds the first position in [first, last) where the sequence
// [s_first, s_last) occurs. Returns that position, `first` when the pattern is
// empty, or `last` when the pattern does not occur.
template <class Iterator1, class Iterator2>
constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, Iterator2 s_last)
{
    while(true)
    {
        Iterator1 hay = first;
        Iterator2 pat = s_first;
        // Walk both ranges in lockstep for as long as they agree.
        while(true)
        {
            if(pat == s_last)
                return first; // the whole pattern matched starting at `first`
            if(hay == last)
                return last; // ran out of input before the pattern completed
            if(not(*hay == *pat))
                break;
            ++hay;
            ++pat;
        }
        ++first;
    }
}
} // namespace migraphx
#endif
......@@ -2,57 +2,26 @@
#define MIGRAPHX_GUARD_KERNELS_ARGS_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx {
template <std::size_t N>
struct arg
{
};
template <std::size_t...>
struct seq
{
using type = seq;
};
template <class, class>
struct merge_seq;
template <std::size_t... Xs, std::size_t... Ys>
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
{
};
template <std::size_t N>
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
{
};
template <>
struct gens<0> : seq<>
{
};
template <>
struct gens<1> : seq<0>
{
};
// Use template specialization since ADL is broken on hcc
template <std::size_t>
template <index_int>
struct make_tensor;
template <class F, std::size_t... Ns, class... Ts>
__device__ auto make_tensors_impl(F f, seq<Ns...>, Ts*... xs)
template <class F, index_int... Ns, class... Ts>
__device__ auto make_tensors_impl(F f, detail::seq<Ns...>, Ts*... xs)
{
f(make_tensor<Ns>::apply(xs)...);
return f(make_tensor<Ns>::apply(xs)...);
}
template <class... Ts>
__device__ auto make_tensors(Ts*... xs)
inline __device__ auto make_tensors()
{
return [=](auto f) { make_tensors_impl(f, gens<sizeof...(Ts)>{}, xs...); };
return [](auto*... xs) {
return [=](auto f) { return make_tensors_impl(f, detail::gens<sizeof...(xs)>{}, xs...); };
};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_ARGS_HPP
\ No newline at end of file
#endif // MIGRAPHX_GUARD_KERNELS_ARGS_HPP
......@@ -2,7 +2,8 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP
#include <migraphx/kernels/types.hpp>
#include <type_traits>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
......@@ -41,8 +42,16 @@ template <class T, index_int N>
struct array
{
T d[N];
constexpr T& operator[](index_int i) { return d[i]; }
constexpr const T& operator[](index_int i) const { return d[i]; }
constexpr T& operator[](index_int i)
{
MIGRAPHX_ASSERT(i < N);
return d[i];
}
constexpr const T& operator[](index_int i) const
{
MIGRAPHX_ASSERT(i < N);
return d[i];
}
constexpr T& front() { return d[0]; }
constexpr const T& front() const { return d[0]; }
......@@ -53,7 +62,7 @@ struct array
constexpr T* data() { return d; }
constexpr const T* data() const { return d; }
constexpr std::integral_constant<index_int, N> size() const { return {}; }
constexpr index_constant<N> size() const { return {}; }
constexpr T* begin() { return d; }
constexpr const T* begin() const { return d; }
......@@ -142,6 +151,18 @@ struct array
result[0] += overflow;
return result;
}
template <class Stream>
friend constexpr const Stream& operator<<(const Stream& ss, const array& a)
{
for(index_int i = 0; i < N; i++)
{
if(i > 0)
ss << ", ";
ss << a[i];
}
return ss;
}
};
template <class T, T... xs>
......@@ -151,6 +172,18 @@ struct integral_const_array : array<T, sizeof...(xs)>
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({xs...}) {}
};
// Compile-time map over a constant array: the result is a new
// integral_const_array whose values are f(x) for each x. Because the results are
// used as template arguments, every call to `f` must be a constant expression.
template <class T, T... xs, class F>
constexpr auto transform(integral_const_array<T, xs...>, F f)
{
return integral_const_array<T, f(xs)...>{};
}
// Binary form: combines the two arrays element by element with f(x, y). The two
// packs are expanded together, so both arrays must have the same length. The
// result uses the element type of the first array.
template <class T, T... xs, class U, U... ys, class F>
constexpr auto transform(integral_const_array<T, xs...>, integral_const_array<U, ys...>, F f)
{
return integral_const_array<T, f(xs, ys)...>{};
}
template <index_int... Ns>
using index_ints = integral_const_array<index_int, Ns...>;
......
#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#include <hip/hip_runtime.h>
namespace migraphx {
// Reports a failed assertion and terminates: prints
// "<file>:<line>: <function>: assertion '<assertion>' failed." with printf and
// then calls abort(). Declared for both host and device code.
inline __host__ __device__ void
assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
{
printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort();
}
// MIGRAPHX_ASSERT(cond): when MIGRAPHX_DEBUG is defined, evaluates `cond` and, if
// it is false, passes the stringified condition, file, line and function name to
// assert_fail through a generic lambda. Without MIGRAPHX_DEBUG the macro expands
// to nothing, so `cond` is not evaluated at all.
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto... xs) { \
assert_fail(xs...); \
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
#else
#define MIGRAPHX_ASSERT(cond)
#endif
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/array.hpp>
namespace migraphx {
// Empty type that can be constructed from any list of arguments and ignores all
// of them. each_args builds one from an expanded pack to evaluate an expression
// once per argument.
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
// `ignore<N>` is `swallow` for every N. The index only gives a pack something to
// expand over; detail::args_at uses it to declare one ignored parameter per index.
template <index_int>
using ignore = swallow;
// Implementation helpers for the public functions defined later in this header.
namespace detail {
// Calls `f` with the given arguments during construction and stores what it
// returned in `result`.
template <class R>
struct eval_helper
{
R result;
template <class F, class... Ts>
constexpr eval_helper(const F& f, Ts&&... xs) : result(f(static_cast<Ts>(xs)...))
{
}
};
// Specialization for callables returning void: the call is still made, and
// `result` is set to 0 so there is always a value to hand back.
template <>
struct eval_helper<void>
{
int result;
template <class F, class... Ts>
constexpr eval_helper(const F& f, Ts&&... xs) : result((f(static_cast<Ts>(xs)...), 0))
{
}
};
// Compile-time list of indices.
template <index_int...>
struct seq
{
using type = seq;
};
// Concatenates two sequences; the indices of the second are shifted by the
// length of the first.
template <class, class>
struct merge_seq;
template <index_int... Xs, index_int... Ys>
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
{
};
// gens<N> derives from seq<0, 1, ..., N-1>. It is built by splitting N in half
// and merging the two halves; gens<0> and gens<1> end the recursion.
template <index_int N>
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
{
};
template <>
struct gens<0> : seq<>
{
};
template <>
struct gens<1> : seq<0>
{
};
// Calls `f` with one index_constant object per index in the sequence and returns
// its result.
template <class F, index_int... Ns>
constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
{
return f(index_constant<Ns>{}...);
}
// Returns a callable that discards one leading argument per index in the
// sequence and returns the argument that follows them (by value).
template <index_int... N>
constexpr auto args_at(seq<N...>)
{
return [](ignore<N>..., auto x, auto...) { return x; };
}
} // namespace detail
// Builds a callable that accepts any arguments, ignores them, and returns a copy
// of `x` every time.
template <class T>
constexpr auto always(T x)
{
    return [x](auto&&...) { return x; };
}
// Calls `f` with index_constant<0>, ..., index_constant<N-1> as arguments and
// returns whatever `f` returns. With N == 0, `f` is called with no arguments.
template <index_int N, class F>
constexpr auto sequence_c(F&& f)
{
return detail::sequence_c_impl(f, detail::gens<N>{});
}
// Same as sequence_c, with the count supplied as an object instead of a template
// argument. `ic` is used directly as a template argument, so its type must
// convert to index_int in a constant expression (e.g. an index_constant).
template <class IntegerConstant, class F>
constexpr auto sequence(IntegerConstant ic, F&& f)
{
return sequence_c<ic>(f);
}
// Returns a callable that applies `f` to each of its arguments and passes the
// results to `g`. The returned value is g's result, or the int 0 when `g`
// returns void (see detail::eval_helper). The calls to `f` appear inside a
// braced initializer, which evaluates them left to right.
template <class F, class G>
constexpr auto by(F f, G g)
{
return [=](auto... xs) {
return detail::eval_helper<decltype(g(f(xs)...))>{g, f(xs)...}.result;
};
}
// Single-callable form: applies `f` to each argument for its side effects only;
// the results of `f` are discarded and the returned callable yields 0.
template <class F>
constexpr auto by(F f)
{
return by([=](auto x) { return (f(x), 0); }, always(0));
}
// Calls `f` once for every argument, forwarding each one. The calls are expanded
// inside a braced initializer of `swallow`, which discards their results.
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
swallow{(f(std::forward<Ts>(xs)), 0)...};
}
// Overload for an empty argument list: nothing to call.
template <class F>
constexpr void each_args(F)
{
}
// Captures the arguments by value and returns a callable that, given `f`,
// invokes f(xs...) and returns its result.
//
// Marked constexpr like every other helper in this header so the result can be
// used in constant expressions.
// NOTE(review): HIP treats constexpr functions as callable from device code; the
// previous non-constexpr, unannotated form was presumably host-only — confirm.
template <class... Ts>
constexpr auto pack(Ts... xs)
{
    return [=](auto f) { return f(xs...); };
}
// Returns a callable that returns its N-th argument (zero-based) by value; the
// other arguments are ignored.
template <index_int N>
constexpr auto arg_c()
{
return [](auto... xs) { return detail::args_at(detail::gens<N>{})(xs...); };
}
// Same as arg_c, with the position supplied as an object. `ic` is used as a
// template argument, so its type must convert to index_int in a constant
// expression (e.g. an index_constant).
template <class IntegralConstant>
constexpr auto arg(IntegralConstant ic)
{
return arg_c<ic>();
}
// Returns a callable that captures its arguments and yields a second callable.
// Calling that one with `f` invokes f with the captured arguments rotated right
// by one: position i receives argument (i + size - 1) % size, so the last
// argument comes first and the rest keep their order.
inline constexpr auto rotate_last()
{
return [](auto... xs) {
return [=](auto&& f) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
};
};
}
// Wraps one argument transformation. `f(xs...)` must return a callable that
// accepts a continuation; the result of transform_args(f)(xs...) is a callable
// that, given `g`, runs that continuation-taking callable so that `g` receives
// the transformed arguments.
template <class F>
constexpr auto transform_args(F f)
{
return [=](auto... xs) {
return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
};
}
// Chains several transformations: the output of `f` is passed on to the
// pipeline built from the remaining `fs...`, in the order given.
template <class F, class... Fs>
constexpr auto transform_args(F f, Fs... fs)
{
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
}
// MIGRAPHX_LIFT(name): wraps the expression `name` in a generic lambda that
// passes its arguments through with their original value category, so an
// overloaded or templated function can be passed around as a single object.
#define MIGRAPHX_LIFT(...) \
([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
......@@ -12,9 +12,43 @@ struct index
index_int local = 0;
index_int group = 0;
__device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
__device__ index_int nglobal() const
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
#else
return blockDim.x * gridDim.x;
#endif
}
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT
__device__ index_int nlocal() const
{
#ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL;
#else
return blockDim.x;
#endif
}
template <class F>
__device__ void global_stride(index_int n, F f) const
{
const auto stride = nglobal();
for(index_int i = global; i < n; i += stride)
{
f(i);
}
}
template <class F>
__device__ void local_stride(index_int n, F f) const
{
const auto stride = nlocal();
for(index_int i = local; i < n; i += stride)
{
f(i);
}
}
};
inline __device__ index make_index()
......
#ifndef MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#define MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#include <migraphx/kernels/types.hpp>
namespace migraphx {
// Carries the compile-time value `v` of type `T` in a type. An object of this
// type converts implicitly to `T` and can also be called to obtain the value.
template <class T, T v>
struct integral_constant
{
    using value_type = T;
    using type       = integral_constant;

    static constexpr value_type value = v;

    constexpr operator value_type() const noexcept { return v; }
    constexpr value_type operator()() const noexcept { return v; }
};
// Defines binary operator `op` for two integral_constant operands. The result is
// itself an integral_constant holding `v op w`, so the computed value stays
// available at compile time.
#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
template <class T, T v, class U, U w> \
constexpr inline integral_constant<decltype(v op w), (v op w)> operator op( \
integral_constant<T, v>, integral_constant<U, w>) noexcept \
{ \
return {}; \
}
// Defines unary operator `op` for an integral_constant operand. The result is an
// integral_constant holding `op v`.
#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
template <class T, T v> \
constexpr inline integral_constant<decltype(op v), (op v)> operator op( \
integral_constant<T, v>) noexcept \
{ \
return {}; \
}
// Operators generated for integral_constant: arithmetic, shifts, comparisons,
// bitwise and logical binary operators, followed by the unary operators.
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(+)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(-)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(*)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(/)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(%)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>>)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<<)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(<=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP (^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(!)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(~)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(+)
MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(-)
// Boolean constants.
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
// Constant of the index type used throughout the kernels.
template <index_int N>
using index_constant = integral_constant<index_int, N>;
// `_c<v>` is an integral_constant object holding `v`; its value type is deduced
// from `v`.
template <auto v>
static constexpr auto _c = integral_constant<decltype(v), v>{};
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
#define MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/args.hpp>
namespace migraphx {
// Evaluates `f` element by element over the input tensors `xs` and writes the
// results to `out`.
// The inputs first go through preload (keyed on the element type of `out`), which
// may replace some of them with copies held in a __shared__ buffer. Each thread
// then visits the linear indices idx.global, idx.global + nglobal(), ... below
// out.get_shape().elements(), turns each into a multi-dimensional index of `out`,
// and uses that same index to read every input.
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) {
auto multi_idx = out.get_shape().multi(i);
out[multi_idx] = f(ps[multi_idx]...);
});
});
}
// Kernel-side entry point for a pointwise operation over raw pointers.
// The pointers pass through a pipeline of three stages: make_tensors() turns
// them into tensor views, rotate_last() moves the last one to the front (so it
// lands in the `out` parameter of pointwise_tensor), and auto_vectorize() may
// switch the views to vector element types. The resulting tensors are handed to
// pointwise_tensor together with this thread's index.
template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps)
{
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
t(ps...)([&](auto... xs) {
auto idx = make_index();
pointwise_tensor(idx, f, xs...);
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_POINTWISE_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
#define MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
namespace migraphx {
// Reports whether a tensor of shape `Shape` is worth staging in the preload
// buffer. The criteria mirror the ones traverse_preload applies: the shape must
// be broadcasted, and the copy must save at least 64 elements compared with the
// number of elements the tensor logically has.
//
// The previous version returned nothing on the broadcasted path, which is
// undefined behaviour at run time and not a constant expression at compile time.
template <class Shape>
constexpr bool is_preloadable()
{
    Shape s{};
    if(not s.broadcasted())
        return false;
    return (s.elements() - s.element_space()) >= 64;
}
// Walks the given tensors (or shapes) in order and decides for each whether it
// should be copied into the preload buffer.
//
// The returned callable takes `f` and optionally a final callable `g`. For every
// item `x` it calls f(x, offset, flag):
//   - flag is false_type when the shape is not broadcasted, or when copying
//     would save fewer than 64 elements; `offset` is then the current running
//     offset and is left unchanged;
//   - flag is true_type otherwise; `offset` is the element offset reserved for
//     this item in the buffer.
// The results of `f` are passed on to `g` through `by`.
template <class T, class... Shapes>
constexpr auto traverse_preload(Shapes... ss)
{
    return [=](auto f, auto... g) {
        index_int offset = 0;
        auto each        = [&](auto x) {
            constexpr auto s    = decltype(x.get_shape()){};
            constexpr auto size = _c<s.element_space()>;
            if constexpr(not s.broadcasted())
                return f(x, offset, false_type{});
            else if constexpr((s.elements() - size) < 64)
                return f(x, offset, false_type{});
            else
            {
                auto pre_offset = offset;
                offset += size;
                // Round the running offset up to the next multiple of 4 so the
                // following region starts on a 4-element boundary. The earlier
                // `offset += offset % 4` only produced a multiple of 4 when the
                // remainder was 0 or 2 (e.g. 5 became 6, 7 became 10).
                offset += (4 - offset % 4) % 4;
                return f(x, pre_offset, true_type{});
            }
        };
        return by(each, g...)(ss...);
    };
}
// Computes, in elements, how large the preload buffer must be for the given
// shapes: the end of the last region that traverse_preload marks for copying,
// or 0 when nothing is copied.
//
// Only copied tensors contribute. Previously every tensor updated `size`, so a
// trailing tensor that is never copied inflated the result, and the result could
// not be 0 for a non-empty argument list even though preload() tests `size > 0`.
template <class T, class... Shapes>
constexpr index_int compute_preload_size(Shapes...)
{
    index_int size = 0;
    traverse_preload<T>(Shapes{}...)([&](auto s, auto offset, auto copy) {
        if constexpr(copy)
            size = offset + s.element_space();
    });
    return size;
}
// Copies the tensors selected by traverse_preload into `buffer` and then calls
// `f` with the resulting tensors.
// For a tensor flagged for copying, the tensor is passed through vectorize(), the
// destination `buffer + offset` is viewed with the matching vector width, and the
// threads copy element_space() elements between them with a local_stride loop.
// The tensor is then replaced by a view over `buffer + offset`; tensors not
// flagged are passed through unchanged.
// `invoke` runs __syncthreads() before calling `f`, i.e. after the copy loops
// above have been issued.
template <class F, class T, class... Ts>
__device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{
auto invoke = [&](auto... ys) {
__syncthreads();
f(ys...);
};
traverse_preload<T>(xs...)(
[&](auto x, auto offset, auto copy) {
if constexpr(copy)
{
auto v = vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
return x.with(buffer + offset);
}
else
{
return x;
}
},
invoke);
}
// Maps a vector type vec<T, N> to its element type T; any other type maps to
// itself.
template <class T>
struct remove_vec
{
using type = T;
};
template <class T, index_int N>
struct remove_vec<vec<T, N>>
{
using type = T;
};
// Returns a callable that runs `f` on the tensors `xs`, staging some of them in a
// __shared__ buffer first when that is possible.
// The buffer size (in elements of `type`, the scalar type behind T) is computed
// at compile time from the shapes. When it is non-zero and below `max_size`, the
// buffer is declared and preload_copy performs the copies before calling `f`;
// otherwise `f` is called directly on the original tensors.
// NOTE(review): `size` counts elements while `max_size` is 512 * sizeof(type);
// confirm which unit the limit is meant to be expressed in.
template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs)
{
using type = typename remove_vec<T>::type;
constexpr auto size = compute_preload_size<type>(xs.get_shape()...);
const index_int max_size = 512 * sizeof(type);
return [=](auto f) {
if constexpr(size > 0 and size < max_size)
{
__shared__ type buffer[size];
preload_copy(idx, f, buffer, xs...);
}
else
{
f(xs...);
}
};
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP
#define MIGRAPHX_GUARD_KERNELS_PRINT_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
// Callable wrapper with a hook that runs on destruction: operator() forwards its
// argument to `f`, and the destructor calls f(g). coutln uses this to emit a
// newline once the printer goes away.
template <class F, class G>
struct on_exit
{
F f;
G g;
template <class T>
__host__ __device__ auto operator()(T x) const
{
return f(x);
}
__host__ __device__ ~on_exit() { f(g); }
};
// Recovers the spelling of the template argument from __PRETTY_FUNCTION__.
// The function signature string is searched for the text
// "PrivateMIGraphXTypeNameProbe = "; the type name is taken to start right after
// it and to end at the first ']' or ';'. The result is a callable that prints
// that substring through the given stream's print_string(pointer, length).
template <class PrivateMIGraphXTypeNameProbe>
constexpr auto print_type_name_probe()
{
constexpr auto name = __PRETTY_FUNCTION__;
constexpr auto size = sizeof(__PRETTY_FUNCTION__);
constexpr auto parameter_name = "PrivateMIGraphXTypeNameProbe = ";
constexpr auto parameter_name_size = sizeof("PrivateMIGraphXTypeNameProbe = ") - 1;
constexpr auto begin =
search(name, name + size, parameter_name, parameter_name + parameter_name_size);
static_assert(begin < name + size, "Type probe not found.");
constexpr auto start = begin + parameter_name_size;
constexpr auto last = find_if(start, name + size, [](auto c) { return c == ']' or c == ';'; });
return [=](const auto& s) { s.print_string(start, last - start); };
}
// Tag object that prints the name of `T` when streamed into a printer.
template <class T>
struct type_printer
{
template <class Stream>
friend constexpr const Stream& operator<<(const Stream& s, type_printer)
{
print_type_name_probe<T>()(s);
return s;
}
};
// Returns a type_printer for `T`, given explicitly as a template argument.
template <class T>
constexpr type_printer<T> type_of()
{
return {};
}
// Returns a type_printer for the type of the argument; the value is unused.
template <class T>
constexpr type_printer<T> type_of(T)
{
return {};
}
// Returns a type_printer for the nested `T::type`, with `T` given explicitly.
template <class T>
constexpr type_printer<typename T::type> sub_type_of()
{
return {};
}
// Returns a type_printer for the nested `type` of the argument's type; the value
// is unused.
template <class T>
constexpr type_printer<typename T::type> sub_type_of(T)
{
return {};
}
// Minimal stream-like printer built on printf, declared for host and device code.
// Every print_* member wraps its printf call in a lambda and hands it to `f`,
// which decides how the output action is run. The operator<< overloads map the
// built-in types onto the print_* members, and each returns the printer so calls
// can be chained.
template <class F>
struct basic_printer
{
F f;
__host__ __device__ const basic_printer& print_long(long value) const
{
f([&] { printf("%li", value); });
return *this;
}
__host__ __device__ const basic_printer& print_ulong(unsigned long value) const
{
f([&] { printf("%lu", value); });
return *this;
}
__host__ __device__ const basic_printer& print_char(char value) const
{
f([&] { printf("%c", value); });
return *this;
}
__host__ __device__ const basic_printer& print_string(const char* value) const
{
f([&] { printf("%s", value); });
return *this;
}
// Prints at most `size` characters of `value`; the string does not need to be
// null-terminated within that length.
__host__ __device__ const basic_printer& print_string(const char* value, int size) const
{
f([&] { printf("%.*s", size, value); });
return *this;
}
__host__ __device__ const basic_printer& print_double(double value) const
{
f([&] { printf("%f", value); });
return *this;
}
__host__ __device__ const basic_printer& print_bool(bool value) const
{
f([&] {
if(value)
printf("true");
else
printf("false");
});
return *this;
}
// Signed integers are widened to long, unsigned ones to unsigned long.
__host__ __device__ const basic_printer& operator<<(short value) const
{
return print_long(value);
}
__host__ __device__ const basic_printer& operator<<(unsigned short value) const
{
return print_ulong(value);
}
__host__ __device__ const basic_printer& operator<<(int value) const
{
return print_long(value);
}
__host__ __device__ const basic_printer& operator<<(unsigned int value) const
{
return print_ulong(value);
}
__host__ __device__ const basic_printer& operator<<(long value) const
{
return print_long(value);
}
__host__ __device__ const basic_printer& operator<<(unsigned long value) const
{
return print_ulong(value);
}
// Floating-point values are printed as double.
__host__ __device__ const basic_printer& operator<<(float value) const
{
return print_double(value);
}
__host__ __device__ const basic_printer& operator<<(double value) const
{
return print_double(value);
}
__host__ __device__ const basic_printer& operator<<(bool value) const
{
return print_bool(value);
}
// Both char and unsigned char are printed as a character, not as a number.
__host__ __device__ const basic_printer& operator<<(char value) const
{
return print_char(value);
}
__host__ __device__ const basic_printer& operator<<(unsigned char value) const
{
return print_char(value);
}
__host__ __device__ const basic_printer& operator<<(const char* value) const
{
return print_string(value);
}
};
// Builds a printer whose output actions are run through `f`.
template <class F>
constexpr basic_printer<F> make_printer(F f)
{
return {f};
}
// Builds a printer whose output actions are run through `f` and which passes `g`
// to `f` once more when the printer is destroyed (see on_exit).
template <class F, class G>
constexpr basic_printer<on_exit<F, G>> make_printer(F f, G g)
{
return {{f, g}};
}
// Returns a printer that runs every output action immediately.
inline __device__ auto cout()
{
return make_printer([](auto f) { f(); });
}
// Like cout(), but the returned printer also prints "\n" when it is destroyed.
inline __device__ auto coutln()
{
return make_printer([](auto f) { f(); }, [] { printf("\n"); });
}
// Streams every argument into a printer obtained from `f`. `f` is called once
// per argument, so each value gets its own printer object (with coutln this
// means one newline per argument).
template <class F, class... Ts>
__device__ void print_each(F f, Ts... xs)
{
each_args([&](auto x) { f() << x; }, xs...);
}
// Same as print_each, but only the thread whose global index is 0 prints.
template <class F, class... Ts>
__device__ void print_each_once(F f, Ts... xs)
{
auto idx = make_index();
if(idx.global == 0)
print_each(f, xs...);
}
// Prints all arguments from every calling thread, without newlines.
template <class... Ts>
__device__ void print(Ts... xs)
{
print_each(&cout, xs...);
}
// Prints all arguments from the thread with global index 0 only.
template <class... Ts>
__device__ void print_once(Ts... xs)
{
print_each_once(&cout, xs...);
}
// Like print(), using coutln so that a newline follows each argument.
template <class... Ts>
__device__ void println(Ts... xs)
{
print_each(&coutln, xs...);
}
// Like print_once(), using coutln so that a newline follows each argument.
template <class... Ts>
__device__ void println_once(Ts... xs)
{
print_each_once(&coutln, xs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PRINT_HPP
......@@ -19,7 +19,7 @@ struct shape
constexpr index_int elements() const { return lens.product(); }
constexpr index_int element_space() const { return strides.dot(lens - 1); }
constexpr index_int element_space() const { return strides.dot(lens - 1) + 1; }
constexpr bool packed() const { return elements() == element_space(); }
constexpr bool broadcasted() const { return strides.product() == 0; }
......@@ -92,6 +92,15 @@ struct shape
result[0] = tidx;
return result;
}
constexpr shape get_shape() const { return *this; }
template <class Stream>
friend constexpr const Stream& operator<<(const Stream& ss, const shape& s)
{
ss << "{" << s.lens << "}, {" << s.strides << "}";
return ss;
}
};
template <class Lens, class Strides>
......
......@@ -2,18 +2,22 @@
#define MIGRAPHX_GUARD_KERNELS_TENSOR_VIEW_HPP
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
template <class T, class Shape>
struct tensor_view
{
using type = T;
constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); }
template <class U>
constexpr T& operator[](U i) const
{
MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
return x[get_shape().index(i)];
}
......@@ -22,6 +26,13 @@ struct tensor_view
constexpr T* begin() const { return data(); }
constexpr T* end() const { return data() + size(); }
template <class U>
constexpr tensor_view<U, Shape> with(U* y) const
{
static_assert(sizeof(T) == sizeof(U), "Not the same size");
return {y};
}
T* x;
};
......
......@@ -9,6 +9,9 @@ using index_int = std::uint32_t;
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
template <class T, index_int N>
using vec = T __attribute__((ext_vector_type(N)));
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_VEC_HPP
#define MIGRAPHX_GUARD_KERNELS_VEC_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
namespace migraphx {
// Number of lanes of a vector value: index_constant<N> for vec<T, N>.
template <class T, index_int N>
constexpr auto vec_size(vec<T, N>)
{
return index_constant<N>{};
}
// Fallback for any other type: index_constant<0>.
template <class T>
constexpr auto vec_size(T, ...)
{
return index_constant<0>{};
}
// Type-only form: the lane count of `T`, obtained from the overloads above
// applied to a value-initialized T.
template <class T>
constexpr auto vec_size()
{
return decltype(vec_size(T{})){};
}
// Reinterprets a pointer to T as a pointer to vec<T, N>. With N == 0 the pointer
// is returned unchanged.
// NOTE(review): the cast assumes `x` is suitably aligned for vec<T, N> — the
// caller has to guarantee that.
template <index_int N, class T>
__device__ __host__ auto as_vec(T* x)
{
if constexpr(N == 0)
return x;
else
return reinterpret_cast<vec<T, N>*>(x);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#define MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx {
// Lane count of a tensor's element type (`T::type`): index_constant<N> when the
// elements are vec<U, N>, index_constant<0> otherwise.
template <class T>
constexpr auto tensor_vec_size(T)
{
return vec_size<typename T::type>();
}
// Computes the shape a tensor has once its elements are grouped into vectors of
// N lanes: every length whose stride is 1 is divided by N, and every stride
// other than 1 is divided by N.
// In debug builds it is checked that the new shape covers the same storage,
// i.e. its element_space() times N equals the original element_space().
template <index_int N, class Shape>
constexpr auto as_vec_shape(Shape s)
{
auto lens = transform(s.lens, s.strides, [](auto len, auto stride) {
if(stride == 1)
return len / N;
else
return len;
});
auto strides = transform(s.strides, [](auto stride) {
if(stride == 1)
return stride;
return stride / N;
});
MIGRAPHX_ASSERT(make_shape(lens, strides).element_space() * N == s.element_space());
return make_shape(lens, strides);
}
// Tensor form of as_vec: returns a view over the same data with vec<_, N>
// elements and the shape produced by as_vec_shape<N>. With N == 0 the tensor is
// returned unchanged.
template <index_int N, class T>
__device__ __host__ auto as_vec(T x)
{
if constexpr(N == 0)
return x;
else
return make_tensor_view(as_vec<N>(x.data()), as_vec_shape<N>(x.get_shape()));
}
// Shrinks one axis of a tensor by a factor of N without changing its data
// pointer, element type or strides; auto_vectorize uses it for tensors whose
// stride along the vector axis is 0. With N == 0 the tensor is returned
// unchanged.
// NOTE(review): the axis is taken from the TYPE of the second argument
// (`Axis{}`), not from its value. The argument therefore has to be an
// integral-constant-like object; a plain index_int always selects axis 0.
template <index_int N, class T, class Axis>
constexpr auto tensor_step(T x, Axis)
{
if constexpr(N == 0)
{
return x;
}
else
{
constexpr auto s = decltype(x.get_shape()){};
// Only meaningful when the tensor is broadcast along the chosen axis.
MIGRAPHX_ASSERT(s.strides[Axis{}] == 0);
return sequence(x.get_shape().lens.size(), [&](auto... is) {
// Pair every length with its position so the chosen axis can be found.
auto lens = transform(s.lens, index_ints<is...>{}, [&](auto i, auto j) {
constexpr auto axis = Axis{};
if(j == axis)
return i / N;
else
return i;
});
return make_tensor_view(x.data(), make_shape(lens, s.strides));
});
}
}
// Convenience form of as_vec taking the lane count as an object. `ic` is used as
// a template argument, so its type must convert to index_int in a constant
// expression (e.g. an index_constant).
template <class IntegralConstant, class T>
__device__ __host__ auto as_vec(IntegralConstant ic, T&& x)
{
return as_vec<ic>(x);
}
// Picks the axis to vectorize along: the position of the first stride equal to 1
// in the first shape that is not broadcasted and has such a stride. Returns 0
// when no shape qualifies.
template <class... Shapes>
constexpr index_int find_vector_axis(Shapes... ss)
{
index_int axis = 0;
bool b = false;
by([&](auto s) {
// Skip broadcasted shapes, and everything after the first hit.
if(s.broadcasted() or b)
return;
auto it = find(s.strides.begin(), s.strides.end(), 1);
if(it == s.strides.end())
return;
axis = it - s.strides.begin();
b = true;
})(ss...);
return axis;
}
// True when every shape can be vectorized by N along `axis`: the length of that
// axis is a multiple of N and its stride is either 1 or 0. With an empty list of
// shapes the result is true.
template <index_int N, class Axis, class... Shapes>
constexpr auto is_vectorizable(Axis axis, Shapes... ss)
{
return (((ss.lens[axis] % N) == 0 and (ss.strides[axis] == 1 or ss.strides[axis] == 0)) and
...);
}
// Checks every shape against its own vector axis, as chosen by
// find_vector_axis for that shape alone.
//
// The axis overload expects (axis, shapes...); the arguments were previously
// passed as (shape, axis), which treats the shape as the axis.
// NOTE(review): for a call with at least one argument, overload resolution
// prefers the (Axis, Shapes...) overload because it is more specialized, so
// callers that want this per-shape check must pass the axis explicitly.
template <index_int N, class... Shapes>
constexpr bool is_vectorizable(Shapes... ss)
{
    return (is_vectorizable<N>(find_vector_axis(ss), ss) and ...);
}
// Returns the widest vector size accepted by `pred`, trying 4 first and then 2;
// returns _c<0> when neither is accepted. `pred` is evaluated inside
// `if constexpr`, so its result must be usable as a compile-time boolean.
template <class P>
constexpr auto find_vectorize_size(P pred)
{
if constexpr(pred(_c<4>))
return _c<4>;
else if constexpr(pred(_c<2>))
return _c<2>;
else
return _c<0>;
}
// Vectorizes a single tensor when its shape allows it: picks the widest lane
// count (4, then 2) that is_vectorizable accepts for the tensor's vector axis
// and returns the tensor viewed with that vector element type. Tensors that are
// already vectorized, or that cannot be vectorized, are returned unchanged.
//
// Two defects are fixed here:
//   - the "already vectorized" test looked at the tensor type itself
//     (vec_size<T>()), which is never a vector, instead of its element type;
//   - is_vectorizable<i>(shape) bound the shape to the `axis` parameter of the
//     axis overload with an empty list of shapes, which is always true, so the
//     lane count was always 4. The axis is now computed and passed explicitly.
// NOTE(review): for a broadcasted shape find_vector_axis returns 0, while
// as_vec divides the stride-1 lengths; confirm the two agree for the tensors
// preload_copy passes in.
template <class T>
__host__ __device__ auto vectorize(T x)
{
    if constexpr(vec_size<typename T::type>() == 0)
    {
        constexpr auto axis = find_vector_axis(x.get_shape());
        constexpr auto n    = find_vectorize_size(
            [&](auto i) { return _c<is_vectorizable<i>(axis, x.get_shape())>; });
        return as_vec<n>(x);
    }
    else
    {
        return x;
    }
}
// Pipeline stage for transform_args: given a set of tensors, returns a callable
// that invokes `f` with the tensors vectorized along a common axis when that is
// possible, and with the original tensors otherwise.
//
// Vectorization is only attempted when every tensor is packed or broadcasted.
// The axis comes from find_vector_axis over all shapes, and the lane count is
// the widest of 4 and 2 accepted by is_vectorizable for all shapes. Tensors
// whose stride along that axis is 0 keep their element type and only have the
// axis length reduced (tensor_step); all others are reinterpreted with vector
// elements (as_vec).
inline __device__ __host__ auto auto_vectorize()
{
    return [](auto... xs) {
        return [=](auto f) {
            // TODO: Just check there a single axis of 1
            constexpr bool packed_or_broadcasted =
                ((xs.get_shape().packed() or xs.get_shape().broadcasted()) and ...);
            if constexpr(packed_or_broadcasted)
            {
                constexpr auto axis = find_vector_axis(xs.get_shape()...);
                constexpr auto n    = find_vectorize_size(
                    [&](auto i) { return _c<is_vectorizable<i>(axis, xs.get_shape()...)>; });
                by(
                    [&](auto x) {
                        constexpr auto s = x.get_shape();
                        if constexpr(s.strides[axis] == 0)
                            // tensor_step reads the axis from the argument's TYPE
                            // (it default-constructs Axis), so the axis must be
                            // passed as an integral constant. Passing the plain
                            // index_int made it always step axis 0.
                            return tensor_step<n>(x, _c<axis>);
                        else
                            return as_vec<n>(x);
                    },
                    f)(xs...);
            }
            else
            {
                f(xs...);
            }
        };
    };
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VECTORIZE_HPP
......@@ -224,23 +224,24 @@ std::vector<value>& get_array_throw(const std::shared_ptr<value_base_impl>& x)
return *a;
}
value* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key)
template <class T>
T* find_impl(const std::shared_ptr<value_base_impl>& x, const std::string& key, T* end)
{
auto* a = if_array_impl(x);
if(a == nullptr)
return nullptr;
return end;
auto* lookup = x->if_object();
if(lookup == nullptr)
return nullptr;
return end;
auto it = lookup->find(key);
if(it == lookup->end())
return a->data() + a->size();
return end;
return std::addressof((*a)[it->second]);
}
value* value::find(const std::string& pkey) { return find_impl(x, pkey); }
value* value::find(const std::string& pkey) { return find_impl(x, pkey, this->end()); }
const value* value::find(const std::string& pkey) const { return find_impl(x, pkey); }
const value* value::find(const std::string& pkey) const { return find_impl(x, pkey, this->end()); }
bool value::contains(const std::string& pkey) const
{
const auto* it = find(pkey);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment