Unverified Commit 31e63991 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add iterators to kernels tensor_view and fix roialign to work with non-standard shape (#1126)

This adds iterators to tensor_view, which allows kernels to work with non-standard shapes, such as the shapes used by roialign.

To improve the performance of indexing when using the iterators, the shape class was updated to use integral_constants, since the compiler doesn't always fold plain const values. An integral_constant at least enforces that folding in the AST.

Finally, since index calculations with a single integer are now cheaper, I also updated pointwise to use a single index rather than a multi-index. This gives about a 4% improvement in some cases.
parent 2d1efd69
...@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>; ...@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>;
template <auto V> template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
// Lift the result of invoking a callable into an integral_constant, so the
// value is carried by the type system. The callable must be stateless and
// invocable in a constant expression (e.g. a captureless lambda).
template <class F>
constexpr auto return_c(F f)
{
    constexpr auto value = f();
    return _c<value>;
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP #endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
// Random-access iterator adaptor that stores an integral position (`index`)
// and a functor `f`; dereferencing yields `f(index)`. With an identity
// functor this enumerates integers; with a lookup functor it can walk the
// elements of a tensor_view in flat-index order.
template <class F, class Iterator = diff_int>
struct basic_iota_iterator
{
    Iterator index; // current position
    F f;            // maps a position to the referenced value

    using difference_type = diff_int;
    // The reference/value/pointer types are derived from what the functor
    // returns when applied to a position.
    using reference = decltype(f(std::declval<Iterator>()));
    using value_type = remove_reference_t<reference>;
    using pointer = add_pointer_t<value_type>;

    // Advance by n positions
    constexpr basic_iota_iterator& operator+=(diff_int n)
    {
        index += n;
        return *this;
    }
    // Move back by n positions
    constexpr basic_iota_iterator& operator-=(diff_int n)
    {
        index -= n;
        return *this;
    }
    constexpr basic_iota_iterator& operator++()
    {
        index++;
        return *this;
    }
    constexpr basic_iota_iterator& operator--()
    {
        index--;
        return *this;
    }
    // Post-increment: return a copy of the iterator prior to advancing
    constexpr basic_iota_iterator operator++(int) // NOLINT
    {
        basic_iota_iterator it = *this;
        index++;
        return it;
    }
    // Post-decrement: return a copy of the iterator prior to stepping back
    constexpr basic_iota_iterator operator--(int) // NOLINT
    {
        basic_iota_iterator it = *this;
        index--;
        return it;
    }
    // TODO: operator->
    constexpr reference operator*() const { return f(index); }
    // Access the element x positions away without moving the iterator
    template <class T>
    constexpr reference operator[](T x) const
    {
        return f(index + x);
    }
};
// Helper to construct a basic_iota_iterator while deducing the position and
// functor types from the arguments.
template <class T, class F>
constexpr basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
    basic_iota_iterator<F, T> result{x, f};
    return result;
}
// Arithmetic on iota iterators: advancing/stepping acts on the by-value copy
// of the iterator and returns it, leaving the argument at the caller unchanged.
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x, diff_int y)
{
    x += y;
    return x;
}

template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(diff_int x, basic_iota_iterator<F, Iterator> y)
{
    y += x;
    return y;
}

// Distance between two iterators, measured on the underlying index
template <class F, class Iterator>
constexpr diff_int operator-(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    diff_int distance = x.index - y.index;
    return distance;
}

template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x, diff_int y)
{
    x -= y;
    return x;
}
// Comparisons forward directly to the underlying index; the functor is not
// compared, so two iterators with equal positions compare equal regardless of
// their functor state.
template <class F, class Iterator>
constexpr bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index == y.index;
}

template <class F, class Iterator>
constexpr bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index != y.index;
}

template <class F, class Iterator>
constexpr bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index < y.index;
}

template <class F, class Iterator>
constexpr bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index > y.index;
}

template <class F, class Iterator>
constexpr bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index >= y.index;
}

template <class F, class Iterator>
constexpr bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
    return x.index <= y.index;
}
struct defaul_iota_iterator
{
template <class T>
constexpr auto operator()(T x) const
{
return x;
}
};
using iota_iterator = basic_iota_iterator<defaul_iota_iterator>;
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
...@@ -39,10 +39,8 @@ template <class F, class T, class... Ts> ...@@ -39,10 +39,8 @@ template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs) __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{ {
preload<typename T::type>(idx, xs...)([&](auto... ps) { preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) { idx.global_stride(out.get_shape().elements(),
auto multi_idx = out.get_shape().multi(i); [&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); });
out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
});
}); });
} }
......
...@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss) ...@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
auto each = [&](auto x) { auto each = [&](auto x) {
using type = remove_vec<typename decltype(x)::type>; using type = remove_vec<typename decltype(x)::type>;
constexpr auto s = decltype(x.get_shape()){}; constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>; constexpr auto size = s.element_space();
if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
not is_same<T, type>{}) not is_same<T, type>{})
return f(x, offset, false_type{}); return f(x, offset, false_type{});
......
...@@ -19,7 +19,7 @@ struct max_pool ...@@ -19,7 +19,7 @@ struct max_pool
} }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int)
{ {
return (x); return (x);
} }
...@@ -36,21 +36,19 @@ struct avg_pool ...@@ -36,21 +36,19 @@ struct avg_pool
} }
template <class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t y) MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{ {
return (y == 0) ? 0.0 : (x / y); return (y == 0) ? 0.0 : (x / y);
} }
}; };
template <class T, class Op> template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data, MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
const array<std::size_t, 2>& dims, const Iterator data, const array<index_int, 2>& dims, array<float, 2> xy, Op pooling)
array<float, 2> xy,
Op pooling)
{ {
array<int, 2> low{}; array<int, 2> low{};
array<int, 2> high{}; array<int, 2> high{};
for(std::size_t ii = 0; ii < xy.size(); ++ii) for(index_int ii = 0; ii < xy.size(); ++ii)
{ {
if(xy[ii] < -1.0f or xy[ii] > dims[ii]) if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{ {
...@@ -65,36 +63,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data, ...@@ -65,36 +63,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
xy[ii] = high[ii] = low[ii] = dims[ii] - 1; xy[ii] = high[ii] = low[ii] = dims[ii] - 1;
} }
} }
array<std::size_t, 4> locs = {low[0] * dims[1] + low[1], array<index_int, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1], low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1], high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]}; high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0]; float ly = xy[0] - low[0];
float lx = xy[1] - low[1]; float lx = xy[1] - low[1];
float hy = 1.0f - ly; float hy = 1.0f - ly;
float hx = 1.0f - lx; float hx = 1.0f - lx;
array<T, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx}; array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]); auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]); auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23); return pooling(v01, v23);
} }
template <class T, class Op> template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data, MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
const array<float, 2>& roi_starts, const array<float, 2>& roi_starts,
const array<float, 2>& bin_size, const array<float, 2>& bin_size,
const array<int, 2>& idx, const array<int, 2>& idx,
const array<std::size_t, 2>& bin_grid_size, const array<index_int, 2>& bin_grid_size,
const array<std::size_t, 2>& dims, const array<index_int, 2>& dims,
float roi_offset, float roi_offset,
Op op) Op op)
{ {
T output_val = op.init(); typename Iterator::value_type output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1]; const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) { dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<std::size_t, 2> id = {iy, ix}; array<index_int, 2> id = {iy, ix};
array<float, 2> locs = array<float, 2> locs =
roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset; roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset;
...@@ -122,19 +120,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs) ...@@ -122,19 +120,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
template <class T, class U, class V, class W, class Settings> template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s) __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
{ {
auto index = make_index(); auto index = make_index();
const auto* x = x_t.data(); const auto x = x_t.begin();
const auto* rois = rois_t.data(); const auto rois = rois_t.begin();
const auto* ind = ind_t.data(); const auto ind = ind_t.begin();
auto* out_ptr = y_t.data(); auto out_ptr = y_t.begin();
// input shape // input shape
auto x_lens = x_t.get_shape().lens; auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1]; auto channel_num = x_lens[1];
// input dims of height and width, in all 2-dim arrays, the first dim // input dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width // is for height and second dim is for width
array<std::size_t, 2> in_dims = {x_lens[2], x_lens[3]}; array<index_int, 2> in_dims = {x_lens[2], x_lens[3]};
const auto stride = index.nglobal(); const auto stride = index.nglobal();
auto out_s = y_t.get_shape(); auto out_s = y_t.get_shape();
...@@ -142,8 +140,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -142,8 +140,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
// output dims of height and width, in all 2-dim arrays, the first dim // output dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width // is for height and second dim is for width
const auto& out_lens = out_s.lens; const auto& out_lens = out_s.lens;
array<std::size_t, 2> out_dims = {out_lens[2], out_lens[3]}; array<index_int, 2> out_dims = {out_lens[2], out_lens[3]};
for(index_int i = index.global; i < out_s.elements(); i += stride) for(index_int i = index.global; i < out_s.elements(); i += stride)
{ {
...@@ -153,8 +151,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -153,8 +151,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
int ph = idx[2]; int ph = idx[2];
int pw = idx[3]; int pw = idx[3];
const auto* offset_rois = rois + (n * roi_column_num); const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n]; const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale, array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * s.spatial_scale}; offset_rois[0] * s.spatial_scale};
...@@ -163,9 +161,9 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -163,9 +161,9 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
array<float, 2> roi_size{}; array<float, 2> roi_size{};
array<float, 2> bin_size{}; array<float, 2> bin_size{};
array<std::size_t, 2> bin_grid_size{}; array<index_int, 2> bin_grid_size{};
for(std::size_t ii = 0; ii < roi_size.size(); ++ii) for(index_int ii = 0; ii < roi_size.size(); ++ii)
{ {
roi_size[ii] = roi_ends[ii] - roi_starts[ii]; roi_size[ii] = roi_ends[ii] - roi_starts[ii];
roi_size[ii] = max(roi_size[ii], 1.0f); roi_size[ii] = max(roi_size[ii], 1.0f);
...@@ -175,7 +173,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& ...@@ -175,7 +173,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
(s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]); (s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
} }
const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]); const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling) if constexpr(s.is_avg_pooling)
{ {
out_ptr[i] = calc_pooling(offset_x, out_ptr[i] = calc_pooling(offset_x,
......
...@@ -17,35 +17,38 @@ struct shape ...@@ -17,35 +17,38 @@ struct shape
constexpr shape(Lens l, Strides s) : lens(l), strides(s) {} constexpr shape(Lens l, Strides s) : lens(l), strides(s) {}
constexpr index_int elements() const { return lens.product(); } constexpr auto elements() const { return _c<Lens{}.product()>; }
constexpr index_int element_space() const { return strides.dot(lens - 1) + 1; } constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr bool packed() const { return elements() == element_space(); } constexpr auto packed() const { return elements() == element_space(); }
constexpr bool broadcasted() const { return strides.product() == 0; } constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr bool transposed() const constexpr auto transposed() const
{ {
if(broadcasted()) return return_c([] {
{ auto lstrides = Strides{};
index_array s; if(shape{}.broadcasted())
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{ {
if(strides[i] != 0) index_array s{};
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{ {
s[j] = strides[i]; if(lstrides[i] != 0)
j++; {
s[j] = lstrides[i];
j++;
}
} }
return not is_sorted(s.begin(), s.begin() + j, greater{});
} }
return not is_sorted(s.begin(), s.begin() + j, greater{}); else
} {
else return not is_sorted(lstrides.begin(), lstrides.end(), greater{});
{ }
return not is_sorted(strides.begin(), strides.end(), greater{}); });
}
} }
constexpr bool standard() const { return packed() and not transposed(); } constexpr auto standard() const { return packed() and not transposed(); }
constexpr index_int index(index_array x) const { return x.dot(strides); } constexpr index_int index(index_array x) const { return x.dot(strides); }
...@@ -63,10 +66,10 @@ struct shape ...@@ -63,10 +66,10 @@ struct shape
return i; return i;
else else
{ {
const index_int rank = this->lens.size(); const auto rank = this->lens.size();
index_int s = 1; index_int s = 1;
index_int result = 0; index_int result = 0;
for(index_int j = 0; j < this->lens.size(); j++) for(index_int j = 0; j < rank; j++)
{ {
const index_int k = rank - j - 1; const index_int k = rank - j - 1;
const index_int stride = this->strides[k]; const index_int stride = this->strides[k];
......
...@@ -3,17 +3,30 @@ ...@@ -3,17 +3,30 @@
#include <migraphx/kernels/shape.hpp> #include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
namespace migraphx { namespace migraphx {
template <class T>
struct tensor_view_iterator_read
{
T* view;
constexpr auto& operator()(std::size_t n) const
{
MIGRAPHX_ASSERT(view != nullptr);
return (*view)[n];
}
};
template <class T, class Shape> template <class T, class Shape>
struct tensor_view struct tensor_view
{ {
using type = T; using type = T;
using shape_type = Shape; using shape_type = Shape;
using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
constexpr Shape get_shape() const { return Shape{}; } constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); } constexpr auto size() const { return get_shape().elements(); }
template <class U> template <class U>
constexpr T& operator[](U i) const constexpr T& operator[](U i) const
...@@ -24,8 +37,8 @@ struct tensor_view ...@@ -24,8 +37,8 @@ struct tensor_view
constexpr T* data() const { return x; } constexpr T* data() const { return x; }
constexpr T* begin() const { return data(); } constexpr auto begin() const { return iterator{0, {this}}; }
constexpr T* end() const { return data() + size(); } constexpr auto end() const { return iterator{this->size(), {this}}; }
template <class U> template <class U>
constexpr tensor_view<U, Shape> with(U* y) const constexpr tensor_view<U, Shape> with(U* y) const
......
...@@ -6,6 +6,12 @@ ...@@ -6,6 +6,12 @@
namespace migraphx { namespace migraphx {
// Minimal replacement for std::type_identity (C++20): wraps T so that other
// traits can inherit a `type` member instead of declaring it themselves.
template <class T>
struct type_identity
{
    using type = T;
};
template <bool B, class T = void> template <bool B, class T = void>
struct enable_if struct enable_if
{ {
...@@ -35,6 +41,33 @@ struct is_same<T, T> : true_type ...@@ -35,6 +41,33 @@ struct is_same<T, T> : true_type
{ {
}; };
// Device-side replacement for std::remove_reference: strips one level of
// lvalue/rvalue reference, leaving non-reference types untouched.
template <class T>
struct remove_reference
{
    using type = T;
};

template <class T>
struct remove_reference<T&>
{
    using type = T;
};

template <class T>
struct remove_reference<T&&>
{
    using type = T;
};

template <class T>
using remove_reference_t = typename remove_reference<T>::type;

// Replacement for std::add_pointer: pointer to the referred-to type.
// NOTE(review): unlike std::add_pointer, this appears ill-formed for types
// that cannot be pointed to (e.g. cv/ref-qualified function types) — likely
// acceptable for kernel use, but worth confirming.
template <class T>
struct add_pointer : type_identity<typename remove_reference<T>::type*>
{
};

template <class T>
using add_pointer_t = typename add_pointer<T>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__> #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx } // namespace migraphx
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
namespace migraphx { namespace migraphx {
using index_int = std::uint32_t; using index_int = std::uint32_t;
// Signed 32-bit type for iterator differences (signed counterpart of index_int)
using diff_int = std::int32_t;
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT #define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment