Commit ff3bd8e6 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 32b69ceb c310bc5c
...@@ -226,6 +226,11 @@ struct id ...@@ -226,6 +226,11 @@ struct id
} }
}; };
template <class... Ts>
void nop(Ts&&...)
{
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -66,6 +66,8 @@ struct literal : raw_data<literal> ...@@ -66,6 +66,8 @@ struct literal : raw_data<literal>
const shape& get_shape() const { return this->m_shape; } const shape& get_shape() const { return this->m_shape; }
std::vector<literal> get_sub_objects() const { return {}; }
/// Convert the data to an argument /// Convert the data to an argument
argument get_argument() const argument get_argument() const
{ {
......
...@@ -33,7 +33,7 @@ struct as_shape ...@@ -33,7 +33,7 @@ struct as_shape
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args.front().reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -64,7 +64,7 @@ struct broadcast ...@@ -64,7 +64,7 @@ struct broadcast
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -48,7 +48,7 @@ struct flatten ...@@ -48,7 +48,7 @@ struct flatten
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -19,11 +19,8 @@ struct identity ...@@ -19,11 +19,8 @@ struct identity
{ {
std::string name() const { return "identity"; } std::string name() const { return "identity"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); } shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape, std::vector<argument> args) const { return args[0]; }
{
return {std::move(output_shape), std::move(args.at(0).data)};
}
bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -34,7 +34,7 @@ struct load ...@@ -34,7 +34,7 @@ struct load
{ {
if((offset + s.bytes()) > args[0].get_shape().bytes()) if((offset + s.bytes()) > args[0].get_shape().bytes())
MIGRAPHX_THROW("Load access is out of bounds"); MIGRAPHX_THROW("Load access is out of bounds");
return {s, args[0].data() + offset}; return argument::load(s, args[0].data() + offset);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -66,7 +66,7 @@ struct multibroadcast ...@@ -66,7 +66,7 @@ struct multibroadcast
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <class Derived>
struct prefix_scan_op : op_name<Derived>
{
int64_t axis;
bool exclusive = false;
bool reverse = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(
f(self.axis, "axis"), f(self.exclusive, "exclusive"), f(self.reverse, "reverse"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return inputs.at(0);
}
argument compute(const shape&, std::vector<argument> args) const
{
argument result = args[0];
auto s = result.get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
lens[axis] = 1;
auto batch = shape{s.type(), lens, s.strides()};
auto& self = static_cast<const Derived&>(*this);
result.visit([&](auto output) {
using type = decltype(output);
par_for(batch.elements(), [&](auto i) {
auto* start = output.data() + batch.index(i);
type x{slice, start};
if(reverse)
{
if(exclusive)
{
std::copy(++x.begin(), x.end(), x.begin());
x.back() = 0;
}
std::partial_sum(std::make_reverse_iterator(x.end()),
std::make_reverse_iterator(x.begin()),
std::make_reverse_iterator(x.end()),
self.op());
}
else
{
if(exclusive)
{
std::copy_backward(x.begin(), --x.end(), x.end());
x.front() = 0;
}
std::partial_sum(x.begin(), x.end(), x.begin(), self.op());
}
});
});
return result;
}
auto init() const {}
prefix_scan_op() : axis(0) {}
prefix_scan_op(int64_t ax) : axis(ax) {}
prefix_scan_op(int64_t ax, bool excl) : axis(ax), exclusive(excl) {}
prefix_scan_op(int64_t ax, bool excl, bool rev) : axis(ax), exclusive(excl), reverse(rev) {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCAN_INCLUSIVE_SUM_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/prefix_scan_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct prefix_scan_sum : prefix_scan_op<prefix_scan_sum>
{
prefix_scan_sum() {}
prefix_scan_sum(int64_t ax) : prefix_scan_op(ax) {}
prefix_scan_sum(int64_t ax, bool excl) : prefix_scan_op(ax, excl) {}
prefix_scan_sum(int64_t ax, bool excl, bool rev) : prefix_scan_op(ax, excl, rev) {}
auto op() const
{
return [](auto x, auto y) { return x + y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -68,7 +68,7 @@ struct reshape ...@@ -68,7 +68,7 @@ struct reshape
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
......
...@@ -37,7 +37,7 @@ struct scalar ...@@ -37,7 +37,7 @@ struct scalar
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.at(0).data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -75,7 +75,7 @@ struct squeeze ...@@ -75,7 +75,7 @@ struct squeeze
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -61,7 +61,7 @@ struct transpose ...@@ -61,7 +61,7 @@ struct transpose
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -68,7 +68,7 @@ struct unsqueeze ...@@ -68,7 +68,7 @@ struct unsqueeze
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
bool is_borrowed() const { return true; } bool is_borrowed() const { return true; }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
......
...@@ -58,10 +58,11 @@ ...@@ -58,10 +58,11 @@
#include <migraphx/op/outline.hpp> #include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp> #include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/prefix_scan_sum.hpp>
#include <migraphx/op/prelu.hpp> #include <migraphx/op/prelu.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/pow.hpp>
#include <migraphx/op/recip.hpp> #include <migraphx/op/recip.hpp>
#include <migraphx/op/reduce_max.hpp> #include <migraphx/op/reduce_max.hpp>
#include <migraphx/op/reduce_mean.hpp> #include <migraphx/op/reduce_mean.hpp>
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F>
struct function_output_iterator
{
F f;
using self = function_output_iterator;
using difference_type = void;
using reference = void;
using value_type = void;
using pointer = void;
using iterator_category = std::output_iterator_tag;
struct output_proxy
{
template <class T>
output_proxy& operator=(const T& value)
{
assert(f);
(*f)(value);
return *this;
}
F* f;
};
output_proxy operator*() { return output_proxy{&f}; }
self& operator++() { return *this; }
self& operator++(int) { return *this; } // NOLINT
};
template <class F>
function_output_iterator<F> make_function_output_iterator(F f)
{
return {std::move(f)};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <functional>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -23,8 +23,10 @@ struct pass ...@@ -23,8 +23,10 @@ struct pass
{ {
/// A unique name used to identify the pass /// A unique name used to identify the pass
std::string name() const; std::string name() const;
/// Run the pass on the module
void apply(module& m) const;
/// Run the pass on the program /// Run the pass on the program
void apply(module& p) const; void apply(program& p) const;
}; };
#else #else
...@@ -35,7 +37,8 @@ struct pass ...@@ -35,7 +37,8 @@ struct pass
* struct pass * struct pass
* { * {
* std::string name() const; * std::string name() const;
* void apply(module & p) const; * void apply(module & m) const;
* void apply(program & p) const;
* }; * };
* *
*/ */
...@@ -109,7 +112,13 @@ struct pass ...@@ -109,7 +112,13 @@ struct pass
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
void apply(module& p) const void apply(module& m) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(m);
}
void apply(program& p) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(p); (*this).private_detail_te_get_handle().apply(p);
...@@ -128,10 +137,37 @@ struct pass ...@@ -128,10 +137,37 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual void apply(module& p) const = 0; virtual void apply(module& m) const = 0;
virtual void apply(program& p) const = 0;
}; };
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, module& m)
-> decltype(private_detail_te_self.apply(m))
{
private_detail_te_self.apply(m);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, module& m)
{
migraphx::nop(private_detail_te_self, m);
}
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, program& p)
-> decltype(private_detail_te_self.apply(p))
{
private_detail_te_self.apply(p);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, program& p)
{
migraphx::nop(private_detail_te_self, p);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -162,7 +198,17 @@ struct pass ...@@ -162,7 +198,17 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
void apply(module& p) const override { private_detail_te_value.apply(p); } void apply(module& m) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, m);
}
void apply(program& p) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, p);
}
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace = tracer{}); void run_passes(module& mod, const std::vector<pass>& passes, tracer trace = tracer{});
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -101,6 +101,9 @@ struct program ...@@ -101,6 +101,9 @@ struct program
std::vector<const module*> get_modules() const; std::vector<const module*> get_modules() const;
std::vector<module*> get_modules(); std::vector<module*> get_modules();
void remove_module(const std::string& name);
void remove_unused_modules();
private: private:
void assign(const program& p); void assign(const program& p);
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment