Commit cf8cc835 authored by Paul

Use hip_shape in nary

parent 33a41ba0
@@ -9,7 +9,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-#if 1
+#if 0
 argument concat(hipStream_t stream,
                 const migraphx::shape&,
                 std::vector<migraphx::argument> args_vec,
...
@@ -13,43 +13,38 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
-template <class T>
-using vec4 = T __attribute__((ext_vector_type(4)));
+template <class T, std::size_t N>
+using vec = T __attribute__((ext_vector_type(N)));
 
-template <class T>
-__device__ __host__ vec4<T>* as_vec4(T* x)
+template <std::size_t N, class T>
+__device__ __host__ vec<T, N>* as_vec(T* x)
 {
-    return reinterpret_cast<vec4<T>*>(x);
+    return reinterpret_cast<vec<T, N>*>(x);
 }
 
-template <class T>
-__device__ __host__ T* as_pointer(vec4<T>* x)
+template <std::size_t N, class T>
+__device__ __host__ T* as_pointer(vec<T, N>* x)
 {
     return reinterpret_cast<T*>(x);
 }
 
-template <class... Ts>
-auto pack_vec4(Ts... xs)
+template <std::size_t N, class... Ts>
+auto pack_vec(Ts... xs)
 {
-    return [=](auto f, std::size_t n) { return f(as_vec4(xs)[n]...); };
+    return [=](auto f, std::size_t n) { return f(as_vec<N>(xs)[n]...); };
 }
 
 template <class F, class... Arguments>
 auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    const auto& output_shape = result.get_shape();
-    visit_all(result, args...)([&](auto output, auto... inputs) {
-        visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-            auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
-                                            device_cast(inputs.data()))...);
-            hip_tensor_descriptor<ndim> out_desc(output_shape);
-            auto* outp = device_cast(output.data());
-            gs_launch(stream, output_shape.elements())([=](auto i) {
-                data([&](auto&&... ps) {
-                    auto outidx = out_desc.multi(i);
-                    outp[i]     = f(ps.second[ps.first.linear(outidx)]...);
-                });
-            });
-        });
+    std::size_t nelements = result.get_shape().elements();
+    hip_visit_all(result, args...)([&](auto output, auto... inputs) {
+        gs_launch(stream, nelements)([=](auto i) {
+            auto idx  = output.get_shape().multi(i);
+            output[i] = f(inputs[idx]...);
+        });
     });
 }
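The non-standard path now runs entirely on hip_shape: multi(i) turns the flat launch index into a per-dimension coordinate, and indexing a view with that coordinate applies each input's strides. A minimal host-side sketch of that round trip, with std::array standing in for the device-side hip_array (the helper names here are illustrative, not MIGraphX API):

#include <array>
#include <cstddef>
#include <iostream>

constexpr std::size_t rank = 3;

// Stand-in for hip_shape::multi(): decompose a flat index into coordinates,
// innermost dimension varying fastest.
std::array<std::size_t, rank> multi(std::size_t i, const std::array<std::size_t, rank>& lens)
{
    std::array<std::size_t, rank> idx{};
    for(std::size_t j = rank; j > 0; j--)
    {
        idx[j - 1] = i % lens[j - 1];
        i /= lens[j - 1];
    }
    return idx;
}

// Stand-in for hip_shape::index(): dot product of coordinates with strides.
std::size_t index(const std::array<std::size_t, rank>& idx,
                  const std::array<std::size_t, rank>& strides)
{
    std::size_t result = 0;
    for(std::size_t j = 0; j < rank; j++)
        result += idx[j] * strides[j];
    return result;
}

int main()
{
    const std::array<std::size_t, rank> lens{2, 3, 4};
    const std::array<std::size_t, rank> strides{12, 4, 1}; // standard row-major strides
    // For a standard shape the round trip is the identity; for transposed or
    // broadcast inputs the strides differ and the same coordinate lands elsewhere.
    for(std::size_t i = 0; i < 24; i++)
        std::cout << i << " -> " << index(multi(i, lens), strides) << '\n';
    return 0;
}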
@@ -75,10 +70,10 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
     visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
         using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* zp   = as_vec4(device_cast(input3.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
+        auto* xp   = as_vec<4>(device_cast(input1.data()));
+        auto* yp   = as_vec<4>(device_cast(input2.data()));
+        auto* zp   = as_vec<4>(device_cast(input3.data()));
+        auto* outp = as_vec<4>(device_cast(output.data()));
         const std::size_t vec_size = 4;
         const std::size_t nlocal   = 1024;
@@ -87,7 +82,7 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
         const std::size_t bdim_vec_len = bdim_len / vec_size;
         launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
+            MIGRAPHX_DEVICE_SHARED vec<type, 4> buffer[2048 / vec_size];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
             {
@@ -100,9 +95,9 @@ void trinary_broadcast_vec_impl(hipStream_t stream,
             {
                 auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
                 auto b    = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> y   = yp[i];
-                vec4<type> out = outp[i];
+                vec<type, 4> x   = xp[i];
+                vec<type, 4> y   = yp[i];
+                vec<type, 4> out = outp[i];
                 for(std::size_t j = 0; j < vec_size; j++)
                 {
                     out[j] = f(x[j], y[j], b);
@@ -181,9 +176,9 @@ void binary_broadcast_vec_impl(
     visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
         using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
-        auto* xp   = as_vec4(device_cast(input1.data()));
-        auto* yp   = as_vec4(device_cast(input2.data()));
-        auto* outp = as_vec4(device_cast(output.data()));
+        auto* xp   = as_vec<4>(device_cast(input1.data()));
+        auto* yp   = as_vec<4>(device_cast(input2.data()));
+        auto* outp = as_vec<4>(device_cast(output.data()));
         const std::size_t vec_size = 4;
         const std::size_t nlocal   = 1024;
@@ -192,7 +187,7 @@ void binary_broadcast_vec_impl(
         const std::size_t bdim_vec_len = bdim_len / vec_size;
         launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
-            MIGRAPHX_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
+            MIGRAPHX_DEVICE_SHARED vec<type, 4> buffer[2048 / vec_size];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
             {
@@ -205,8 +200,8 @@ void binary_broadcast_vec_impl(
             {
                 auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
                 auto b    = bp[bidx];
-                vec4<type> x   = xp[i];
-                vec4<type> out = outp[i];
+                vec<type, 4> x   = xp[i];
+                vec<type, 4> out = outp[i];
                 for(std::size_t j = 0; j < vec_size; j++)
                 {
                     out[j] = f(x[j], b);
@@ -270,10 +265,10 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments...
     visit_all(result, args...)([&](auto output, auto... inputs) {
         using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
         const std::size_t vec_size = 4;
-        auto data  = pack_vec4(device_cast(inputs.data())...);
-        auto* outp = as_vec4(device_cast(output.data()));
+        auto data  = pack_vec<4>(device_cast(inputs.data())...);
+        auto* outp = as_vec<4>(device_cast(output.data()));
         gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
-            vec4<type> out = outp[i];
+            vec<type, 4> out = outp[i];
             data(
                 [&](auto... xs) {
                     for(std::size_t j = 0; j < vec_size; j++)
@@ -290,13 +285,11 @@ void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments...
 template <class F, class... Arguments>
 void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
-    // assert(x.get_shape().elements() == y.get_shape().elements());
-    const auto& output_shape = result.get_shape();
-    visit_all(result, args...)([&](auto output, auto... inputs) {
-        auto data  = pack(device_cast(inputs.data())...);
-        auto* outp = device_cast(output.data());
-        gs_launch(stream, output_shape.elements())(
-            [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
+    std::size_t nelements = result.get_shape().elements();
+    hip_visit_all(result, args...)([&](auto output, auto... inputs) {
+        gs_launch(stream, nelements)([=](auto i) {
+            output.data()[i] = f(inputs.data()[i]...);
+        });
     });
 }
...
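For orientation, these _impl overloads sit behind the generic nary() entry point that the pointwise operators call. A hedged caller-side sketch (not part of this commit; the add function and its signature are illustrative of the pattern, assuming the nary() dispatcher defined elsewhere in this header):

// Hypothetical caller: an elementwise add dispatched through the helpers above.
// The lambda runs once per output element; the dispatcher picks the standard,
// vectorized, or non-standard implementation based on the argument shapes.
void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
}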
@@ -5,6 +5,7 @@
 #include <migraphx/functional.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/tensor_view.hpp>
+#include <migraphx/gpu/device/types.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -64,6 +65,30 @@ struct hip_array
     MIGRAPHX_DEVICE_CONSTEXPR T* end() { return d + size(); }
     MIGRAPHX_DEVICE_CONSTEXPR const T* end() const { return d + size(); }
 
+    MIGRAPHX_DEVICE_CONSTEXPR T dot(const hip_array& x) const
+    {
+        T result = 0;
+        for(std::size_t i = 0; i < N; i++)
+            result += x[i] * d[i];
+        return result;
+    }
+
+    MIGRAPHX_DEVICE_CONSTEXPR T product() const
+    {
+        T result = 1;
+        for(std::size_t i = 0; i < N; i++)
+            result *= d[i];
+        return result;
+    }
+
+    friend MIGRAPHX_DEVICE_CONSTEXPR hip_array operator*(const hip_array& x, const hip_array& y)
+    {
+        hip_array result;
+        for(std::size_t i = 0; i < N; i++)
+            result[i] = x[i] * y[i];
+        return result;
+    }
 };
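dot() and product() let hip_shape phrase its index arithmetic directly on hip_array (see the elements() and index() changes below). A quick worked check of what they compute, with assumed values and plain scalars standing in for the device-side types:

#include <cassert>
#include <cstddef>

int main()
{
    // For lens = {2, 3, 4} with standard row-major strides {12, 4, 1}:
    const std::size_t elements = 2 * 3 * 4;              // lens.product()
    const std::size_t offset   = 1 * 12 + 2 * 4 + 3 * 1; // hip_index{1, 2, 3}.dot(strides)
    assert(elements == 24); // total element count
    assert(offset == 23);   // flat offset of coordinate {1, 2, 3}
    return 0;
}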
 template <class T, std::size_t N>
@@ -130,12 +155,11 @@ struct hip_shape
     using hip_index = hip_array<std::size_t, N>;
     hip_array<std::size_t, N> lens    = {};
     hip_array<std::size_t, N> strides = {};
-    std::size_t elements              = 0;
     bool standard                     = false;
 
     __device__ __host__ hip_shape() = default;
 
-    hip_shape(const shape& s) : elements(s.elements()), standard(s.standard())
+    hip_shape(const shape& s) : standard(s.standard())
     {
         assert(s.lens().size() == N);
         assert(s.strides().size() == N);
@@ -143,12 +167,14 @@ struct hip_shape
         std::copy(s.strides().begin(), s.strides().end(), strides.begin());
     }
 
+    MIGRAPHX_DEVICE_CONSTEXPR std::size_t elements() const
+    {
+        return lens.product();
+    }
+
     MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(hip_index x) const
     {
-        std::size_t idx = 0;
-        for(std::size_t i = 0; i < x.size(); i++)
-            idx += x[i] * strides[i];
-        return idx;
+        return x.dot(strides);
     }
 
     MIGRAPHX_DEVICE_CONSTEXPR std::size_t index(std::initializer_list<std::size_t> x) const
@@ -173,9 +199,10 @@ struct hip_shape
             const std::size_t k      = rank - j - 1;
             const std::size_t stride = this->strides[k];
             const std::size_t len    = this->lens[k];
-            const std::size_t idx    = (i % (s * len)) / s;
+            const std::size_t slen   = s * len;
+            const std::size_t idx    = (i % slen) / s;
             result += stride * idx;
-            s *= len;
+            s = slen;
         }
         return result;
     }
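The change in index(i) is a small strength reduction, but it also makes the mixed-radix structure explicit: each pass recovers one coordinate of the flat index, and s * len from one iteration becomes the modulus of the next. A hand trace with assumed values (lens {2, 3, 4}, standard strides {12, 4, 1}, i = 23, s starting at 1):

// j = 0: k = 2, slen = 1 * 4  = 4,  idx = (23 % 4)  / 1  = 3, result += 1 * 3,  s = 4
// j = 1: k = 1, slen = 4 * 3  = 12, idx = (23 % 12) / 4  = 2, result += 4 * 2,  s = 12
// j = 2: k = 0, slen = 12 * 2 = 24, idx = (23 % 24) / 12 = 1, result += 12 * 1, s = 24
// result = 3 + 8 + 12 = 23: with standard strides the mapping is the identity.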
@@ -197,23 +224,25 @@ struct hip_tensor_view
 template <class T, std::size_t N>
 struct hip_tensor_view
 {
+    using value_type = device_type<T>;
     __device__ __host__ hip_tensor_view() = default;
-    __device__ __host__ hip_tensor_view(tensor_view<T> x) : d(x.data()), s(x.get_shape()) {}
+    __host__ hip_tensor_view(tensor_view<T> x) : d(device_cast(x.data())), s(x.get_shape()) {}
 
     MIGRAPHX_DEVICE_CONSTEXPR const hip_shape<N>& get_shape() const { return s; }
 
-    MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements; }
+    MIGRAPHX_DEVICE_CONSTEXPR std::size_t size() const { return s.elements(); }
 
-    MIGRAPHX_DEVICE_CONSTEXPR T* data() const { return d; }
+    MIGRAPHX_DEVICE_CONSTEXPR value_type* data() const { return d; }
 
-    MIGRAPHX_DEVICE_CONSTEXPR T& operator[](std::size_t i) const { return d[s.index(i)]; }
+    template <class U>
+    MIGRAPHX_DEVICE_CONSTEXPR value_type& operator[](U i) const { return d[s.index(i)]; }
 
-    MIGRAPHX_DEVICE_CONSTEXPR T* begin() const { return d; }
-    MIGRAPHX_DEVICE_CONSTEXPR T* end() const { return d + size(); }
+    MIGRAPHX_DEVICE_CONSTEXPR value_type* begin() const { return d; }
+    MIGRAPHX_DEVICE_CONSTEXPR value_type* end() const { return d + size(); }
 
     private:
-    T* d = nullptr;
+    value_type* d = nullptr;
     hip_shape<N> s{};
 };
...
@@ -408,8 +408,8 @@ void fuse_ops::apply(program& p) const
     // clang-format off
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
-        find_conv_bias_relu{ctx},
-        find_conv_bias{ctx},
+        // find_conv_bias_relu{ctx},
+        // find_conv_bias{ctx},
         find_add_relu{}
     );
     // clang-format on
...