"vscode:/vscode.git/clone" did not exist on "150234bd6ad9f78c3603168c3f9371651fba5054"
Commit eed607a3 authored by Paul

Use static ranks

parent e6686d25
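This change makes the tensor rank a compile-time template parameter: hip_shape's members switch from runtime-sized hip_vector<std::size_t, 5> to hip_array<std::size_t, N>, and hip_tensor_view carries the same N. A minimal sketch of the idea, using a hypothetical static_shape type rather than the actual MIGraphX classes:

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for hip_shape<N>: the rank N is fixed at compile time,
// so lens/strides live in fixed-size arrays and the index loop has a constant
// trip count that the device compiler can unroll.
template <std::size_t N>
struct static_shape
{
    std::array<std::size_t, N> lens{};
    std::array<std::size_t, N> strides{};

    static_shape(const std::vector<std::size_t>& l, const std::vector<std::size_t>& s)
    {
        // The runtime rank must match the compile-time rank exactly.
        assert(l.size() == N && s.size() == N);
        std::copy(l.begin(), l.end(), lens.begin());
        std::copy(s.begin(), s.end(), strides.begin());
    }

    // Linear offset of a multi-index; N is known at compile time.
    std::size_t index(const std::array<std::size_t, N>& x) const
    {
        std::size_t result = 0;
        for(std::size_t i = 0; i < N; i++)
            result += x[i] * strides[i];
        return result;
    }
};

With the rank fixed at compile time, lens and strides can stay in registers on the GPU instead of being addressed through a runtime size.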
@@ -9,13 +9,13 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-#if 0
+#if 1
 argument concat(hipStream_t stream,
                 const migraphx::shape&,
                 std::vector<migraphx::argument> args_vec,
                 std::vector<std::size_t> offsets_vec)
 {
-    static constexpr const std::size_t limit = 10;
+    static constexpr const std::size_t limit = 6;
     if (offsets_vec.size() > limit)
         MIGRAPHX_THROW("Too many arguments to concat");
     std::size_t nelements = std::max_element(args_vec.begin(), std::prev(args_vec.end()), by(std::less<>{}, [&](auto&& x) { return x.get_shape().elements(); }))->get_shape().elements();
@@ -23,7 +23,7 @@ argument concat(hipStream_t stream,
     hip_visit_all<limit+1>(args_vec)([&](auto args) {
         auto output = args.back();
         auto ninputs = args.size() - 1;
-        gs_launch(stream, nelements)([=](auto x) {
+        gs_launch(stream, nelements)([=](auto i) {
             for(std::size_t j = 0;j < ninputs;j++)
             {
                 auto&& arg = args[j];
...
@@ -124,23 +124,26 @@ hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
     return result;
 }
 
-using hip_index = hip_vector<std::size_t, 5>;
-
+template<std::size_t N>
 struct hip_shape
 {
-    hip_vector<std::size_t, 5> lens = {};
-    hip_vector<std::size_t, 5> strides = {};
+    using hip_index = hip_array<std::size_t, N>;
+    hip_array<std::size_t, N> lens = {};
+    hip_array<std::size_t, N> strides = {};
     std::size_t elements = 0;
     bool standard = false;
     __device__ __host__ hip_shape() = default;
     hip_shape(const shape& s)
-        : lens(s.lens().begin(), s.lens().end()),
-          strides(s.strides().begin(), s.strides().end()),
-          elements(s.elements()),
+        : elements(s.elements()),
           standard(s.standard())
     {
+        assert(s.lens().size() == N);
+        assert(s.strides().size() == N);
+        std::copy(s.lens().begin(), s.lens().end(), lens.begin());
+        std::copy(s.strides().begin(), s.strides().end(), strides.begin());
     }
 
     MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const
@@ -183,7 +186,7 @@ struct hip_shape
     MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(std::size_t idx) const
     {
-        hip_index result(lens.size());
+        hip_index result;
         std::size_t tidx = idx;
         for(std::size_t is = 0; is < result.size(); is++)
         {
@@ -194,13 +197,13 @@
     }
 };
 
-template <class T>
+template <class T, std::size_t N>
 struct hip_tensor_view
 {
     __device__ __host__ hip_tensor_view() = default;
     __device__ __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
-    MIGRAPHX_DEVICE_CONSTEXPR const hip_shape& get_shape() const { return s; }
+    MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
     MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements; }
@@ -214,34 +217,42 @@ struct hip_tensor_view
 private:
     T* d = nullptr;
-    hip_shape s{};
+    hip_shape<N> s{};
 };
 
-template <class T>
-hip_tensor_view<T> make_hip_tensor_view(tensor_view<T> x)
+template <std::size_t N, class T>
+hip_tensor_view<T, N> make_hip_tensor_view(tensor_view<T> x)
 {
     return x;
 }
 
-template <std::size_t N, class T>
-hip_vector<hip_tensor_view<T>, N> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
+template <std::size_t N, std::size_t M, class T>
+hip_vector<hip_tensor_view<T, N>, M> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
 {
-    hip_vector<hip_tensor_view<T>, N> result(x.size());
+    hip_vector<hip_tensor_view<T, N>, M> result(x.size());
     std::transform(
-        x.begin(), x.end(), result.begin(), [&](auto y) { return make_hip_tensor_view(y); });
+        x.begin(), x.end(), result.begin(), [&](auto y) { return make_hip_tensor_view<N>(y); });
     return result;
 }
 
-template <class... Ts>
-auto hip_visit_all(Ts&&... xs)
+template <class T, class... Ts>
+auto hip_visit_all(T&& x, Ts&&... xs)
 {
-    return [&](auto f) { visit_all(xs...)([&](auto... vs) { f(make_hip_tensor_view(vs)...); }); };
+    return [&](auto f) {
+        visit_tensor_size(x.get_shape().lens().size(), [&](auto dim) {
+            visit_all(x, xs...)([&](auto... vs) { f(make_hip_tensor_view<dim>(vs)...); });
+        });
+    };
 }
 
 template <std::size_t N, class T>
 auto hip_visit_all(const std::vector<T>& x)
 {
-    return [&](auto f) { visit_all(x)([&](auto&& v) { f(make_hip_tensor_views<N>(v)); }); };
+    return [&](auto f) {
+        visit_tensor_size(x.front().get_shape().lens().size(), [&](auto dim) {
+            visit_all(x)([&](auto&& v) { f(make_hip_tensor_views<dim, N>(v)); });
+        });
+    };
 }
 
 template <std::size_t NDim>
...
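Because the rank is now a template parameter, the runtime rank has to be lifted to a compile-time constant before these templates can be instantiated; the rewritten hip_visit_all does this through visit_tensor_size. A rough sketch of that dispatch pattern, using a hypothetical visit_rank helper (the supported ranks and the error handling are illustrative, not the MIGraphX implementation):

#include <cstddef>
#include <stdexcept>
#include <type_traits>

// Hypothetical rank-dispatch helper (not the MIGraphX API): maps a runtime rank
// onto a compile-time std::integral_constant so rank-templated types such as
// static_shape<N> above can be instantiated for each supported rank.
template <class F>
void visit_rank(std::size_t rank, F f)
{
    switch(rank)
    {
    case 1: f(std::integral_constant<std::size_t, 1>{}); break;
    case 2: f(std::integral_constant<std::size_t, 2>{}); break;
    case 3: f(std::integral_constant<std::size_t, 3>{}); break;
    case 4: f(std::integral_constant<std::size_t, 4>{}); break;
    case 5: f(std::integral_constant<std::size_t, 5>{}); break;
    default: throw std::runtime_error("Unsupported tensor rank");
    }
}

// Usage: inside the lambda the rank is a compile-time constant, e.g.
//   visit_rank(lens.size(), [&](auto dim) {
//       static_shape<decltype(dim)::value> ss(lens, strides);
//   });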