Unverified Commit fd3252dc authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Merge branch 'develop' into dot-add

parents 56615a84 8192f37f
......@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN PATH=/opt/cmake/bin:$PATH cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@02078ce236ad90e3aec04c0c770ef5bfc99e49c2
RUN cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@26a4b3cfc0a1a15181490f24ae461608fef1b04e -DBUILD_MIXR_TARGET=On
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
......
......@@ -80,7 +80,7 @@
"outputs": [],
"source": [
"if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16 --binary -o yolov4_fp16.mxr\n",
"if not os.path.exists(\"yolov4.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
]
......
......@@ -88,6 +88,7 @@ add_library(migraphx
shape.cpp
simplify_algebra.cpp
simplify_reshapes.cpp
target_assignments.cpp
tmp_dir.cpp
value.cpp
verify_args.cpp
......
......@@ -93,9 +93,11 @@ static bool try_compute_shape(instruction_ref ins,
return try_compute_shape(ins, inputs, mods);
}
void eliminate_contiguous::apply(module& m) const
template <class F>
static void remove_contiguous(const std::string& op_name, module& m, F f)
{
std::vector<instruction_ref> const_instruction;
auto last = std::prev(m.end());
std::vector<instruction_ref> const_instructions;
for(auto ins : iterator_for(m))
{
......@@ -103,6 +105,12 @@ void eliminate_contiguous::apply(module& m) const
if(ins->name() == "@return")
continue;
if(ins != last and ins->outputs().empty())
continue;
if(not f(ins))
continue;
// Make a copy so we can modify it while we iterate
auto args = ins->inputs();
auto new_args = args;
......@@ -110,36 +118,46 @@ void eliminate_contiguous::apply(module& m) const
for(auto arg : ins->inputs())
{
if(arg->name() == op_name)
if(arg->name() != op_name)
continue;
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
else if(prev->can_eval())
{
const_instruction.push_back(arg);
}
const_instructions.push_back(arg);
}
}
}
// Perform evaluations in parallel
std::vector<argument> literals(const_instruction.size());
par_for(const_instruction.size(), 1, [&](const auto i) {
std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instruction[i]->inputs().front();
auto prev = const_instructions[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
});
for(size_t i = 0; i < const_instruction.size(); i++)
for(size_t i = 0; i < const_instructions.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instruction[i], l);
m.replace_instruction(const_instructions[i], l);
}
}
void eliminate_contiguous::apply(module& m) const
{
// Skip contiguous from splits first
remove_contiguous(op_name, m, [](auto ins) {
if(ins->name() != "slice")
return true;
return (ins->inputs().front()->outputs().size() == 1);
});
remove_contiguous(op_name, m, [](auto) { return true; });
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -142,7 +142,7 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
input_map[input] = map_ins[param];
}
}
pm->replace_return(pm->insert_module_instructions(last, xm, map_ins));
pm->replace_return(pm->insert_instructions(last, xm, map_ins));
return inputs;
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#define MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
#include <migraphx/support_metric.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct assignment_options
{
support_metric metric = support_metric::latency;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_RTGLIB_ASSIGNMENT_OPTIONS_HPP
......@@ -71,6 +71,11 @@ struct check_shapes
return end - begin;
}
/*!
* Check if the number of shape objects is equal to atleast one of the
* given sizes.
* \param ns template parameter pack of sizes to check against
*/
template <class... Ts>
const check_shapes& has(Ts... ns) const
{
......
......@@ -349,25 +349,27 @@ match::matcher_result find_match(module& modl, M&& m)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the module
template <class... Ms>
void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(mod, ins, m.matcher());
if(r.result == mod.end())
if(trace > 1)
std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end())
return;
if(trace)
if(trace > 0)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
mod.debug_print(ins);
get_module(mod).debug_print(ins);
}
m.apply(mod, r);
match = true;
......@@ -376,10 +378,10 @@ void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
}
/// Find matches in a module
template <class... Ms>
void find_matches(module& mod, Ms&&... ms)
template <class Mod, class... Ms>
void find_matches(Mod& mod, Ms&&... ms)
{
for(auto ins : iterator_for(mod))
for(auto ins : iterator_for(get_module(mod)))
{
find_matches(mod, ins, ms...);
}
......
......@@ -120,9 +120,33 @@ struct module
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
std::vector<instruction_ref>
insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
......@@ -140,6 +164,10 @@ struct module
instruction_ref replace_return(std::vector<instruction_ref> args);
instruction_ref insert_literal(instruction_ref ins, literal l);
instruction_ref insert_parameter(instruction_ref ins, std::string name, shape s);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
......@@ -203,6 +231,8 @@ struct module
std::unique_ptr<module_impl> impl;
};
inline module& get_module(module& m) { return m; }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -32,7 +32,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
using module_ref = module*;
using module_ref = module*;
using const_module_ref = const module*;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -42,11 +42,12 @@ namespace op {
struct unsqueeze
{
std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
return pack(f(self.axes, "axes"), f(self.steps, "steps"));
}
value attributes() const
......@@ -73,6 +74,9 @@ struct unsqueeze
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size);
......@@ -80,16 +84,27 @@ struct unsqueeze
std::size_t p = 0;
for(auto i : range(new_size))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{
new_lens[i] = 1;
if(p == 0) // unsqueeze on the first axes
std::int64_t step = 1;
if(axis_idx < steps.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{
new_strides[i] = old_lens[0] * old_strides[0];
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
}
else // unsqueeze on middle or last axes
else
{
new_strides[i] = (p < old_strides.size()) ? old_strides[p - 1] : 1;
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
}
}
else
......
......@@ -55,8 +55,14 @@ inline std::vector<int64_t> sort_permutation(const Vector& data, Op op)
return result;
}
/*!
* Returns the permutation needed to apply to the shape to undo the current permutation
*/
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation);
/*!
* Finds the permutation most likely from a transpose operator that has been applied to the shape.
*/
std::vector<int64_t> find_permutation(const shape& s);
std::vector<int64_t> find_permutation(const std::vector<shape>& shapes);
......
......@@ -33,6 +33,8 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/target_assignments.hpp>
#include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
......@@ -84,6 +86,9 @@ struct program
instruction_ref validate() const;
target_assignments get_target_assignments(const std::vector<target>& targets,
assignment_options options = assignment_options{});
void compile(const target& t, compile_options options = compile_options{});
bool is_compiled() const;
......
......@@ -198,6 +198,12 @@ void transform(Range&& r, Iterator it, F f)
std::transform(r.begin(), r.end(), it, f);
}
template <class Range1, class Range2, class Iterator, class F>
void transform(Range1&& r1, Range2&& r2, Iterator it, F f)
{
std::transform(r1.begin(), r1.end(), r2.begin(), it, f);
}
template <class Range>
auto reverse(Range& r)
{
......
......@@ -82,6 +82,23 @@ struct shape
{
};
struct dynamic_dimension
{
std::size_t min = 0;
std::size_t max = 0;
std::size_t opt = 0;
template <class Self, class F>
static auto reflect(Self& self, F f);
bool is_fixed() const;
bool has_optimal() const;
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
};
static const std::vector<type_t>& types();
static std::string name(type_t t);
......@@ -92,6 +109,12 @@ struct shape
shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
// Force all calls of the format `shape( type_t, { size_t compatibles } )` to map to
// shape(type_t, std::vector<std::size_t> l)
shape(type_t t, std::initializer_list<std::size_t> d);
shape(type_t t, std::vector<dynamic_dimension> dims);
template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{
......@@ -112,10 +135,44 @@ struct shape
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
/*!
* Return the number of elements in the tensor.
*/
std::size_t elements() const;
/*!
* Return the number of total bytes used for storage of the tensor data; includes subshapes.
* For dynamic shape, returns the maximum number of bytes presuming a packed shape.
*/
std::size_t bytes() const;
/*!
* Return the size of the type of the main shape.
* Returns 0 if there are subshapes.
*/
std::size_t type_size() const;
const std::vector<dynamic_dimension>& dyn_dims() const;
/*!
* Minimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> min_lens() const;
/*!
* Maximum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> max_lens() const;
/*!
* Optimum lengths for dynamic shape.
* lens() for fixed shape.
*/
std::vector<std::size_t> opt_lens() const;
/// Map multiple indices to space index
std::size_t index(std::initializer_list<std::size_t> l) const;
/// Map multiple indices to space index
......@@ -136,19 +193,27 @@ struct shape
std::vector<std::size_t> multi(std::size_t i) const;
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed with no padding
/// Returns true if the shape is packed (number of elements and buffer size the same) with no
/// padding
bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool standard() const;
/// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const;
/// Return true if the shape is dynamic
bool dynamic() const;
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
......@@ -191,6 +256,10 @@ struct shape
std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
auto is_integral() const { return std::is_integral<type>{}; }
auto is_signed() const { return std::is_signed<type>{}; }
auto is_unsigned() const { return std::is_unsigned<type>{}; }
template <class U>
type* from(U* buffer, std::size_t n = 0) const
{
......@@ -248,6 +317,11 @@ struct shape
const std::vector<shape>& sub_shapes() const;
/*!
* Returns the number of elements in the data buffer.
* For a dynamic shape, returns the maximum number of elements of the data buffer and assumes it
* is packed.
*/
std::size_t element_space() const;
private:
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
enum class support_metric
{
latency,
throughput
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SUPPORT_METRIC_HPP
......@@ -37,6 +37,8 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -61,6 +63,13 @@ struct target
* @return The context to be used during compilation and execution.
*/
context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
/**
* @brief copy an argument to the current target.
*
......@@ -105,6 +114,12 @@ argument copy_from_target(T&, const argument& arg)
return arg;
}
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
{
return 0;
}
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
......@@ -117,6 +132,8 @@ struct target
//
context get_context() const;
// (optional)
float is_supported(instruction_ref ins, support_metric m) const;
// (optional)
argument copy_to(const argument& input) const;
// (optional)
argument copy_from(const argument& input) const;
......@@ -207,6 +224,12 @@ struct target
return (*this).private_detail_te_get_handle().get_context();
}
float is_supported(instruction_ref ins, support_metric m) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().is_supported(ins, m);
}
argument copy_to(const argument& input) const
{
assert((*this).private_detail_te_handle_mem_var);
......@@ -242,11 +265,31 @@ struct target
virtual std::vector<pass> get_passes(context& ctx,
const compile_options& options) const = 0;
virtual context get_context() const = 0;
virtual float is_supported(instruction_ref ins, support_metric m) const = 0;
virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0;
};
template <class T>
static auto private_detail_te_default_is_supported(char,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
-> decltype(private_detail_te_self.is_supported(ins, m))
{
return private_detail_te_self.is_supported(ins, m);
}
template <class T>
static float private_detail_te_default_is_supported(float,
T&& private_detail_te_self,
instruction_ref ins,
support_metric m)
{
return target_is_supported(private_detail_te_self, ins, m);
}
template <class T>
static auto
private_detail_te_default_copy_to(char, T&& private_detail_te_self, const argument& input)
......@@ -329,6 +372,12 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); }
float is_supported(instruction_ref ins, support_metric m) const override
{
return private_detail_te_default_is_supported(char(0), private_detail_te_value, ins, m);
}
argument copy_to(const argument& input) const override
{
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
#include <unordered_map>
#include <migraphx/instruction_ref.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct target_assignments
{
void add_assignment(instruction_ref ins, const std::string& target);
auto begin() const { return assignments.cbegin(); }
auto end() const { return assignments.cend(); }
private:
std::unordered_map<instruction_ref, std::string> assignments;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ASSIGNMENT_HPP
......@@ -35,7 +35,7 @@ static void inline_submodule(module& m, instruction_ref ins, bool cond)
{
const auto& mod_inputs = ins->module_inputs();
module_ref smod = cond ? mod_inputs.at(0) : mod_inputs.at(1);
auto mod_outputs = m.insert_module_instructions(ins, smod);
auto mod_outputs = m.insert_instructions(ins, smod);
auto ins_outputs = ins->outputs();
assert(mod_outputs.size() >= ins_outputs.size());
......
......@@ -197,6 +197,62 @@ void module::assign(const module& m)
}
}
template <class Range>
static std::vector<instruction_ref>
insert_generic_instructions(module& m,
instruction_ref ins,
Range&& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
assert(m.has_instruction(ins) or is_end(ins, m.end()));
std::vector<instruction_ref> mod_outputs;
instruction_ref last;
for(instruction_ref sins : instructions)
{
last = sins;
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = m.add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = m.add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = m.add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
copy_ins = m.insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty() and instructions.begin() != instructions.end())
mod_outputs = {map_ins.at(last)};
return mod_outputs;
}
instruction_ref module::add_instruction(const operation& op, std::vector<instruction_ref> args)
{
return insert_instruction(impl->instructions.end(), op, std::move(args));
......@@ -335,61 +391,56 @@ instruction_ref module::move_instructions(instruction_ref src, instruction_ref d
return src;
}
std::vector<instruction_ref> module::insert_module_instructions(
instruction_ref ins, module_ref m, std::unordered_map<instruction_ref, instruction_ref> map_ins)
std::vector<instruction_ref>
module::add_instructions(const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
std::vector<instruction_ref> mod_outputs;
for(auto sins : iterator_for(*m))
{
if(contains(map_ins, sins))
continue;
instruction_ref copy_ins;
if(sins->name() == "@literal")
{
auto l = sins->get_literal();
copy_ins = this->add_literal(l);
}
else if(sins->name() == "@param")
{
auto&& name = any_cast<builtin::param>(sins->get_operator()).parameter;
auto s = sins->get_shape();
copy_ins = this->add_parameter(name, s);
}
else if(sins->name() == "@outline")
{
auto s = sins->get_shape();
copy_ins = this->add_outline(s);
}
else
{
auto mod_args = sins->module_inputs();
auto inputs = sins->inputs();
std::vector<instruction_ref> copy_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), copy_inputs.begin(), [&](auto i) {
return contains(map_ins, i) ? map_ins[i] : i;
});
return this->insert_instructions(this->end(), instructions, std::move(map_ins));
}
if(sins->name() == "@return")
{
mod_outputs = copy_inputs;
break;
}
std::vector<instruction_ref>
module::add_instructions(const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), m, std::move(map_ins));
}
copy_ins = this->insert_instruction(ins, sins->get_operator(), copy_inputs, mod_args);
}
map_ins[sins] = copy_ins;
}
if(mod_outputs.empty())
mod_outputs = {map_ins.at(std::prev(m->end()))};
return mod_outputs;
std::vector<instruction_ref>
module::add_instructions(instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return this->insert_instructions(this->end(), start, last, std::move(map_ins));
}
instruction_ref module::add_literal(literal l)
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const std::vector<instruction_ref>& instructions,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
impl->emplace_front(std::move(l));
return impl->instructions.begin();
return insert_generic_instructions(*this, ins, instructions, std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
const_module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
return insert_generic_instructions(*this, ins, iterator_for(*m), std::move(map_ins));
}
std::vector<instruction_ref>
module::insert_instructions(instruction_ref ins,
instruction_ref start,
instruction_ref last,
std::unordered_map<instruction_ref, instruction_ref> map_ins)
{
auto r = range(start, last);
return insert_generic_instructions(*this, ins, iterator_for(r), std::move(map_ins));
}
instruction_ref module::add_literal(literal l) { return insert_literal(begin(), std::move(l)); }
instruction_ref module::add_outline(const shape& s)
{
impl->push_front({builtin::outline{s}, s, {}});
......@@ -398,10 +449,7 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref module::add_parameter(std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return impl->instructions.begin();
return insert_parameter(begin(), std::move(name), std::move(s));
}
instruction_ref module::add_return(std::vector<instruction_ref> args)
......@@ -414,6 +462,20 @@ instruction_ref module::add_return(std::vector<instruction_ref> args)
return result;
}
instruction_ref module::insert_literal(instruction_ref ins, literal l)
{
impl->emplace(ins, std::move(l));
return std::prev(ins);
}
instruction_ref module::insert_parameter(instruction_ref ins, std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->insert(ins, {builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return std::prev(ins);
}
instruction_ref module::replace_return(std::vector<instruction_ref> args)
{
auto last = std::prev(this->end());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment