Commit cf8cc835 authored by Paul
Browse files

Use hip_shape in nary

parent 33a41ba0
......@@ -9,7 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#if 1
#if 0
argument concat(hipStream_t stream,
const migraphx::shape&,
std::vector<migraphx::argument> args_vec,
......
......@@ -13,43 +13,38 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Clang extended vector of N lanes of T (e.g. vec<float, 4> behaves like
// float4); used to widen global-memory loads/stores in the kernels below.
template <class T, std::size_t N>
using vec = T __attribute__((ext_vector_type(N)));

// Reinterpret a scalar pointer as a pointer to N-wide vectors.
// NOTE(review): caller must guarantee the buffer is a multiple of N elements
// and suitably aligned for the vector type — nothing here checks that.
template <std::size_t N, class T>
__device__ __host__ vec<T, N>* as_vec(T* x)
{
    return reinterpret_cast<vec<T, N>*>(x);
}
// Inverse of as_vec: view an N-wide vector pointer as its scalar type again.
template <std::size_t N, class T>
__device__ __host__ T* as_pointer(vec<T, N>* x)
{
    return reinterpret_cast<T*>(x);
}
template <class... Ts>
auto pack_vec4(Ts... xs)
template <std::size_t N, class... Ts>
auto pack_vec(Ts... xs)
{
return [=](auto f, std::size_t n) { return f(as_vec4(xs)[n]...); };
return [=](auto f, std::size_t n) { return f(as_vec<4>(xs)[n]...); };
}
// Elementwise n-ary kernel for tensors whose layouts are not packed/standard:
// each element index is converted to a multi-index so every input is addressed
// through its own strides via hip_shape.
template <class F, class... Arguments>
auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
    // Element count is read on the host before the device lambda captures it.
    std::size_t nelements = result.get_shape().elements();
    hip_visit_all(result, args...)([&](auto output, auto... inputs) {
        gs_launch(stream, nelements)([=](auto i) {
            auto idx = output.get_shape().multi(i);
            // NOTE(review): the output is written with the flat index i while
            // the inputs use the multi-index — presumably the output is packed
            // here; confirm against the callers that route to this overload.
            output[i] = f(inputs[idx]...);
        });
    });
}
......@@ -75,10 +70,10 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = as_vec4(device_cast(input1.data()));
auto* yp = as_vec4(device_cast(input2.data()));
auto* zp = as_vec4(device_cast(input3.data()));
auto* outp = as_vec4(device_cast(output.data()));
auto* xp = as_vec<4>(device_cast(input1.data()));
auto* yp = as_vec<4>(device_cast(input2.data()));
auto* zp = as_vec<4>(device_cast(input3.data()));
auto* outp = as_vec<4>(device_cast(output.data()));
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
......@@ -87,7 +82,7 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
MIGRAPHX_DEVICE_SHARED vec<type, 4> buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
......@@ -100,9 +95,9 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> y = yp[i];
vec4<type> out = outp[i];
vec<type, 4> x = xp[i];
vec<type, 4> y = yp[i];
vec<type, 4> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(x[j], y[j], b);
......@@ -181,9 +176,9 @@ void binary_broadcast_vec_impl(
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
auto* xp = as_vec4(device_cast(input1.data()));
auto* yp = as_vec4(device_cast(input2.data()));
auto* outp = as_vec4(device_cast(output.data()));
auto* xp = as_vec<4>(device_cast(input1.data()));
auto* yp = as_vec<4>(device_cast(input2.data()));
auto* outp = as_vec<4>(device_cast(output.data()));
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
......@@ -192,7 +187,7 @@ void binary_broadcast_vec_impl(
const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
MIGRAPHX_DEVICE_SHARED vec<type, 4> buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
......@@ -205,8 +200,8 @@ void binary_broadcast_vec_impl(
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> out = outp[i];
vec<type, 4> x = xp[i];
vec<type, 4> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(x[j], b);
......@@ -270,10 +265,10 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
visit_all(result, args...)([&](auto output, auto... inputs) {
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
const std::size_t vec_size = 4;
auto data = pack_vec4(device_cast(inputs.data())...);
auto* outp = as_vec4(device_cast(output.data()));
auto data = pack_vec<4>(device_cast(inputs.data())...);
auto* outp = as_vec<4>(device_cast(output.data()));
gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
vec4<type> out = outp[i];
vec<type, 4> out = outp[i];
data(
[&](auto... xs) {
for(std::size_t j = 0; j < vec_size; j++)
......@@ -290,13 +285,11 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments.
// Elementwise n-ary kernel for the standard (packed, identical-layout) case:
// every tensor can be addressed with the same flat element index, so no
// multi-index/stride computation is needed per element.
template <class F, class... Arguments>
void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
{
    // Element count is read on the host before the device lambda captures it.
    std::size_t nelements = result.get_shape().elements();
    hip_visit_all(result, args...)([&](auto output, auto... inputs) {
        gs_launch(stream, nelements)(
            [=](auto i) { output.data()[i] = f(inputs.data()[i]...); });
    });
}
......
......@@ -5,6 +5,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -64,6 +65,30 @@ struct hip_array
MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
// Inner product of this array with x (sum over i of x[i] * d[i]).
MIGRAPHX_DEVICE_CONSTEXPR T dot(const hip_array& x) const
{
    T acc = 0;
    for(std::size_t j = 0; j < N; j++)
        acc += d[j] * x[j];
    return acc;
}
// Product of all elements; yields 1 when N == 0.
MIGRAPHX_DEVICE_CONSTEXPR T product() const
{
    T acc = 1;
    for(std::size_t j = 0; j < N; j++)
        acc *= d[j];
    return acc;
}
// Elementwise multiplication of two arrays of the same size N.
// (Loop spacing fixed to match the file's clang-format style used everywhere else.)
friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y)
{
    hip_array result;
    for(std::size_t i = 0; i < N; i++)
        result[i] = x[i] * y[i];
    return result;
}
};
template <class T, std::size_t N>
......@@ -130,12 +155,11 @@ struct hip_shape
using hip_index = hip_array<std::size_t, N>;
hip_array<std::size_t, N> lens = {};
hip_array<std::size_t, N> strides = {};
std::size_t elements = 0;
bool standard = false;
__device__ __host__ hip_shape() = default;
hip_shape(const shape& s) : elements(s.elements()), standard(s.standard())
hip_shape(const shape& s) : standard(s.standard())
{
assert(s.lens().size() == N);
assert(s.strides().size() == N);
......@@ -143,12 +167,14 @@ struct hip_shape
std::copy(s.strides().begin(), s.strides().end(), strides.begin());
}
// Total number of elements: the product of all dimension lengths.
MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const
{
    std::size_t n = 1;
    for(std::size_t i = 0; i < N; i++)
        n *= lens[i];
    return n;
}
// Linearize a multi-dimensional index into a storage offset: the inner
// product of the coordinates with the strides.
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const
{
    return x.dot(strides);
}
MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
......@@ -173,9 +199,10 @@ struct hip_shape
const std::size_t k = rank - j - 1;
const std::size_t stride = this->strides[k];
const std::size_t len = this->lens[k];
const std::size_t idx = (i % (s * len)) / s;
const std::size_t slen = s * len;
const std::size_t idx = (i % slen) / s;
result += stride * idx;
s *= len;
s = slen;
}
return result;
}
......@@ -197,23 +224,25 @@ struct hip_shape
// Device-side view of a tensor: a raw device pointer plus an N-dimensional
// hip_shape. Constructed on the host from a tensor_view; trivially copyable
// so it can be captured by value into device lambdas.
template <class T, std::size_t N>
struct hip_tensor_view
{
    // Element type as seen on the device (device_cast'ed T).
    using value_type = device_type<T>;

    __device__ __host__ hip_tensor_view() = default;
    // Host-only: converts the host pointer with device_cast and snapshots the shape.
    __host__ hip_tensor_view(tensor_view<T> x) : d(device_cast(x.data())), s(x.get_shape()) {}

    MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
    MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements(); }
    MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; }

    // Accepts any index type hip_shape::index supports (flat index,
    // multi-index, initializer list) and maps it through the strides.
    template <class U>
    MIGRAPHX_DEVICE_CONSTEXPR value_type& operator[](U i) const
    {
        return d[s.index(i)];
    }

    MIGRAPHX_DEVICE_CONSTEXPR value_type* begin() const { return d; }
    MIGRAPHX_DEVICE_CONSTEXPR value_type* end() const { return d + size(); }

    private:
    value_type* d = nullptr;
    hip_shape<N> s{};
};
......
......@@ -408,8 +408,8 @@ void fuse_ops::apply(program& p) const
// clang-format off
match::find_matches(p, find_triadd{});
match::find_matches(p,
find_conv_bias_relu{ctx},
find_conv_bias{ctx},
// find_conv_bias_relu{ctx},
// find_conv_bias{ctx},
find_add_relu{}
);
// clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment