Commit a045fb19 authored by Alan Turner's avatar Alan Turner
Browse files

Merge branch 'develop' into ck-flash-attn

parents 135eb63e 434a06cf
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/instruction.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
...@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt ...@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt
{ {
write_buffer(filename, save_buffer(p, options)); write_buffer(filename, save_buffer(p, options));
} }
// MIOpen doesn't support serializing fusion plans with Find-2.0 APIs
void print_miopen_warning(const program& p)
{
auto mods = p.get_modules();
if(std::any_of(mods.begin(), mods.end(), [](const auto* m) {
return std::any_of(m->begin(), m->end(), [](const instruction& i) {
return i.name() == "gpu::miopen_fusion";
});
}))
{
std::cout << "[WARNING]: Program has miopen_fusion instructions for which tuned solutions "
"are not stored inside serialized MIGraphX program. Consider serializing with "
"MIGRAPHX_DISABLE_MIOPEN_FUSION=1 flag set."
<< std::endl;
;
}
}
std::vector<char> save_buffer(const program& p, const file_options& options) std::vector<char> save_buffer(const program& p, const file_options& options)
{ {
value v = p.to_value(); value v = p.to_value();
print_miopen_warning(p);
std::vector<char> buffer; std::vector<char> buffer;
if(options.format == "msgpack") if(options.format == "msgpack")
{ {
......
...@@ -25,6 +25,33 @@ ...@@ -25,6 +25,33 @@
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <msgpack.hpp> #include <msgpack.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Leave an extra byte for error checking
constexpr std::size_t msgpack_size_limit = std::numeric_limits<uint32_t>::max() - 1;
template <class Range>
std::size_t msgpack_chunk_size(const Range& r)
{
return 1 + (r.size() - 1) / msgpack_size_limit;
}
template <class Iterator, class F>
void msgpack_chunk_for_each(Iterator start, Iterator last, F f)
{
while(std::distance(start, last) > msgpack_size_limit)
{
auto next = std::next(start, msgpack_size_limit);
f(start, next);
start = next;
}
f(start, last);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace msgpack { namespace msgpack {
MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{ {
...@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
break; break;
} }
case msgpack::type::BIN: { case msgpack::type::BIN: {
// For backwards compatibility
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size}; v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break; break;
} }
case msgpack::type::ARRAY: { case msgpack::type::ARRAY: {
if(o.via.array.size != 0 and o.via.array.ptr->type == msgpack::type::BIN)
{
auto bin = migraphx::value::binary{};
std::for_each(
o.via.array.ptr,
o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) {
bin.insert(bin.end(), so.via.bin.ptr, so.via.bin.ptr + so.via.bin.size);
});
v = bin;
}
else
{
migraphx::value r = migraphx::value::array{}; migraphx::value r = migraphx::value::array{};
std::for_each( std::for_each(
o.via.array.ptr, o.via.array.ptr,
o.via.array.ptr + o.via.array.size, o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); }); [&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
v = r; v = r;
}
break; break;
} }
case msgpack::type::MAP: { case msgpack::type::MAP: {
...@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{ {
const auto* data = reinterpret_cast<const char*>(x.data()); const auto* data = reinterpret_cast<const char*>(x.data());
auto size = x.size(); auto size = x.size();
o.pack_bin(size); o.pack_array(migraphx::msgpack_chunk_size(x));
o.pack_bin_body(data, size); migraphx::msgpack_chunk_for_each(
data, data + size, [&](const char* start, const char* last) {
o.pack_bin(last - start);
o.pack_bin_body(start, last - start);
});
return o; return o;
} }
}; };
...@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
o.pack_array(0); o.pack_array(0);
return; return;
} }
if(v.size() > migraphx::msgpack_size_limit)
MIGRAPHX_THROW("Size is too large for msgpack");
if(not v.front().get_key().empty()) if(not v.front().get_key().empty())
{ {
o.pack_map(v.size()); o.pack_map(v.size());
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
#include <migraphx/normalize_attributes.hpp> #include <migraphx/normalize_attributes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -191,15 +191,21 @@ bool normalize_attributes(operation& op, const shape& input_shape) ...@@ -191,15 +191,21 @@ bool normalize_attributes(operation& op, const shape& input_shape)
auto attrs = op.attributes(); auto attrs = op.attributes();
auto val = op.to_value(); auto val = op.to_value();
if(attrs.contains("normalize_padding")) if(attrs.contains("normalize_padding"))
{
bool use_auto_padding =
(val.contains("padding_mode") and
(val.at("padding_mode").to<int>() != migraphx::op::padding_mode_t::default_));
if(not use_auto_padding)
{ {
auto padding = val.at(attrs.at("normalize_padding").to<std::string>()); auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size(); auto padding_size = padding.size();
auto padding_start = 2; auto padding_start = 2;
if(padding_size == 2 * (input_shape.ndim() - padding_start)) if(padding_size == 2 * (input_shape.ndim() - padding_start))
tuned = true; tuned = true;
else if(padding_size != (input_shape.ndim() - padding_start)) else if(padding_size != (input_shape.ndim() - padding_start))
MIGRAPHX_THROW("inconsistent padding size"); {
MIGRAPHX_THROW("normalize_attributes: inconsistent padding vector size ");
}
else else
{ {
auto result = tune_pad_attribute(padding); auto result = tune_pad_attribute(padding);
...@@ -208,6 +214,7 @@ bool normalize_attributes(operation& op, const shape& input_shape) ...@@ -208,6 +214,7 @@ bool normalize_attributes(operation& op, const shape& input_shape)
tuned = true; tuned = true;
} }
} }
}
if(not attrs.contains("normalize_axes")) if(not attrs.contains("normalize_axes"))
{ {
return tuned; return tuned;
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant> ...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const const std::vector<instruction_ref>& /*args*/) const
{ {
literal v = parser.parse_value(info.attributes.at("value")); static const std::vector<std::string> attributes = {
"value", "value_float", "value_floats", "value_int", "value_ints"};
std::vector<std::string> present_attributes;
std::copy_if(attributes.begin(),
attributes.end(),
std::back_inserter(present_attributes),
[&](const std::string& a) { return contains(info.attributes, a); });
if(present_attributes.empty())
{
MIGRAPHX_THROW("Constant node does not contain any supported attribute");
}
if(present_attributes.size() > 1)
{
MIGRAPHX_THROW("Constant contains multiple attributes: " +
join_strings(std::move(present_attributes), ", "));
}
// cppcheck-suppress accessMoved
auto&& attr = info.attributes[present_attributes[0]];
literal v = parser.parse_value(attr);
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return info.add_literal(literal{v.get_shape().type()}); return info.add_literal(literal{v.get_shape().type()});
} }
auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(attr.has_t() and attr.t().dims_size() == 0)
{ {
migraphx::shape scalar_shape{v.get_shape().type()}; migraphx::shape scalar_shape{v.get_shape().type()};
return info.add_literal(migraphx::literal{scalar_shape, v.data()}); return info.add_literal(migraphx::literal{scalar_shape, v.data()});
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -151,26 +151,6 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -151,26 +151,6 @@ struct parse_pooling : op_parser<parse_pooling>
kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings"); kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
} }
if(contains(info.attributes, "auto_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
}
else
{
values["padding"].clear();
// return paddings could be empty, then setting to 0 for no padding
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
{1, 1},
in_shape.lens(),
paddings);
}
}
if(paddings.size() != 2 * kdims) if(paddings.size() != 2 * kdims)
{ {
paddings.resize(kdims * 2); paddings.resize(kdims * 2);
...@@ -192,6 +172,36 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -192,6 +172,36 @@ struct parse_pooling : op_parser<parse_pooling>
// used to calculate the supposed output shape // used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings; std::vector<int64_t> orig_padding = paddings;
// TODO: add parsing for dilations
if(contains(info.attributes, "auto_pad") and
to_upper(info.attributes["auto_pad"].s()) != "NOTSET")
{
auto auto_pad = to_upper(info.attributes["auto_pad"].s());
// don't use the given padding sizes, if any
// values["padding"].clear();
if(in_shape.dynamic())
{
// set padding_mode to trigger auto padding at runtime
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
values["padding_mode"] = is_same_upper ? to_value(op::padding_mode_t::same_upper)
: to_value(op::padding_mode_t::same_lower);
}
else
{
// Calculate auto padding
// dilations (argument 4) not supported; default to all 1's
cal_auto_padding_size(info,
values,
values["lengths"].to_vector<std::size_t>(),
std::vector<size_t>(in_shape.ndim() - 2, 1),
in_shape.lens(),
paddings);
values["padding"] = paddings;
// default padding_mode indicates that padding sizes are not calculated dynamically
values["padding_mode"] = migraphx::op::padding_mode_t::default_;
}
}
std::vector<int64_t> slice_start; std::vector<int64_t> slice_start;
std::vector<int64_t> slice_end; std::vector<int64_t> slice_end;
tune_padding_size(values, paddings, count_include_pad, slice_start); tune_padding_size(values, paddings, count_include_pad, slice_start);
...@@ -208,8 +218,9 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -208,8 +218,9 @@ struct parse_pooling : op_parser<parse_pooling>
orig_padding.insert(orig_padding.begin(), 2, 0); orig_padding.insert(orig_padding.begin(), 2, 0);
op::pad pad{orig_padding, 0.0f}; op::pad pad{orig_padding, 0.0f};
shape padded_shape = pad.compute_shape({l0->get_shape()}); shape padded_shape = pad.compute_shape({l0->get_shape()});
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// make an op just to get its output shape
auto out_lens = make_op("pooling", values).compute_shape({padded_shape}).lens();
// compute slice_end information // compute slice_end information
slice_end.resize(slice_start.size()); slice_end.resize(slice_start.size());
std::transform(out_lens.begin() + 2, std::transform(out_lens.begin() + 2,
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode) ...@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode)
static std::vector<int> static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind, calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
int i_dim, int i_dim,
const std::vector<std::vector<std::size_t>>& vec_dims, std::vector<std::vector<std::size_t>> vec_dims,
const shape& in_s) const shape& in_s)
{ {
if(i_dim == vvv_ind.size()) if(i_dim == vvv_ind.size())
{ {
std::vector<int> vec_ind; std::vector<int> vec_ind(vec_dims.size());
vec_ind.resize(vec_dims.size());
std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) { std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
return static_cast<int>(in_s.index(idx)); return static_cast<int>(in_s.index(idx));
}); });
return vec_ind; return vec_ind;
} }
const auto& vv_ind = vvv_ind[i_dim]; const auto& vv_lo = vvv_ind[i_dim][0];
const auto& vv_lo = vv_ind.at(0);
std::vector<std::vector<std::size_t>> vec_dims1; std::vector<std::vector<std::size_t>> vec_dims1;
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{ {
...@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
}); });
} }
const auto& vv_hi = vv_ind.at(1); const auto& vv_hi = vvv_ind[i_dim][1];
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(std::size_t start = 0; start < vec_dims.size(); start += vv_hi.size())
{ {
std::transform(vv_hi.begin(), std::transform(vv_hi.begin(),
vv_hi.end(), vv_hi.end(),
...@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
return dim; return dim;
}); });
} }
vec_dims.clear();
return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s); return calc_neighbor_points(vvv_ind, i_dim + 1, std::move(vec_dims1), in_s);
} }
static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr) static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
...@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize>
auto arg_out_s = arg->eval(); auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s, check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!"); "PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); }); arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size()) if(out_lens.size() != in_lens.size())
{ {
...@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize>
"PARSE_" + opd.op_name + "PARSE_" + opd.op_name +
": dynamic input scale is not supported!"); ": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); }); arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size()) if(in_lens.size() != vec_scale.size())
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + opd.op_name +
...@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize> ...@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize>
// map out_idx to in_idx // map out_idx to in_idx
auto nearest_op = get_nearest_op(nearest_mode); auto nearest_op = get_nearest_op(nearest_mode);
shape_for_each(out_s, [&](auto idx) { shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
auto in_idx = idx; std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < in_lens.size(); ++ii) for(auto ii = 0; ii < in_lens.size(); ++ii)
{ {
auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
in_idx[ii] = nearest_op(in_lens[ii], idx_val); in_idx[ii] = nearest_op(in_lens[ii], idx_val);
} }
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx)); ind[out_idx] = static_cast<int64_t>(in_s.index(in_idx));
}); });
shape ind_s{shape::int32_type, out_lens}; shape ind_s{shape::int32_type, out_lens};
...@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize> ...@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize>
// get the number of dimensions // get the number of dimensions
std::size_t n_dim = out_lens.size(); std::size_t n_dim = out_lens.size();
std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements)); auto vvv_ind = std::vector(n_dim, std::vector(2, std::vector<size_t>(out_elements)));
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements)); std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
shape_for_each(out_s, [&](auto idx) { shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
auto in_idx = idx;
auto out_idx = out_s.index(idx);
for(auto ii = 0; ii < in_lens.size(); ++ii) for(auto ii = 0; ii < in_lens.size(); ++ii)
{ {
auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val); vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val); vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx]; delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx];
} }
}); });
std::vector<std::vector<std::size_t>> vec_dims(out_elements); auto ind = calc_neighbor_points(
auto ind = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s); vvv_ind, 0, std::vector<std::vector<std::size_t>>(out_elements), in_s);
auto ind_lens = out_lens; auto ind_lens = out_lens;
ind_lens[0] *= (std::size_t{1} << n_dim); ind_lens[0] *= (std::size_t{1} << n_dim);
shape ind_s{shape::int32_type, ind_lens}; shape ind_s{shape::int32_type, ind_lens};
......
...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std::vector<op_desc> operators() const { return {{"RoiAlign"}}; } std::vector<op_desc> operators() const { return {{"RoiAlign"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode =
if(contains(info.attributes, "coordinate_transformation_mode")) parser.opset_version >= 16 ? "half_pixel" : "output_half_pixel";
if(const auto* a = "coordinate_transformation_mode"; contains(info.attributes, a))
{ {
coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s(); coord_trans_mode = info.attributes.at(a).s();
} }
if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode)) if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode))
{ {
MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode + MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode +
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,9 +35,13 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,9 +35,13 @@ inline namespace MIGRAPHX_INLINE_NS {
void optimize_module::apply(module_pass_manager& mpm) const void optimize_module::apply(module_pass_manager& mpm) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{
// loop to further optimize after initial transformations
for(int j = 0; j < 2; j++)
{ {
mpm.run_pass(simplify_reshapes{}); mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{}); mpm.run_pass(simplify_algebra{});
}
mpm.run_pass(eliminate_common_subexpression{}); mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{}); mpm.run_pass(propagate_constant{});
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -52,6 +52,11 @@ void calculate_padding(int64_t idx, ...@@ -52,6 +52,11 @@ void calculate_padding(int64_t idx,
} }
} }
/**
* Given the input array dimensions; kernel (wei_lens); strides; and dilations,
* calculate the padding value in each dimension.
*
*/
std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens, std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& wei_lens, const std::vector<std::size_t>& wei_lens,
const std::vector<std::size_t>& strides, const std::vector<std::size_t>& strides,
...@@ -60,6 +65,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input ...@@ -60,6 +65,7 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
{ {
std::vector<std::size_t> padding; std::vector<std::size_t> padding;
assert(input_lens.size() >= 3); assert(input_lens.size() >= 3);
assert(input_lens.size() == wei_lens.size());
std::size_t num_spatial_dims = input_lens.size() - 2; std::size_t num_spatial_dims = input_lens.size() - 2;
padding.resize(2 * num_spatial_dims); padding.resize(2 * num_spatial_dims);
for(std::size_t i = 0; i < num_spatial_dims; i++) for(std::size_t i = 0; i < num_spatial_dims; i++)
...@@ -88,6 +94,11 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input ...@@ -88,6 +94,11 @@ std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input
return padding; return padding;
} }
/**
* Calculate the correct output shape for a convolution with
* a given input size and other parameters.
*
*/
shape compute_padded_shape(const shape& input, shape compute_padded_shape(const shape& input,
const shape& weights, const shape& weights,
const std::vector<std::size_t>& padding, const std::vector<std::size_t>& padding,
...@@ -111,5 +122,33 @@ shape compute_padded_shape(const shape& input, ...@@ -111,5 +122,33 @@ shape compute_padded_shape(const shape& input,
return input.with_lens(output_lens); return input.with_lens(output_lens);
} }
/**
* Calculate the correct output shape for a pooling with
* a given input size and other parameters. This uses
* the same formula for pooling that compute_padded_shape() uses
* for convolutions, but takes slightly different inputs.
*
*/
shape compute_padded_pool_shape(const shape& input,
const shape& kernel,
const std::vector<std::size_t>& padding,
const std::vector<std::size_t>& stride,
const std::vector<std::size_t>& dilation)
{
const size_t num_spatial_dims = input.lens().size() - 2;
std::vector<size_t> output_lens{input.lens()[0], input.lens()[1]};
// calculate the output shape of the pooling: ((W - K + 2P) / S) + 1
for(size_t i = 0; i < num_spatial_dims; ++i)
{
auto padding_factor = padding[i] + padding[i + num_spatial_dims];
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[i + 2] - (1 + dilation[i] * (kernel.lens()[i] - 1)) + padding_factor) /
stride[i] +
1)));
}
return input.with_lens(output_lens);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -624,7 +624,7 @@ std::string get_migraphx_version() ...@@ -624,7 +624,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file. if any changes occur to the format of the MXR file.
*/ */
const int program_file_version = 6; const int program_file_version = 7;
value program::to_value() const value program::to_value() const
{ {
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins) bool skip_propagate(instruction_ref ins)
{ {
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front()); return skip_propagate(ins->inputs().front());
auto&& s = ins->get_shape(); auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar()) if(s.broadcasted() and not s.scalar())
return true; return true;
...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propagate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
......
...@@ -50,13 +50,14 @@ struct shape_impl ...@@ -50,13 +50,14 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size());
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
...@@ -151,6 +152,22 @@ struct shape_impl ...@@ -151,6 +152,22 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::size_t get_index(size_t i) const
{
std::size_t result = 0;
std::size_t s = 1;
for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
{
std::size_t stride = m_strides[k];
std::size_t len = m_lens[k];
std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
std::vector<std::size_t> min_lens() const std::vector<std::size_t> min_lens() const
{ {
std::vector<std::size_t> ret(m_dyn_dims.size()); std::vector<std::size_t> ret(m_dyn_dims.size());
...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t) ...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
} }
MIGRAPHX_THROW("Invalid type"); MIGRAPHX_THROW("Invalid type");
} }
std::string shape::cpp_type(shape::type_t t) std::string shape::cpp_type(shape::type_t t)
{ {
switch(t) switch(t)
...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t) ...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {} shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<std::size_t> l) shape::shape(type_t t, std::vector<std::size_t> l)
: impl(std::make_shared<shape_impl>(t, std::move(l))) : impl(std::make_shared<shape_impl>(t, std::move(l)))
{ {
} }
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s))) : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{ {
...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const ...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
else
{ return impl->get_index(i);
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens().size(); j++)
{
const std::size_t k = this->lens().size() - j - 1;
const std::size_t stride = this->strides()[k];
const std::size_t len = this->lens()[k];
const std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
} }
std::vector<std::size_t> shape::multi(std::size_t idx) const std::vector<std::size_t> shape::multi(std::size_t idx) const
......
...@@ -1327,46 +1327,57 @@ struct find_split_reshape ...@@ -1327,46 +1327,57 @@ struct find_split_reshape
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front(); auto input = slc->inputs().front();
// Only apply simplification when slices are on a single axis
auto axes = any_cast<op::slice>(slc->get_operator()).axes;
if(axes.size() > 1)
{
return;
}
auto split_outputs = get_splits(input); auto split_outputs = get_splits(input);
if(split_outputs.empty()) if(split_outputs.empty())
{ {
return; return;
} }
// Only want to apply this optimization if each split output is followed by // Find all the reshapes (similar to rsp) that can be simplified
// a contiguous op and a reshape std::vector<instruction_ref> conts;
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) { std::vector<instruction_ref> vec_rsp;
if(i->outputs().size() == 1)
// Iterate through slice and contiguous outputs to allow simplifications when
// slice is followed by multiple reshapes
for(auto& i : split_outputs)
{ {
auto cont = i->outputs().front(); std::copy_if(i->outputs().begin(),
return cont->outputs().size() != 1; i->outputs().end(),
std::back_inserter(conts),
[](auto j) { return j->name() == "contiguous"; });
} }
return false;
})) for(auto& i : conts)
{ {
return; std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(vec_rsp),
[&](auto j) { return j->get_operator() == rsp->get_operator(); });
} }
std::vector<instruction_ref> vec_rsp(split_outputs.size()); // No simplification needed if there is only one slice -> cont -> reshape
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) { if(vec_rsp.size() <= 1)
auto cont = i->outputs().front();
return cont->outputs().front();
});
// all outputs are reshape and of the same shape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
if(not same_ops(vec_rsp))
{ {
return; return;
} }
// ensure reshape happens after the axis dimension // ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0]; auto axis = axes[0];
auto slc_lens = slc->get_shape().lens(); auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate( auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>()); slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
auto input_lens = input->get_shape().lens();
auto input_size = input->get_shape().elements();
auto slc_axis_len = input_lens[axis];
// search the reshape output (standard shape) to decide which axis are // search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size // in its output corresponding to the slc_dim_size
...@@ -1393,16 +1404,67 @@ struct find_split_reshape ...@@ -1393,16 +1404,67 @@ struct find_split_reshape
{ {
rsp_axis = std::distance(rsp_strides.begin(), ait); rsp_axis = std::distance(rsp_strides.begin(), ait);
} }
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { // Calculate reshape output shape
return is->get_shape().lens()[rsp_axis]; // Need to find a reshape such that data represented by instructions in vec_rsp can be
}); // written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = 1;
auto rsp_fixed_size = std::accumulate(
rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies<std::size_t>());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); // cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
return;
}
auto rsp_axis_len = input_size / rsp_fixed_size;
rsp_out_lens[rsp_axis] = rsp_axis_len;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std::vector<int64_t> new_starts(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).starts[0] * rsp_axis_len /
slc_axis_len;
});
std::vector<int64_t> new_ends(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).ends[0] * rsp_axis_len /
slc_axis_len;
});
// insert the reshape instruction and add contiguous if needed // insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard()) if(not input->get_shape().standard())
...@@ -1413,16 +1475,14 @@ struct find_split_reshape ...@@ -1413,16 +1475,14 @@ struct find_split_reshape
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice // replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i) for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{ {
m.replace_instruction( m.replace_instruction(
vec_rsp[i], vec_rsp[i],
make_op( make_op(
"slice", "slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}), {{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}),
rsp_ins); rsp_ins);
start += vec_dims[i];
} }
} }
}; };
...@@ -1446,10 +1506,13 @@ struct find_split_transpose ...@@ -1446,10 +1506,13 @@ struct find_split_transpose
{ {
return; return;
} }
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size()); std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) { std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front(); return i->outputs().front();
}); });
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct find_static_2in_broadcasts
{
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct find_const_3in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(3),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
if(not starts_arg.empty() and not ends_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto slice_val = ins->get_operator().to_value();
auto axes_vec = slice_val.at("axes").to_vector<int64_t>();
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct find_const_4in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(4),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()),
match::arg(3)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
argument axes_arg = inputs.at(3)->eval();
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("transpose")(
match::arg(0)(match::name("multibroadcast").bind("bcast_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
if(not input->get_shape().scalar())
return;
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
}
};
struct find_slice_transpose struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
...@@ -784,7 +808,7 @@ struct find_transpose_slice ...@@ -784,7 +808,7 @@ struct find_transpose_slice
void simplify_reshapes::apply(module& m) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 4; i++) for(int i = 0; i < depth; i++)
{ {
match::find_matches(m, match::find_matches(m,
find_where_op{}, find_where_op{},
...@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_slice{}, find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{}, find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
......
...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes) ...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max}; dds_it->max};
} }
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/** /**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if` * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those * and `loop` instructions, depending on how the submodules for those
...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens}); dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return // redirect to select_module operator and return
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS ...@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(migraphx_device ${DEVICE_GPU_SRCS})
...@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device) ...@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device PUBLIC migraphx) target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu) target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_BINAR_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>) target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes) target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device) migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
...@@ -123,6 +125,7 @@ add_library(migraphx_gpu ...@@ -123,6 +125,7 @@ add_library(migraphx_gpu
lrn.cpp lrn.cpp
mlir.cpp mlir.cpp
multinomial.cpp multinomial.cpp
no_device.cpp
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp pack_int8_args.cpp
......
...@@ -115,6 +115,12 @@ struct hiprtc_program ...@@ -115,6 +115,12 @@ struct hiprtc_program
std::string cpp_src = ""; std::string cpp_src = "";
std::string cpp_name = ""; std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<hiprtc_src_file> srcs) hiprtc_program(std::vector<hiprtc_src_file> srcs)
{ {
for(auto&& src : srcs) for(auto&& src : srcs)
...@@ -130,6 +136,14 @@ struct hiprtc_program ...@@ -130,6 +136,14 @@ struct hiprtc_program
include_names.push_back(std::move(src.path)); include_names.push_back(std::move(src.path));
} }
} }
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(), prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(), cpp_name.c_str(),
headers.size(), headers.size(),
...@@ -137,7 +151,7 @@ struct hiprtc_program ...@@ -137,7 +151,7 @@ struct hiprtc_program
include_names.data()); include_names.data());
} }
void compile(const std::vector<std::string>& options) const void compile(const std::vector<std::string>& options, bool quiet = false) const
{ {
if(enabled(MIGRAPHX_TRACE_HIPRTC{})) if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl; std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
...@@ -148,7 +162,7 @@ struct hiprtc_program ...@@ -148,7 +162,7 @@ struct hiprtc_program
[](const std::string& s) { return s.c_str(); }); [](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log(); auto prog_log = log();
if(not prog_log.empty()) if(not prog_log.empty() and not quiet)
{ {
std::cerr << prog_log << std::endl; std::cerr << prog_log << std::endl;
} }
...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return {prog.get_code_obj()}; return {prog.get_code_obj()};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
...@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)}; return {compiler.compile(srcs)};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src;
src_file input;
input.path = "main.cpp";
input.content = std::make_pair(src.data(), src.data() + src.size());
try
{
compiler.compile({input});
return true;
}
catch(...)
{
return false;
}
}
#endif // MIGRAPHX_USE_HIPRTC #endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param) std::string enum_params(std::size_t count, std::string param)
......
...@@ -91,9 +91,10 @@ __content__ ...@@ -91,9 +91,10 @@ __content__
return replace_string(args_hpp, "__content__", inner); return replace_string(args_hpp, "__content__", inner);
} }
const std::vector<std::string>& compiler_warnings() static std::vector<std::string> get_compiler_warnings()
{ {
static std::vector<std::string> warnings = {"-Weverything", std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat", "-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic", "-Wno-c++98-compat-pedantic",
"-Wno-conversion", "-Wno-conversion",
...@@ -112,7 +113,17 @@ const std::vector<std::string>& compiler_warnings() ...@@ -112,7 +113,17 @@ const std::vector<std::string>& compiler_warnings()
"-Wno-sign-compare", "-Wno-sign-compare",
"-Wno-unused-command-line-argument", "-Wno-unused-command-line-argument",
"-Wno-weak-vtables", "-Wno-weak-vtables",
"-Wno-c99-extensions"}; "-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = get_compiler_warnings();
return warnings; return warnings;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment