Commit 5a14c0bf authored by umangyadav

Merge branch 'develop' into workspace_size

parents cb01e280 5fa42993
@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
 {
     return {};
 }
+template <class T>
+void wait_for_context(T&, any_ptr)
+{
+}
+template <class T>
+void finish_on_context(T&, any_ptr)
+{
+}
 #ifdef TYPE_ERASED_DECLARATION
@@ -78,6 +87,10 @@ struct context
     void from_value(const value& v);
     // (optional)
     any_ptr get_queue();
+    // (optional)
+    void wait_for(any_ptr queue);
+    // (optional)
+    void finish_on(any_ptr queue);
     //
     void finish() const;
 };
@@ -165,6 +178,18 @@ struct context
         return (*this).private_detail_te_get_handle().get_queue();
     }
+    void wait_for(any_ptr queue)
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        (*this).private_detail_te_get_handle().wait_for(queue);
+    }
+    void finish_on(any_ptr queue)
+    {
+        assert((*this).private_detail_te_handle_mem_var);
+        (*this).private_detail_te_get_handle().finish_on(queue);
+    }
     void finish() const
     {
         assert((*this).private_detail_te_handle_mem_var);
@@ -187,6 +212,8 @@ struct context
         virtual value to_value() const = 0;
         virtual void from_value(const value& v) = 0;
         virtual any_ptr get_queue() = 0;
+        virtual void wait_for(any_ptr queue) = 0;
+        virtual void finish_on(any_ptr queue) = 0;
         virtual void finish() const = 0;
     };
@@ -231,6 +258,33 @@ struct context
         return get_queue_context(private_detail_te_self);
     }
+    template <class T>
+    static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
+        -> decltype(private_detail_te_self.wait_for(queue))
+    {
+        private_detail_te_self.wait_for(queue);
+    }
+    template <class T>
+    static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
+    {
+        wait_for_context(private_detail_te_self, queue);
+    }
+    template <class T>
+    static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
+        -> decltype(private_detail_te_self.finish_on(queue))
+    {
+        private_detail_te_self.finish_on(queue);
+    }
+    template <class T>
+    static void
+    private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
+    {
+        finish_on_context(private_detail_te_self, queue);
+    }
     template <typename PrivateDetailTypeErasedT>
     struct private_detail_te_handle_type : private_detail_te_handle_base_type
     {
@@ -248,7 +302,7 @@ struct context
             PrivateDetailTypeErasedT value,
             typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
                                     int>::type* = nullptr) noexcept
-            : private_detail_te_value(std::move(value))
+            : private_detail_te_value(value)
         {
         }
@@ -277,6 +331,18 @@ struct context
             return private_detail_te_default_get_queue(char(0), private_detail_te_value);
         }
+        void wait_for(any_ptr queue) override
+        {
+            private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
+        }
+        void finish_on(any_ptr queue) override
+        {
+            private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
+        }
         void finish() const override { private_detail_te_value.finish(); }
         PrivateDetailTypeErasedT private_detail_te_value;
...
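The private_detail_te_default_* pairs above are how wait_for and finish_on stay optional on a context implementation: the call site passes char(0), which prefers the char overload, but that overload only participates (via its trailing decltype) when T actually has the member; otherwise the implicit char-to-float conversion selects the fallback, which calls the free wait_for_context/finish_on_context defaults added at the top of this diff. A minimal standalone sketch of the same dispatch, with hypothetical gpu_ctx/cpu_ctx types and int standing in for any_ptr:

#include <iostream>

struct gpu_ctx
{
    void wait_for(int q) { std::cout << "gpu waits on queue " << q << "\n"; }
};

struct cpu_ctx
{
    // no wait_for member; the free-function default is used instead
};

// Free-function default, analogous to wait_for_context in the diff.
template <class T>
void wait_for_context(T&, int)
{
    std::cout << "default: nothing to wait for\n";
}

// Preferred overload: viable only when T has a wait_for member, because the
// trailing decltype removes it from overload resolution otherwise.
template <class T>
auto dispatch_wait_for(char, T&& self, int q) -> decltype(self.wait_for(q))
{
    self.wait_for(q);
}

// Fallback: the char(0) argument converts to float, so this ranks lower and
// is chosen only when the overload above has dropped out.
template <class T>
void dispatch_wait_for(float, T&& self, int q)
{
    wait_for_context(self, q);
}

int main()
{
    gpu_ctx g;
    cpu_ctx c;
    dispatch_wait_for(char(0), g, 1); // member found: "gpu waits on queue 1"
    dispatch_wait_for(char(0), c, 1); // fallback: "default: nothing to wait for"
}

The same char/float ranking already appears in the get_queue default directly above; the new methods simply follow it.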
@@ -21,22 +21,21 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_ADD_HPP
-#define MIGRAPHX_GUARD_RTGLIB_ADD_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/add.hpp>
+#include <migraphx/any_ptr.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-namespace gpu {
-struct hip_add : binary_device<hip_add, device::add>
+struct execution_environment
 {
+    any_ptr queue = any_ptr{};
+    bool async    = false;
 };
-} // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
-#endif
+#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
@@ -50,8 +50,8 @@ struct layernorm_matcher
     {
         return f("div")(arg(0)(x_minus_mean()),
-                        arg(1)(skip_broadcasts(f("sqrt")(
-                            arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
+                        arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
+                            f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
     }
     auto matcher() const { return layernorm_onnx(); }
...
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-#ifndef MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
-#define MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
-#include <migraphx/check_shapes.hpp>
-#include <migraphx/config.hpp>
-#include <cmath>
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-namespace op {
-struct batch_norm_inference
-{
-    float epsilon  = 1.0e-6f;
-    float momentum = 0.9f;
-    std::string name() const { return "batch_norm_inference"; }
-    enum bn_infer_mode_t
-    {
-        per_activation,
-        spatial,
-    };
-    bn_infer_mode_t bn_mode = spatial;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(
-            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
-    }
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        check_shapes{inputs, *this}.has(5);
-        check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
-        check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape();
-        return inputs.front();
-    }
-};
-} // namespace op
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
-#endif
@@ -33,11 +33,11 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {
+// Padding mode is default_ for fixed shape padding.
+// same_lower and same_upper are used for dynamic padding.
 enum padding_mode_t
 {
     default_, // NOLINT
-    same,
-    valid,
     same_lower,
     same_upper
 };
...
@@ -43,7 +43,6 @@ struct convolution
     int group = 1;
     padding_mode_t padding_mode = default_;
-    bool use_dynamic_same_auto_pad = false;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -52,8 +51,7 @@ struct convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.group, "group"),
-                    f(self.padding_mode, "padding_mode"),
-                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
+                    f(self.padding_mode, "padding_mode"));
     }
     std::string name() const { return "convolution"; }
@@ -93,13 +91,6 @@ struct convolution
            x_shape.lens().at(1) != (w_shape.lens().at(1) * group))
             MIGRAPHX_THROW("CONVOLUTION: mismatched channel numbers");
-        std::vector<op::padding_mode_t> dyn_pad_modes = {op::padding_mode_t::same_upper,
-                                                         op::padding_mode_t::same_lower};
-        if(use_dynamic_same_auto_pad and not contains(dyn_pad_modes, padding_mode))
-        {
-            MIGRAPHX_THROW("CONVOLUTION: use_dynamic_same_auto_pad set with invalid padding mode");
-        }
         if(x_shape.dynamic() or w_shape.dynamic())
         {
             return dynamic_compute_shape(x_shape, w_shape);
@@ -161,7 +152,7 @@ struct convolution
         dynamic_shape_push_back(w_shape);
         const size_t num_spatial_dims = x_shape.max_lens().size() - 2;
-        if(use_dynamic_same_auto_pad)
+        if(padding_mode != default_)
         {
             for(std::size_t i = 0; i < num_spatial_dims; ++i)
             {
...
@@ -40,7 +40,6 @@ struct fmod : binary<fmod>
         a["commutative"] = false;
         return a;
     }
-    std::string point_function() const { return "fmod"; }
     auto apply() const
     {
         return [](auto x, auto y) { return std::fmod(x, y); };
...
@@ -38,9 +38,9 @@ struct mod : binary<mod>
     {
         auto a = base_attributes();
         a["commutative"] = false;
+        a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})";
         return a;
     }
-    std::string point_function() const { return "mod"; }
     auto apply() const
     {
         return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); };
...
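The new point_op string mirrors the existing apply() lambda: fmod(remainder(x, y) + y, y) is the usual trick for a floored, Python-style modulus whose sign follows the divisor, unlike plain std::fmod, which follows the dividend. A small standalone check of that identity (not part of the diff):

#include <cmath>
#include <cstdio>

int main()
{
    // mod(-3, 2) should be 1 and mod(5, -3) should be -1, matching Python's %
    double tests[][2] = {{-3, 2}, {5, -3}, {7, 2.5}};
    for(auto& t : tests)
    {
        double x = t[0], y = t[1];
        double m = std::fmod(std::remainder(x, y) + y, y);
        std::printf("mod(%g, %g) = %g (fmod would give %g)\n", x, y, m, std::fmod(x, y));
    }
}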
@@ -43,7 +43,6 @@ struct quant_convolution
     padding_mode_t padding_mode = default_;
     int group = 1;
-    bool use_dynamic_same_auto_pad = false;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -52,8 +51,7 @@ struct quant_convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.padding_mode, "padding_mode"),
-                    f(self.group, "group"),
-                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
+                    f(self.group, "group"));
     }
     value attributes() const
...
@@ -35,7 +35,6 @@
 #include <migraphx/op/as_shape.hpp>
 #include <migraphx/op/atan.hpp>
 #include <migraphx/op/atanh.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
 #include <migraphx/op/binary.hpp>
 #include <migraphx/op/broadcast.hpp>
 #include <migraphx/op/capture.hpp>
...
@@ -24,9 +24,10 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
-#include <migraphx/config.hpp>
 #include <cstdint>
 #include <vector>
+#include <migraphx/config.hpp>
+#include <migraphx/shape.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -42,18 +43,21 @@ void calculate_padding(int64_t idx,
 /*!
  * Calculate the padding for auto_padding. Used for dynamic shapes
  * where the padding calculation must be done at evaluation time.
- * \param tensor_lens input tensor image shape
- * \param k_lens weights kernel shape
- * \param strides strides for the kernel
- * \param dilations dilations for the kernel
- * \param use_upper put odd padding on upper or lower side
  * \return padding in the form of {x0_begin, x1_begin, ... x0_end, x1_end, ...}
  */
-std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
-                                           std::vector<std::size_t> k_lens,
-                                           std::vector<std::size_t> strides,
-                                           std::vector<std::size_t> dilations,
-                                           bool use_upper = true);
+std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
+                                           const std::vector<std::size_t>& wei_lens,
+                                           const std::vector<std::size_t>& strides,
+                                           const std::vector<std::size_t>& dilations,
+                                           bool use_upper);
+// Used for dynamic auto padding of convolution operators since padding needs to be computed at
+// evaluation time.
+shape compute_padded_shape(const shape& input,
+                           const shape& weights,
+                           const std::vector<std::size_t>& padding,
+                           const std::vector<std::size_t>& stride,
+                           const std::vector<std::size_t>& dilation);
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
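For reference, SAME-style auto-padding is usually computed as: output length ceil(input / stride), total padding (output - 1) * stride + dilated_kernel - input clamped at zero, with the odd pixel placed on the upper or lower side. A sketch under those assumptions (the real calc_dyn_auto_pad lives in pad_calc.cpp and may differ in details):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> same_auto_pad(const std::vector<std::size_t>& input_lens,
                                       const std::vector<std::size_t>& wei_lens,
                                       const std::vector<std::size_t>& strides,
                                       const std::vector<std::size_t>& dilations,
                                       bool use_upper)
{
    std::size_t n = input_lens.size();
    // {x0_begin, x1_begin, ..., x0_end, x1_end, ...}, as documented above
    std::vector<std::size_t> pads(2 * n, 0);
    for(std::size_t i = 0; i < n; ++i)
    {
        std::size_t out      = (input_lens[i] + strides[i] - 1) / strides[i];
        std::size_t eff_k    = (wei_lens[i] - 1) * dilations[i] + 1;
        std::ptrdiff_t total = static_cast<std::ptrdiff_t>((out - 1) * strides[i] + eff_k) -
                               static_cast<std::ptrdiff_t>(input_lens[i]);
        auto t      = static_cast<std::size_t>(std::max<std::ptrdiff_t>(total, 0));
        pads[i]     = use_upper ? t / 2 : t - t / 2; // begin side
        pads[i + n] = use_upper ? t - t / 2 : t / 2; // end side, gets the odd pixel
    }
    return pads;
}

int main()
{
    // 7-wide input, 3-wide kernel, stride 2: output 4, total padding 2, split 1/1
    auto p = same_auto_pad({7}, {3}, {2}, {1}, true);
    std::cout << p[0] << "," << p[1] << "\n"; // prints 1,1
}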
@@ -37,6 +37,7 @@
 #include <migraphx/assignment_options.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/execution_environment.hpp>
 #include <algorithm>
 #include <iostream>
@@ -76,8 +77,8 @@ struct program
     std::unordered_map<std::string, shape> get_parameter_shapes() const;
-    std::vector<argument> eval(parameter_map params) const;
+    std::vector<argument> eval(parameter_map params,
+                               execution_environment exec_env = execution_environment{}) const;
     std::size_t size() const;
     std::vector<shape> get_output_shapes() const;
...
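The new eval overload is what the execution_environment header added earlier in this diff exists for: running a program on a caller-owned queue and opting out of implicit synchronization. A hypothetical usage sketch; the name-tagged any_ptr constructor and the "ihipStream_t*" type string are assumptions, not shown in this diff:

#include <migraphx/program.hpp>
#include <migraphx/execution_environment.hpp>

// `prog`, `params`, and `stream` are assumed to be set up elsewhere.
void run_async(migraphx::program& prog, migraphx::parameter_map params, void* stream)
{
    migraphx::execution_environment exec_env;
    exec_env.queue = migraphx::any_ptr(stream, "ihipStream_t*");
    exec_env.async = true; // caller becomes responsible for synchronizing the queue
    auto results   = prog.eval(std::move(params), exec_env);
    (void)results;
}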
@@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
 }
 template <class T>
-auto reflectable_impl(rank<1>, T&& x)
+auto reflectable_impl(rank<1>, const T& x)
     -> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});
 template <class T>
-auto reflectable_impl(rank<0>, T &&) -> decltype(std::false_type{});
+auto reflectable_impl(rank<0>, const T&) -> decltype(std::false_type{});
 template <class T>
 struct remove_rvalue_reference
@@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
 template <class T>
 auto reflect_tie(T& x)
 {
-    return reflect(x, [](auto&& y, auto&&...) { return detail::wrap<decltype(y)>(y); })(
-        [](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
+    return reflect(x, [](auto&& y, auto&&...) {
+        // cppcheck-suppress UnnecessaryElseStatement
+        if constexpr(is_reflectable<decltype(y)>{})
+        {
+            auto t = reflect_tie(y);
+            return detail::wrap<decltype(t)>(t);
+        }
+        else
+        {
+            return detail::wrap<decltype(y)>(y);
+        }
+    })([](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
 }
 template <class T, class F>
...
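reflect_tie previously wrapped each reflected member directly; the change recurses when a member is itself reflectable, so the resulting tuple ties leaf fields instead of requiring the nested type to define its own comparisons. A sketch of what that enables, following the reflect idiom used by the operators elsewhere in this diff (illustrative only; pack is migraphx::pack from migraphx/functional.hpp):

#include <migraphx/functional.hpp>
#include <migraphx/reflect.hpp>
#include <cstddef>
#include <vector>

struct inner
{
    float scale = 1.0f;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.scale, "scale"));
    }
};

struct outer
{
    std::vector<std::size_t> dims;
    inner config; // a reflectable member inside a reflectable struct
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.dims, "dims"), f(self.config, "config"));
    }
};

// With the recursive reflect_tie above, comparing two `outer` values via
// migraphx::reflect_tie(a) == migraphx::reflect_tie(b) descends into `config`
// and compares `scale`, instead of needing `inner` itself to be comparable.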
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-#ifndef MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
-#define MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
-#include <string>
-#include <migraphx/instruction_ref.hpp>
-#include <migraphx/config.hpp>
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-struct module;
-/**
- * Rewrite batchnorm to a multiply and add.
- */
-struct rewrite_batchnorm
-{
-    std::string name() const { return "rewrite_batchnorm"; }
-    void apply(module& m) const;
-};
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
-#endif
@@ -26,8 +26,11 @@
 #include <ostream>
 #include <algorithm>
+#include <migraphx/reflect.hpp>
 #include <migraphx/rank.hpp>
+#include <migraphx/requires.hpp>
 #include <migraphx/config.hpp>
+#include <vector>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -59,10 +62,22 @@ inline stream_range_container<Range> stream_range(const Range& r)
 namespace detail {
-inline void stream_write_value_impl(rank<2>, std::ostream& os, const std::string& x) { os << x; }
+template <class T>
+auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype(os << x, void())
+{
+    os << x;
+}
+template <class T>
+void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r)
+{
+    os << "{";
+    os << stream_range(r);
+    os << "}";
+}
 template <class Range>
-auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
+auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
     -> decltype(r.begin(), r.end(), void())
 {
     os << "{";
@@ -70,17 +85,26 @@ auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r)
     os << "}";
 }
-template <class T>
+template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
 void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
 {
-    os << x;
+    char delim = '{';
+    reflect_each(x, [&](auto&& y, auto name) {
+        os << delim;
+        os << name << "=";
+        stream_write_value_impl(rank<2>{}, os, y);
+        delim = ',';
+    });
+    if(delim == ',')
+        os << "}";
 }
 } // namespace detail
 template <class T>
 void stream_write_value(std::ostream& os, const T& x)
 {
-    detail::stream_write_value_impl(rank<2>{}, os, x);
+    detail::stream_write_value_impl(rank<1>{}, os, x);
 }
 } // namespace MIGRAPHX_INLINE_NS
...
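The reflectable overload prints fields as {name=value,...} with a delimiter that starts as '{' and flips to ',' after the first field; the closing '}' is emitted only when at least one field was written, so an empty reflection prints nothing at all. The pattern in isolation, as a standalone demo:

#include <iostream>
#include <utility>
#include <vector>

int main()
{
    // Stand-in for reflect_each's (value, name) pairs
    std::vector<std::pair<const char*, int>> fields = {{"epsilon", 1}, {"momentum", 9}};
    char delim = '{';
    for(const auto& [name, v] : fields)
    {
        std::cout << delim << name << "=" << v;
        delim = ',';
    }
    if(delim == ',') // only close the brace when at least one field was printed
        std::cout << "}";
    std::cout << "\n"; // prints {epsilon=1,momentum=9}
}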
@@ -184,6 +184,12 @@ struct value
         {
         }
         explicit binary(std::size_t s) : base(s) {}
+        friend std::ostream& operator<<(std::ostream& os, const binary& obj)
+        {
+            os << "{binary_object: " << obj.size() << "}";
+            return os;
+        }
     };
     value() = default;
...
@@ -385,9 +385,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
 instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
 {
-    this->move_instruction(src, dst);
     for(auto ins : src->inputs())
-        this->move_instruction(ins, src);
+    {
+        if(not contains(this->impl->instructions, ins))
+            continue;
+        this->move_instructions(ins, dst);
+    }
+    this->move_instruction(src, dst);
     return src;
 }
...
@@ -24,7 +24,7 @@
 #include <migraphx/onnx/op_parser.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/make_op.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
+#include <migraphx/instruction.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -36,28 +36,64 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
     instruction_ref parse(const op_desc& /*opd*/,
                           const onnx_parser& parser,
-                          onnx_parser::node_info info,
-                          const std::vector<instruction_ref>& args) const
+                          const onnx_parser::node_info& info,
+                          std::vector<instruction_ref> args) const
     {
         float epsilon = 1e-5f;
-        float momentum = 0.9f;
-        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
         if(contains(info.attributes, "epsilon"))
         {
             epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
         }
-        if(contains(info.attributes, "momentum"))
-        {
-            momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
-        }
-        if(contains(info.attributes, "spatial"))
-        {
-            bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
-                          ? op::batch_norm_inference::spatial
-                          : op::batch_norm_inference::per_activation;
-        }
-        op::batch_norm_inference op{epsilon, momentum, bn_mode};
-        return info.add_instruction(op, args);
+        auto x_lens = args[0]->get_shape().lens();
+        auto x_type = args[0]->get_shape().type();
+        if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
+               return a->get_shape().lens().size() != 1;
+           }))
+        {
+            MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
+        }
+        auto x_rank = x_lens.size();
+        if(x_rank == 1 or x_rank == 2)
+        {
+            auto rt      = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto numer   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+            auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
+            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
+            auto r0      = info.add_broadcastable_binary_op("mul", div0, args[1]);
+            return info.add_broadcastable_binary_op("add", r0, args[2]);
+        }
+        else if(x_rank > 2)
+        {
+            // unsqueeze tensors of shape (C) to broadcast correctly
+            std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
+            std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
+            auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
+            auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
+            auto scale_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
+            auto bias_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]);
+            auto mean_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
+            auto var_unsqueeze = info.add_instruction(
+                migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
+            auto numer   = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
+            auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
+            auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+            auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
+            auto r0      = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
+            return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
+        }
+        else
+        {
+            // rank == 0
+            MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
+                           " input tensor, unhandled data format");
+        }
     }
 };
...
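The decomposition above implements the standard batch-norm inference formula, with args[0..4] mapping to the input, scale, bias, running mean, and running variance, and the pow(var + eps, 0.5) literal supplying the square root:

y = \gamma \cdot \frac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \epsilon}} + \beta

For ranks above 2, the (C)-shaped per-channel tensors are first unsqueezed to (C, 1, ..., 1) so the broadcast lines up with channel axis 1 of x; the momentum and spatial attributes only matter for training and are now ignored.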
@@ -125,11 +125,9 @@ struct parse_convolution : op_parser<parse_convolution>
                     values["padding_mode"] = is_same_upper
                                                  ? to_value(op::padding_mode_t::same_upper)
                                                  : to_value(op::padding_mode_t::same_lower);
-                    values["use_dynamic_same_auto_pad"] = true;
                 }
                 else
                 {
-                    values["padding_mode"] = to_value(op::padding_mode_t::same);
                     // kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
                     auto weight_lens = weights->get_shape().max_lens();
                     std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
...
@@ -95,6 +95,8 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             check_attr_sizes(
                 kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
         }
+        // TODO: auto padding needs to be implemented for this parser and operator
         if(contains(info.attributes, "auto_pad"))
         {
             auto s = info.attributes["auto_pad"].s();
@@ -106,7 +108,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
             if(s.find("SAME") != std::string::npos)
             {
-                values["padding_mode"] = to_value(op::padding_mode_t::same);
+                bool is_same_upper     = (s.find("SAME_UPPER") != std::string::npos);
+                values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
+                                                       : to_value(op::padding_mode_t::same_lower);
             }
         }
...