Commit 350bbea2 authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge branch 'develop' into resnet50_partition

parents 848a476d 74ba9649
...@@ -205,7 +205,7 @@ void transform(Range1&& r1, Range2&& r2, Iterator it, F f) ...@@ -205,7 +205,7 @@ void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
} }
template <class Range> template <class Range>
auto reverse(Range& r) auto reverse(Range&& r)
{ {
return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin())); return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin()));
} }
......
...@@ -263,7 +263,7 @@ struct MIGRAPHX_EXPORT shape ...@@ -263,7 +263,7 @@ struct MIGRAPHX_EXPORT shape
/// no padding /// no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true if the shape has been transposed. That is the strides are not in descending
/// order /// order
bool transposed() const; bool transposed() const;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -37,11 +37,11 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -37,11 +37,11 @@ inline namespace MIGRAPHX_INLINE_NS {
template <class F> template <class F>
void shape_for_each(const migraphx::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
{ {
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
std::vector<std::size_t> indices(s.lens().size()); std::vector<std::size_t> indices(s.lens().size());
const auto& index_const_ref = indices;
shape ss{s.type(), s.lens()}; shape ss{s.type(), s.lens()};
for(std::size_t i = 0; i < ss.elements(); i++) size_t max = ss.elements();
for(std::size_t i = 0; i < max; i++)
{ {
std::transform(ss.strides().begin(), std::transform(ss.strides().begin(),
ss.strides().end(), ss.strides().end(),
...@@ -51,9 +51,13 @@ void shape_for_each(const migraphx::shape& s, F f) ...@@ -51,9 +51,13 @@ void shape_for_each(const migraphx::shape& s, F f)
assert(len > 0 and stride > 0); assert(len > 0 and stride > 0);
return (i / stride) % len; return (i / stride) % len;
}); });
call(indices); if constexpr(std::is_invocable<F, decltype(index_const_ref), decltype(i)>{})
f(index_const_ref, i);
else
f(index_const_ref);
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -86,7 +86,7 @@ inline std::string join_strings(Strings strings, const std::string& delim) ...@@ -86,7 +86,7 @@ inline std::string join_strings(Strings strings, const std::string& delim)
inline std::vector<std::string> split_string(const std::string& s, char delim) inline std::vector<std::string> split_string(const std::string& s, char delim)
{ {
std::vector<std::string> elems; std::vector<std::string> elems;
std::stringstream ss(s + ' '); std::stringstream ss(s + delim);
std::string item; std::string item;
while(std::getline(ss, item, delim)) while(std::getline(ss, item, delim))
{ {
...@@ -149,6 +149,10 @@ interpolate_string(const std::string& input, F f, std::string start = "${", std: ...@@ -149,6 +149,10 @@ interpolate_string(const std::string& input, F f, std::string start = "${", std:
result.append(it, next_start); result.append(it, next_start);
if(next_start == input.end()) if(next_start == input.end())
break; break;
if(next_end == input.end())
{
throw std::runtime_error("Unbalanced brackets");
}
auto r = f(next_start + start.size(), next_end); auto r = f(next_start + start.size(), next_end);
result.append(r.begin(), r.end()); result.append(r.begin(), r.end());
it = next_end + end.size(); it = next_end + end.size();
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/instruction.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
...@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt ...@@ -60,9 +61,29 @@ void save(const program& p, const std::string& filename, const file_options& opt
{ {
write_buffer(filename, save_buffer(p, options)); write_buffer(filename, save_buffer(p, options));
} }
// MIOpen doesn't support serializing fusion plans with Find-2.0 APIs
void print_miopen_warning(const program& p)
{
auto mods = p.get_modules();
if(std::any_of(mods.begin(), mods.end(), [](const auto* m) {
return std::any_of(m->begin(), m->end(), [](const instruction& i) {
return i.name() == "gpu::miopen_fusion";
});
}))
{
std::cout << "[WARNING]: Program has miopen_fusion instructions for which tuned solutions "
"are not stored inside serialized MIGraphX program. Consider serializing with "
"MIGRAPHX_DISABLE_MIOPEN_FUSION=1 flag set."
<< std::endl;
;
}
}
std::vector<char> save_buffer(const program& p, const file_options& options) std::vector<char> save_buffer(const program& p, const file_options& options)
{ {
value v = p.to_value(); value v = p.to_value();
print_miopen_warning(p);
std::vector<char> buffer; std::vector<char> buffer;
if(options.format == "msgpack") if(options.format == "msgpack")
{ {
......
...@@ -25,6 +25,33 @@ ...@@ -25,6 +25,33 @@
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <msgpack.hpp> #include <msgpack.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Leave an extra byte for error checking
constexpr std::size_t msgpack_size_limit = std::numeric_limits<uint32_t>::max() - 1;
template <class Range>
std::size_t msgpack_chunk_size(const Range& r)
{
return 1 + (r.size() - 1) / msgpack_size_limit;
}
template <class Iterator, class F>
void msgpack_chunk_for_each(Iterator start, Iterator last, F f)
{
while(std::distance(start, last) > msgpack_size_limit)
{
auto next = std::next(start, msgpack_size_limit);
f(start, next);
start = next;
}
f(start, last);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace msgpack { namespace msgpack {
MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{ {
...@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -63,16 +90,31 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
break; break;
} }
case msgpack::type::BIN: { case msgpack::type::BIN: {
// For backwards compatibility
v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size}; v = migraphx::value::binary{o.via.bin.ptr, o.via.bin.size};
break; break;
} }
case msgpack::type::ARRAY: { case msgpack::type::ARRAY: {
migraphx::value r = migraphx::value::array{}; if(o.via.array.size != 0 and o.via.array.ptr->type == msgpack::type::BIN)
std::for_each( {
o.via.array.ptr, auto bin = migraphx::value::binary{};
o.via.array.ptr + o.via.array.size, std::for_each(
[&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); }); o.via.array.ptr,
v = r; o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) {
bin.insert(bin.end(), so.via.bin.ptr, so.via.bin.ptr + so.via.bin.size);
});
v = bin;
}
else
{
migraphx::value r = migraphx::value::array{};
std::for_each(
o.via.array.ptr,
o.via.array.ptr + o.via.array.size,
[&](const msgpack::object& so) { r.push_back(so.as<migraphx::value>()); });
v = r;
}
break; break;
} }
case msgpack::type::MAP: { case msgpack::type::MAP: {
...@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -102,8 +144,12 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
{ {
const auto* data = reinterpret_cast<const char*>(x.data()); const auto* data = reinterpret_cast<const char*>(x.data());
auto size = x.size(); auto size = x.size();
o.pack_bin(size); o.pack_array(migraphx::msgpack_chunk_size(x));
o.pack_bin_body(data, size); migraphx::msgpack_chunk_for_each(
data, data + size, [&](const char* start, const char* last) {
o.pack_bin(last - start);
o.pack_bin_body(start, last - start);
});
return o; return o;
} }
}; };
...@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) ...@@ -129,6 +175,8 @@ MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS)
o.pack_array(0); o.pack_array(0);
return; return;
} }
if(v.size() > migraphx::msgpack_size_limit)
MIGRAPHX_THROW("Size is too large for msgpack");
if(not v.front().get_key().empty()) if(not v.front().get_key().empty())
{ {
o.pack_map(v.size()); o.pack_map(v.size());
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode) ...@@ -97,22 +97,19 @@ const auto& get_original_idx_op(const std::string& mode)
static std::vector<int> static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind, calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
int i_dim, int i_dim,
const std::vector<std::vector<std::size_t>>& vec_dims, std::vector<std::vector<std::size_t>> vec_dims,
const shape& in_s) const shape& in_s)
{ {
if(i_dim == vvv_ind.size()) if(i_dim == vvv_ind.size())
{ {
std::vector<int> vec_ind; std::vector<int> vec_ind(vec_dims.size());
vec_ind.resize(vec_dims.size());
std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) { std::transform(vec_dims.begin(), vec_dims.end(), vec_ind.begin(), [&](auto idx) {
return static_cast<int>(in_s.index(idx)); return static_cast<int>(in_s.index(idx));
}); });
return vec_ind; return vec_ind;
} }
const auto& vv_ind = vvv_ind[i_dim]; const auto& vv_lo = vvv_ind[i_dim][0];
const auto& vv_lo = vv_ind.at(0);
std::vector<std::vector<std::size_t>> vec_dims1; std::vector<std::vector<std::size_t>> vec_dims1;
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size())
{ {
...@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -126,8 +123,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
}); });
} }
const auto& vv_hi = vv_ind.at(1); const auto& vv_hi = vvv_ind[i_dim][1];
for(std::size_t start = 0; start < vec_dims.size(); start += vv_lo.size()) for(std::size_t start = 0; start < vec_dims.size(); start += vv_hi.size())
{ {
std::transform(vv_hi.begin(), std::transform(vv_hi.begin(),
vv_hi.end(), vv_hi.end(),
...@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v ...@@ -138,8 +135,8 @@ calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& v
return dim; return dim;
}); });
} }
vec_dims.clear();
return calc_neighbor_points(vvv_ind, i_dim + 1, vec_dims1, in_s); return calc_neighbor_points(vvv_ind, i_dim + 1, std::move(vec_dims1), in_s);
} }
static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr) static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
...@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -240,7 +237,7 @@ struct parse_resize : op_parser<parse_resize>
auto arg_out_s = arg->eval(); auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s, check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!"); "PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); }); arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size()) if(out_lens.size() != in_lens.size())
{ {
...@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -267,7 +264,7 @@ struct parse_resize : op_parser<parse_resize>
"PARSE_" + opd.op_name + "PARSE_" + opd.op_name +
": dynamic input scale is not supported!"); ": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); }); arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size()) if(in_lens.size() != vec_scale.size())
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + opd.op_name +
...@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize> ...@@ -300,15 +297,15 @@ struct parse_resize : op_parser<parse_resize>
// map out_idx to in_idx // map out_idx to in_idx
auto nearest_op = get_nearest_op(nearest_mode); auto nearest_op = get_nearest_op(nearest_mode);
shape_for_each(out_s, [&](auto idx) { shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
auto in_idx = idx; std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < in_lens.size(); ++ii) for(auto ii = 0; ii < in_lens.size(); ++ii)
{ {
auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
in_idx[ii] = nearest_op(in_lens[ii], idx_val); in_idx[ii] = nearest_op(in_lens[ii], idx_val);
} }
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx)); ind[out_idx] = static_cast<int64_t>(in_s.index(in_idx));
}); });
shape ind_s{shape::int32_type, out_lens}; shape ind_s{shape::int32_type, out_lens};
...@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize> ...@@ -323,24 +320,21 @@ struct parse_resize : op_parser<parse_resize>
// get the number of dimensions // get the number of dimensions
std::size_t n_dim = out_lens.size(); std::size_t n_dim = out_lens.size();
std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements)); auto vvv_ind = std::vector(n_dim, std::vector(2, std::vector<size_t>(out_elements)));
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements)); std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
shape_for_each(out_s, [&](auto idx) { shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
auto in_idx = idx;
auto out_idx = out_s.index(idx);
for(auto ii = 0; ii < in_lens.size(); ++ii) for(auto ii = 0; ii < in_lens.size(); ++ii)
{ {
auto idx_val = idx_op(in_lens[ii], out_lens[ii], idx[ii], vec_scale[ii]); auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val); vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val); vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx]; delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx];
} }
}); });
std::vector<std::vector<std::size_t>> vec_dims(out_elements); auto ind = calc_neighbor_points(
auto ind = calc_neighbor_points(vvv_ind, 0, vec_dims, in_s); vvv_ind, 0, std::vector<std::vector<std::size_t>>(out_elements), in_s);
auto ind_lens = out_lens; auto ind_lens = out_lens;
ind_lens[0] *= (std::size_t{1} << n_dim); ind_lens[0] *= (std::size_t{1} << n_dim);
shape ind_s{shape::int32_type, ind_lens}; shape ind_s{shape::int32_type, ind_lens};
......
...@@ -629,7 +629,7 @@ std::string get_migraphx_version() ...@@ -629,7 +629,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file. if any changes occur to the format of the MXR file.
*/ */
const int program_file_version = 6; const int program_file_version = 7;
value program::to_value() const value program::to_value() const
{ {
......
...@@ -50,13 +50,14 @@ struct shape_impl ...@@ -50,13 +50,14 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size());
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
...@@ -151,6 +152,22 @@ struct shape_impl ...@@ -151,6 +152,22 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::size_t get_index(size_t i) const
{
std::size_t result = 0;
std::size_t s = 1;
for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
{
std::size_t stride = m_strides[k];
std::size_t len = m_lens[k];
std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
std::vector<std::size_t> min_lens() const std::vector<std::size_t> min_lens() const
{ {
std::vector<std::size_t> ret(m_dyn_dims.size()); std::vector<std::size_t> ret(m_dyn_dims.size());
...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t) ...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
} }
MIGRAPHX_THROW("Invalid type"); MIGRAPHX_THROW("Invalid type");
} }
std::string shape::cpp_type(shape::type_t t) std::string shape::cpp_type(shape::type_t t)
{ {
switch(t) switch(t)
...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t) ...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {} shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<std::size_t> l) shape::shape(type_t t, std::vector<std::size_t> l)
: impl(std::make_shared<shape_impl>(t, std::move(l))) : impl(std::make_shared<shape_impl>(t, std::move(l)))
{ {
} }
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s))) : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{ {
...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const ...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
else
{ return impl->get_index(i);
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens().size(); j++)
{
const std::size_t k = this->lens().size() - j - 1;
const std::size_t stride = this->strides()[k];
const std::size_t len = this->lens()[k];
const std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
} }
std::vector<std::size_t> shape::multi(std::size_t idx) const std::vector<std::size_t> shape::multi(std::size_t idx) const
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -115,6 +115,12 @@ struct hiprtc_program ...@@ -115,6 +115,12 @@ struct hiprtc_program
std::string cpp_src = ""; std::string cpp_src = "";
std::string cpp_name = ""; std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<hiprtc_src_file> srcs) hiprtc_program(std::vector<hiprtc_src_file> srcs)
{ {
for(auto&& src : srcs) for(auto&& src : srcs)
...@@ -130,6 +136,14 @@ struct hiprtc_program ...@@ -130,6 +136,14 @@ struct hiprtc_program
include_names.push_back(std::move(src.path)); include_names.push_back(std::move(src.path));
} }
} }
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(), prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(), cpp_name.c_str(),
headers.size(), headers.size(),
...@@ -137,7 +151,7 @@ struct hiprtc_program ...@@ -137,7 +151,7 @@ struct hiprtc_program
include_names.data()); include_names.data());
} }
void compile(const std::vector<std::string>& options) const void compile(const std::vector<std::string>& options, bool quiet = false) const
{ {
if(enabled(MIGRAPHX_TRACE_HIPRTC{})) if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl; std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
...@@ -148,7 +162,7 @@ struct hiprtc_program ...@@ -148,7 +162,7 @@ struct hiprtc_program
[](const std::string& s) { return s.c_str(); }); [](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log(); auto prog_log = log();
if(not prog_log.empty()) if(not prog_log.empty() and not quiet)
{ {
std::cerr << prog_log << std::endl; std::cerr << prog_log << std::endl;
} }
...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return {prog.get_code_obj()}; return {prog.get_code_obj()};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
...@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)}; return {compiler.compile(srcs)};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src;
src_file input;
input.path = "main.cpp";
input.content = std::make_pair(src.data(), src.data() + src.size());
try
{
compiler.compile({input});
return true;
}
catch(...)
{
return false;
}
}
#endif // MIGRAPHX_USE_HIPRTC #endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param) std::string enum_params(std::size_t count, std::string param)
......
...@@ -91,28 +91,39 @@ __content__ ...@@ -91,28 +91,39 @@ __content__
return replace_string(args_hpp, "__content__", inner); return replace_string(args_hpp, "__content__", inner);
} }
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings() const std::vector<std::string>& compiler_warnings()
{ {
static std::vector<std::string> warnings = {"-Weverything", static std::vector<std::string> warnings = get_compiler_warnings();
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions"};
return warnings; return warnings;
} }
......
...@@ -103,7 +103,10 @@ struct mlir_op ...@@ -103,7 +103,10 @@ struct mlir_op
} }
if(ins->name() == "@return") if(ins->name() == "@return")
{ {
return ins_shapes[ins->inputs().at(0)].with_type(type); auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
} }
std::vector<shape> input_shapes; std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size()); input_shapes.resize(ins->inputs().size());
...@@ -119,6 +122,33 @@ struct mlir_op ...@@ -119,6 +122,33 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op); MIGRAPHX_REGISTER_OP(mlir_op);
namespace { namespace {
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{ {
...@@ -134,7 +164,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) ...@@ -134,7 +164,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return true; return true;
} }
struct find_mlir_op struct find_mlir_fused_ops
{ {
auto matcher() const auto matcher() const
{ {
...@@ -163,34 +193,6 @@ struct find_mlir_op ...@@ -163,34 +193,6 @@ struct find_mlir_op
return ins_map; return ins_map;
} }
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
// Whitelist supported fusion options, including imposing type constraints // Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function) // for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types. // on particular types.
...@@ -236,8 +238,7 @@ struct find_mlir_op ...@@ -236,8 +238,7 @@ struct find_mlir_op
"log", "log",
"recip", "recip",
"rsqrt", "rsqrt",
// There are bugs in MLIR right now for models using sigmoid so disable it for now "sigmoid",
// "sigmoid",
"softmax", "softmax",
"tanh", "tanh",
}; };
...@@ -301,14 +302,95 @@ struct find_mlir_op ...@@ -301,14 +302,95 @@ struct find_mlir_op
} }
}; };
struct find_mlir_standalone_convolution_op
{
auto matcher() const { return match::name("convolution"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto conv_based_op = r.result;
// enable only for fp32/fp16/i8 types
if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
mm->set_bypass();
auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
mm->add_return({anchor_op});
mpm.get_module().replace_instruction(
conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
}
};
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
bool is_requested(std::string_view option)
{
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
const auto options = split_string(string_value, ',');
return contains(options, option);
}
bool is_fusion_enabled()
{
if(is_self_decide())
{
return true;
}
return is_requested("fused");
}
bool is_standalone_convs_enabled(context* ctx)
{
if(is_self_decide())
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
}
return is_requested("convolution");
}
} // namespace } // namespace
#endif #endif // MIGRAPHX_MLIR
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_op{}); if(is_fusion_enabled())
{
match::find_matches(mpm, find_mlir_fused_ops{});
}
if(is_standalone_convs_enabled(this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
#else #else
(void)mpm; (void)mpm;
#endif #endif
......
...@@ -58,6 +58,8 @@ struct hiprtc_src_file ...@@ -58,6 +58,8 @@ struct hiprtc_src_file
} }
}; };
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc( MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch); std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
......
...@@ -84,8 +84,10 @@ struct miopen_convolution ...@@ -84,8 +84,10 @@ struct miopen_convolution
{ {
check_shapes{inputs, op}.has(4); check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts( check_shapes{conv_inputs, *this}
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}}); .max_ndims(5)
.packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}})
.same_layout();
return migraphx::compute_shape<Op>(op, conv_inputs); return migraphx::compute_shape<Op>(op, conv_inputs);
} }
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/make_op.hpp" #include "migraphx/make_op.hpp"
#include <migraphx/stringutils.hpp>
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
...@@ -69,6 +70,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -69,6 +70,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
...@@ -684,8 +686,10 @@ struct mlir_program ...@@ -684,8 +686,10 @@ struct mlir_program
{ {
tuning_config tc; tuning_config tc;
run_high_level_pipeline(); run_high_level_pipeline();
mlir_tuning_space params{ auto tuning_mode = RocmlirTuningParamSetKindFull;
mlirRockTuningSpaceCreate(mmodule.get(), RocmlirTuningParamSetKindFull)}; if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
for(auto i : range(mlirRockTuningGetNumParams(params.get()))) for(auto i : range(mlirRockTuningGetNumParams(params.get())))
{ {
mlir_tuning_param param{mlirRockTuningParamCreate()}; mlir_tuning_param param{mlirRockTuningParamCreate()};
...@@ -719,7 +723,8 @@ struct mlir_program ...@@ -719,7 +723,8 @@ struct mlir_program
if(not tuning_cfg_path.empty()) if(not tuning_cfg_path.empty())
{ {
std::vector<std::string> tokens = split_string(prob_config, '\t'); std::vector<std::string> tokens = split_string(prob_config, '\t');
std::string prob = tokens[1]; std::string prob = tokens[2];
if(starts_with(prob, "conv")) if(starts_with(prob, "conv"))
{ {
tuning_cfg_path += ".conv"; tuning_cfg_path += ".conv";
...@@ -729,6 +734,8 @@ struct mlir_program ...@@ -729,6 +734,8 @@ struct mlir_program
tuning_cfg_path += ".gemm"; tuning_cfg_path += ".gemm";
} }
std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app); std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app);
prob =
trim(prob, [](unsigned char c) { return (c == '\0') or (std::isspace(c) != 0); });
tuning_cfg << prob << std::endl; tuning_cfg << prob << std::endl;
} }
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -78,16 +78,6 @@ bool verify_args(const std::string& name, ...@@ -78,16 +78,6 @@ bool verify_args(const std::string& name,
if(verify::range_zero(target)) if(verify::range_zero(target))
std::cout << "Target data is all zeros" << std::endl; std::cout << "Target data is all zeros" << std::endl;
// auto mxdiff = max_diff(ref, target);
// std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(ref, target, float_equal);
// if(idx < verify::range_distance(ref))
// {
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// << std::endl;
// }
auto ref_nan_idx = find_idx(ref, verify::not_finite); auto ref_nan_idx = find_idx(ref, verify::not_finite);
if(ref_nan_idx >= 0) if(ref_nan_idx >= 0)
std::cout << "Non finite number found in ref at " << ref_nan_idx << ": " std::cout << "Non finite number found in ref at " << ref_nan_idx << ": "
...@@ -97,7 +87,7 @@ bool verify_args(const std::string& name, ...@@ -97,7 +87,7 @@ bool verify_args(const std::string& name,
if(target_nan_idx >= 0) if(target_nan_idx >= 0)
std::cout << "Non finite number found in target at " << target_nan_idx << ": " std::cout << "Non finite number found in target at " << target_nan_idx << ": "
<< target[target_nan_idx] << std::endl; << target[target_nan_idx] << std::endl;
// std::cout << std::endl; std::cout << "MIGraphX verification passed successfully." << std::endl;
} }
}); });
return passed; return passed;
......
...@@ -31,6 +31,11 @@ set(CTEST_PARALLEL_LEVEL ${N} CACHE STRING "CTest parallel level") ...@@ -31,6 +31,11 @@ set(CTEST_PARALLEL_LEVEL ${N} CACHE STRING "CTest parallel level")
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -j ${CTEST_PARALLEL_LEVEL} -C ${CMAKE_CFG_INTDIR} --timeout 5000) add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -j ${CTEST_PARALLEL_LEVEL} -C ${CMAKE_CFG_INTDIR} --timeout 5000)
add_custom_target(tests) add_custom_target(tests)
set(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS Off CACHE BOOL "")
if(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS)
add_compile_definitions(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS)
endif()
find_program(MIGRAPHX_GDB gdb) find_program(MIGRAPHX_GDB gdb)
if(MIGRAPHX_GDB) if(MIGRAPHX_GDB)
......
...@@ -31,24 +31,39 @@ ...@@ -31,24 +31,39 @@
using migraphx::shape; using migraphx::shape;
bool create_shapes(bool dynamic_allowed) void create_shapes(bool dynamic_allowed)
{ {
try shape a{shape::int64_type, {3}};
{ shape b{shape::float_type, {{3, 6}, {4, 4}}};
shape a{shape::int64_type, {3}}; migraphx::check_shapes{{a, b}, "", dynamic_allowed}.has(2);
shape b{shape::float_type, {{3, 6}, {4, 4}}};
auto op = migraphx::make_op("add");
migraphx::check_shapes{{a, b}, op, dynamic_allowed}.has(2);
return true;
}
catch(...)
{
return false;
}
} }
TEST_CASE(allow_dynamic_shape) { EXPECT(create_shapes(true)); } TEST_CASE(allow_dynamic_shape)
{
EXPECT(not test::throws([] { create_shapes(true); }));
}
TEST_CASE(fail_dynamic_shape)
{
EXPECT(test::throws([] { create_shapes(false); }));
}
TEST_CASE(fail_dynamic_shape) { EXPECT(not create_shapes(false)); } TEST_CASE(same_layout_fail)
{
EXPECT(test::throws([] {
shape a{shape::float_type, {2, 3}};
shape b{shape::float_type, {2, 3}, {1, 2}};
migraphx::check_shapes{{a, b}, ""}.same_layout();
}));
}
TEST_CASE(same_layout_pass)
{
EXPECT(not test::throws([] {
shape a{shape::float_type, {2, 3}, {1, 2}};
shape b{shape::float_type, {2, 3}, {1, 2}};
migraphx::check_shapes{{a, b}, ""}.same_layout();
}));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment