Commit 31065c7d authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_squeeze' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents 6bec381f 6acbd4e4
......@@ -24,7 +24,7 @@
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -36,28 +36,64 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "momentum"))
auto x_lens = args[0]->get_shape().max_lens();
auto x_type = args[0]->get_shape().type();
if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
return a->get_shape().lens().size() != 1;
}))
{
MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
}
auto x_rank = x_lens.size();
if(x_rank == 1 or x_rank == 2)
{
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto numer = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto var_eps = info.add_broadcastable_binary_op("add", args[4], eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom);
auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]);
return info.add_broadcastable_binary_op("add", r0, args[2]);
}
else if(x_rank > 2)
{
momentum = parser.parse_value(info.attributes.at("momentum")).at<float>();
// unsqueeze tensors of shape (C) to broadcast correctly
std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
auto bias_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]);
auto mean_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
auto var_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
}
if(contains(info.attributes, "spatial"))
else
{
bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
? op::batch_norm_inference::spatial
: op::batch_norm_inference::per_activation;
// rank == 0
MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
" input tensor, unhandled data format");
}
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return info.add_instruction(op, args);
}
};
......
......@@ -38,7 +38,7 @@ struct parse_cast : op_parser<parse_cast>
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
if(!contains(info.attributes, "to"))
if(not contains(info.attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
......
......@@ -93,7 +93,7 @@ struct parse_constant_fill : op_parser<parse_constant_fill>
}
else if(input_as_shape == 0)
{
if(!contains(info.attributes, "shape"))
if(not contains(info.attributes, "shape"))
{
MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
}
......
......@@ -125,11 +125,9 @@ struct parse_convolution : op_parser<parse_convolution>
values["padding_mode"] = is_same_upper
? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
values["use_dynamic_same_auto_pad"] = true;
}
else
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
// kernel shape will be fixed, so max_lens() == min_len() for kernel lengths
auto weight_lens = weights->get_shape().max_lens();
std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
......
......@@ -95,6 +95,8 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
check_attr_sizes(
kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
}
// TODO: auto padding needs to be implemented for this parser and operator
if(contains(info.attributes, "auto_pad"))
{
auto s = info.attributes["auto_pad"].s();
......@@ -106,7 +108,9 @@ struct parse_deconvolution : op_parser<parse_deconvolution>
if(s.find("SAME") != std::string::npos)
{
values["padding_mode"] = to_value(op::padding_mode_t::same);
bool is_same_upper = (s.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
}
......
......@@ -94,7 +94,7 @@ struct parse_gemm : op_parser<parse_gemm>
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
{
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
......
......@@ -58,7 +58,6 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Log", "log"},
{"LRN", "lrn"},
{"Neg", "neg"},
{"NonMaxSuppression", "nonmaxsuppression"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
......@@ -75,7 +74,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const
{
return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name);
return contains({"flatten", "gather", "scatter"}, op_name);
}
instruction_ref parse(const op_desc& opd,
......
......@@ -31,7 +31,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for LpNormalization ONNX operator.
// Parser for LpNormalization ONNX operator.
/*!
Normalizes a tensor by the L1 or L2 norms along a given axis.
Norms that evaluate to 0 are changed to 1 to prevent division by zero.
......
......@@ -67,7 +67,8 @@ struct parse_matmul : op_parser<parse_matmul>
instruction_ref bl0 = l0;
instruction_ref bl1 = l1;
if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
if(not std::equal(
l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
{
auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
......
......@@ -40,9 +40,9 @@ struct parse_mod : op_parser<parse_mod>
std::vector<instruction_ref> args) const
{
std::string mod = "mod";
if(is_type_float(args[0]->get_shape().type()) || is_type_float(args[1]->get_shape().type()))
if(is_type_float(args[0]->get_shape().type()) or is_type_float(args[1]->get_shape().type()))
{
if(!contains(info.attributes, "fmod"))
if(not contains(info.attributes, "fmod"))
{
MIGRAPHX_THROW("Mod operator with float args and fmod=0 invalid");
}
......
......@@ -21,22 +21,29 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ACOS_HPP
#define MIGRAPHX_GUARD_RTGLIB_ACOS_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/acos.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace onnx {
struct hip_acos : unary_device<hip_acos, device::acos>
struct parse_nonmaxsuppression : op_parser<parse_nonmaxsuppression>
{
std::vector<op_desc> operators() const { return {{"NonMaxSuppression", "nonmaxsuppression"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& parser,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
auto op = parser.load(opd.op_name, info);
op.from_value({{"use_dyn_output", parser.use_dyn_output}});
return info.add_instruction(op, args);
}
};
} // namespace gpu
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -37,7 +37,7 @@ static std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
std::vector<std::size_t> indices;
for(std::size_t i = 0; i < data.size(); ++i)
{
if(!float_equal(data[i], 0))
if(not float_equal(data[i], 0))
indices.push_back(i);
}
......
......@@ -160,7 +160,7 @@ struct parse_pad : op_parser<parse_pad>
if(args.size() == 3)
{
auto val_ins = args.at(2);
if(!val_ins->can_eval())
if(not val_ins->can_eval())
{
MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
}
......
......@@ -157,7 +157,7 @@ struct parse_pooling : op_parser<parse_pooling>
std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start);
if(!slice_start.empty())
if(not slice_start.empty())
{
// calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
......@@ -180,7 +180,7 @@ struct parse_pooling : op_parser<parse_pooling>
op.from_value(values);
auto l1 = info.add_instruction(op, l0);
if(!slice_start.empty())
if(not slice_start.empty())
{
std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2);
......
......@@ -46,7 +46,7 @@ auto compute_type(shape::type_t t1, shape::type_t t2)
int it1 = t1;
int it2 = t2;
if(!contains(op_order, it1) or !contains(op_order, it2))
if(not contains(op_order, it1) or not contains(op_order, it2))
{
MIGRAPHX_THROW("PARSE_POW: Input data type not supported!");
}
......
......@@ -56,7 +56,7 @@ const auto& get_nearest_op(const std::string& mode)
return static_cast<std::size_t>(std::ceil((val)));
}}};
if(!contains(nearest_ops, mode))
if(not contains(nearest_ops, mode))
{
MIGRAPHX_THROW("PARSE_RESIZE: nearest_mode " + mode + " not supported!");
}
......@@ -86,7 +86,7 @@ const auto& get_original_idx_op(const std::string& mode)
return (idx + 0.5) / scale;
}}};
if(!contains(idx_ops, mode))
if(not contains(idx_ops, mode))
{
MIGRAPHX_THROW("PARSE_RESIZE: coordinate_transformation_mode " + mode + " not supported!");
}
......
......@@ -31,7 +31,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for ReverseSequence ONNX operator.
// Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
......
......@@ -29,7 +29,7 @@ inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& m) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
if(not enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&m, allocation_op, verify);
opt.run();
......
......@@ -42,7 +42,7 @@ void memory_coloring_impl::run()
{
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
while(!alloc_queue.empty())
while(not alloc_queue.empty())
{
interval_ptr interval = alloc_queue.top();
allocate(interval);
......@@ -72,7 +72,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
if(conflict_table.find(vn) != conflict_table.end())
{
std::set<int>& vn_set = conflict_table[vn];
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
......@@ -96,7 +96,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
}
std::size_t offset = 0;
while(!conflict_queue.empty())
while(not conflict_queue.empty())
{
live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset;
......@@ -149,7 +149,7 @@ void memory_coloring_impl::build()
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) || is_lit)
if(is_allocate(iter) or is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
......@@ -157,12 +157,12 @@ void memory_coloring_impl::build()
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(!is_lit || unify_literals)
if(not is_lit or unify_literals)
alloc_queue.push(def_interval);
live_set.erase(range.vn);
}
}
else if(!is_param(iter) && !is_outline(iter) && !is_check_context(iter))
else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
{
is_dead = true;
}
......@@ -179,7 +179,7 @@ void memory_coloring_impl::build()
if(not p_mod->has_instruction(arg))
continue;
if(is_param(arg) || is_outline(arg))
if(is_param(arg) or is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
......@@ -235,7 +235,7 @@ void memory_coloring_impl::rewrite()
if(interval->get_begin() == invalid_offset)
continue;
if(!unify_literals && interval->is_literal)
if(not unify_literals && interval->is_literal)
continue;
std::size_t offset = 0;
......@@ -267,12 +267,12 @@ void memory_coloring_impl::verify()
{
for(int i = 0; i < num_of_lives; ++i)
{
live_interval& interval = live_intervals[i];
live_range& segment = interval.segment;
const live_interval& interval = live_intervals[i];
const live_range& segment = interval.segment;
if(segment.begin == invalid_offset)
{
// if(!interval.is_live_on_entry)
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
......@@ -284,13 +284,13 @@ void memory_coloring_impl::verify()
int vn = segment.vn;
if(conflict_table.find(vn) != conflict_table.end())
{
std::set<int>& vn_set = conflict_table[vn];
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
if(range->offset == invalid_offset)
continue;
if(!is_disjoin(*range, segment))
if(not is_disjoin(*range, segment))
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
......@@ -319,8 +319,8 @@ void memory_coloring_impl::dump_intervals()
{
std::cout << " segment:" << i;
std::cout << " =>";
std::set<int>& table = conflict_table[i];
for(auto& iter : table)
const std::set<int>& table = conflict_table[i];
for(const auto& iter : table)
{
std::cout << (iter) << ",";
}
......@@ -357,7 +357,7 @@ void live_interval::dump()
std::cout << "id:" << id;
segment.dump();
std::cout << " uses:";
for(auto& iter : use_points)
for(const auto& iter : use_points)
{
std::cout << " " << get_ins_enum(iter) << ",";
}
......
......@@ -125,11 +125,11 @@ struct memory_coloring_impl
static bool is_disjoin(const live_range& range1, const live_range& range2)
{
if((range1.size == 0) || (range2.size == 0))
if((range1.size == 0) or (range2.size == 0))
return false;
auto end1 = range1.offset + range1.size - 1;
auto end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) || (end2 < range1.offset));
return ((end1 < range2.offset) or (end2 < range1.offset));
}
void verify();
#ifdef MIGRAPHX_DEBUG_OPT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment