Unverified Commit 9c91c08d authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents a56bb11d c1b8c975
...@@ -66,7 +66,7 @@ struct convert : unary<convert> ...@@ -66,7 +66,7 @@ struct convert : unary<convert>
auto type = target_type; auto type = target_type;
return [type](auto x) { return [type](auto x) {
auto y = x; auto y = x;
shape::visit(type, [&](auto as) { y = std::min(std::max(as(x), as.min()), as.max()); }); shape::visit(type, [&](auto as) { y = as(x); });
return y; return y;
}; };
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,9 +36,9 @@ namespace op { ...@@ -36,9 +36,9 @@ namespace op {
/** /**
* Broadcast multiple dimensions between two tensors. * Broadcast multiple dimensions between two tensors.
* Two versions of this operator: one input and two inputs. * Two versions of this operator: 1 input and 2+ inputs.
* One input version uses output_lens attribute and broadcasts to it. * One input version uses output_lens attribute and broadcasts to it.
* Two inputs version broadcasts both inputs to the common shape at evaluation time. * 2+ inputs version broadcasts first input to the common shape at evaluation time.
*/ */
struct multibroadcast struct multibroadcast
{ {
...@@ -57,12 +57,12 @@ struct multibroadcast ...@@ -57,12 +57,12 @@ struct multibroadcast
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1, 2); check_shapes{inputs, *this, true}.has_at_least(1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto s0 = inputs.at(0); auto s0 = inputs.at(0);
if(s0.max_lens().empty()) if(s0.ndim() < 1)
{ {
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should be > 0"); MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should be > 0");
} }
...@@ -81,6 +81,9 @@ struct multibroadcast ...@@ -81,6 +81,9 @@ struct multibroadcast
if(inputs.size() == 1) if(inputs.size() == 1)
{ {
if(s0.dynamic())
MIGRAPHX_THROW(
"MULTIBROADCAST: Single dynamic input shape not supported. Use two inputs.");
if(s0.lens().size() > output_lens.size()) if(s0.lens().size() > output_lens.size())
{ {
MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size"); MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
...@@ -102,19 +105,20 @@ struct multibroadcast ...@@ -102,19 +105,20 @@ struct multibroadcast
} }
else else
{ {
// two inputs // 2+ inputs
auto s1 = inputs.at(1); if(std::any_of(
if(s0.dynamic() or s1.dynamic()) inputs.cbegin(), inputs.cend(), [](auto input) { return input.dynamic(); }))
{ {
if(not output_dyn_dims.empty()) if(not output_dyn_dims.empty())
{ {
return {t, output_dyn_dims}; return {t, output_dyn_dims};
} }
return {t, compute_broadcasted_dyn_dims(s0, s1)}; return {t, compute_common_dyn_dims(inputs)};
} }
else else
{ {
auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens()); // output_lens will not be set for 2+ input version
auto bcast_lens = compute_common_lens(inputs);
auto offset = bcast_lens.size() - s0.lens().size(); auto offset = bcast_lens.size() - s0.lens().size();
auto bcast_strides = make_bcast_strides(bcast_lens, offset); auto bcast_strides = make_bcast_strides(bcast_lens, offset);
return {t, std::move(bcast_lens), std::move(bcast_strides)}; return {t, std::move(bcast_lens), std::move(bcast_strides)};
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -96,9 +97,115 @@ struct reshape ...@@ -96,9 +97,115 @@ struct reshape
return {s0.type(), output_dyn_dims}; return {s0.type(), output_dyn_dims};
} }
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will reshape the dimesions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start;
assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if(d != 1)
return nullopt;
rstrides.push_back(stride);
}
}
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{ {
check_shapes{inputs, *this}.standard(); check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens(); auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end()); std::vector<std::size_t> rdims(dims.begin(), dims.end());
...@@ -125,12 +232,17 @@ struct reshape ...@@ -125,12 +232,17 @@ struct reshape
} }
} }
shape s{inputs.front().type(), rdims}; auto s = reshape_dims(inputs.front(), rdims);
if(s.elements() != inputs.front().elements()) if(not s.has_value())
MIGRAPHX_THROW("Reshape on axis that is not packed.");
if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s.elements()) + " elements whereas the input has " + std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
return s;
assert(s->bytes() == inputs.front().bytes());
return *s;
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -261,11 +261,13 @@ auto compute_op(rank<1>, ...@@ -261,11 +261,13 @@ auto compute_op(rank<1>,
template <class T, class F> template <class T, class F>
argument compute_op(rank<0>, argument compute_op(rank<0>,
const T& x, const T& x,
const shape&, const shape& output,
const std::vector<argument>&, const std::vector<argument>& inputs,
const std::vector<module_ref>&, const std::vector<module_ref>& module_args,
F) F)
{ {
if(module_args.empty())
return compute_op(x, output, inputs);
std::string name = x.name(); std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name); MIGRAPHX_THROW("Not computable: " + name);
} }
......
...@@ -56,12 +56,12 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op) ...@@ -56,12 +56,12 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
} }
/*! /*!
* Returns the permutation needed to apply to the shape to undo the current permutation * Returns the inverse permutation that could be applied to undo the inputted permutation
*/ */
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation); std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
/*! /*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape. * Finds the permutation that would make the shape not transposed (refering to shape.transposed())
*/ */
std::vector<int64_t> find_permutation(const shape& s); std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes); std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
......
...@@ -79,6 +79,9 @@ struct program ...@@ -79,6 +79,9 @@ struct program
std::vector<argument> eval(parameter_map params, std::vector<argument> eval(parameter_map params,
execution_environment exec_env = execution_environment{}) const; execution_environment exec_env = execution_environment{}) const;
void finish() const;
std::size_t size() const; std::size_t size() const;
std::vector<shape> get_output_shapes() const; std::vector<shape> get_output_shapes() const;
......
...@@ -187,6 +187,7 @@ struct raw_data : raw_data_base ...@@ -187,6 +187,7 @@ struct raw_data : raw_data_base
std::string to_string() const std::string to_string() const
{ {
std::stringstream ss; std::stringstream ss;
ss.precision(std::numeric_limits<double>::max_digits10);
ss << static_cast<const Derived&>(*this); ss << static_cast<const Derived&>(*this);
return ss.str(); return ss.str();
} }
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module_pass_manager;
/** /**
* Replace `allocate` instructions with target allocations or output parameters. * Replace `allocate` instructions with target allocations or output parameters.
...@@ -40,7 +40,7 @@ struct replace_allocate ...@@ -40,7 +40,7 @@ struct replace_allocate
allocation_model model; allocation_model model;
bool offload_copy = false; bool offload_copy = false;
std::string name() const { return "replace_allocate"; } std::string name() const { return "replace_allocate"; }
void apply(module& m) const; void apply(module_pass_manager& mpm) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -156,14 +156,34 @@ struct shape ...@@ -156,14 +156,34 @@ struct shape
shape(const std::vector<shape>& subs); shape(const std::vector<shape>& subs);
/**
* Creates an output shape with dimensions equal to the input lengths and strides determined
* by the permutation argument such that find_permutation() of the output shape returns the
* inputted permuation.
*
* 2D example:
* parameters:
* l = [2, 3], perm = [1, 0]
* therefore:
* "original" shape = {lens = [3, 2], strides = [2, 1]}
* output_shape = {lens = [2, 3], strides = [1, 2]
*
* 3D example:
* parameters:
* l = [2, 3, 4], perm = [1, 2, 0]
* therefore:
* "original" shape = {lens = [3, 4, 2], strides = [8, 2, 1]}
* output_shape = {lens = [2, 3, 4], strides = [1, 8, 2]}
*/
static shape static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm); from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
/*! /*!
* The number of dimensions in the shape. * The number of dimensions in the shape, either static or dynamic.
* Same as the number of indices required to get a data value. * Same as the number of indices required to get a data value.
*/ */
std::size_t ndim() const; std::size_t ndim() const;
...@@ -279,6 +299,8 @@ struct shape ...@@ -279,6 +299,8 @@ struct shape
type min() const { return std::numeric_limits<type>::lowest(); } type min() const { return std::numeric_limits<type>::lowest(); }
type nan() const { return std::numeric_limits<type>::quiet_NaN(); }
template <class U> template <class U>
type operator()(U u) const type operator()(U u) const
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#elif defined(__has_include)
#if __has_include(<source_location>) && __cplusplus >= 202003L
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#endif
#if __has_include(<experimental/source_location>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#if MIGRAPHX_HAS_SOURCE_LOCATION
#include <source_location>
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
#include <experimental/source_location>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_SOURCE_LOCATION
using source_location = std::source_location;
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
using source_location = std::experimental::source_location;
#else
struct source_location
{
static constexpr source_location current() noexcept { return source_location{}; }
constexpr std::uint_least32_t line() const noexcept { return 0; }
constexpr std::uint_least32_t column() const noexcept { return 0; }
constexpr const char* file_name() const noexcept { return ""; }
constexpr const char* function_name() const noexcept { return ""; }
};
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
...@@ -45,6 +45,8 @@ ...@@ -45,6 +45,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN #ifdef DOXYGEN
/// An interface for a compilation target /// An interface for a compilation target
...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x) ...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
#endif #endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <memory> #include <memory>
#include <cstdint>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
#include <tuple> #include <tuple>
...@@ -392,8 +393,8 @@ struct value ...@@ -392,8 +393,8 @@ struct value
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(object, )
} }
MIGRAPHX_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
} }
...@@ -461,6 +462,8 @@ struct value ...@@ -461,6 +462,8 @@ struct value
friend std::ostream& operator<<(std::ostream& os, const value& d); friend std::ostream& operator<<(std::ostream& os, const value& d);
std::size_t hash() const;
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
type_t get_type() const; type_t get_type() const;
...@@ -481,4 +484,15 @@ struct value ...@@ -481,4 +484,15 @@ struct value
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::value>
{
using argument_type = migraphx::value;
using result_type = std::size_t;
result_type operator()(const migraphx::value& x) const { return x.hash(); }
};
} // namespace std
#endif #endif
...@@ -473,7 +473,9 @@ operation instruction::normalized_operator() const ...@@ -473,7 +473,9 @@ operation instruction::normalized_operator() const
return o; return o;
} }
std::size_t instruction::get_target_id() const { return target_id; } std::size_t instruction::get_target_id() const { return target_id; }
void instruction::set_target_id(std::size_t tid) { this->target_id = tid; } void instruction::set_target_id(std::size_t tid) { this->target_id = tid; }
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args) std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{ {
std::vector<shape> shapes(args.size()); std::vector<shape> shapes(args.size());
......
...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref ...@@ -326,6 +326,8 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
if(ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
// "rep" instruction could be used earlier in the program and moving it at the end
// may cause invalid program, therefore make an identity operation in this case.
return replace_instruction(ins, make_op("identity"), rep); return replace_instruction(ins, make_op("identity"), rep);
} }
...@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const ...@@ -650,8 +652,9 @@ instruction_ref module::find_dangling_reference() const
return end(); return end();
} }
void module::finalize(context& ctx) void module::finalize(std::vector<context>& contexts)
{ {
assert(not contexts.empty());
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{}); const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
...@@ -660,10 +663,10 @@ void module::finalize(context& ctx) ...@@ -660,10 +663,10 @@ void module::finalize(context& ctx)
std::cout << "Finalize: "; std::cout << "Finalize: ";
this->debug_print(ins); this->debug_print(ins);
} }
ins->finalize(ctx); ins->finalize(contexts[ins->get_target_id()]);
for(const auto& smod : ins->module_inputs()) for(const auto& smod : ins->module_inputs())
{ {
smod->finalize(ctx); smod->finalize(contexts);
} }
} }
......
...@@ -38,6 +38,9 @@ ...@@ -38,6 +38,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_ONNX_PARSER)
static shape shape_from_dyn_dims(shape::type_t shape_type, static shape shape_from_dyn_dims(shape::type_t shape_type,
const std::vector<shape::dynamic_dimension>& dyn_dims) const std::vector<shape::dynamic_dimension>& dyn_dims)
...@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type, ...@@ -53,8 +56,6 @@ static shape shape_from_dyn_dims(shape::type_t shape_type,
return {shape_type, dyn_dims}; return {shape_type, dyn_dims};
} }
namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
{ {
std::unordered_map<std::string, onnx::AttributeProto> result; std::unordered_map<std::string, onnx::AttributeProto> result;
...@@ -149,6 +150,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s ...@@ -149,6 +150,25 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
return this->add_common_op(op_name, arg0, arg1); return this->add_common_op(op_name, arg0, arg1);
} }
/**
* @brief A wrapper for insert_common_args(), which constructs an argument list
* and inserts multibroadcast and convert ops to match inputs to a common shape and type
* as required. The requested operation is placed after the added multibroadcast and convert ops,
* if any, so that their results are transparent to the programmer.
*
* Use add_common_op() to match input sizes when inputs may be
* either static or dynamic.
*
* @param op_name string; Name of operation (op) to add; valid names are the same as
* for make_op()
*
* @param inputs vector of instruction_ref. List of instructions for the new
* operator. Multibroadcast and convert operations, if needed, are deduced from these too.
*
* @return instruction_ref Returns an instruction_ref which is the result of the requested
* operation.
*
*/
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name, instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const std::vector<instruction_ref> inputs) const
{ {
...@@ -278,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) ...@@ -278,16 +298,48 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version; return version;
} }
std::vector<instruction_ref> void print_added_instructions(module* mod,
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining) const std::vector<instruction_ref>& args,
const std::vector<instruction_ref>& result)
{
// Print instructions added by the parser not in args
std::vector<instruction_ref> added_instructions;
fix([&](auto self, auto r) {
for(auto ins : r)
{
if(contains(args, ins))
continue;
if(contains(added_instructions, ins))
continue;
self(ins->inputs());
added_instructions.push_back(ins);
}
})(result);
mod->debug_print(added_instructions);
}
std::unordered_map<std::string, instruction_ref>
parse_intializer(const onnx_parser& parser, module* mod, const onnx::GraphProto& graph)
{ {
std::unordered_map<std::string, instruction_ref> mod_insts; std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
{ {
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
std::cout << "initializer: " << f.name() << std::endl;
// backup instructions in parent mod // backup instructions in parent mod
mod_insts[f.name()] = mod->add_literal(parse_tensor(f)); mod_insts[f.name()] = mod->add_literal(parser.parse_tensor(f));
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
mod->debug_print(mod_insts[f.name()]);
} }
return mod_insts;
}
std::unordered_map<std::string, instruction_ref>
parse_inputs(const onnx_parser& parser,
module* mod,
const onnx::GraphProto& graph,
std::unordered_map<std::string, instruction_ref> mod_insts)
{
for(auto&& input : graph.input()) for(auto&& input : graph.input())
{ {
const std::string& name = input.name(); const std::string& name = input.name();
...@@ -298,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini ...@@ -298,7 +350,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
// scenario that a nested subgraph contains a parameter with the // scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph. // name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that. // In the current implementation, MIGraphX throws an exception for that.
if(contains(instructions, name)) if(contains(parser.instructions, name))
{ {
MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name + MIGRAPHX_THROW("module \"" + mod->name() + "\" has parameter name \"" + name +
"\" existing in parent graph!"); "\" existing in parent graph!");
...@@ -306,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini ...@@ -306,28 +358,41 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
shape s; shape s;
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0) if(parser.map_input_dims.count(name) > 0)
{ {
dims = map_input_dims.at(name); dims = parser.map_input_dims.at(name);
s = parse_type(input.type(), dims); s = parser.parse_type(input.type(), dims);
} }
else if(map_dyn_input_dims.count(name) > 0) else if(parser.map_dyn_input_dims.count(name) > 0)
{ {
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type()); shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name)); s = shape_from_dyn_dims(shape_type, parser.map_dyn_input_dims.at(name));
} }
else else
{ {
s = parse_type(input.type(), dims); s = parser.parse_type(input.type(), dims);
} }
mod_insts[name] = mod->add_parameter(name, s); mod_insts[name] = mod->add_parameter(name, s);
} }
} }
return mod_insts;
}
std::vector<instruction_ref>
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{
std::unordered_map<std::string, instruction_ref> mod_insts =
parse_intializer(*this, mod, graph);
mod_insts = parse_inputs(*this, mod, graph, mod_insts);
std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end())); std::copy(mod_insts.begin(), mod_insts.end(), std::inserter(instructions, instructions.end()));
for(auto&& node : graph.node()) for(auto&& node : graph.node())
{ {
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
std::cout << "operator: " << node.op_type() << std::endl;
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
for(auto&& input : node.input()) for(auto&& input : node.input())
{ {
...@@ -365,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini ...@@ -365,6 +430,11 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
result.begin(), result.begin(),
std::inserter(instructions, instructions.end()), std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(x, y); }); [](auto&& x, auto&& y) { return std::make_pair(x, y); });
if(enabled(MIGRAPHX_TRACE_ONNX_PARSER{}))
{
print_added_instructions(mod, args, result);
}
} }
// Find instructions corresponding to the output // Find instructions corresponding to the output
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,10 +21,14 @@ ...@@ -21,10 +21,14 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <iterator>
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT);
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,54 +43,105 @@ struct parse_instancenorm : op_parser<parse_instancenorm> ...@@ -39,54 +43,105 @@ struct parse_instancenorm : op_parser<parse_instancenorm>
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> oargs) const
{ {
// y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
// mean = reduce_mean({D1, D2, ... Dk}, x) // mean = reduce_mean({D1, D2, ... Dk}, x)
// variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2) // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
// Convert fp16 to fp32 to workaround for FP16 accuracy issues with reduce_mean/variance.
bool convert_fp16 = true;
if(enabled(MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT{}))
{
convert_fp16 = false;
}
float epsilon = 1e-5f; float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>(); epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
} }
auto dtype = oargs[0]->get_shape().type();
auto literal_dtype = dtype;
std::vector<instruction_ref> args;
// cppcheck-suppress knownConditionTrueFalse
if(dtype == shape::half_type and convert_fp16)
{
std::transform(oargs.begin(), oargs.end(), std::back_inserter(args), [&](const auto i) {
return info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), i);
});
literal_dtype = shape::float_type;
}
else
{
args = oargs;
}
auto x = args[0]; auto x = args[0];
auto scale = args[1]; auto scale = args[1];
auto bias = args[2]; auto bias = args[2];
auto dims = x->get_shape().lens(); auto dims = x->get_shape().lens();
auto dtype = x->get_shape().type();
if(not contains(valid_types, dtype)) if(not contains(valid_types, dtype))
MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) + MIGRAPHX_THROW(opd.op_name + ": invalid output type: " + std::to_string(dtype) +
". Valid types are 1 (float), 10 (half), and 11 (double)."); ". Valid types are 1 (float), 10 (half), and 11 (double).");
auto ndims = dims.size(); bool dyn_input = x->get_shape().dynamic();
auto ndims = x->get_shape().ndim();
assert(ndims >= 2); assert(ndims >= 2);
auto kdims = ndims - 2; auto kdims = ndims - 2;
std::vector<int64_t> axes(kdims); std::vector<int64_t> axes(kdims);
std::iota(axes.begin(), axes.end(), 2); std::iota(axes.begin(), axes.end(), 2);
auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x); auto mean = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
auto mean_bcast =
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), mean); // Use add_common_op() to insert multibroadcast/convert instructions where needed when
auto l0 = info.add_instruction(make_op("sqdiff"), x, mean_bcast); // inputs may be either static or dynamic.
auto variance = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0); auto l1 = info.add_common_op("sub", x, mean);
auto l1 = info.add_instruction(make_op("sub"), x, mean_bcast); // for the fp16, if not converting to fp32 then divide `x` and `mean` by `sqrt(n)` and take
auto epsilon_literal = info.add_literal(literal{shape{dtype}, {epsilon}}); // reduce_sum to calculate variance i.e.
auto epsilon_bcast = // var = reduce_sum((x/s_n - mean/s_n)^2) where s_n = sqrt(n)
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), epsilon_literal); std::string reduce_op_name =
auto variance_bcast = (dtype == shape::half_type and not convert_fp16) ? "reduce_sum" : "reduce_mean";
info.add_instruction(make_op("multibroadcast", {{"out_lens", dims}}), variance); if(dtype == shape::half_type and not convert_fp16)
auto l2 = info.add_instruction(make_op("add"), variance_bcast, epsilon_bcast); {
double n =
std::accumulate(dims.begin() + 2, dims.end(), 1, [&](const auto& i, const auto& j) {
return i * j;
});
n = 1.0 / std::sqrt(n);
auto n_literal = info.add_literal(literal{dtype, {n}});
x = info.add_common_op("mul", {x, n_literal});
}
auto l0 = info.add_common_op("sqdiff", x, mean);
auto variance = info.add_instruction(make_op(reduce_op_name, {{"axes", axes}}), l0);
auto epsilon_literal = info.add_literal(literal{shape{literal_dtype}, {epsilon}});
auto l2 = info.add_common_op("add", variance, epsilon_literal);
auto l3 = info.add_instruction(make_op("rsqrt"), l2); auto l3 = info.add_instruction(make_op("rsqrt"), l2);
auto l4 = info.add_instruction(make_op("mul"), l1, l3); auto l4 = info.add_common_op("mul", l1, l3);
auto scale_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale); // add_common_op() doesn't apply the plain broadcast op, so we add that op explicitly for
; // both scale and bias.
auto bias_bcast = instruction_ref scale_bcast;
instruction_ref bias_bcast;
if(dyn_input)
{
scale_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), scale, x);
bias_bcast = info.add_instruction(make_op("broadcast", {{"axis", 1}}), bias, x);
}
else
{
scale_bcast = info.add_instruction(
make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
bias_bcast =
info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias); info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
}
auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast); auto l5 = info.add_instruction(make_op("mul"), l4, scale_bcast);
return info.add_instruction(make_op("add"), l5, bias_bcast); auto ret = info.add_instruction(make_op("add"), l5, bias_bcast);
if(dtype == shape::half_type and convert_fp16)
{
return info.add_instruction(make_op("convert", {{"target_type", shape::half_type}}),
ret);
}
return ret;
} }
}; };
......
...@@ -56,6 +56,7 @@ struct parse_where : op_parser<parse_where> ...@@ -56,6 +56,7 @@ struct parse_where : op_parser<parse_where>
auto lens = auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens()); compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens) if(args[0]->get_shape().lens() != lens)
{ {
args[0] = args[0] =
......
...@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace) ...@@ -68,12 +68,18 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct module_pm : module_pass_manager struct module_pm : module_pass_manager
{ {
module* mod = nullptr; module* mod = nullptr;
module* root_mod = nullptr;
tracer* t = nullptr; tracer* t = nullptr;
module* common_parent = nullptr; module* common_parent = nullptr;
program* prog = nullptr; program* prog = nullptr;
module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {} module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
module_pm(module* pmod = nullptr, module* rmod = nullptr, tracer* pt = nullptr)
: mod(pmod), root_mod(rmod), t(pt)
{
}
template <class... Ts> template <class... Ts>
void trace(Ts&&... xs) const void trace(Ts&&... xs) const
{ {
...@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager ...@@ -97,6 +103,8 @@ struct module_pm : module_pass_manager
virtual module* get_root_module() override virtual module* get_root_module() override
{ {
if(root_mod != nullptr)
return root_mod;
assert(prog); assert(prog);
return prog->get_main_module(); return prog->get_main_module();
} }
...@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas ...@@ -140,7 +148,7 @@ void run_passes(program& prog, module_ref root_mod, const std::vector<pass>& pas
continue; continue;
if(not visited.insert(mod).second) if(not visited.insert(mod).second)
continue; continue;
module_pm mpm{mod, &trace}; module_pm mpm{mod, root_mod, &trace};
mpm.prog = &prog; mpm.prog = &prog;
auto parents = range(tree.equal_range(mod)); auto parents = range(tree.equal_range(mod));
auto nparents = distance(parents); auto nparents = distance(parents);
...@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) ...@@ -164,7 +172,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace = tracer{std::cout}; trace = tracer{std::cout};
for(const auto& p : passes) for(const auto& p : passes)
{ {
module_pm{&mod, &trace}.run_pass(p); module_pm{&mod, &mod, &trace}.run_pass(p);
} }
} }
......
...@@ -70,9 +70,8 @@ struct program_impl ...@@ -70,9 +70,8 @@ struct program_impl
{ {
// A map is used to keep references to modules of the program // A map is used to keep references to modules of the program
std::unordered_map<std::string, module> modules; std::unordered_map<std::string, module> modules;
context ctx;
std::string target_name;
std::vector<context> contexts; std::vector<context> contexts;
std::vector<target> targets;
}; };
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); } program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
...@@ -96,14 +95,8 @@ void program::assign(const program& p) ...@@ -96,14 +95,8 @@ void program::assign(const program& p)
{ {
impl = std::make_unique<program_impl>(); impl = std::make_unique<program_impl>();
} }
else if(not impl->modules.empty())
{
impl->modules.clear();
}
impl->ctx = p.impl->ctx; *impl = *p.impl;
impl->target_name = p.impl->target_name;
impl->modules = p.impl->modules;
// build a map from old ins to new ins // build a map from old ins to new ins
// Build a map from old module to new module // Build a map from old module to new module
...@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const ...@@ -166,7 +159,11 @@ std::vector<shape> program::get_output_shapes() const
return mm->get_output_shapes(); return mm->get_output_shapes();
} }
context& program::get_context() const { return impl->ctx; } context& program::get_context() const
{
assert(impl->contexts.size() == 1);
return impl->contexts.front();
}
instruction_ref program::validate() const instruction_ref program::validate() const
{ {
...@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta ...@@ -217,7 +214,7 @@ target_assignments program::get_target_assignments(const std::vector<target>& ta
return p; return p;
} }
bool program::is_compiled() const { return not this->impl->target_name.empty(); } bool program::is_compiled() const { return not this->impl->contexts.empty(); }
void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts) void program::compile(const std::vector<target>& targets, std::vector<compile_options> compile_opts)
{ {
...@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op ...@@ -299,24 +296,24 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() + MIGRAPHX_THROW("Dangling reference in module " + current_mod->name() +
" from instruction " + std::to_string(index)); " from instruction " + std::to_string(index));
} }
current_mod->finalize(this->impl->contexts[root_target_id]);
} }
} }
this->finalize();
} }
void program::compile(const target& t, compile_options options) void program::compile(const target& t, compile_options options)
{ {
// todo: combine with multi-target compile method // todo: combine with multi-target compile method
assert(not this->is_compiled()); assert(not this->is_compiled());
this->impl->target_name = t.name(); this->impl->targets = {t};
this->impl->ctx = t.get_context(); this->impl->contexts = {t.get_context()};
if(enabled(MIGRAPHX_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout}; options.trace = tracer{std::cout};
options.trace(*this); options.trace(*this);
options.trace(); options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options); auto&& passes = t.get_passes(this->impl->contexts.front(), options);
run_passes(*this, passes, options.trace); run_passes(*this, passes, options.trace);
auto mods = this->get_modules(); auto mods = this->get_modules();
// Validate and finalize // Validate and finalize
...@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options) ...@@ -335,14 +332,14 @@ void program::compile(const target& t, compile_options options)
MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " + MIGRAPHX_THROW("Dangling reference in module " + mod->name() + " from instruction " +
std::to_string(index)); std::to_string(index));
} }
mod->finalize(this->impl->ctx); mod->finalize(this->impl->contexts);
} }
} }
void program::finalize() void program::finalize()
{ {
auto* mm = this->get_main_module(); auto* mm = this->get_main_module();
mm->finalize(this->impl->ctx); mm->finalize(this->impl->contexts);
} }
template <class T> template <class T>
...@@ -359,6 +356,31 @@ std::string classify(T x) ...@@ -359,6 +356,31 @@ std::string classify(T x)
} }
} }
void print_statistics(std::ostream& os, const argument& a)
{
a.visit(
[&](auto t) {
os << "Min value: " << *std::min_element(t.begin(), t.end()) << ", ";
os << "Max value: " << *std::max_element(t.begin(), t.end()) << ", ";
double num_elements = t.size();
auto mean = std::accumulate(t.begin(), t.end(), 0.0) / num_elements;
auto stddev = std::sqrt(
std::accumulate(t.begin(),
t.end(),
0.0,
[&](auto r, auto v) { return r + std::pow((v - mean), 2.0); }) /
num_elements);
os << "Mean: " << mean << ", ";
os << "StdDev: " << stddev << "\n";
},
[&](const auto& xs) {
for(const auto& x : xs)
{
print_statistics(os, x);
}
});
}
std::unordered_set<std::string> classify_argument(const argument& a) std::unordered_set<std::string> classify_argument(const argument& a)
{ {
std::unordered_set<std::string> result; std::unordered_set<std::string> result;
...@@ -404,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a) ...@@ -404,16 +426,15 @@ void preview_argument(std::ostream& os, const argument& a)
template <class F> template <class F>
std::vector<argument> generic_eval(const module* mod, std::vector<argument> generic_eval(const module* mod,
context& ctx, std::vector<context>& ctx,
std::unordered_map<std::string, argument> params, std::unordered_map<std::string, argument> params,
std::unordered_map<instruction_ref, argument> results, std::unordered_map<instruction_ref, argument> results,
F make_trace) F trace)
{ {
assert(mod->validate() == mod->end()); assert(mod->validate() == mod->end());
results.reserve(mod->size() * 2); results.reserve(mod->size() * 2);
std::vector<argument> values; std::vector<argument> values;
values.reserve(16); values.reserve(16);
auto trace = make_trace(mod);
for(auto ins : iterator_for(*mod)) for(auto ins : iterator_for(*mod))
{ {
assert(results.find(ins) == results.end()); assert(results.find(ins) == results.end());
...@@ -469,13 +490,18 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -469,13 +490,18 @@ std::vector<argument> generic_eval(const module* mod,
const auto& mod_args = ins->module_inputs(); const auto& mod_args = ins->module_inputs();
auto module_eval = [&](module_ref smod, auto module_eval = [&](module_ref smod,
const std::unordered_map<std::string, argument>& inputs) { const std::unordered_map<std::string, argument>& inputs) {
auto ssctx = ctx; return generic_eval(smod, ctx, inputs, results, trace);
return generic_eval(smod, ssctx, inputs, results, make_trace);
}; };
results.emplace(ins, trace(ins, [&] { results.emplace(
return ins->normalized_operator().compute( ins, trace(ins, [&] {
ctx, ins->get_shape(), values, mod_args, module_eval); auto op = ins->normalized_operator();
if(op.is_context_free())
return op.compute(ins->get_shape(), values, mod_args, module_eval);
if(ins->get_target_id() >= ctx.size())
MIGRAPHX_THROW("No context available for " + op.name());
return op.compute(
ctx[ins->get_target_id()], ins->get_shape(), values, mod_args, module_eval);
})); }));
} }
assert(results.find(ins) != results.end()); assert(results.find(ins) != results.end());
...@@ -489,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -489,44 +515,25 @@ std::vector<argument> generic_eval(const module* mod,
template <class F> template <class F>
std::vector<argument> generic_eval(const program& p, std::vector<argument> generic_eval(const program& p,
context& ctx, std::vector<context>& ctx,
std::unordered_map<std::string, argument> params, std::unordered_map<std::string, argument> params,
F make_trace) F trace)
{ {
const module* mm = p.get_main_module(); const module* mm = p.get_main_module();
return generic_eval(mm, ctx, params, {}, make_trace); return generic_eval(mm, ctx, params, {}, trace);
} }
std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
{ {
auto& ctx = this->impl->ctx; auto& contexts = this->impl->contexts;
#ifndef NDEBUG
auto with_check_context = [&](auto f) {
return [=, &ctx](auto&&) {
auto sctx = std::make_shared<context>(ctx);
auto check_context = [=, &ctx](auto g) {
assert(is_shared(ctx, *sctx));
auto x = g();
*sctx = ctx;
return x;
};
return [=](auto&&... xs) { return f(xs..., check_context); };
};
};
#else
auto with_check_context = [](auto f) {
return [=](auto&&) {
return [=](auto&&... xs) { return f(xs..., [](auto g) { return g(); }); };
};
};
#endif
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{}); auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
std::vector<argument> ret; std::vector<argument> ret;
if(exec_env.async) if(exec_env.async)
{ {
ctx.wait_for(exec_env.queue); assert(contexts.size() == 1);
contexts.front().wait_for(exec_env.queue);
} }
if(trace_level > 0) if(trace_level > 0)
...@@ -538,32 +545,42 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -538,32 +545,42 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
instruction::print(ss, x, ins_names); instruction::print(ss, x, ins_names);
ins_out[x] = ss.str(); ins_out[x] = ss.str();
}); });
ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
ret = generic_eval(*this, auto& ctx = contexts[ins->get_target_id()];
ctx,
std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) {
ctx.finish(); ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{}; timer t{};
auto result = check_context(f); auto result = f();
double t1 = t.record<milliseconds>(); double t1 = t.record<milliseconds>();
ctx.finish(); ctx.finish();
double t2 = t.record<milliseconds>(); double t2 = t.record<milliseconds>();
std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl; std::cout << "Time: " << t1 << "ms, " << t2 << "ms" << std::endl;
if(trace_level > 1 and ins->name().front() != '@' and if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load" and
ins->name() != "load" and not result.empty()) not result.empty())
{
migraphx::argument buffer;
try
{ {
target tgt = make_target(this->impl->target_name); const target& tgt = this->impl->targets.at(ins->get_target_id());
auto buffer = tgt.copy_from(result); buffer = tgt.copy_from(result);
}
catch(const migraphx::exception&)
{
// instruction was run on host then no need to copy buffer from target
buffer = result;
}
catch(...)
{
MIGRAPHX_THROW("MIGraphX program execution with MIGRAPHX_TRACE_EVAL failed.\n");
}
if(trace_level == 2) if(trace_level == 2)
{ {
std::cout << "Output has " std::cout << "Output has " << to_string_range(classify_argument(buffer))
<< to_string_range(classify_argument(buffer))
<< std::endl; << std::endl;
std::cout << "Output: "; std::cout << "Output: ";
preview_argument(std::cout, buffer); preview_argument(std::cout, buffer);
std::cout << std::endl; std::cout << std::endl;
print_statistics(std::cout, buffer);
} }
else else
{ {
...@@ -571,35 +588,36 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -571,35 +588,36 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
} }
} }
return result; return result;
})); });
} }
else else
{ {
ret = generic_eval(*this, ret = generic_eval(*this, contexts, std::move(params), [&](auto&&, auto f) { return f(); });
ctx,
std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f);
}));
} }
if(exec_env.async) if(exec_env.async)
{ {
ctx.finish_on(exec_env.queue); assert(contexts.size() == 1);
contexts.front().finish_on(exec_env.queue);
} }
return ret; return ret;
} }
const int program_file_version = 5; void program::finish() const
{
for(const auto& ctx : this->impl->contexts)
ctx.finish();
}
const int program_file_version = 6;
value program::to_value() const value program::to_value() const
{ {
value result; value result;
result["version"] = program_file_version; result["version"] = program_file_version;
result["target"] = this->impl->target_name; result["targets"] = migraphx::to_value(this->impl->targets);
if(not this->impl->target_name.empty()) result["contexts"] = migraphx::to_value(this->impl->contexts);
result["context"] = this->impl->ctx.to_value();
value module_vals = value::object{}; value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names; std::unordered_map<instruction_ref, std::string> names;
...@@ -728,12 +746,12 @@ void program::from_value(const value& v) ...@@ -728,12 +746,12 @@ void program::from_value(const value& v)
MIGRAPHX_THROW("Warning: Program version mismatch"); MIGRAPHX_THROW("Warning: Program version mismatch");
} }
this->impl->target_name = v.at("target").to<std::string>(); migraphx::from_value(v.at("targets"), this->impl->targets);
if(not this->impl->target_name.empty())
for(auto i : range(this->impl->targets.size()))
{ {
target t = make_target(this->impl->target_name); this->impl->contexts.push_back(this->impl->targets[i].get_context());
this->impl->ctx = t.get_context(); this->impl->contexts.back().from_value(v.at("contexts")[i]);
this->impl->ctx.from_value(v.at("context"));
} }
auto module_vals = v.at("modules"); auto module_vals = v.at("modules");
...@@ -754,6 +772,8 @@ void program::from_value(const value& v) ...@@ -754,6 +772,8 @@ void program::from_value(const value& v)
auto* mm = get_main_module(); auto* mm = get_main_module();
mod_from_val(mm, module_vals, map_insts, map_mods); mod_from_val(mm, module_vals, map_insts, map_mods);
// Finalize a compiled model
if(not this->impl->contexts.empty())
this->finalize(); this->finalize();
} }
...@@ -774,19 +794,19 @@ std::string perf_group(const operation& op) ...@@ -774,19 +794,19 @@ std::string perf_group(const operation& op)
void program::mark(const parameter_map& params, marker&& m) void program::mark(const parameter_map& params, marker&& m)
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->contexts;
// Run once by itself // Run once by itself
eval(params); eval(params);
ctx.finish(); this->finish();
// Start marking // Start marking
m.mark_start(*this); m.mark_start(*this);
generic_eval(*this, ctx, params, always([&](auto ins, auto f) { generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result; argument result;
m.mark_start(ins); m.mark_start(ins);
result = f(); result = f();
m.mark_stop(ins); m.mark_stop(ins);
return result; return result;
})); });
m.mark_stop(*this); m.mark_stop(*this);
} }
...@@ -795,10 +815,10 @@ void program::perf_report(std::ostream& os, ...@@ -795,10 +815,10 @@ void program::perf_report(std::ostream& os,
parameter_map params, parameter_map params,
std::size_t batch) const std::size_t batch) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->contexts;
// Run once by itself // Run once by itself
eval(params); eval(params);
ctx.finish(); this->finish();
// Run and time entire program // Run and time entire program
std::vector<double> total_vec; std::vector<double> total_vec;
total_vec.reserve(n); total_vec.reserve(n);
...@@ -806,28 +826,28 @@ void program::perf_report(std::ostream& os, ...@@ -806,28 +826,28 @@ void program::perf_report(std::ostream& os,
{ {
total_vec.push_back(time<milliseconds>([&] { total_vec.push_back(time<milliseconds>([&] {
eval(params); eval(params);
ctx.finish(); this->finish();
})); }));
} }
std::sort(total_vec.begin(), total_vec.end()); std::sort(total_vec.begin(), total_vec.end());
std::unordered_map<instruction_ref, std::vector<double>> ins_vec; std::unordered_map<instruction_ref, std::vector<double>> ins_vec;
// Fill the map // Fill the map
generic_eval(*this, ctx, params, always([&](auto ins, auto) { generic_eval(*this, ctx, params, [&](auto ins, auto) {
ins_vec[ins].reserve(n); ins_vec[ins].reserve(n);
return argument{ins->get_shape(), nullptr}; return argument{ins->get_shape(), nullptr};
})); });
// Run and time each instruction // Run and time each instruction
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
generic_eval(*this, ctx, params, always([&](auto ins, auto f) { generic_eval(*this, ctx, params, [&](auto ins, auto f) {
argument result; argument result;
ins_vec[ins].push_back(time<milliseconds>([&] { ins_vec[ins].push_back(time<milliseconds>([&] {
result = f(); result = f();
ctx.finish(); this->impl->contexts[ins->get_target_id()].finish();
})); }));
return result; return result;
})); });
} }
for(auto&& p : ins_vec) for(auto&& p : ins_vec)
std::sort(p.second.begin(), p.second.end()); std::sort(p.second.begin(), p.second.end());
...@@ -995,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const ...@@ -995,10 +1015,10 @@ void program::print_cpp(std::ostream& os) const
void program::dry_run(std::unordered_map<std::string, argument> params) const void program::dry_run(std::unordered_map<std::string, argument> params) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->contexts;
generic_eval(*this, ctx, std::move(params), always([](auto ins, auto&&...) { generic_eval(*this, ctx, std::move(params), [](auto ins, auto&&...) {
return argument{ins->get_shape(), nullptr}; return argument{ins->get_shape(), nullptr};
})); });
} }
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
......
...@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const ...@@ -34,7 +34,7 @@ void promote_literals::apply(module_pass_manager& mpm) const
{ {
module& m = mpm.get_module(); module& m = mpm.get_module();
module_ref root_module = mpm.get_root_module(); module_ref root_module = mpm.get_root_module();
if(m.name() == "main") if(m == *root_module)
return; return;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment