Unverified Commit 31e63991 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add iterators to kernels tensor_view and fix roialign to work with non-standard shape (#1126)

This adds iterators to tensor_view, which can allow kernels to work with non-standard shapes like for roialign.

To improve the performance of indexing when using the iterators, the shape class was updated to use integral_constants since the compiler doesn't always fold the const values. An integral_constant will at least enforce that in the AST.

Finally, since index calculations with single integers are improved, I also updated pointwise to use single index rather than multi index. There is about 4% improvement in some cases.
parent 2d1efd69
......@@ -70,5 +70,11 @@ using index_constant = integral_constant<index_int, N>;
template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
template <class F>
constexpr auto return_c(F f)
{
return _c<f()>;
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
template <class F, class Iterator = diff_int>
struct basic_iota_iterator
{
Iterator index;
F f;
using difference_type = diff_int;
using reference = decltype(f(std::declval<Iterator>()));
using value_type = remove_reference_t<reference>;
using pointer = add_pointer_t<value_type>;
constexpr basic_iota_iterator& operator+=(diff_int n)
{
index += n;
return *this;
}
constexpr basic_iota_iterator& operator-=(diff_int n)
{
index -= n;
return *this;
}
constexpr basic_iota_iterator& operator++()
{
index++;
return *this;
}
constexpr basic_iota_iterator& operator--()
{
index--;
return *this;
}
constexpr basic_iota_iterator operator++(int) // NOLINT
{
basic_iota_iterator it = *this;
index++;
return it;
}
constexpr basic_iota_iterator operator--(int) // NOLINT
{
basic_iota_iterator it = *this;
index--;
return it;
}
// TODO: operator->
constexpr reference operator*() const { return f(index); }
template <class T>
constexpr reference operator[](T x) const
{
return f(index + x);
}
};
template <class T, class F>
constexpr basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
return basic_iota_iterator<F, T>{x, f};
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x, diff_int y)
{
return x += y;
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(diff_int x, basic_iota_iterator<F, Iterator> y)
{
return y + x;
}
template <class F, class Iterator>
constexpr diff_int operator-(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index - y.index;
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x, diff_int y)
{
return x -= y;
}
template <class F, class Iterator>
constexpr bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator>
constexpr bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
template <class F, class Iterator>
constexpr bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index < y.index;
}
template <class F, class Iterator>
constexpr bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index > y.index;
}
template <class F, class Iterator>
constexpr bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index >= y.index;
}
template <class F, class Iterator>
constexpr bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index <= y.index;
}
struct defaul_iota_iterator
{
template <class T>
constexpr auto operator()(T x) const
{
return x;
}
};
using iota_iterator = basic_iota_iterator<defaul_iota_iterator>;
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
......@@ -39,10 +39,8 @@ template <class F, class T, class... Ts>
__device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
{
preload<typename T::type>(idx, xs...)([&](auto... ps) {
idx.global_stride(out.get_shape().elements(), [&](auto i) {
auto multi_idx = out.get_shape().multi(i);
out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
});
idx.global_stride(out.get_shape().elements(),
[&](auto i) { out[i] = implicit_conversion(f(ps[i]...)); });
});
}
......
......@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
auto each = [&](auto x) {
using type = remove_vec<typename decltype(x)::type>;
constexpr auto s = decltype(x.get_shape()){};
constexpr auto size = _c<s.element_space()>;
constexpr auto size = s.element_space();
if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
not is_same<T, type>{})
return f(x, offset, false_type{});
......
......@@ -19,7 +19,7 @@ struct max_pool
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t)
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int)
{
return (x);
}
......@@ -36,21 +36,19 @@ struct avg_pool
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, std::size_t y)
MIGRAPHX_DEVICE_CONSTEXPR T final(T x, index_int y)
{
return (y == 0) ? 0.0 : (x / y);
}
};
template <class T, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
const array<std::size_t, 2>& dims,
array<float, 2> xy,
Op pooling)
template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
const Iterator data, const array<index_int, 2>& dims, array<float, 2> xy, Op pooling)
{
array<int, 2> low{};
array<int, 2> high{};
for(std::size_t ii = 0; ii < xy.size(); ++ii)
for(index_int ii = 0; ii < xy.size(); ++ii)
{
if(xy[ii] < -1.0f or xy[ii] > dims[ii])
{
......@@ -65,36 +63,36 @@ MIGRAPHX_DEVICE_CONSTEXPR T bilinear_interpolate(const T* data,
xy[ii] = high[ii] = low[ii] = dims[ii] - 1;
}
}
array<std::size_t, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]};
array<index_int, 4> locs = {low[0] * dims[1] + low[1],
low[0] * dims[1] + high[1],
high[0] * dims[1] + low[1],
high[0] * dims[1] + high[1]};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
array<T, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
float ly = xy[0] - low[0];
float lx = xy[1] - low[1];
float hy = 1.0f - ly;
float hx = 1.0f - lx;
array<typename Iterator::value_type, 4> ws = {hy * hx, hy * lx, ly * hx, ly * lx};
auto v01 = pooling(data[locs[0]] * ws[0], data[locs[1]] * ws[1]);
auto v23 = pooling(data[locs[2]] * ws[2], data[locs[3]] * ws[3]);
return pooling(v01, v23);
}
template <class T, class Op>
MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data,
const array<float, 2>& roi_starts,
const array<float, 2>& bin_size,
const array<int, 2>& idx,
const array<std::size_t, 2>& bin_grid_size,
const array<std::size_t, 2>& dims,
float roi_offset,
Op op)
template <class Iterator, class Op>
MIGRAPHX_DEVICE_CONSTEXPR auto calc_pooling(const Iterator& data,
const array<float, 2>& roi_starts,
const array<float, 2>& bin_size,
const array<int, 2>& idx,
const array<index_int, 2>& bin_grid_size,
const array<index_int, 2>& dims,
float roi_offset,
Op op)
{
T output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
typename Iterator::value_type output_val = op.init();
const int64_t count = bin_grid_size[0] * bin_grid_size[1];
dfor(bin_grid_size[0], bin_grid_size[1])([&](auto iy, auto ix) {
array<std::size_t, 2> id = {iy, ix};
array<index_int, 2> id = {iy, ix};
array<float, 2> locs =
roi_starts + idx * bin_size + bin_size * (id + 0.5f) / bin_grid_size + roi_offset;
......@@ -122,19 +120,19 @@ constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
template <class T, class U, class V, class W, class Settings>
__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
{
auto index = make_index();
const auto* x = x_t.data();
const auto* rois = rois_t.data();
const auto* ind = ind_t.data();
auto index = make_index();
const auto x = x_t.begin();
const auto rois = rois_t.begin();
const auto ind = ind_t.begin();
auto* out_ptr = y_t.data();
auto out_ptr = y_t.begin();
// input shape
auto x_lens = x_t.get_shape().lens;
auto channel_num = x_lens[1];
// input dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width
array<std::size_t, 2> in_dims = {x_lens[2], x_lens[3]};
array<index_int, 2> in_dims = {x_lens[2], x_lens[3]};
const auto stride = index.nglobal();
auto out_s = y_t.get_shape();
......@@ -142,8 +140,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
// output dims of height and width, in all 2-dim arrays, the first dim
// is for height and second dim is for width
const auto& out_lens = out_s.lens;
array<std::size_t, 2> out_dims = {out_lens[2], out_lens[3]};
const auto& out_lens = out_s.lens;
array<index_int, 2> out_dims = {out_lens[2], out_lens[3]};
for(index_int i = index.global; i < out_s.elements(); i += stride)
{
......@@ -153,8 +151,8 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
int ph = idx[2];
int pw = idx[3];
const auto* offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n];
const auto offset_rois = rois + (n * roi_column_num);
const int batch_ind = ind[n];
array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
offset_rois[0] * s.spatial_scale};
......@@ -163,9 +161,9 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
array<float, 2> roi_size{};
array<float, 2> bin_size{};
array<std::size_t, 2> bin_grid_size{};
array<index_int, 2> bin_grid_size{};
for(std::size_t ii = 0; ii < roi_size.size(); ++ii)
for(index_int ii = 0; ii < roi_size.size(); ++ii)
{
roi_size[ii] = roi_ends[ii] - roi_starts[ii];
roi_size[ii] = max(roi_size[ii], 1.0f);
......@@ -175,7 +173,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
(s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
}
const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
const auto offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);
if constexpr(s.is_avg_pooling)
{
out_ptr[i] = calc_pooling(offset_x,
......
......@@ -17,35 +17,38 @@ struct shape
constexpr shape(Lens l, Strides s) : lens(l), strides(s) {}
constexpr index_int elements() const { return lens.product(); }
constexpr auto elements() const { return _c<Lens{}.product()>; }
constexpr index_int element_space() const { return strides.dot(lens - 1) + 1; }
constexpr auto element_space() const { return _c<Strides{}.dot(Lens{} - 1) + 1>; }
constexpr bool packed() const { return elements() == element_space(); }
constexpr bool broadcasted() const { return strides.product() == 0; }
constexpr bool transposed() const
constexpr auto packed() const { return elements() == element_space(); }
constexpr auto broadcasted() const { return _c<Strides{}.product() == 0>; }
constexpr auto transposed() const
{
if(broadcasted())
{
index_array s;
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
return return_c([] {
auto lstrides = Strides{};
if(shape{}.broadcasted())
{
if(strides[i] != 0)
index_array s{};
index_int j = 0;
for(index_int i = 0; i < s.size(); i++)
{
s[j] = strides[i];
j++;
if(lstrides[i] != 0)
{
s[j] = lstrides[i];
j++;
}
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
}
return not is_sorted(s.begin(), s.begin() + j, greater{});
}
else
{
return not is_sorted(strides.begin(), strides.end(), greater{});
}
else
{
return not is_sorted(lstrides.begin(), lstrides.end(), greater{});
}
});
}
constexpr bool standard() const { return packed() and not transposed(); }
constexpr auto standard() const { return packed() and not transposed(); }
constexpr index_int index(index_array x) const { return x.dot(strides); }
......@@ -63,10 +66,10 @@ struct shape
return i;
else
{
const index_int rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < this->lens.size(); j++)
const auto rank = this->lens.size();
index_int s = 1;
index_int result = 0;
for(index_int j = 0; j < rank; j++)
{
const index_int k = rank - j - 1;
const index_int stride = this->strides[k];
......
......@@ -3,17 +3,30 @@
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/iota_iterator.hpp>
namespace migraphx {
template <class T>
struct tensor_view_iterator_read
{
T* view;
constexpr auto& operator()(std::size_t n) const
{
MIGRAPHX_ASSERT(view != nullptr);
return (*view)[n];
}
};
template <class T, class Shape>
struct tensor_view
{
using type = T;
using shape_type = Shape;
using iterator = basic_iota_iterator<tensor_view_iterator_read<const tensor_view>, index_int>;
constexpr Shape get_shape() const { return Shape{}; }
constexpr index_int size() const { return get_shape().elements(); }
constexpr auto size() const { return get_shape().elements(); }
template <class U>
constexpr T& operator[](U i) const
......@@ -24,8 +37,8 @@ struct tensor_view
constexpr T* data() const { return x; }
constexpr T* begin() const { return data(); }
constexpr T* end() const { return data() + size(); }
constexpr auto begin() const { return iterator{0, {this}}; }
constexpr auto end() const { return iterator{this->size(), {this}}; }
template <class U>
constexpr tensor_view<U, Shape> with(U* y) const
......
......@@ -6,6 +6,12 @@
namespace migraphx {
template <class T>
struct type_identity
{
using type = T;
};
template <bool B, class T = void>
struct enable_if
{
......@@ -35,6 +41,33 @@ struct is_same<T, T> : true_type
{
};
template <class T>
struct remove_reference
{
using type = T;
};
template <class T>
struct remove_reference<T&>
{
using type = T;
};
template <class T>
struct remove_reference<T&&>
{
using type = T;
};
template <class T>
using remove_reference_t = typename remove_reference<T>::type;
template <class T>
struct add_pointer : type_identity<typename remove_reference<T>::type*>
{
};
template <class T>
using add_pointer_t = typename add_pointer<T>::type;
#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
} // namespace migraphx
......
......@@ -6,6 +6,7 @@
namespace migraphx {
using index_int = std::uint32_t;
using diff_int = std::int32_t;
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__ // NOLINT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment