Commit eed607a3 authored by Paul's avatar Paul
Browse files

Use static ranks

parent e6686d25
......@@ -9,13 +9,13 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#if 0
#if 1
argument concat(hipStream_t stream,
const migraphx::shape&,
std::vector<migraphx::argument> args_vec,
std::vector<std::size_t> offsets_vec)
{
static constexpr const std::size_t limit = 10;
static constexpr const std::size_t limit = 6;
if (offsets_vec.size() > limit)
MIGRAPHX_THROW("Too many arguments to concat");
std::size_t nelements = std::max_element(args_vec.begin(), std::prev(args_vec.end()), by(std::less<>{}, [&](auto&& x) { return x.get_shape().elements(); }))->get_shape().elements();
......@@ -23,7 +23,7 @@ argument concat(hipStream_t stream,
hip_visit_all<limit+1>(args_vec)([&](auto args) {
auto output = args.back();
auto ninputs = args.size() - 1;
gs_launch(stream, nelements)([=](auto x) {
gs_launch(stream, nelements)([=](auto i) {
for(std::size_t j = 0;j < ninputs;j++)
{
auto&& arg = args[j];
......
......@@ -124,23 +124,26 @@ hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
return result;
}
using hip_index = hip_vector<std::size_t, 5>;
template<std::size_t N>
struct hip_shape
{
hip_vector<std::size_t, 5> lens = {};
hip_vector<std::size_t, 5> strides = {};
using hip_index = hip_array<std::size_t, N>;
hip_array<std::size_t, N> lens = {};
hip_array<std::size_t, N> strides = {};
std::size_t elements = 0;
bool standard = false;
__device__ __host__ hip_shape() = default;
hip_shape(const shape& s)
: lens(s.lens().begin(), s.lens().end()),
strides(s.strides().begin(), s.strides().end()),
elements(s.elements()),
: elements(s.elements()),
standard(s.standard())
{
assert(s.lens().size() == N);
assert(s.strides().size() == N);
std::copy(s.lens().begin(), s.lens().end(), lens.begin());
std::copy(s.strides().begin(), s.strides().end(), strides.begin());
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const
......@@ -183,7 +186,7 @@ struct hip_shape
MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(std::size_t idx) const
{
hip_index result(lens.size());
hip_index result;
std::size_t tidx = idx;
for(std::size_t is = 0; is < result.size(); is++)
{
......@@ -194,13 +197,13 @@ struct hip_shape
}
};
template <class T>
template <class T, std::size_t N>
struct hip_tensor_view
{
__device__ __host__ hip_tensor_view() = default;
__device__ __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
MIGRAPHX_DEVICE_CONSTEXPR const hip_shape& get_shape() const { return s; }
MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements; }
......@@ -214,34 +217,42 @@ struct hip_tensor_view
private:
T* d = nullptr;
hip_shape s{};
hip_shape<N> s{};
};
template <class T>
hip_tensor_view<T> make_hip_tensor_view(tensor_view<T> x)
template <std::size_t N, class T>
hip_tensor_view<T, N> make_hip_tensor_view(tensor_view<T> x)
{
return x;
}
template <std::size_t N, class T>
hip_vector<hip_tensor_view<T>, N> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
template <std::size_t N, std::size_t M, class T>
hip_vector<hip_tensor_view<T, N>, M> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
{
hip_vector<hip_tensor_view<T>, N> result(x.size());
hip_vector<hip_tensor_view<T, N>, M> result(x.size());
std::transform(
x.begin(), x.end(), result.begin(), [&](auto y) { return make_hip_tensor_view(y); });
x.begin(), x.end(), result.begin(), [&](auto y) { return make_hip_tensor_view<N>(y); });
return result;
}
template <class... Ts>
auto hip_visit_all(Ts&&... xs)
template <class T, class... Ts>
auto hip_visit_all(T&& x, Ts&&... xs)
{
return [&](auto f) { visit_all(xs...)([&](auto... vs) { f(make_hip_tensor_view(vs)...); }); };
return [&](auto f) {
visit_tensor_size(x.get_shape().lens().size(), [&](auto dim) {
visit_all(x, xs...)([&](auto... vs) { f(make_hip_tensor_view<dim>(vs)...); });
});
};
}
template <std::size_t N, class T>
auto hip_visit_all(const std::vector<T>& x)
{
return [&](auto f) { visit_all(x)([&](auto&& v) { f(make_hip_tensor_views<N>(v)); }); };
return [&](auto f) {
visit_tensor_size(x.front().get_shape().lens().size(), [&](auto dim) {
visit_all(x)([&](auto&& v) { f(make_hip_tensor_views<dim, N>(v)); });
});
};
}
template <std::size_t NDim>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment