Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
......@@ -21,6 +21,26 @@ struct greater
}
};
template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for(; first != last; ++first)
{
init = op(std::move(init), *first);
}
return init;
}
template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
while(first != last)
{
*d_first++ = *first++;
}
return d_first;
}
template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{
......@@ -96,6 +116,35 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
}
}
template <class InputIt1, class InputIt2, class T, class BinaryOperation1, class BinaryOperation2>
constexpr T inner_product(InputIt1 first1,
InputIt1 last1,
InputIt2 first2,
T init,
BinaryOperation1 op1,
BinaryOperation2 op2)
{
while(first1 != last1)
{
init = op1(init, op2(*first1, *first2));
++first1;
++first2;
}
return init;
}
template <class InputIt1, class InputIt2, class T>
constexpr T inner_product(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init)
{
return inner_product(
first1,
last1,
first2,
init,
[](auto x, auto y) { return x + y; },
[](auto x, auto y) { return x * y; });
}
} // namespace migraphx
#endif
......@@ -74,6 +74,7 @@ struct array
constexpr const T* data() const { return d; }
constexpr index_constant<N> size() const { return {}; }
constexpr auto empty() const { return size() == _c<0>; }
constexpr T* begin() { return d; }
constexpr const T* begin() const { return d; }
......@@ -145,8 +146,8 @@ struct array
constexpr array carry(array result) const
{
uint32_t overflow = 0;
for(std::ptrdiff_t i = result.size() - 1; i > 0; i--)
index_int overflow = 0;
for(diff_int i = result.size() - 1; i > 0; i--)
{
auto z = result[i] + overflow;
// Reset overflow
......
......@@ -42,6 +42,32 @@ struct print_buffer
pos++;
}
}
template <class T, class = decltype(T{} % 10, -T{})>
constexpr void append(T i)
{
if(i < 0)
{
append('-');
i = -i;
}
char c = (i % 10) + '0';
if(i > 9)
append(i / 10);
append(c);
}
constexpr void append(const char* str)
{
if(str == nullptr)
return;
int i = 512;
while(*str != 0 and i > 0)
{
append(*str);
str++;
i--;
}
}
template <size_t M>
constexpr void append(const char (&array)[M])
......@@ -54,14 +80,36 @@ struct print_buffer
template <class... Ts>
__host__ __device__ void print(const Ts&... xs)
{
const auto size = (sizeof(xs) + ...);
print_buffer<size> buffer;
print_buffer<1024> buffer;
swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer);
}
} // namespace debug
struct source_location
{
int line = __builtin_LINE();
const char* file = __builtin_FILE();
const char* function = __builtin_FUNCTION();
};
template <class T>
struct source_location_capture
{
T x;
source_location loc;
template <class U, class = decltype(T(U{}))>
constexpr source_location_capture(U px, source_location ploc = source_location{})
: x(px), loc(ploc)
{
}
constexpr operator source_location() const { return loc; }
constexpr operator T() const { return x; }
};
// noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
......@@ -73,13 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
abort();
}
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
template <class... Ts>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_location& loc,
Ts... xs)
{
debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n");
abort();
}
// NOLINTNEXTLINE
#define MIGRAPHX_ASSERT_FAIL(cond, ...) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
}(__VA_ARGS__))
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
#ifdef MIGRAPHX_DEBUG
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_KERNELS_DPP_HPP
#define MIGRAPHX_GUARD_KERNELS_DPP_HPP
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
#ifndef MIGRAPHX_HAS_DPP
#define MIGRAPHX_HAS_DPP 1
#endif
#if MIGRAPHX_HAS_DPP
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
constexpr unsigned int dpp_row_bcast(unsigned int x)
{
unsigned int y = 0;
switch(x)
{
case 15: y = 0x142; break;
case 31: y = 0x143; break;
default: MIGRAPHX_UNREACHABLE();
}
return y;
}
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{
static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type
{
uint32_t reg[n];
T data;
};
type output{};
type input{};
// cppcheck-suppress unreadVariable
input.data = x;
for(index_int i = 0; i < n; i++)
{
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
}
return output.data;
}
#endif
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
......@@ -3,6 +3,14 @@
#include <migraphx/kernels/array.hpp>
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
namespace migraphx {
struct swallow
......@@ -129,7 +137,7 @@ constexpr auto by(F f)
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
swallow{(f(std::forward<Ts>(xs)), 0)...};
swallow{(f(static_cast<Ts&&>(xs)), 0)...};
}
template <class F>
......@@ -161,6 +169,18 @@ constexpr auto pack(Ts... xs)
return [=](auto f) { return f(xs...); };
}
template <class G, class F>
constexpr auto join(G g, F f)
{
return f([=](auto... xs) { return g(xs...); });
}
template <class G, class F, class... Fs>
constexpr auto join(G g, F f, Fs... fs)
{
return f([=](auto... xs) { return join([=](auto... ys) { return g(xs..., ys...); }, fs...); });
}
template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{
......@@ -191,39 +211,45 @@ constexpr auto arg(IntegralConstant ic)
return arg_c<ic>();
}
inline constexpr auto rotate_last()
template <class F>
constexpr auto make_transform(F f)
{
return [](auto... xs) {
return [=](auto&& f) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
};
};
return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
}
// An arg transformation takes the arguments and then a function to take the new arguments:
// transform(xs...)([](auto... ys) { ... })
// The transform_args function takes a list of transformations and continually applies them
template <class F>
constexpr auto transform_args(F f)
{
return [=](auto... xs) {
return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
};
return f;
}
template <class F, class... Fs>
constexpr auto transform_args(F f, Fs... fs)
{
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
return make_transform([=](auto g, auto... xs) {
return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
});
}
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// identity transform
inline constexpr auto transform_args()
{
return make_transform([](auto f, auto... xs) { return f(xs...); });
}
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
// Rotate the first argument to the last argument
inline constexpr auto rotate_last()
{
return make_transform([](auto f, auto... xs) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T>
struct gathernd_settings
{
T batch_dims{};
};
template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
auto ind = make_index();
auto batch_dims = s.batch_dims;
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
auto data_shape = data_t.get_shape();
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const std::size_t num_batches = accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
assert(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
std::multiplies<std::size_t>());
relative_slice_offset += index * size_from_slice_dims;
}
auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
output_t[i] = data_t[slice_offset + i % slice_size];
});
}
} // namespace migraphx
#endif
......@@ -3,6 +3,7 @@
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
namespace migraphx {
......@@ -12,23 +13,23 @@ struct index
index_int local = 0;
index_int group = 0;
__device__ index_int nglobal() const
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; }
#else
__device__ index_int nglobal() const
{
return blockDim.x * gridDim.x; // NOLINT
#endif
}
#endif
__device__ index_int nlocal() const
{
#ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL;
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; }
#else
return blockDim.x; // NOLINT
#endif
__device__ index_int nlocal() const
{
return blockDim.x; // NOLINT
}
#endif
template <class F>
__device__ void global_stride(index_int n, F f) const
......
......@@ -48,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP (^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
......@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>;
template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
template <class F>
constexpr auto return_c(F f)
{
return _c<f()>;
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
template <class F, class Iterator = diff_int>
struct basic_iota_iterator
{
Iterator index;
F f;
using difference_type = diff_int;
using reference = decltype(f(declval<Iterator>()));
using value_type = remove_reference_t<reference>;
using pointer = add_pointer_t<value_type>;
constexpr basic_iota_iterator& operator+=(diff_int n)
{
index += n;
return *this;
}
constexpr basic_iota_iterator& operator-=(diff_int n)
{
index -= n;
return *this;
}
constexpr basic_iota_iterator& operator++()
{
index++;
return *this;
}
constexpr basic_iota_iterator& operator--()
{
index--;
return *this;
}
constexpr basic_iota_iterator operator++(int) // NOLINT
{
basic_iota_iterator it = *this;
index++;
return it;
}
constexpr basic_iota_iterator operator--(int) // NOLINT
{
basic_iota_iterator it = *this;
index--;
return it;
}
// TODO: operator->
constexpr reference operator*() const { return f(index); }
template <class T>
constexpr reference operator[](T x) const
{
return f(index + x);
}
};
template <class T, class F>
constexpr basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
return basic_iota_iterator<F, T>{x, f};
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x, diff_int y)
{
return x += y;
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(diff_int x, basic_iota_iterator<F, Iterator> y)
{
return y + x;
}
template <class F, class Iterator>
constexpr diff_int operator-(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index - y.index;
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x, diff_int y)
{
return x -= y;
}
template <class F, class Iterator>
constexpr bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator>
constexpr bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
template <class F, class Iterator>
constexpr bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index < y.index;
}
template <class F, class Iterator>
constexpr bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index > y.index;
}
template <class F, class Iterator>
constexpr bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index >= y.index;
}
template <class F, class Iterator>
constexpr bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index <= y.index;
}
struct defaul_iota_iterator
{
template <class T>
constexpr auto operator()(T x) const
{
return x;
}
};
using iota_iterator = basic_iota_iterator<defaul_iota_iterator>;
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
......@@ -40,12 +40,31 @@ constexpr T as_float(T x)
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); }
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number.
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF2(name, fname) \
template <class... Ts> \
auto __device__ name(migraphx::vec<migraphx::half, 2> x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::vec<migraphx::half, 2>{fname(x, xs...)}); \
template <class... Ts, index_int N, MIGRAPHX_REQUIRES(N % 2 == 0 && (N > 2))> \
auto __device__ name(migraphx::vec<migraphx::half, N> x, Ts... xs) \
{ \
return vec_packed_transform<2>(x, xs...)( \
[](auto... ys) -> migraphx::vec<migraphx::half, 2> { return fname(ys...); }); \
}
MIGRAPHX_DEVICE_MATH(abs, ::abs)
MIGRAPHX_DEVICE_MATH(acos, ::acos)
MIGRAPHX_DEVICE_MATH(acosh, ::acosh)
......@@ -59,6 +78,7 @@ MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(round, ::round)
......@@ -103,6 +123,7 @@ MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
......@@ -110,12 +131,65 @@ MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// Most but not all of these math ops have operators of the same names. Ones not yet implemented
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp)
MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10)
MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
template <class T, class U>
constexpr auto where(bool cond, const T& a, const U& b)
{
return cond ? a : b;
}
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
// Add overloads for half that calls the float version
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b)
{
return where(a < b, b, a);
}
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto min(const T& a, const T& b)
{
return where(a < b, a, b);
}
template <class T, class U, MIGRAPHX_REQUIRES(not is_same<T, U>{} and not is_any_vec<T, U>())>
constexpr auto max(const T& a, const U& b)
{
return max<common_type_t<T, U>>(a, b);
}
template <class T, class U, MIGRAPHX_REQUIRES(not is_same<T, U>{} and not is_any_vec<T, U>())>
constexpr auto min(const T& a, const U& b)
{
return min<common_type_t<T, U>>(a, b);
}
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
......@@ -129,7 +203,10 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
......@@ -140,18 +217,6 @@ MIGRAPHX_DEVICE_MATH_VEC(tan)
MIGRAPHX_DEVICE_MATH_VEC(tanh)
MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto max(const T& a, const U& b)
{
return where(a < b, b, a);
}
template <class T, class U>
constexpr auto min(const T& a, const U& b)
{
return where(a > b, b, a);
}
template <class T, class U>
constexpr auto convert(U v)
{
......
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_OPS_HPP
#define MIGRAPHX_GUARD_KERNELS_OPS_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/math.hpp>
namespace migraphx {
namespace op {
struct sum
{
template <class T, class U>
constexpr auto operator()(T x, U y) const
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x + y;
}
......@@ -17,7 +18,7 @@ struct sum
struct product
{
template <class T, class U>
constexpr auto operator()(T x, U y) const
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x * y;
}
......@@ -26,7 +27,7 @@ struct product
struct id
{
template <class T>
constexpr auto operator()(T x) const
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x;
}
......@@ -34,40 +35,39 @@ struct id
struct mean
{
size_t item_num = 1;
index_int item_num = 1;
template <class T>
constexpr auto operator()(T x) const
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x / static_cast<T>(item_num);
}
};
struct max_f
struct max
{
template <class T, class U>
constexpr auto operator()(T x, U y) const
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return (x > y) ? x : y;
return migraphx::max(x, y);
}
};
inline constexpr auto max = max_f{};
struct min_f
struct min
{
template <class T, class U>
constexpr auto operator()(T x, U y) const
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return (x < y) ? x : y;
return migraphx::min(x, y);
}
};
inline constexpr auto min = min_f{};
} // namespace op
struct lowest
{
template <class T>
constexpr operator T() const
{
return std::numeric_limits<T>::lowest();
return numeric_lowest<T>();
}
};
......@@ -76,9 +76,8 @@ struct highest
template <class T>
constexpr operator T() const
{
return std::numeric_limits<T>::max();
return numeric_max<T>();
}
};
} // namespace migraphx
#endif // MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#endif // MIGRAPHX_GUARD_KERNELS_OPS_HPP
......@@ -38,22 +38,17 @@ constexpr implicit_conversion_op<T> implicit_conversion(T x)
template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) {
auto multi_idx = out.get_shape().multi(i);
out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
});
});
idx.global_stride(out.get_shape().elements(),
[&](auto i) { out[i] = implicit_conversion(f(xs[i]...)); });
}
template <class F, class... Ts>
__device__ void pointwise(F f, Ts*... ps)
template <class... Transforms>
__device__ auto pointwise(index idx, Transforms... transforms)
{
auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
t(ps...)([&](auto... xs) {
auto idx = make_index();
pointwise_tensor(idx, f, xs...);
});
return [=](auto f, auto*... ps) {
auto t = transform_args(make_tensors(), rotate_last(), transforms...);
t(ps...)([&](auto... xs) { pointwise_tensor(idx, f, xs...); });
};
}
} // namespace migraphx
......
......@@ -3,18 +3,37 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/vec.hpp>
namespace migraphx {
template <class T>
struct remove_vec_impl
{
using type = T;
};
template <class T, index_int N>
struct remove_vec_impl<vec<T, N>>
{
using type = T;
};
template <class T>
using remove_vec = typename remove_vec_impl<T>::type;
template <class T, class... Shapes>
constexpr auto traverse_preload(Shapes... ss)
{
return [=](auto f, auto... g) {
index_int offset = 0;
auto each = [&](auto x) {
using type = remove_vec<typename decltype(x)::type>;
constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>;
if constexpr(not s.broadcasted() or (s.elements() - size) < 64)
constexpr auto size = s.element_space();
if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
not is_same<T, type>{})
return f(x, offset, false_type{});
else
{
......@@ -56,7 +75,7 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
{
if constexpr(decltype(tensor_vec_size(x)){} == 0)
{
auto v = vectorize(x);
auto v = auto_vectorize(x);
auto b = as_vec(tensor_vec_size(v), buffer + offset);
idx.local_stride(v.get_shape().element_space(),
[&](auto i) { b[i] = v.data()[i]; });
......@@ -78,23 +97,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
invoke);
}
template <class T>
struct remove_vec
template <class T, class Shape>
struct shape_type : Shape
{
using type = T;
};
template <class T, index_int N>
struct remove_vec<vec<T, N>>
template <class T>
constexpr auto make_shape_type(T)
{
using type = T;
};
return shape_type<typename T::type, typename T::shape_type>{};
}
template <class T, class... Ts>
__device__ auto preload(index idx, Ts... xs)
{
using type = typename remove_vec<T>::type;
constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){};
using type = remove_vec<T>;
constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){};
const index_int max_size = 512 * sizeof(type);
return [=](auto f) {
if constexpr(size > 0 and size < max_size)
......@@ -109,5 +128,47 @@ __device__ auto preload(index idx, Ts... xs)
};
}
inline __device__ auto auto_preload(index idx)
{
return make_transform([=](auto f, auto out, auto... xs) {
preload<typename decltype(out)::type>(idx, xs...)([&](auto... ys) { f(out, ys...); });
});
}
template <bool B, class T>
__device__ auto preload_copy(index idx, T x)
{
return [=](auto f) {
if constexpr(B)
{
using type = typename T::type;
constexpr auto size = get_shape_c<T>{}.element_space();
__shared__ type buffer[size];
// TODO: Always vecotrize when size > 4, and then use a second loop for remainder
constexpr auto n = find_vectorize_size([&](auto i) { return (size % i) == 0; });
auto input = as_vec<n>(remove_bool(x.data()));
auto b = as_vec<n>(remove_bool(buffer));
idx.local_stride(size / n, [&](auto i) { b[i] = input[i]; });
return f(x.with(buffer));
}
else
{
return f(x);
}
};
}
template <bool... Bs>
__device__ auto auto_preload(index idx)
{
return make_transform([=](auto f, auto... xs) {
auto invoke = [=](auto... ys) {
__syncthreads();
f(ys...);
};
join(invoke, preload_copy<Bs>(idx, xs)...);
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_PRELOAD_HPP
......@@ -140,6 +140,10 @@ struct basic_printer
{
return print_ulong(value);
}
__host__ __device__ const basic_printer& operator<<(migraphx::half value) const
{
return print_double(value);
}
__host__ __device__ const basic_printer& operator<<(float value) const
{
return print_double(value);
......
#ifndef MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
#define MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
#include <migraphx/kernels/dpp.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ops.hpp>
namespace migraphx {
#if MIGRAPHX_HAS_DPP
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
T out{};
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(2)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(4), 0xf, 0xe>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(8), 0xf, 0xc>(in);
in = op(in, out);
#if __AMDGCN_WAVEFRONT_SIZE == 64
out = dpp_mov<dpp_row_bcast(15), 0xa>(in);
in = op(in, out);
out = dpp_mov<dpp_row_bcast(31), 0xc>(in);
in = op(in, out);
#endif
}
#if defined(MIGRAPHX_USE_CLANG_TIDY) || defined(CPPCHECK)
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) x = 1
#elif __AMDGCN_WAVEFRONT_SIZE == 64
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:15 row_mask:0xa\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_bcast:31 row_mask:0xc\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
#else
#define MIGRAPHX_DPP_REDUCE_ASM(x, ins) \
__asm__ volatile("s_nop 4\n" #ins " %0 %0 %0 row_shr:1\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:2\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:4 bank_mask:0xe\n" \
"s_nop 1\n" #ins " %0 %0 %0 row_shr:8 bank_mask:0xc\n" \
"s_nop 1\n" \
"s_nop 1\n" \
: "=v"(x) \
: "0"(x))
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_DPP_REDUCE(op, prefix) \
__device__ inline void dpp_reduce(double& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f64); } \
__device__ inline void dpp_reduce(float& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f32); } \
__device__ inline void dpp_reduce(half& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_f16); } \
__device__ inline void dpp_reduce(int32_t& x, op) \
{ \
MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); \
} \
__device__ inline void dpp_reduce(uint32_t& x, op) { MIGRAPHX_DPP_REDUCE_ASM(x, prefix##_u32); }
MIGRAPHX_DPP_REDUCE(op::sum, v_add)
MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
{
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16;
#else
constexpr index_int lanes_per_thread = 64;
#endif
using type = decltype(f(0));
__shared__ type buffer[idx.nlocal() / lanes_per_thread];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op);
const auto ldsidx = idx.local / lanes_per_thread;
if((idx.local % lanes_per_thread) == lanes_per_thread - 1)
{
buffer[ldsidx] = x;
}
__syncthreads();
type y = init;
for(index_int i = 0; i < idx.nlocal() / lanes_per_thread; i++)
{
y = op(y, buffer[i]);
}
return y;
}
#else
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
{
using type = decltype(f(0));
__shared__ type buffer[idx.nlocal()];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x;
__syncthreads();
for(index_int s = 1; s < idx.nlocal(); s *= 2)
{
const index_int index = 2 * s * idx.local;
if(index + s < idx.nlocal())
{
buffer[index] = op(buffer[index], buffer[index + s]);
}
__syncthreads();
}
return buffer[0];
}
#endif
template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i)
{
constexpr auto lens = transform(get_shape_c<Input>{}.lens,
get_shape_c<Output>{}.lens,
[](index_int x, index_int y) -> index_int {
if(x == y)
return 1;
return x;
});
;
constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides);
MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <=
input.get_shape().element_space());
return make_tensor_view(&input[i], s);
}
namespace reduce {
template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f)
{
return [=](auto x, auto... xs) {
// TODO: assert all elements are the same
return f(slicer(x), slicer(xs)...);
};
}
struct block
{
template <class Slicer>
struct reducer
{
index idx;
Slicer slicer;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx,
op,
init,
x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }),
op);
});
}
template <class F>
__device__ void outer(F f) const
{
if(idx.local == 0)
f();
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
struct lane
{
template <class Slicer>
struct reducer
{
index idx;
Slicer slicer;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
using type = typename decltype(x)::type;
type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
r = op(r, read(x[j], xs[j]...));
}
return r;
});
}
template <class F>
__device__ void outer(F f) const
{
f();
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements, [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i);
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
} // namespace reduce
template <class Algo,
class Op,
class T,
class Input,
class Output,
class ReadInput,
class WriteOuput>
__device__ void
simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write)
{
Algo::template run<Output>([&](auto out_idx, auto r) {
auto x = r.reduce(op, init, read)(input);
r.outer([&] { output[out_idx] = write(x); });
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
......@@ -3,14 +3,15 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/array.hpp>
namespace migraphx {
struct max_pool
{
MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest(); }
MIGRAPHX_DEVICE_CONSTEXPR auto init() { return lowest{}; }
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T operator()(T x, T y)
......@@ -19,7 +20,7 @@ struct max_pool
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t)
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int)
{
return (x);
}
......@@ -36,28 +37,26 @@ struct avg_pool
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t y)
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{
return (y == 0) ? 0.0 : (x / y);
}
};
template <class T, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
const array<std::size_t, 2>& dims,
array<float, 2> xy,
Op pooling)
template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
const Iterator data, const array<index_int, 2>& dims, array<float, 2> xy, Op pooling)
{
array<int, 2> low{};
array<int, 2> high{};
for(std::size_t ii = 0; ii < xy.size(); ++ii)
for(index_int ii = 0; ii < xy.size(); ++ii)
{
if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{
return 0;
}
xy[ii] = max(xy[ii], 0.0f);
xy[ii] = migraphx::max(xy[ii], 0.0f);
low[ii] = xy[ii];
high[ii] = low[ii] + 1;
if(low[ii] >= dims[ii] - 1)
......@@ -65,36 +64,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
xy[ii] = high[ii] = low[ii] = dims[ii] - 1;
}
}
array<std::size_t, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]};
array<index_int, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
array<T, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23);
}
template <class T, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data,
const array<float, 2>& roi_starts,
const array<float, 2>& bin_size,
const array<int, 2>& idx,
const array<std::size_t, 2>& bin_grid_size,
const array<std::size_t, 2>& dims,
float roi_offset,
Op op)
template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
const array<float, 2>& roi_starts,
const array<float, 2>& bin_size,
const array<int, 2>& idx,
const array<index_int, 2>& bin_grid_size,
const array<index_int, 2>& dims,
float roi_offset,
Op op)
{
T output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
typename Iterator::value_type output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<std::size_t, 2> id = {iy, ix};
array<index_int, 2> id = {iy, ix};
array<float, 2> locs =
roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset;
......@@ -120,21 +119,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
}
template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t, Settings s)
{
auto index = make_index();
const auto* x = x_t.data();
const auto* rois = rois_t.data();
const auto* ind = ind_t.data();
auto* out_ptr = y_t.data();
auto index = make_index();
const auto x = x_t.begin();
const auto rois = rois_t.begin();
const auto ind = ind_t.begin();
// input shape
auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1];
// input dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width
array<std::size_t, 2> in_dims = {x_lens[2], x_lens[3]};
array<index_int, 2> in_dims = {x_lens[2], x_lens[3]};
const auto stride = index.nglobal();
auto out_s = y_t.get_shape();
......@@ -142,8 +139,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
// output dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width
const auto& out_lens = out_s.lens;
array<std::size_t, 2> out_dims = {out_lens[2], out_lens[3]};
const auto& out_lens = out_s.lens;
array<index_int, 2> out_dims = {out_lens[2], out_lens[3]};
for(index_int i = index.global; i < out_s.elements(); i += stride)
{
......@@ -153,8 +150,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
int ph = idx[2];
int pw = idx[3];
const auto* offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n];
const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * s.spatial_scale};
......@@ -163,40 +160,41 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
array<float, 2> roi_size{};
array<float, 2> bin_size{};
array<std::size_t, 2> bin_grid_size{};
array<index_int, 2> bin_grid_size{};
for(std::size_t ii = 0; ii < roi_size.size(); ++ii)
for(index_int ii = 0; ii < roi_size.size(); ++ii)
{
roi_size[ii] = roi_ends[ii] - roi_starts[ii];
roi_size[ii] = max(roi_size[ii], 1.0f);
roi_size[ii] = migraphx::max(roi_size[ii], 1.0f);
bin_size[ii] = roi_size[ii] / out_dims[ii];
bin_grid_size[ii] =
(s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
bin_size[ii] = roi_size[ii] / out_dims[ii];
bin_grid_size[ii] = (s.sampling_ratio > 0)
? s.sampling_ratio
: migraphx::ceil(roi_size[ii] / out_dims[ii]);
}
const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling)
{
out_ptr[i] = calc_pooling(offset_x,
roi_starts,
bin_size,
{ph, pw},
bin_grid_size,
in_dims,
s.roi_offset,
avg_pool{});
y_t[i] = calc_pooling(offset_x,
roi_starts,
bin_size,
{ph, pw},
bin_grid_size,
in_dims,
s.roi_offset,
avg_pool{});
}
else
{
out_ptr[i] = calc_pooling(offset_x,
roi_starts,
bin_size,
{ph, pw},
bin_grid_size,
in_dims,
s.roi_offset,
max_pool{});
y_t[i] = calc_pooling(offset_x,
roi_starts,
bin_size,
{ph, pw},
bin_grid_size,
in_dims,
s.roi_offset,
max_pool{});
}
}
}
......
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
struct assign_none
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x = y;
}
};
struct assign_add
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x += y;
}
};
struct assign_mul
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& x, U y) const
{
x *= y;
}
};
template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{
auto index = make_index();
auto updates_shape = updates_t.get_shape();
index.global_stride(updates_shape.elements(), [&](auto i) {
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
auto k = indices_shape.lens.back();
auto q = indices_shape.lens.size();
auto updates_idx = updates_shape.multi(i);
auto indices_idx = indices_shape.multi(0);
copy(updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices_t.begin() + indices_shape.index(indices_idx);
auto index_end = index_start + k;
auto out_idx = output_shape.multi(0);
copy(index_start, index_end, out_idx.begin());
copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
f(output_t[out_idx], updates_t[i]);
});
}
} // namespace migraphx
#endif
......@@ -17,35 +17,38 @@ struct shape
constexpr shape(Lens l, Strides s) : lens(l), strides(s) {}
constexpr index_int elements() const { return lens.product(); }
constexpr auto elements() const { return _c<Lens{}.product()>; }
constexpr index_int element_space() const { return strides.dot(lens - 1) + 1; }
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr bool packed() const { return elements() == element_space(); }
constexpr bool broadcasted() const { return strides.product() == 0; }
constexpr bool transposed() const
constexpr auto packed() const { return elements() == element_space(); }
constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr auto transposed() const
{
if(broadcasted())
{
index_array s;
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
return return_c([] {
auto lstrides = Strides{};
if(shape{}.broadcasted())
{
if(strides[i] != 0)
index_array s{};
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{
s[j] = strides[i];
j++;
if(lstrides[i] != 0)
{
s[j] = lstrides[i];
j++;
}
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
}
else
{
return not is_sorted(strides.begin(), strides.end(), greater{});
}
else
{
return not is_sorted(lstrides.begin(), lstrides.end(), greater{});
}
});
}
constexpr bool standard() const { return packed() and not transposed(); }
constexpr auto standard() const { return packed() and not transposed(); }
constexpr index_int index(index_array x) const { return x.dot(strides); }
......@@ -63,10 +66,10 @@ struct shape
return i;
else
{
const index_int rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < this->lens.size(); j++)
const auto rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
......@@ -80,11 +83,12 @@ struct shape
}
}
/// Convert single index into a multi-index
constexpr index_array multi(index_int idx) const
{
index_array result;
index_int tidx = idx;
for(std::ptrdiff_t is = result.size() - 1; is > 0; is--)
for(diff_int is = result.size() - 1; is > 0; is--)
{
result[is] = tidx % lens[is];
tidx = tidx / lens[is];
......@@ -92,6 +96,13 @@ struct shape
result[0] = tidx;
return result;
}
/// Convert multi-index into a single index
constexpr index_int single(index_array idx) const
{
if(idx.empty())
return 0;
return inner_product(lens.begin() + 1, lens.end(), idx.begin(), idx.back());
}
constexpr shape get_shape() const { return *this; }
......
......@@ -3,28 +3,62 @@
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
namespace migraphx {
template <class T>
struct tensor_view_iterator_read
{
T* view;
constexpr auto& operator()(index_int n) const
{
MIGRAPHX_ASSERT(view != nullptr);
return (*view)[n];
}
};
template <class T, class Shape>
struct tensor_view
{
using type = T;
using type = T;
using shape_type = Shape;
using index_array = typename Shape::index_array;
using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); }
constexpr auto size() const { return get_shape().elements(); }
template <class U>
constexpr T& operator[](U i) const
struct index_to_offset
{
MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
return x[get_shape().index(i)];
index_int offset;
template <class U>
constexpr index_to_offset(U i) : offset(Shape{}.index(i))
{
}
};
constexpr T& operator[](MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i) const
{
index_to_offset ito = i;
MIGRAPHX_WARN(ito.offset < get_shape().element_space(),
i,
"Out of bounds access at offset: ",
ito.offset);
return x[ito.offset];
}
constexpr T* data() const { return x; }
constexpr T* begin() const { return data(); }
constexpr T* end() const { return data() + size(); }
constexpr auto begin() const { return iterator{0, {this}}; }
constexpr auto end() const { return iterator{this->size(), {this}}; }
constexpr auto begin_at(index_array i) const
{
MIGRAPHX_ASSERT(get_shape().single(i) < get_shape().elements());
MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space());
return iterator{get_shape().single(i), {this}};
}
template <class U>
constexpr tensor_view<U, Shape> with(U* y) const
......@@ -36,6 +70,9 @@ struct tensor_view
T* x;
};
template <class T>
using get_shape_c = typename T::shape_type;
template <class T, class Shape>
constexpr tensor_view<T, Shape> make_tensor_view(T* x, Shape)
{
......
......@@ -6,6 +6,21 @@
namespace migraphx {
template <class T, class U = T&&>
U private_declval(int);
template <class T>
T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
template <class T>
struct type_identity
{
using type = T;
};
template <bool B, class T = void>
struct enable_if
{
......@@ -20,11 +35,178 @@ struct enable_if<true, T>
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
template <class From, class To>
struct is_convertible : bool_constant<__is_convertible(From, To)>
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
template <bool B, class T, class F>
using conditional_t = typename conditional<B, T, F>::type;
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_arithmetic);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_destructible);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_nothrow_destructible);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pointer);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_scalar);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_signed);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_void);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_abstract);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_aggregate);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_array);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_class);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_compound);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_const);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_empty);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_enum);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_final);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_floating_point);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_function);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_fundamental);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_integral);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_literal_type);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_lvalue_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_function_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_object_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_object);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pod);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_polymorphic);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_rvalue_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_standard_layout);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivial);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_destructible);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_union);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_unsigned);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_volatile);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_base_of);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_convertible);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_nothrow_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_same);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_trivially_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
template <class T>
struct remove_reference
{
using type = T;
};
template <class T>
struct remove_reference<T&>
{
using type = T;
};
template <class T>
struct remove_reference<T&&>
{
using type = T;
};
template <class T>
using remove_reference_t = typename remove_reference<T>::type;
template <class T>
struct add_pointer : type_identity<typename remove_reference<T>::type*>
{
};
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
template <class... Ts>
struct common_type;
template <class T>
struct common_type<T>
{
using type = T;
};
template <class T, class U>
struct common_type<T, U>
{
using type = decltype(true ? declval<T>() : declval<U>());
};
template <class T, class U, class... Us>
struct common_type<T, U, Us...>
{
using type = typename common_type<typename common_type<T, U>::type, Us...>::type;
};
template <class... Ts>
using common_type_t = typename common_type<Ts...>::type;
constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }
template <class T>
constexpr T numeric_max()
{
if constexpr(is_integral<T>{})
{
if constexpr(is_unsigned<T>{})
return int_max(sizeof(T)) * 2;
else
return int_max(sizeof(T));
}
else if constexpr(is_same<T, double>{})
return __DBL_MAX__;
else if constexpr(is_same<T, float>{})
return __FLT_MAX__;
else if constexpr(is_same<T, migraphx::half>{})
return __FLT16_MAX__;
else
return 0;
}
template <class T>
constexpr T numeric_lowest()
{
if constexpr(is_integral<T>{})
{
if constexpr(is_unsigned<T>{})
return 0;
else
return -numeric_max<T>() - 1;
}
else
{
return -numeric_max<T>();
}
}
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment