"src/tl_templates/vscode:/vscode.git/clone" did not exist on "f4a828f6ba004f4d1165e6d46ac8b42e25f736fd"
Commit c4b1102e authored by charlie

Merge branch 'dyn_model_test' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 5fc48e77 31065c7d
@@ -35,7 +35,6 @@
 #include <migraphx/op/as_shape.hpp>
 #include <migraphx/op/atan.hpp>
 #include <migraphx/op/atanh.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
 #include <migraphx/op/binary.hpp>
 #include <migraphx/op/broadcast.hpp>
 #include <migraphx/op/capture.hpp>
...
@@ -37,6 +37,7 @@
 #include <migraphx/assignment_options.hpp>
 #include <migraphx/env.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/execution_environment.hpp>
 #include <algorithm>
 #include <iostream>
@@ -76,8 +77,8 @@ struct program
     std::unordered_map<std::string, shape> get_parameter_shapes() const;
-    std::vector<argument> eval(parameter_map params) const;
+    std::vector<argument> eval(parameter_map params,
+                               execution_environment exec_env = execution_environment{}) const;
     std::size_t size() const;
     std::vector<shape> get_output_shapes() const;
...
@@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
 }
 template <class T>
-auto reflectable_impl(rank<1>, T&& x)
+auto reflectable_impl(rank<1>, const T& x)
     -> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});
 template <class T>
-auto reflectable_impl(rank<0>, T &&) -> decltype(std::false_type{});
+auto reflectable_impl(rank<0>, const T&) -> decltype(std::false_type{});
 template <class T>
 struct remove_rvalue_reference
@@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
 template <class T>
 auto reflect_tie(T& x)
 {
-    return reflect(x, [](auto&& y, auto&&...) { return detail::wrap<decltype(y)>(y); })(
-        [](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
+    return reflect(x, [](auto&& y, auto&&...) {
+        // cppcheck-suppress UnnecessaryElseStatement
+        if constexpr(is_reflectable<decltype(y)>{})
+        {
+            auto t = reflect_tie(y);
+            return detail::wrap<decltype(t)>(t);
+        }
+        else
+        {
+            return detail::wrap<decltype(y)>(y);
+        }
+    })([](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
 }
 template <class T, class F>
...
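With this change, reflect_tie recurses into members that are themselves reflectable, so the resulting tuple is built from leaf fields. A minimal sketch of the effect, using hypothetical structs that follow the same reflect convention as this codebase:

struct inner
{
    int a;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.a, "a"));
    }
};

struct outer
{
    inner i;
    int b;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::pack(f(self.i, "i"), f(self.b, "b"));
    }
};

// Before: reflect_tie(o) tied (o.i, o.b), so comparisons needed operators on inner.
// After: it ties the nested leaves ((o.i.a), o.b), so equality/ordering built on
// reflect_tie works through nested reflectable members.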
@@ -30,6 +30,7 @@
 #include <numeric>
 #include <memory>
+#include <migraphx/functional.hpp>
 #include <migraphx/errors.hpp>
 #include <migraphx/half.hpp>
 #include <migraphx/config.hpp>
@@ -89,7 +90,10 @@ struct shape
         std::size_t opt = 0;
         template <class Self, class F>
-        static auto reflect(Self& self, F f);
+        static auto reflect(Self& self, F f)
+        {
+            return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
+        }
         bool is_fixed() const;
         bool has_optimal() const;
@@ -115,6 +119,12 @@ struct shape
     shape(type_t t, std::vector<dynamic_dimension> dims);
+    // Construct a dynamic shape from three sets of lengths (of the same rank)
+    shape(type_t t,
+          std::vector<std::size_t> mins,
+          std::vector<std::size_t> maxes,
+          std::vector<std::size_t> opts);
     template <class Range>
     shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
     {
@@ -136,6 +146,12 @@ struct shape
     const std::vector<std::size_t>& lens() const;
     const std::vector<std::size_t>& strides() const;
+    /*!
+     * The number of dimensions in the shape.
+     * Same as the number of indices required to get a data value.
+     */
+    std::size_t ndim() const;
     /*!
      * Return the number of elements in the tensor.
      */
@@ -221,6 +237,9 @@ struct shape
     shape with_type(type_t t) const;
+    // convert the shape to an equivalent dynamic shape
+    shape to_dynamic() const;
     friend bool operator==(const shape& x, const shape& y);
     friend bool operator!=(const shape& x, const shape& y);
     friend std::ostream& operator<<(std::ostream& os, const shape& x);
...
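A brief usage sketch of the additions above (values are illustrative; assumes <cassert>):

migraphx::shape s{migraphx::shape::float_type, {1, 3, 224, 224}};
assert(s.ndim() == 4); // same as lens().size() for static shapes

// to_dynamic() maps every length to a dynamic_dimension {len, len, 0}
migraphx::shape dyn = s.to_dynamic();
assert(dyn.dynamic() and dyn.ndim() == 4);

// the new constructor builds a dynamic shape from per-dimension min/max/opt lengths
migraphx::shape d{migraphx::shape::float_type,
                  {1, 3, 224, 224},  // mins
                  {8, 3, 224, 224},  // maxes
                  {4, 0, 0, 0}};     // opts (0 = no optimal value)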
@@ -26,7 +26,9 @@
 #include <ostream>
 #include <algorithm>
+#include <migraphx/reflect.hpp>
 #include <migraphx/rank.hpp>
+#include <migraphx/requires.hpp>
 #include <migraphx/config.hpp>
 #include <vector>
@@ -83,6 +85,20 @@ auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
     os << "}";
 }
+template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
+void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
+{
+    char delim = '{';
+    reflect_each(x, [&](auto&& y, auto name) {
+        os << delim;
+        os << name << "=";
+        stream_write_value_impl(rank<2>{}, os, y);
+        delim = ',';
+    });
+    if(delim == ',')
+        os << "}";
+}
 } // namespace detail
 template <class T>
...
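This overload gives every reflectable type a default printed form of {name1=value1,name2=value2,...}; a reflectable with no fields prints nothing, since the closing brace is only written after at least one field. A hypothetical call, assuming the public wrapper declared after the detail namespace is named stream_write_value and forwards here:

// hypothetical entry point; field values illustrative
migraphx::shape::dynamic_dimension dd{1, 4, 2};
migraphx::stream_write_value(std::cout, dd); // prints: {min=1,max=4,opt=2}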
@@ -25,7 +25,6 @@
 #include <migraphx/file_buffer.hpp>
 #include <migraphx/json.hpp>
 #include <migraphx/msgpack.hpp>
-#include <migraphx/file_buffer.hpp>
 #include <fstream>
 namespace migraphx {
...
@@ -34,7 +34,6 @@
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_target.hpp>
-#include <migraphx/make_op.hpp>
 #include <migraphx/json.hpp>
 #include <iostream>
 #include <sstream>
...
@@ -30,7 +30,7 @@ namespace onnx {
 void recalc_conv_attributes(value& v, size_t kdims)
 {
-    if(not(v["padding"].size() == kdims or v["padding"].size() == kdims * 2))
+    if(v["padding"].size() != kdims and v["padding"].size() != kdims * 2)
     {
         v["padding"].resize(kdims);
         std::fill_n(v["padding"].begin(), kdims, 0);
...
@@ -44,7 +44,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
 {
     epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
 }
-auto x_lens = args[0]->get_shape().lens();
+auto x_lens = args[0]->get_shape().max_lens();
 auto x_type = args[0]->get_shape().type();
 if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
@@ -54,18 +54,19 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
     MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
 }
-if(x_lens.size() == 1)
+auto x_rank = x_lens.size();
+if(x_rank == 1 or x_rank == 2)
 {
     auto rt  = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
     auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
-    auto n0   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
-    auto d0   = info.add_broadcastable_binary_op("add", args[4], eps);
-    auto d1   = info.add_broadcastable_binary_op("pow", d0, rt);
-    auto div0 = info.add_broadcastable_binary_op("div", n0, d1);
+    auto numer   = info.add_broadcastable_binary_op("sub", args[0], args[3]);
+    auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
+    auto denom   = info.add_broadcastable_binary_op("pow", var_eps, rt);
+    auto div0    = info.add_broadcastable_binary_op("div", numer, denom);
     auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]);
     return info.add_broadcastable_binary_op("add", r0, args[2]);
 }
-else if(x_lens.size() > 2)
+else if(x_rank > 2)
 {
     // unsqueeze tensors of shape (C) to broadcast correctly
     std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
@@ -89,7 +90,7 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
 }
 else
 {
-    // num dims either 0 or 2
+    // rank == 0
     MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
                    " input tensor, unhandled data format");
 }
...
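For reference, the rank-1/rank-2 branch above assembles the standard batch-norm inference formula out of broadcastable elementwise ops, where args[0..4] are the input x, scale \gamma, bias \beta, mean \mu, and variance \sigma^2, and pow with exponent 0.5 supplies the square root:

y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta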
@@ -96,7 +96,7 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
         kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
 }
-// TODO: nothing is done with this?
+// TODO: auto padding needs to be implemented for this parser and operator
 if(contains(info.attributes, "auto_pad"))
 {
     auto s = info.attributes["auto_pad"].s();
...
@@ -59,6 +59,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
                                            bool use_upper)
 {
     std::vector<std::size_t> padding;
+    assert(input_lens.size() >= 3);
     std::size_t num_spatial_dims = input_lens.size() - 2;
     padding.resize(2 * num_spatial_dims);
     for(std::size_t i = 0; i < num_spatial_dims; i++)
...
@@ -398,7 +398,7 @@ std::vector<argument> generic_eval(const program& p,
     return generic_eval(mm, ctx, params, {}, make_trace);
 }
-std::vector<argument> program::eval(parameter_map params) const
+std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
 {
     auto& ctx = this->impl->ctx;
 #ifndef NDEBUG
@@ -423,6 +423,12 @@ std::vector<argument> program::eval(parameter_map params) const
 #endif
     auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
+    std::vector<argument> ret;
+    if(exec_env.async)
+    {
+        ctx.wait_for(exec_env.queue);
+    }
     if(trace_level > 0)
     {
@@ -434,49 +440,56 @@ std::vector<argument> program::eval(parameter_map params) const
             ins_out[x] = ss.str();
         });
-        return generic_eval(*this,
+        ret = generic_eval(*this,
                            ctx,
                            std::move(params),
                            with_check_context([&](auto& ins, auto f, auto&& check_context) {
                                ctx.finish();
                                std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
                                timer t{};
                                auto result = check_context(f);
                                double t1 = t.record<milliseconds>();
                                ctx.finish();
                                double t2 = t.record<milliseconds>();
                                std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
                                if(trace_level > 1 and ins->name().front() != '@' and
                                   ins->name() != "load" and not result.empty())
                                {
                                    target tgt  = make_target(this->impl->target_name);
                                    auto buffer = tgt.copy_from(result);
                                    if(trace_level == 2)
                                    {
                                        std::cout << "Output has "
                                                  << to_string_range(classify_argument(buffer))
                                                  << std::endl;
                                        std::cout << "Output: ";
                                        preview_argument(std::cout, buffer);
                                        std::cout << std::endl;
                                    }
                                    else
                                    {
                                        std::cout << "Output: " << buffer << std::endl;
                                    }
                                }
                                return result;
                            }));
     }
     else
     {
-        return generic_eval(*this,
+        ret = generic_eval(*this,
                            ctx,
                            std::move(params),
                            with_check_context([&](auto&, auto f, auto&& check_context) {
                                return check_context(f);
                            }));
     }
+    if(exec_env.async)
+    {
+        ctx.finish_on(exec_env.queue);
+    }
+    return ret;
 }
 const int program_file_version = 5;
...
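A hedged usage sketch of the new asynchronous path; the stream handle and its type-name string are placeholders, and the execution_environment initializer mirrors the Python binding below:

// assume `stream` is an existing device queue handle, e.g. a hipStream_t
migraphx::execution_environment exec_env{
    migraphx::any_ptr(reinterpret_cast<void*>(stream), "ihipStream_t*"),
    /*async=*/true};
auto results = p.eval(params, exec_env);
// eval first waits for pending work on the queue (ctx.wait_for), runs the
// program, then records completion back onto the queue (ctx.finish_on), so
// the caller can keep enqueueing on `stream` without a host-side sync.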
@@ -355,6 +355,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
     }
     return p.eval(pm);
 })
+.def("run_async",
+     [](migraphx::program& p,
+        py::dict params,
+        std::uintptr_t stream,
+        std::string stream_name) {
+         migraphx::parameter_map pm;
+         for(auto x : params)
+         {
+             std::string key      = x.first.cast<std::string>();
+             py::buffer b         = x.second.cast<py::buffer>();
+             py::buffer_info info = b.request();
+             pm[key]              = migraphx::argument(to_shape(info), info.ptr);
+         }
+         migraphx::execution_environment exec_env{
+             migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true};
+         return p.eval(pm, exec_env);
+     })
 .def("sort", &migraphx::program::sort)
 .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
 .def("__eq__", std::equal_to<migraphx::program>{})
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "batch_norm_inference")
continue;
// Get scale, bias, mean, variance from inputs
auto gamma = ins->inputs()[1]->eval();
auto bias = ins->inputs()[2]->eval();
auto mean = ins->inputs()[3]->eval();
auto variance = ins->inputs()[4]->eval();
if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
continue;
std::vector<std::size_t> lens = ins->inputs()[1]->get_shape().lens();
shape s{ins->get_shape().type(), lens};
// Get epsilon
auto bn_op = any_cast<op::batch_norm_inference>(ins->get_operator());
auto epsilon = bn_op.epsilon;
argument a{s};
argument b{s};
visit_all(gamma, bias, mean, variance, a, b)(
[&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
dfor(a.get_shape().elements())(
[&](std::size_t c) { a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon); });
dfor(b.get_shape().elements())([&](std::size_t c) {
b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
});
});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = m.add_literal({a.get_shape(), a.data()});
auto a_broadcast = m.insert_instruction(ins, broadcast, a_ins);
auto mul = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = m.add_literal({b.get_shape(), b.data()});
auto b_broadcast = m.insert_instruction(ins, broadcast, b_ins);
auto add = m.insert_instruction(ins, make_op("add"), mul, b_broadcast);
m.replace_instruction(ins, add);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
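The rewrite_batchnorm pass shown above folded batch_norm_inference into a broadcast mul plus add by precomputing per-channel constants from the literal inputs:

a_c = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad b_c = \beta_c - \frac{\gamma_c \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad y = a_c \cdot x + b_c

With the batch_norm_inference operator removed in this commit (see the include removals above) and the ONNX parser emitting the elementwise ops directly, the pass is dropped from the target pipelines (see the pass-list change below).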
@@ -46,9 +46,6 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/ranges.hpp>
-#include <migraphx/op/common.hpp>
-#include <migraphx/op/rnn_var_sl_last_output.hpp>
-#include <migraphx/op/rnn_variable_seq_lens.hpp>
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -71,6 +71,19 @@ struct shape_impl
 {
 }
+shape_impl(shape::type_t t,
+           std::vector<std::size_t> mins,
+           std::vector<std::size_t> maxes,
+           std::vector<std::size_t> opts)
+    : m_type(t)
+{
+    assert(mins.size() == maxes.size() and maxes.size() == opts.size());
+    for(size_t i = 0; i < mins.size(); ++i)
+    {
+        m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], opts[i]});
+    }
+}
 shape_impl(const std::vector<shape>& subs) : m_type(shape::tuple_type), m_shapes(subs) {}
 shape::type_t m_type;
@@ -224,6 +237,14 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
 {
 }
+shape::shape(type_t t,
+             std::vector<std::size_t> mins,
+             std::vector<std::size_t> maxes,
+             std::vector<std::size_t> opts)
+    : impl(std::make_shared<shape_impl>(t, std::move(mins), std::move(maxes), std::move(opts)))
+{
+}
 shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
 shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
@@ -244,6 +265,15 @@ const std::vector<std::size_t>& shape::lens() const { return impl->m_lens; }
 const std::vector<std::size_t>& shape::strides() const { return impl->m_strides; }
+std::size_t shape::ndim() const
+{
+    if(this->dynamic())
+    {
+        return dyn_dims().size();
+    }
+    return lens().size();
+}
 std::size_t shape::elements() const { return impl->elements(); }
 std::size_t shape::bytes() const
@@ -437,6 +467,16 @@ shape shape::with_type(type_t t) const
     return {c};
 }
+shape shape::to_dynamic() const
+{
+    if(this->dynamic())
+    {
+        return *this;
+    }
+    std::vector<std::size_t> zeroes(this->ndim(), 0);
+    return {type(), lens(), lens(), zeroes};
+}
 std::size_t shape::element_space() const { return impl->element_space(); }
 std::string shape::type_string() const { return name(this->type()); }
@@ -464,15 +504,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
 bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
-template <class Self, class F>
-auto shape::dynamic_dimension::reflect(Self& self, F f)
-{
-    return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
-}
 bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
 {
-    return (x.min == y.min and x.max == y.max and x.opt == y.opt);
+    // don't check opt if both are fixed
+    return (x.min == y.min and x.max == y.max and
+            ((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt)));
 }
 bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
...
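A sketch of the relaxed dynamic_dimension equality (aggregate order is min, max, opt, matching the brace-init in shape_impl above; assumes <cassert>):

migraphx::shape::dynamic_dimension a{2, 2, 0};
migraphx::shape::dynamic_dimension b{2, 2, 2};
assert(a == b); // both fixed (min == max), so opt is ignored

migraphx::shape::dynamic_dimension c{2, 4, 2};
migraphx::shape::dynamic_dimension d{2, 4, 4};
assert(c != d); // not fixed, so opt must still match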
@@ -827,7 +827,7 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
     };
     auto dots  = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
     auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
-    return not(dots < 2 and convs < 2);
+    return (dots >= 2 or convs >= 2);
 }
 struct find_conv_dot_horiz_fusion
@@ -933,6 +933,73 @@ struct find_div_const
     }
 };
+struct find_unit_ops
+{
+    auto matcher() const
+    {
+        auto mul_1 = match::name("mul")(
+            match::either_arg(0, 1)(match::has_value(1.0f), match::any().bind("x")));
+        auto div_1 =
+            match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
+        auto add_0 = match::name("add")(
+            match::either_arg(0, 1)(match::has_value(0.0f, 1e-12), match::any().bind("x")));
+        auto sub_0 =
+            match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f)));
+        return match::any_of(mul_1, div_1, add_0, sub_0);
+    }
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins  = r.result;
+        auto c_in = r.instructions["x"];
+        m.replace_instruction(ins, c_in);
+    }
+};
+struct find_neg_unit_ops
+{
+    auto matcher() const
+    {
+        auto mul_neg_1 = match::name("mul")(
+            match::either_arg(0, 1)(match::has_value(-1.0f), match::any().bind("x")));
+        auto div_neg_1 =
+            match::name("div")(match::args(match::any().bind("x"), match::has_value(-1.0f)));
+        auto sub_0 =
+            match::name("sub")(match::args(match::has_value(0.0f), match::any().bind("x")));
+        return match::any_of(mul_neg_1, div_neg_1, sub_0);
+    }
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins  = r.result;
+        auto c_in = r.instructions["x"];
+        auto neg  = m.add_instruction(make_op("neg"), c_in);
+        m.replace_instruction(ins, neg);
+    }
+};
+struct find_zero_ops
+{
+    auto matcher() const
+    {
+        auto mul_zero = match::name("mul")(
+            match::either_arg(0, 1)(match::has_value(0.0f).bind("x"), match::any()));
+        auto div_zero =
+            match::name("div")(match::args(match::has_value(0.0f).bind("x"), match::any()));
+        return match::any_of(mul_zero, div_zero);
+    }
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins      = r.result;
+        auto zero_ins = r.instructions["x"];
+        m.replace_instruction(ins, zero_ins);
+    }
+};
 struct find_sub_const
 {
     auto matcher() const
@@ -1149,6 +1216,9 @@ void simplify_algebra::apply(module& m) const
     find_mul_conv{},
     find_mul_slice_conv{},
     find_mul_add{},
+    find_unit_ops{},
+    find_neg_unit_ops{},
+    find_zero_ops{},
     find_dot_add{},
     find_div_const{},
     find_sub_const{},
...
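A hedged, test-style sketch of what the new matchers eliminate (APIs as used elsewhere in this diff; run_passes from pass_manager.hpp is assumed, and the literal shape is illustrative):

migraphx::module m;
migraphx::shape s{migraphx::shape::float_type, {4}};
auto x   = m.add_parameter("x", s);
auto one = m.add_literal(migraphx::literal{s, {1.0f, 1.0f, 1.0f, 1.0f}});
auto y   = m.add_instruction(migraphx::make_op("mul"), x, one);
m.add_return({y});

// find_unit_ops rewrites mul(x, 1) to x; mul(x, -1) would become neg(x) via
// find_neg_unit_ops, and mul(x, 0) collapses to the zero literal via
// find_zero_ops. dead_code_elimination then removes the dead mul.
migraphx::run_passes(m, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});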
@@ -26,7 +26,6 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/dfor.hpp>
 #include <migraphx/op/identity.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
 #include <migraphx/op/convolution.hpp>
 #include <migraphx/op/deconvolution.hpp>
 #include <migraphx/op/quant_convolution.hpp>
...
@@ -37,12 +37,10 @@
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/register_target.hpp>
 #include <migraphx/replace_allocate.hpp>
-#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/schedule.hpp>
-#include <migraphx/memory_coloring.hpp>
 #include <migraphx/simplify_algebra.hpp>
 #include <migraphx/simplify_qdq.hpp>
 #include <migraphx/simplify_reshapes.hpp>
@@ -78,8 +76,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options)
     eliminate_identity{},
     eliminate_pad{},
     dead_code_elimination{},
-    rewrite_batchnorm{},
-    dead_code_elimination{},
     rewrite_rnn{},
     dead_code_elimination{},
     eliminate_common_subexpression{},
...
@@ -78,17 +78,13 @@ add_library(migraphx_gpu
     allocation_model.cpp
     argmax.cpp
     argmin.cpp
-    batch_norm_inference.cpp
     code_object_op.cpp
     compile_ops.cpp
     compile_gen.cpp
     compile_hip.cpp
     compile_hip_code_object.cpp
     compiler.cpp
-    convolution.cpp
-    deconvolution.cpp
     device_name.cpp
-    elu.cpp
     fuse_mlir.cpp
     fuse_ops.cpp
     gather.cpp
@@ -101,7 +97,6 @@ add_library(migraphx_gpu
     logsoftmax.cpp
     loop.cpp
     lrn.cpp
-    leaky_relu.cpp
     mlir.cpp
     multinomial.cpp
     nonzero.cpp
@@ -111,7 +106,6 @@ add_library(migraphx_gpu
     pad.cpp
     perfdb.cpp
     pooling.cpp
-    quant_convolution.cpp
     reverse.cpp
     rnn_variable_seq_lens.cpp
     rocblas.cpp
@@ -146,16 +140,10 @@ register_migraphx_gpu_ops(hip_
 )
 register_migraphx_gpu_ops(miopen_
     abs
-    batch_norm_inference
     contiguous
-    convolution
-    deconvolution
-    elu
     int8_conv_pack
-    leaky_relu
     lrn
     pooling
-    quant_convolution
 )
 register_op(migraphx_gpu
     HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
@@ -169,6 +157,9 @@ register_op(migraphx_gpu
     HEADER migraphx/gpu/gemm.hpp
     OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
     INCLUDES migraphx/gpu/context.hpp)
+register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
+    OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
+    INCLUDES migraphx/gpu/context.hpp)
 rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_gpu)
...