Commit b5c0f7ef authored by Paul's avatar Paul
Browse files

Use hip_add

parent d0bcc85a
......@@ -5,6 +5,14 @@
namespace migraph {
struct swallow
{
    /// Variadic sink: accepts any number of arguments of any type and
    /// discards them. Used to force evaluation of parameter-pack
    /// expansions for their side effects.
    template <class... Args>
    constexpr swallow(Args&&...)
    {
    }
};
namespace detail {
template <class R, class F>
......@@ -19,8 +27,48 @@ struct fix_f
}
};
// Compile-time list of indices; a stand-in for std::index_sequence.
template <std::size_t...>
struct seq
{
    // Self-alias so that seq works as its own metafunction result
    // (classes deriving from seq, e.g. gens below, expose it as ::type).
    using type = seq;
};
// Concatenates two index sequences, offsetting every element of the second
// by the length of the first, so that merging two [0..k) ranges yields one
// contiguous [0..k1+k2) range, e.g. merge_seq<seq<0,1>, seq<0,1>> -> seq<0,1,2,3>.
template <class, class>
struct merge_seq;
template <std::size_t... Xs, std::size_t... Ys>
struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
{
};
// Generates seq<0, 1, ..., N-1> by recursively splitting N in half and
// merging the two halves, giving O(log N) template instantiation depth
// instead of the O(N) of a naive one-at-a-time recursion.
template <std::size_t N>
struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
{
};
// Recursion base cases: the empty sequence and the single-element sequence.
template <>
struct gens<0> : seq<>
{
};
template <>
struct gens<1> : seq<0>
{
};
// Invokes f once per index in the sequence, wrapping each index in a
// std::integral_constant so the callee receives it as a compile-time value.
// The braced-init-list of swallow guarantees left-to-right evaluation order.
template <class Callable, std::size_t... Is>
constexpr void repeat_c_impl(Callable f, seq<Is...>)
{
    (void)swallow{(f(std::integral_constant<std::size_t, Is>{}), 0)...};
}
} // namespace detail
template<std::size_t N, class F>
constexpr void repeat_c(F f)
{
detail::repeat_c_impl(f, detail::gens<N>{});
}
/// Implements a fix-point combinator
template <class R, class F>
detail::fix_f<R, F> fix(F f)
......
......@@ -2,17 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
#include <migraph/functional.hpp>
namespace migraph {
struct swallow
{
template <class... Ts>
swallow(Ts&&...)
{
}
};
struct tracer
{
tracer() {}
......
......@@ -33,10 +33,10 @@ inline auto launch(std::size_t global, std::size_t local)
};
}
inline auto gs_launch(std::size_t n, std::size_t local = 512)
inline auto gs_launch(std::size_t n, std::size_t local = 256)
{
std::size_t groups = 1 + n / local;
std::size_t nglobal = std::min<std::size_t>(512, groups) * local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return [=](auto f) {
launch(nglobal, local)([=](auto idx) {
......
......@@ -17,13 +17,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
visit_all(result, args...)([&](auto output, auto... inputs) {
visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
auto data =
pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
inputs.get_shape().strides()},
pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
inputs.data())...);
hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
hip_tensor_descriptor<ndim> out_desc(output_shape);
auto* outp = output.data();
gs_launch(output_shape.elements())([=](auto i) {
data([&](auto... ps) {
data([&](auto&&... ps) {
auto outidx = out_desc.multi(i);
outp[i] = f(ps.second[ps.first.linear(outidx)]...);
});
......@@ -57,7 +56,10 @@ template <class... Arguments>
auto nary(argument result, Arguments... args)
{
return [=](auto f) {
if(all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); }))
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes = all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes))
nary_standard(result, args...)(f);
else
nary_nonstandard(result, args...)(f);
......
......@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_DEAVICE_TENSOR_HPP
#include <hip/hip_runtime.h>
#include <migraph/functional.hpp>
namespace migraph {
namespace gpu {
......@@ -53,14 +54,13 @@ template <size_t NDim>
struct hip_tensor_descriptor
{
__device__ __host__ hip_tensor_descriptor() = default;
template <typename T, typename V>
__device__ __host__ hip_tensor_descriptor(const T& lens_ext, const V& strides_ext)
hip_tensor_descriptor(const shape& s)
{
for(size_t i = 0; i < NDim; i++)
lens[i] = lens_ext[i];
for(size_t i = 0; i < NDim; i++)
strides[i] = strides_ext[i];
std::copy(s.lens().begin(), s.lens().end(), lens);
std::copy(s.strides().begin(), s.strides().end(), strides);
}
__device__ __host__ hip_index<NDim> multi(size_t idx) const
{
hip_index<NDim> result{};
......
......@@ -174,6 +174,7 @@ struct hip_add
std::string name() const { return "gpu::add"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
// check_shapes{inputs, *this}.has(3).standard();
check_shapes{inputs, *this}.has(3);
return inputs.at(0);
}
......@@ -185,6 +186,37 @@ struct hip_add
}
};
// Elementwise tensor addition backed by MIOpen's OpTensor kernel.
// Takes three arguments: the two inputs (args[0], args[1]) and a
// pre-allocated destination buffer (args[2]), which is returned as the result.
struct miopen_add
{
    std::string name() const { return "gpu::add"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // Require exactly 3 shapes, none broadcasted; the output shape is
        // taken from the first input.
        check_shapes{inputs, *this}.has(3).not_broadcasted();
        return inputs.at(0);
    }
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
    {
        // NOTE(review): per miopenOpTensor semantics this computes
        // C = op(alpha1*A, alpha2*B) + beta*C; with alpha=1, beta=0 and
        // miopenTensorOpAdd that is C = A + B — confirm against MIOpen docs.
        float alpha = 1, beta = 0;
        auto a_desc = make_tensor(args[0].get_shape());
        auto b_desc = make_tensor(args[1].get_shape());
        auto c_desc = make_tensor(output_shape);
        // NOTE(review): the miopenStatus_t returned here is ignored;
        // consider checking it and reporting failures.
        miopenOpTensor(ctx.handle.get(),
                       miopenTensorOpAdd,
                       &alpha,
                       a_desc.get(),
                       args[0].implicit(),
                       &alpha,
                       b_desc.get(),
                       args[1].implicit(),
                       &beta,
                       c_desc.get(),
                       args[2].implicit());
        // args[2] doubles as the destination buffer and the op's output.
        return args[2];
    }
};
struct miopen_gemm
{
gemm op;
......
......@@ -21,14 +21,14 @@ std::vector<pass> target::get_passes(migraph::context& gctx) const
// clang-format off
return
{
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{},
dead_code_elimination{},
auto_contiguous{},
simplify_reshapes{},
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{},
dead_code_elimination{},
lowering{ctx},
fuse_ops{},
// fuse_ops{},
dead_code_elimination{},
eliminate_workspace{},
eliminate_contiguous{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment