Commit 7a3ec119 authored by Paul
Browse files

Initial batch concat

parent 8b6a5cda
......@@ -205,6 +205,23 @@ auto visit_all(T&& x, Ts&&... xs)
};
}
template <class T>
auto visit_all(const std::vector<T>& x)
{
    // Visits every element of x as a tensor_view of the common element type.
    // All elements must have the same shape type; the visitor v receives a
    // std::vector<tensor_view<type>> with one view per element of x.
    // Throws: if x is empty, or if the element types differ.
    if(x.empty())
        MIGRAPHX_THROW("visit_all: nothing to visit");
    auto&& s = x.front().get_shape();
    if(!std::all_of(x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
        MIGRAPHX_THROW("Types must be the same");
    // NOTE(review): the returned lambda captures x and s by reference — it
    // must be invoked before the arguments go out of scope.
    return [&](auto v) {
        s.visit_type([&](auto as) {
            using type = typename decltype(as)::type;
            std::vector<tensor_view<type>> result;
            result.reserve(x.size()); // one view per input; avoid reallocation
            for(const auto& y : x)
                result.push_back(make_view(y.get_shape(), as.from(y.data())));
            v(result);
        });
    };
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -9,29 +9,81 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#if 0
argument concat(hipStream_t stream,
                const migraphx::shape&,
                std::vector<migraphx::argument> args_vec,
                std::vector<std::size_t> offsets_vec)
{
    // Single-launch concat: copies every input (all but args_vec.back()) into
    // the output buffer args_vec.back().  offsets_vec[j] is the linear offset
    // of input j inside the output.  Currently disabled (#if 0).
    static constexpr const std::size_t limit = 10;
    if(offsets_vec.size() > limit)
        MIGRAPHX_THROW("Too many arguments to concat");
    // The largest input's element count bounds the launch size; the output
    // (last element) is excluded from the search via std::prev.
    std::size_t nelements =
        std::max_element(args_vec.begin(),
                         std::prev(args_vec.end()),
                         by(std::less<>{}, [&](auto&& x) { return x.get_shape().elements(); }))
            ->get_shape()
            .elements();
    auto offsets = to_hip_vector<limit>(offsets_vec);
    hip_visit_all<limit + 1>(args_vec)([&](auto args) {
        auto output  = args.back();
        auto ninputs = args.size() - 1;
        // i is the linear element index; was previously declared as `x`,
        // leaving `i` undeclared inside the kernel body.
        gs_launch(stream, nelements)([=](auto i) {
            for(std::size_t j = 0; j < ninputs; j++)
            {
                auto&& arg = args[j];
                // Inputs smaller than the largest one run out of work early.
                if(i >= arg.size())
                    continue;
                auto idx = output.get_shape().index(arg.get_shape().multi(i));
                output.data()[idx + offsets[j]] = arg.data()[i];
            }
        });
    });
    return args_vec.back();
}
#else
argument concat(hipStream_t stream,
                const migraphx::shape&,
                std::vector<migraphx::argument> args,
                std::vector<std::size_t> offsets)
{
    // Copies each input argument (all but args.back()) into the output buffer
    // args.back(), one kernel launch per input.  offsets[j] is the linear
    // offset of input j inside the concatenated output.
    auto ninputs = args.size() - 1;
    for(std::size_t j = 0; j < ninputs; j++)
    {
        auto&& arg            = args[j];
        std::size_t nelements = arg.get_shape().elements();
        auto offset           = offsets[j];
        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
            gs_launch(stream, nelements)([=](auto i) {
                // Map the input's multi-index into the output's layout, then
                // shift by this input's offset within the result.
                auto idx = output.get_shape().index(input.get_shape().multi(i));
                output.data()[idx + offset] = input.data()[i];
            });
        });
    }
    return args.back();
}
// argument concat(hipStream_t stream,
// const migraphx::shape& output_shape,
// std::vector<migraphx::argument> args,
// std::vector<std::size_t> offsets)
// {
// for(std::size_t l = 0; l < args.size() - 1; l++)
// {
// auto argl = args[l];
// std::size_t nelements = argl.get_shape().elements();
// visit_all(args.back(), argl)([&](auto output, auto input) {
// visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
// auto* outptr = output.data() + offsets[l];
// const auto* inptr = input.data();
// hip_tensor_descriptor<ndim> desc_input(input.get_shape());
// hip_tensor_descriptor<ndim> desc_output(output.get_shape());
// gs_launch(stream, nelements)(
// [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
// });
// });
// }
// return args.back();
// }
#endif
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -4,12 +4,15 @@
#include <hip/hip_runtime.h>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tensor_view.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#define MIGRAPHX_DEVICE_CONSTEXPR constexpr __device__ __host__
template <class F>
void visit_tensor_size(std::size_t n, F f)
{
......@@ -44,15 +47,217 @@ void visit_tensor_size(std::size_t n, F f)
}
}
template <size_t NDim>
struct hip_index
template <class T, std::size_t N>
struct hip_array
{
size_t d[NDim];
__device__ __host__ size_t& operator[](size_t i) { return d[i]; }
__device__ __host__ size_t operator[](size_t i) const { return d[i]; }
T d[N];
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR std::integral_constant<std::size_t, N> size() const { return {}; }
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d+size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d+size(); }
};
template <size_t NDim>
template <class T, std::size_t N>
struct hip_vector
{
MIGRAPHX_DEVICE_CONSTEXPR hip_vector() = default;
MIGRAPHX_DEVICE_CONSTEXPR hip_vector(std::size_t s)
: len(s)
{}
template<class Iterator>
__device__ __host__ hip_vector(Iterator start, Iterator last)
{
auto it = std::copy(start, last, d);
len = std::distance(d, it);
}
__device__ __host__ hip_vector(std::initializer_list<T> x)
{
auto it = std::copy(x.begin(), x.end(), d);
len = x.size();
}
MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& operator[](std::size_t i) const { return d[i]; }
MIGRAPHX_DEVICE_CONSTEXPR T& front() { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& front() const { return d[0]; }
MIGRAPHX_DEVICE_CONSTEXPR T& back() { return d[size() - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR const T& back() const { return d[size() - 1]; }
MIGRAPHX_DEVICE_CONSTEXPR T* data() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* data() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return len; }
MIGRAPHX_DEVICE_CONSTEXPR T* begin() { return d; }
MIGRAPHX_DEVICE_CONSTEXPR const T* begin() const { return d; }
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d+size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d+size(); }
template<class U>
MIGRAPHX_DEVICE_CONSTEXPR void push_back(U&& x)
{
d[len] = static_cast<U&&>(x);
len++;
}
private:
T d[N] = {};
std::size_t len = 0;
};
// Copies a std::vector into a fixed-capacity hip_vector.
// Precondition: x.size() <= N (hip_vector has no bounds checking).
template<std::size_t N, class T>
hip_vector<T, N> to_hip_vector(const std::vector<T>& x)
{
    hip_vector<T, N> out;
    for(const auto& e : x)
        out.push_back(e);
    return out;
}
using hip_index = hip_vector<std::size_t, 5>;
// Device-side mirror of migraphx::shape: fixed-capacity (max 5 dims) lens and
// strides plus a cached element count and standard-layout flag, so shape
// arithmetic can run inside HIP kernels without the host-only shape class.
struct hip_shape
{
hip_vector<std::size_t, 5> lens = {};
hip_vector<std::size_t, 5> strides = {};
std::size_t elements = 0;
bool standard = false;
__device__ __host__ hip_shape() = default;
// Host-only converting constructor copying the dimensions, strides, element
// count and standard flag out of the full migraphx::shape.
hip_shape(const shape& s)
: lens(s.lens().begin(), s.lens().end()), strides(s.strides().begin(), s.strides().end()), elements(s.elements()), standard(s.standard())
{}
// Dot product of a multi-index with the strides: the linear data offset.
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const
{
std::size_t idx = 0;
for(std::size_t i = 0; i < x.size(); i++)
idx += x[i] * strides[i];
return idx;
}
// Same computation for a braced index list, e.g. index({n, c, h, w}).
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
{
std::size_t idx = 0;
for(std::size_t i = 0; i < x.size(); i++)
idx += *(x.begin()+i) * strides[i];
return idx;
}
// Maps an ordinal element number i to its data offset.  For standard
// (packed) shapes this is the identity; otherwise each dimension's
// coordinate is peeled off i from the innermost axis outward and scaled by
// that dimension's stride.
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::size_t i) const
{
if(this->standard)
return i;
else
{
const std::size_t rank = this->lens.size();
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens.size(); j++)
{
const std::size_t k = rank - j - 1;
const std::size_t stride = this->strides[k];
const std::size_t len = this->lens[k];
const std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
}
// Converts a linear index to a multi-index by successive division by the
// strides.  NOTE(review): this assumes strides are strictly decreasing
// (standard packed layout); verify before using on transposed/broadcast
// shapes, where division by a stride does not recover the coordinate.
MIGRAPHX_DEVICE_CONSTEXPR hip_index multi(std::size_t idx) const
{
hip_index result(lens.size());
std::size_t tidx = idx;
for(std::size_t is = 0; is < result.size(); is++)
{
result[is] = tidx / strides[is];
tidx = tidx % strides[is];
}
return result;
}
};
// Non-owning view (raw pointer + hip_shape) of tensor data, usable in kernels.
template<class T>
struct hip_tensor_view
{
    __device__ __host__ hip_tensor_view() = default;
    // Host-side conversion from a tensor_view; the shape is copied by value.
    __device__ __host__ hip_tensor_view(tensor_view<T> x) : ptr(x.data()), sh(x.get_shape()) {}
    MIGRAPHX_DEVICE_CONSTEXPR const hip_shape& get_shape() const { return sh; }
    MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return sh.elements; }
    MIGRAPHX_DEVICE_CONSTEXPR T* data() const { return ptr; }
    // Element access goes through hip_shape::index, honouring the layout.
    MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) const { return ptr[sh.index(i)]; }
    MIGRAPHX_DEVICE_CONSTEXPR T* begin() const { return ptr; }
    MIGRAPHX_DEVICE_CONSTEXPR T* end() const { return ptr + size(); }

    private:
    T* ptr = nullptr;
    hip_shape sh{};
};
// Deduction helper: wraps a tensor_view in a hip_tensor_view of the same T.
template<class T>
hip_tensor_view<T> make_hip_tensor_view(tensor_view<T> x)
{
    return hip_tensor_view<T>{x};
}
// Converts a vector of tensor_views into a fixed-capacity hip_vector of
// hip_tensor_views.  Precondition: x.size() <= N.
template<std::size_t N, class T>
hip_vector<hip_tensor_view<T>, N> make_hip_tensor_views(const std::vector<tensor_view<T>>& x)
{
    hip_vector<hip_tensor_view<T>, N> views(x.size());
    std::size_t i = 0;
    for(const auto& v : x)
        views[i++] = make_hip_tensor_view(v);
    return views;
}
// Like visit_all, but the visitor receives kernel-safe hip_tensor_views
// instead of tensor_views: hip_visit_all(a, b)([&](auto av, auto bv) {...}).
// NOTE(review): the returned lambda captures xs... by reference — it must be
// invoked before the arguments go out of scope.
template<class... Ts>
auto hip_visit_all(Ts&&... xs)
{
return [&](auto f) {
visit_all(xs...)([&](auto... vs) {
// Convert each typed tensor_view to its device-side counterpart.
f(make_hip_tensor_view(vs)...);
});
};
}
// Vector overload: visits all elements of x (same element type required by
// visit_all) and passes the visitor a hip_vector of hip_tensor_views.  N is
// the compile-time bound on the number of arguments.
template <std::size_t N, class T>
auto hip_visit_all(const std::vector<T>& x)
{
return [&](auto f) {
visit_all(x)([&](auto&& v) {
// v is a std::vector of typed tensor_views; repackage for device use.
f(make_hip_tensor_views<N>(v));
});
};
}
template <std::size_t NDim>
using hip_tensor_index = hip_array<std::size_t, NDim>;
template <std::size_t NDim>
struct hip_tensor_descriptor
{
__device__ __host__ hip_tensor_descriptor() = default;
......@@ -63,26 +268,26 @@ struct hip_tensor_descriptor
std::copy(s.strides().begin(), s.strides().end(), strides);
}
// Converts a linear element index into an NDim multi-index by repeated
// division by the strides (assumes packed, decreasing strides).
__device__ __host__ hip_tensor_index<NDim> multi(std::size_t idx) const
{
    hip_tensor_index<NDim> result{};
    std::size_t tidx = idx;
    for(std::size_t is = 0; is < NDim; is++)
    {
        result[is] = tidx / strides[is];
        tidx       = tidx % strides[is];
    }
    return result;
}
// Dot product of a multi-index with the strides: the linear data offset.
__device__ __host__ std::size_t linear(hip_tensor_index<NDim> s) const
{
    std::size_t idx = 0;
    for(std::size_t i = 0; i < NDim; i++)
        idx += s[i] * strides[i];
    return idx;
}
std::size_t lens[NDim] = {};
std::size_t strides[NDim] = {};
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment