"...resnet50_onnxruntime_migraphx.git" did not exist on "e7035bdfc2b2bccbbf758e91525b0dd862de7af6"
Commit e6686d25 authored by Paul's avatar Paul
Browse files

Formatting

parent 7a3ec119
...@@ -209,13 +209,14 @@ template <class T> ...@@ -209,13 +209,14 @@ template <class T>
auto visit_all(const std::vector<T>& x) auto visit_all(const std::vector<T>& x)
{ {
auto&& s = x.front().get_shape(); auto&& s = x.front().get_shape();
if(!std::all_of(x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); })) if(!std::all_of(
x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
MIGRAPHX_THROW("Types must be the same"); MIGRAPHX_THROW("Types must be the same");
return [&](auto v) { return [&](auto v) {
s.visit_type([&](auto as) { s.visit_type([&](auto as) {
using type = typename decltype(as)::type; using type = typename decltype(as)::type;
std::vector<tensor_view<type>> result; std::vector<tensor_view<type>> result;
for(const auto& y:x) for(const auto& y : x)
result.push_back(make_view(y.get_shape(), as.from(y.data()))); result.push_back(make_view(y.get_shape(), as.from(y.data())));
v(result); v(result);
}); });
......
...@@ -45,7 +45,7 @@ argument concat(hipStream_t stream, ...@@ -45,7 +45,7 @@ argument concat(hipStream_t stream,
std::vector<std::size_t> offsets) std::vector<std::size_t> offsets)
{ {
auto ninputs = args.size() - 1; auto ninputs = args.size() - 1;
for(std::size_t j = 0;j < ninputs;j++) for(std::size_t j = 0; j < ninputs; j++)
{ {
auto&& arg = args[j]; auto&& arg = args[j];
std::size_t nelements = arg.get_shape().elements(); std::size_t nelements = arg.get_shape().elements();
...@@ -60,7 +60,6 @@ argument concat(hipStream_t stream, ...@@ -60,7 +60,6 @@ argument concat(hipStream_t stream,
return args.back(); return args.back();
} }
// argument concat(hipStream_t stream, // argument concat(hipStream_t stream,
// const migraphx::shape& output_shape, // const migraphx::shape& output_shape,
// std::vector<migraphx::argument> args, // std::vector<migraphx::argument> args,
......
...@@ -62,18 +62,16 @@ struct hip_array ...@@ -62,18 +62,16 @@ struct hip_array
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; } MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; } MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d+size(); } MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d+size(); } MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
}; };
template <class T, std::size_t N> template <class T, std::size_t N>
struct hip_vector struct hip_vector
{ {
MIGRAPHX_DEVICE_CONSTEXPR hip_vector() = default; MIGRAPHX_DEVICE_CONSTEXPR hip_vector() = default;
MIGRAPHX_DEVICE_CONSTEXPR hip_vector(std::size_t s) MIGRAPHX_DEVICE_CONSTEXPR hip_vector(std::size_t s) : len(s) {}
: len(s) template <class Iterator>
{}
template<class Iterator>
__device__ __host__ hip_vector(Iterator start, Iterator last) __device__ __host__ hip_vector(Iterator start, Iterator last)
{ {
auto it = std::copy(start, last, d); auto it = std::copy(start, last, d);
...@@ -103,22 +101,22 @@ struct hip_vector ...@@ -103,22 +101,22 @@ struct hip_vector
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; } MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; } MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d+size(); } MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d+size(); } MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
template<class U> template <class U>
MIGRAPHX_DEVICE_CONSTEXPR void push_back(U&& x) MIGRAPHX_DEVICE_CONSTEXPR void push_back(U&& x)
{ {
d[len] = static_cast<U&&>(x); d[len] = static_cast<U&&>(x);
len++; len++;
} }
private: private:
T d[N] = {}; T d[N] = {};
std::size_t len = 0; std::size_t len = 0;
}; };
template<std::size_t N, class T> template <std::size_t N, class T>
hip_vector<T, N> to_hip_vector(const std::vector<T>& x) hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
{ {
hip_vector<T, N> result(x.size()); hip_vector<T, N> result(x.size());
...@@ -138,8 +136,12 @@ struct hip_shape ...@@ -138,8 +136,12 @@ struct hip_shape
__device__ __host__ hip_shape() = default; __device__ __host__ hip_shape() = default;
hip_shape(const shape& s) hip_shape(const shape& s)
: lens(s.lens().begin(), s.lens().end()), strides(s.strides().begin(), s.strides().end()), elements(s.elements()), standard(s.standard()) : lens(s.lens().begin(), s.lens().end()),
{} strides(s.strides().begin(), s.strides().end()),
elements(s.elements()),
standard(s.standard())
{
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const
{ {
...@@ -153,7 +155,7 @@ struct hip_shape ...@@ -153,7 +155,7 @@ struct hip_shape
{ {
std::size_t idx = 0; std::size_t idx = 0;
for(std::size_t i = 0; i < x.size(); i++) for(std::size_t i = 0; i < x.size(); i++)
idx += *(x.begin()+i) * strides[i]; idx += *(x.begin() + i) * strides[i];
return idx; return idx;
} }
...@@ -192,13 +194,11 @@ struct hip_shape ...@@ -192,13 +194,11 @@ struct hip_shape
} }
}; };
template<class T> template <class T>
struct hip_tensor_view struct hip_tensor_view
{ {
__device__ __host__ hip_tensor_view() = default; __device__ __host__ hip_tensor_view() = default;
__device__ __host__ hip_tensor_view(tensor_view<T> x) __device__ __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
: d(x.data()), s(x.get_shape())
{}
MIGRAPHX_DEVICE_CONSTEXPR const hip_shape& get_shape() const { return s; } MIGRAPHX_DEVICE_CONSTEXPR const hip_shape& get_shape() const { return s; }
...@@ -210,48 +210,38 @@ struct hip_tensor_view ...@@ -210,48 +210,38 @@ struct hip_tensor_view
MIGRAPHX_DEVICE_CONSTEXPR T* begin() const { return d; } MIGRAPHX_DEVICE_CONSTEXPR T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() const { return d+size(); } MIGRAPHX_DEVICE_CONSTEXPR T* end() const { return d + size(); }
private: private:
T* d = nullptr; T* d = nullptr;
hip_shape s{}; hip_shape s{};
}; };
template<class T> template <class T>
hip_tensor_view<T> make_hip_tensor_view(tensor_view<T> x) hip_tensor_view<T> make_hip_tensor_view(tensor_view<T> x)
{ {
return x; return x;
} }
template<std::size_t N, class T> template <std::size_t N, class T>
hip_vector<hip_tensor_view<T>, N> make_hip_tensor_views(const std::vector<tensor_view<T>>& x) hip_vector<hip_tensor_view<T>, N> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
{ {
hip_vector<hip_tensor_view<T>, N> result(x.size()); hip_vector<hip_tensor_view<T>, N> result(x.size());
std::transform(x.begin(), x.end(), result.begin(), [&](auto y) { std::transform(
return make_hip_tensor_view(y); x.begin(), x.end(), result.begin(), [&](auto y) { return make_hip_tensor_view(y); });
});
return result; return result;
} }
template<class... Ts> template <class... Ts>
auto hip_visit_all(Ts&&... xs) auto hip_visit_all(Ts&&... xs)
{ {
return [&](auto f) { return [&](auto f) { visit_all(xs...)([&](auto... vs) { f(make_hip_tensor_view(vs)...); }); };
visit_all(xs...)([&](auto... vs) {
f(make_hip_tensor_view(vs)...);
});
};
} }
template <std::size_t N, class T> template <std::size_t N, class T>
auto hip_visit_all(const std::vector<T>& x) auto hip_visit_all(const std::vector<T>& x)
{ {
return [&](auto f) { return [&](auto f) { visit_all(x)([&](auto&& v) { f(make_hip_tensor_views<N>(v)); }); };
visit_all(x)([&](auto&& v) {
f(make_hip_tensor_views<N>(v));
});
};
} }
template <std::size_t NDim> template <std::size_t NDim>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment