Commit b5c0f7ef authored by Paul

Use hip_add

parent d0bcc85a
@@ -5,6 +5,14 @@
 
 namespace migraph {
 
+struct swallow
+{
+    template <class... Ts>
+    constexpr swallow(Ts&&...)
+    {
+    }
+};
+
 namespace detail {
 
 template <class R, class F>
@@ -19,8 +27,48 @@ struct fix_f
     }
 };
 
+template <std::size_t...>
+struct seq
+{
+    using type = seq;
+};
+
+template <class, class>
+struct merge_seq;
+
+template <std::size_t... Xs, std::size_t... Ys>
+struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
+{
+};
+
+template <std::size_t N>
+struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
+{
+};
+
+template <>
+struct gens<0> : seq<>
+{
+};
+
+template <>
+struct gens<1> : seq<0>
+{
+};
+
+template <class F, std::size_t... Ns>
+constexpr void repeat_c_impl(F f, seq<Ns...>)
+{
+    swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...};
+}
+
 } // namespace detail
 
+template <std::size_t N, class F>
+constexpr void repeat_c(F f)
+{
+    detail::repeat_c_impl(f, detail::gens<N>{});
+}
+
 /// Implements a fix-point combinator
 template <class R, class F>
 detail::fix_f<R, F> fix(F f)
...
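For reference: `detail::gens<N>` builds `seq<0, 1, ..., N-1>` by merging two half-length sequences and offsetting the second half by `sizeof...(Xs)`, which keeps template instantiation depth logarithmic in N, and `swallow`'s braced-init-list constructor forces left-to-right evaluation of the pack expansion in `repeat_c_impl`. A standalone sketch of how the added `repeat_c` behaves, with C++14's `std::make_index_sequence` swapped in for `gens<N>` so it compiles on its own:

#include <cstddef>
#include <cstdio>
#include <type_traits>
#include <utility>

// Stand-in model of the repeat_c added above; std::make_index_sequence
// replaces the log-depth gens<N> purely to keep this self-contained.
struct swallow
{
    template <class... Ts>
    constexpr swallow(Ts&&...)
    {
    }
};

template <class F, std::size_t... Ns>
constexpr void repeat_c_impl(F f, std::index_sequence<Ns...>)
{
    // The braced-init-list guarantees left-to-right evaluation, so f is
    // called with 0, 1, ..., N-1 in order.
    swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...};
}

template <std::size_t N, class F>
constexpr void repeat_c(F f)
{
    repeat_c_impl(f, std::make_index_sequence<N>{});
}

int main()
{
    repeat_c<3>([](auto i) {
        // i is an integral_constant, so its value is a compile-time constant
        static_assert(decltype(i)::value < 3, "index is a constant expression");
        std::printf("%zu\n", std::size_t{decltype(i)::value});
    });
}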
@@ -2,17 +2,10 @@
 #define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
 
 #include <ostream>
+#include <migraph/functional.hpp>
 
 namespace migraph {
 
-struct swallow
-{
-    template <class... Ts>
-    swallow(Ts&&...)
-    {
-    }
-};
-
 struct tracer
 {
     tracer() {}
...
@@ -33,10 +33,10 @@ inline auto launch(std::size_t global, std::size_t local)
     };
 }
 
-inline auto gs_launch(std::size_t n, std::size_t local = 512)
+inline auto gs_launch(std::size_t n, std::size_t local = 256)
 {
     std::size_t groups  = 1 + n / local;
-    std::size_t nglobal = std::min<std::size_t>(512, groups) * local;
+    std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
 
     return [=](auto f) {
         launch(nglobal, local)([=](auto idx) {
...
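This halves both the default workgroup size and the cap on the number of groups from 512 to 256, so at most 256 * local threads are launched regardless of n. A hypothetical host-side model of the capped mapping (the body of `launch` is collapsed in this diff, so the grid-stride inner loop is an assumption about how each thread covers the indices beyond nglobal):

#include <algorithm>
#include <cstddef>

// Hypothetical CPU model of gs_launch: nglobal threads cover n work-items,
// each striding by the total thread count.
template <class F>
void gs_launch_model(std::size_t n, F f, std::size_t local = 256)
{
    std::size_t groups  = 1 + n / local;                              // groups needed to cover n
    std::size_t nglobal = std::min<std::size_t>(256, groups) * local; // never more than 256 groups
    for(std::size_t tid = 0; tid < nglobal; tid++)    // one pass per simulated thread
        for(std::size_t i = tid; i < n; i += nglobal) // grid-stride over the work-items
            f(i);
}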
@@ -17,13 +17,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
     visit_all(result, args...)([&](auto output, auto... inputs) {
         visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
             auto data =
-                pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
-                                                                inputs.get_shape().strides()},
+                pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
                                     inputs.data())...);
-            hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
+            hip_tensor_descriptor<ndim> out_desc(output_shape);
             auto* outp = output.data();
             gs_launch(output_shape.elements())([=](auto i) {
-                data([&](auto... ps) {
+                data([&](auto&&... ps) {
                     auto outidx = out_desc.multi(i);
                     outp[i]     = f(ps.second[ps.first.linear(outidx)]...);
                 });
@@ -57,7 +56,10 @@ template <class... Arguments>
 auto nary(argument result, Arguments... args)
 {
     return [=](auto f) {
-        if(all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }))
+        bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
+        bool packed   = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
+        bool same_shapes = all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+        if(standard or (packed and same_shapes))
             nary_standard(result, args...)(f);
         else
             nary_nonstandard(result, args...)(f);
...
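The rewritten dispatch widens the fast path: `nary_standard` indexes every tensor with one flat loop, which is valid when all shapes are standard (dense, row-major), and also when all shapes are packed (gap-free) and equal to the result's shape, since identical packed layouts map each flat index to corresponding elements even when the layout is not row-major. A small hypothetical illustration (not the migraph API) of that flat-offset argument:

#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: multi-index -> flat offset, sum_i multi[i] * strides[i].
// Two tensors agree element-for-element under a flat loop exactly when they
// map every multi-index to the same flat offset.
std::size_t linear(const std::vector<std::size_t>& strides,
                   const std::vector<std::size_t>& multi)
{
    return std::inner_product(multi.begin(), multi.end(), strides.begin(), std::size_t{0});
}

int main()
{
    // The same packed, non-row-major layout in both tensors: offsets match,
    // so out[i] = f(a[i], b[i]) still visits corresponding elements.
    std::vector<std::size_t> strides{1, 2}; // column-major 2x2, packed but not standard
    std::vector<std::size_t> idx{1, 0};
    return linear(strides, idx) == 1 ? 0 : 1;
}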
@@ -2,6 +2,7 @@
 #define MIGRAPH_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
 
 #include <hip/hip_runtime.h>
+#include <migraph/functional.hpp>
 
 namespace migraph {
 namespace gpu {
@@ -53,14 +54,13 @@ template <size_t NDim>
 struct hip_tensor_descriptor
 {
     __device__ __host__ hip_tensor_descriptor() = default;
-    template <typename T, typename V>
-    __device__ __host__ hip_tensor_descriptor(const T& lens_ext, const V& strides_ext)
+    hip_tensor_descriptor(const shape& s)
     {
-        for(size_t i = 0; i < NDim; i++)
-            lens[i] = lens_ext[i];
-        for(size_t i = 0; i < NDim; i++)
-            strides[i] = strides_ext[i];
+        std::copy(s.lens().begin(), s.lens().end(), lens);
+        std::copy(s.strides().begin(), s.strides().end(), strides);
     }
 
     __device__ __host__ hip_index<NDim> multi(size_t idx) const
     {
         hip_index<NDim> result{};
...
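The descriptor now copies `lens` and `strides` directly from a `shape`; this assumes `s.lens().size() == NDim`, which the `visit_tensor_size` dispatch in nary.hpp guarantees at the call site. The body of `multi` is collapsed above; a plausible standalone sketch of the flat-index to multi-index conversion it performs (an assumption, since only the signature is visible here):

#include <array>
#include <cstddef>

// Hypothetical sketch of multi(): peel coordinates off a flat row-major
// index, innermost dimension first.
template <std::size_t NDim>
std::array<std::size_t, NDim> multi(std::size_t idx,
                                    const std::array<std::size_t, NDim>& lens)
{
    std::array<std::size_t, NDim> result{};
    for(std::size_t i = NDim; i > 0; i--)
    {
        result[i - 1] = idx % lens[i - 1]; // coordinate in dimension i-1
        idx /= lens[i - 1];                // remaining flat index
    }
    return result;
}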
@@ -174,6 +174,7 @@ struct hip_add
     std::string name() const { return "gpu::add"; }
     shape compute_shape(const std::vector<shape>& inputs) const
     {
+        // check_shapes{inputs, *this}.has(3).standard();
         check_shapes{inputs, *this}.has(3);
         return inputs.at(0);
     }
@@ -185,6 +186,37 @@
     }
 };
 
+struct miopen_add
+{
+    std::string name() const { return "gpu::add"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(3).not_broadcasted();
+        return inputs.at(0);
+    }
+
+    argument
+    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    {
+        float alpha = 1, beta = 0;
+        auto a_desc = make_tensor(args[0].get_shape());
+        auto b_desc = make_tensor(args[1].get_shape());
+        auto c_desc = make_tensor(output_shape);
+        miopenOpTensor(ctx.handle.get(),
+                       miopenTensorOpAdd,
+                       &alpha,
+                       a_desc.get(),
+                       args[0].implicit(),
+                       &alpha,
+                       b_desc.get(),
+                       args[1].implicit(),
+                       &beta,
+                       c_desc.get(),
+                       args[2].implicit());
+        return args[2];
+    }
+};
+
 struct miopen_gemm
 {
     gemm op;
...
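MIOpen's `miopenOpTensor` computes `C = op(alpha1 * A, alpha2 * B) + beta * C`, so with `miopenTensorOpAdd`, both alphas at 1, and beta at 0, the new op writes `args[0] + args[1]` into `args[2]` and returns it. The call above discards the returned `miopenStatus_t`; a hedged sketch of a checked wrapper (the throwing policy is an assumption, not migraph's error handling):

#include <miopen/miopen.h>
#include <stdexcept>

// Wrap miopenOpTensor and surface failures instead of discarding the
// status. Parameters mirror the call in the diff above.
inline void add_tensors(miopenHandle_t handle,
                        miopenTensorDescriptor_t a_desc, const void* a,
                        miopenTensorDescriptor_t b_desc, const void* b,
                        miopenTensorDescriptor_t c_desc, void* c)
{
    float alpha = 1, beta = 0;
    auto status = miopenOpTensor(handle, miopenTensorOpAdd,
                                 &alpha, a_desc, a,
                                 &alpha, b_desc, b,
                                 &beta, c_desc, c);
    if(status != miopenStatusSuccess)
        throw std::runtime_error("miopenOpTensor failed"); // assumed error policy
}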
@@ -21,14 +21,14 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
     // clang-format off
     return
     {
+        dead_code_elimination{},
+        fwd_conv_batchnorm_rewrite{},
         dead_code_elimination{},
         auto_contiguous{},
         simplify_reshapes{},
         dead_code_elimination{},
-        fwd_conv_batchnorm_rewrite{},
-        dead_code_elimination{},
         lowering{ctx},
-        fuse_ops{},
+        // fuse_ops{},
         dead_code_elimination{},
         eliminate_workspace{},
         eliminate_contiguous{},
...