"src/vscode:/vscode.git/clone" did not exist on "5703f6c1dc2afb8bade0fe2fce69567ffe1d0628"
Commit e6686d25 authored by Paul's avatar Paul
Browse files

Formatting

parent 7a3ec119
......@@ -209,13 +209,14 @@ template <class T>
auto visit_all(const std::vector<T>& x)
{
auto&& s = x.front().get_shape();
if(!std::all_of(x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
if(!std::all_of(
x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
return [&](auto v) {
s.visit_type([&](auto as) {
using type = typename decltype(as)::type;
std::vector<tensor_view<type>> result;
for(const auto& y:x)
for(const auto& y : x)
result.push_back(make_view(y.get_shape(), as.from(y.data())));
v(result);
});
......
......@@ -45,7 +45,7 @@ argument concat(hipStream_t stream,
std::vector<std::size_t> offsets)
{
auto ninputs = args.size() - 1;
for(std::size_t j = 0;j < ninputs;j++)
for(std::size_t j = 0; j < ninputs; j++)
{
auto&& arg = args[j];
std::size_t nelements = arg.get_shape().elements();
......@@ -60,7 +60,6 @@ argument concat(hipStream_t stream,
return args.back();
}
// argument concat(hipStream_t stream,
// const migraphx::shape& output_shape,
// std::vector<migraphx::argument> args,
......
......@@ -62,18 +62,16 @@ struct hip_array
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d+size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d+size(); }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
};
template <class T, std::size_t N>
struct hip_vector
{
MIGRAPHX_DEVICE_CONSTEXPR hip_vector() = default;
MIGRAPHX_DEVICE_CONSTEXPR hip_vector(std::size_t s)
: len(s)
{}
template<class Iterator>
MIGRAPHX_DEVICE_CONSTEXPR hip_vector(std::size_t s) : len(s) {}
template <class Iterator>
__device__ __host__ hip_vector(Iterator start, Iterator last)
{
auto it = std::copy(start, last, d);
......@@ -103,22 +101,22 @@ struct hip_vector
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d+size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d+size(); }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
template<class U>
template <class U>
MIGRAPHX_DEVICE_CONSTEXPR void push_back(U&& x)
{
d[len] = static_cast<U&&>(x);
len++;
}
private:
private:
T d[N] = {};
std::size_t len = 0;
};
template<std::size_t N, class T>
template <std::size_t N, class T>
hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
{
hip_vector<T, N> result(x.size());
......@@ -138,8 +136,12 @@ struct hip_shape
__device__ __host__ hip_shape() = default;
hip_shape(const shape& s)
: lens(s.lens().begin(), s.lens().end()), strides(s.strides().begin(), s.strides().end()), elements(s.elements()), standard(s.standard())
{}
: lens(s.lens().begin(), s.lens().end()),
strides(s.strides().begin(), s.strides().end()),
elements(s.elements()),
standard(s.standard())
{
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const
{
......@@ -153,7 +155,7 @@ struct hip_shape
{
std::size_t idx = 0;
for(std::size_t i = 0; i < x.size(); i++)
idx += *(x.begin()+i) * strides[i];
idx += *(x.begin() + i) * strides[i];
return idx;
}
......@@ -192,13 +194,11 @@ struct hip_shape
}
};
template<class T>
template <class T>
struct hip_tensor_view
{
__device__ __host__ hip_tensor_view() = default;
__device__ __host__ hip_tensor_view(tensor_view<T> x)
: d(x.data()), s(x.get_shape())
{}
__device__ __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
MIGRAPHX_DEVICE_CONSTEXPR const hip_shape& get_shape() const { return s; }
......@@ -210,48 +210,38 @@ struct hip_tensor_view
MIGRAPHX_DEVICE_CONSTEXPR T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() const { return d+size(); }
MIGRAPHX_DEVICE_CONSTEXPR T* end() const { return d + size(); }
private:
private:
T* d = nullptr;
hip_shape s{};
};
template<class T>
template <class T>
hip_tensor_view<T> make_hip_tensor_view(tensor_view<T> x)
{
return x;
}
template<std::size_t N, class T>
template <std::size_t N, class T>
hip_vector<hip_tensor_view<T>, N> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
{
hip_vector<hip_tensor_view<T>, N> result(x.size());
std::transform(x.begin(), x.end(), result.begin(), [&](auto y) {
return make_hip_tensor_view(y);
});
std::transform(
x.begin(), x.end(), result.begin(), [&](auto y) { return make_hip_tensor_view(y); });
return result;
}
template<class... Ts>
template <class... Ts>
auto hip_visit_all(Ts&&... xs)
{
return [&](auto f) {
visit_all(xs...)([&](auto... vs) {
f(make_hip_tensor_view(vs)...);
});
};
return [&](auto f) { visit_all(xs...)([&](auto... vs) { f(make_hip_tensor_view(vs)...); }); };
}
template <std::size_t N, class T>
auto hip_visit_all(const std::vector<T>& x)
{
return [&](auto f) {
visit_all(x)([&](auto&& v) {
f(make_hip_tensor_views<N>(v));
});
};
return [&](auto f) { visit_all(x)([&](auto&& v) { f(make_hip_tensor_views<N>(v)); }); };
}
template <std::size_t NDim>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment