Commit baac1dab authored by Alan Turner's avatar Alan Turner
Browse files

Merge remote-tracking branch 'origin/develop' into ck-host-lib

parents 830dff7a 77042e30
......@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std::vector<int64_t> steps;
// slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice
// to decide whether MIGRAPHX can handle this slice.
if(args.size() == 5)
{
migraphx::argument step_arg = args.back()->eval();
......@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
// If axes arg is not given, the default is all of them.
if(op.axes.empty())
{
std::vector<int64_t> axes(args[0]->get_shape().lens().size());
std::vector<int64_t> axes(args[0]->get_shape().ndim());
std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes;
}
......@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(steps.size()))
{
if(steps[i] >= 0)
......@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty())
{
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
}
// If any steps are other than default 1, add a "steps" op
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
// TODO: broadcasting for dynamic shapes is only implemented
// for binary ops at time of writing, not ternary ops.
// When it becomes available, add multibroadcasting steps in the dynamic shape case.
// For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if(std::all_of(args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
if(args[1]->get_shape().lens() != lens)
else if(std::none_of(
args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
}
// If shapes are static and any are broadcasted, insert multibroadcast ops
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
{
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
if(args[1]->get_shape().lens() != lens)
{
args[1] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
}
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
else
MIGRAPHX_THROW("PARSE_WHERE: doesn't support mixed static and dynamic shape inputs");
}
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include "memory_coloring_impl.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring_impl::run()
{
// calc implicit depdendencies
mod_implicit_deps = p_mod->calc_implicit_deps();
MIGRAPHX_DEBUG(dump("---Before memory coloring---"));
MIGRAPHX_DEBUG(dump_module());
build();
if(num_of_lives != 0)
{
MIGRAPHX_DEBUG(dump_intervals());
// Coloring
while(not alloc_queue.empty())
{
interval_ptr interval = alloc_queue.top();
allocate(interval);
alloc_queue.pop();
}
// rewrite happens after all modules are processed
rewrite();
if(enable_verify)
verify();
}
}
bool memory_coloring_impl::allocate(interval_ptr interval)
{
shape s = interval->result;
std::size_t size = s.bytes();
if(size == 0)
return false;
std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
live_range& segment = interval->segment;
int vn = segment.vn;
std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
std::unordered_map<long long, live_range*> offset2_live;
offset2_live.clear();
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
long long offset = range->offset;
if(offset != invalid_offset)
{
conflict_queue.push(range);
if(offset2_live.find(offset) == offset2_live.end())
{
offset2_live[offset] = range;
}
else
{
live_range* prev = offset2_live[offset];
assert(prev->offset == offset);
if(prev->size < range->size)
offset2_live[offset] = range;
}
}
}
}
std::size_t offset = 0;
while(not conflict_queue.empty())
{
live_range* range = conflict_queue.top();
std::size_t iter_offset = range->offset;
if(offset > iter_offset)
{
offset = std::max(offset, iter_offset + range->size);
}
else if(offset2_live[iter_offset] == range)
{
if((iter_offset > offset) && (iter_offset - offset) >= size)
{
break;
}
offset = iter_offset + range->size;
}
// alignment
if((offset % element_size) != 0)
offset += (element_size - (offset % element_size));
conflict_queue.pop();
}
// when int8 type is used, the offset could be any number
// if not 4-byte aligned, miopen int8 convolution can crash
offset = (offset + 3) / 4 * 4;
segment.offset = offset;
MIGRAPHX_DEBUG(segment.dump());
required_bytes = std::max(required_bytes, offset + segment.size);
return true;
}
void memory_coloring_impl::build()
{
std::size_t num_of_instrs = p_mod->size();
if(num_of_instrs == 0)
return;
auto cur_points = num_of_instrs * 2;
instruction_ref iter = p_mod->end();
instruction_ref begin = p_mod->begin();
std::vector<instruction_ref> dead_instrs;
std::set<int> live_set;
// Build live intervals.
live_intervals.resize(num_of_instrs);
do
{
iter = std::prev(iter);
const instruction* p_iter = &(*iter);
interval_ptr def_interval = nullptr;
bool is_dead = false;
if(instr2_live.find(p_iter) != instr2_live.end())
{
def_interval = instr2_live[p_iter];
bool is_lit = is_literal(iter);
if(is_allocate(iter) or is_lit)
{
live_range& range = def_interval->segment;
def_interval->result = iter->get_shape();
def_interval->is_literal = is_lit;
range.begin = cur_points;
def_interval->def_point = cur_points;
range.size = (iter->get_shape()).bytes();
if(not is_lit or unify_literals)
alloc_queue.push(def_interval);
live_set.erase(range.vn);
}
}
else if(not is_param(iter) && not is_outline(iter) && not is_check_context(iter))
{
is_dead = true;
}
auto inputs = iter->inputs();
if(contains(mod_implicit_deps, iter))
{
const auto& impl_deps = mod_implicit_deps.at(iter);
inputs.insert(inputs.end(), impl_deps.begin(), impl_deps.end());
}
for(auto&& arg : inputs)
{
if(not p_mod->has_instruction(arg))
continue;
if(is_param(arg) or is_outline(arg))
{
if(is_output_param(arg))
is_dead = false;
if(def_interval != nullptr)
{
def_interval->is_live_on_entry = true;
}
continue;
}
const instruction* p_arg = &(*instruction::get_output_alias(arg));
if(instr2_live.find(p_arg) == instr2_live.end())
{
// First time see a use, create a live interval.
int id = num_of_lives++;
interval_ptr interval = &(live_intervals[id]);
interval->id = id;
interval->segment.end = cur_points;
interval->segment.vn = ++max_value_number;
interval->add_use(cur_points);
instr2_live[p_arg] = interval;
add_conflicts(live_set, max_value_number);
live_set.insert(max_value_number);
live_ranges[max_value_number] = &(interval->segment);
earliest_end_point = cur_points;
if(latest_end_point == -1)
latest_end_point = cur_points;
}
else
{
interval_ptr interval = instr2_live[p_arg];
interval->add_use(cur_points);
assert(live_set.find(interval->id) != live_set.end());
}
}
if(is_dead)
dead_instrs.push_back(iter);
cur_points -= 2;
} while(iter != begin);
}
void memory_coloring_impl::rewrite()
{
std::vector<std::size_t> dims;
dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_mod->add_parameter("scratch", s);
for(auto ins : iterator_for(*p_mod))
{
const instruction* p_iter = &(*ins);
if(instr2_live.find(p_iter) != instr2_live.end())
{
interval_ptr interval = instr2_live[p_iter];
if(interval->get_begin() == invalid_offset)
continue;
if(not unify_literals && interval->is_literal)
continue;
std::size_t offset = 0;
if(interval->get_offset() != invalid_offset)
{
offset = interval->get_offset();
}
else
{
assert(interval->result.bytes() == 0);
}
if(is_allocate(ins))
{
p_mod->replace_instruction(
ins,
make_op("load", {{"shape", to_value(ins->get_shape())}, {"offset", offset}}),
scratch_param);
}
}
}
MIGRAPHX_DEBUG(dump("---After rewrite---"));
MIGRAPHX_DEBUG(dump_module());
}
void memory_coloring_impl::verify()
{
if(num_of_lives > 0)
{
for(int i = 0; i < num_of_lives; ++i)
{
const live_interval& interval = live_intervals[i];
const live_range& segment = interval.segment;
if(segment.begin == invalid_offset)
{
// if(not interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
if(segment.offset == invalid_offset)
{
continue;
}
int vn = segment.vn;
if(conflict_table.find(vn) != conflict_table.end())
{
const std::set<int>& vn_set = conflict_table[vn];
for(const auto& iter : vn_set)
{
live_range* range = live_ranges[iter];
if(range->offset == invalid_offset)
continue;
if(not is_disjoin(*range, segment))
MIGRAPHX_THROW("range and segment is not disjoined");
}
}
}
}
}
#ifdef MIGRAPHX_DEBUG_OPT
void memory_coloring_impl::dump(const std::string& str) { std::cout << str << std::endl; }
void memory_coloring_impl::dump_module() { std::cout << *p_mod << std::endl; }
void memory_coloring_impl::dump_intervals()
{
if(num_of_lives > 0)
{
std::cout << "---live intervals ---" << std::endl;
for(int i = 0; i < num_of_lives; ++i)
{
live_interval& interval = live_intervals[i];
interval.dump();
}
std::cout << "---conflict table---" << std::endl;
for(int i = 0; i <= max_value_number; ++i)
{
std::cout << " segment:" << i;
std::cout << " =>";
const std::set<int>& table = conflict_table[i];
for(const auto& iter : table)
{
std::cout << (iter) << ",";
}
}
std::cout << std::endl;
}
}
// map liveness tracking point to instruction enum.
static int get_ins_enum(int x)
{
if(x > 0)
{
return (x / 2) - 1;
}
else
return invalid_offset;
}
void live_range::dump()
{
std::cout << " segment:" << vn;
std::cout << " [" << get_ins_enum(begin) << ", " << get_ins_enum(end) << "]";
if(offset != invalid_offset)
{
std::cout << " mem:";
std::cout << " [" << offset << "," << offset + size - 1 << "]";
}
std::cout << std::endl;
}
void live_interval::dump()
{
std::cout << "id:" << id;
segment.dump();
std::cout << " uses:";
for(const auto& iter : use_points)
{
std::cout << " " << get_ins_enum(iter) << ",";
}
std::cout << " def:";
std::cout << " " << get_ins_enum(def_point);
if(is_literal)
std::cout << " literal";
std::cout << " " << result;
std::cout << std::endl;
}
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_IMPL_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/pass_config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/config.hpp>
#include <set>
#include <list>
#include <vector>
#include <queue>
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
static const std::size_t invalid_offset = std::numeric_limits<std::size_t>::max();
struct live_range
{
std::size_t begin; // begin point in the instruction stream.
std::size_t end; // end point in the instruction stream.
std::size_t offset; // offset to base pointer of allocated memory trunk.
std::size_t vn; // value number that identifies this live_range.
std::size_t size; // size of required memory in bytes
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
};
struct live_interval
{
live_interval() : segment({invalid_offset, invalid_offset, invalid_offset, invalid_offset, 0})
{
}
void add_use(std::size_t use) { use_points.push_front(use); }
std::size_t get_begin() const { return segment.begin; }
std::size_t get_end() const { return segment.end; }
long long get_offset() const { return segment.offset; }
#ifdef MIGRAPHX_DEBUG_OPT
void dump();
#endif
live_range segment;
std::size_t id = invalid_offset;
std::list<std::size_t> use_points{};
std::size_t def_point = invalid_offset;
shape result{};
bool is_literal = false;
bool is_live_on_entry = false;
};
using interval_ptr = live_interval*;
struct memory_coloring_impl
{
memory_coloring_impl(module* p, std::string alloc_op, bool p_verify)
: p_mod(p), allocation_op(std::move(alloc_op)), enable_verify(p_verify)
{
}
bool allocate(interval_ptr);
void add_conflicts(const std::set<int>& live_set, int val)
{
for(const auto& iter : live_set)
{
conflict_table[iter].insert(val);
conflict_table[val].insert(iter);
}
}
void build();
void run();
void rewrite();
private:
static bool is_param(const instruction_ref ins) { return ins->name() == "@param"; }
static bool is_output_param(const instruction_ref ins)
{
if(not is_param(ins))
return false;
auto param_name = any_cast<builtin::param>(ins->get_operator()).parameter;
return contains(param_name, "#output_");
}
bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins)
{
return ins->name() == "check_context";
}
static bool is_disjoin(const live_range& range1, const live_range& range2)
{
if((range1.size == 0) or (range2.size == 0))
return false;
auto end1 = range1.offset + range1.size - 1;
auto end2 = range2.offset + range2.size - 1;
return ((end1 < range2.offset) or (end2 < range1.offset));
}
void verify();
#ifdef MIGRAPHX_DEBUG_OPT
void dump(const std::string&);
void dump_module();
void dump_intervals();
#endif
struct ordering
{
bool operator()(const interval_ptr& i1, const interval_ptr& i2) const
{
auto len1 = i1->get_end() - i1->get_begin();
auto len2 = i2->get_end() - i2->get_begin();
if(len1 != len2)
{
return (len1 < len2);
}
else if(i1->result.bytes() != i2->result.bytes())
{
return (i1->result.bytes() < i2->result.bytes());
}
else
{
return i1->id > i2->id;
}
}
bool operator()(const live_range* i1, const live_range* i2) const
{
return (i1->offset > i2->offset);
}
};
module* p_mod;
std::unordered_map<const instruction*, interval_ptr> instr2_live;
// universe of live intervals.
std::vector<live_interval> live_intervals = {};
// Map live range value number to live range.
std::unordered_map<int, live_range*> live_ranges = {};
// Map live range value number to a set of conflicting live ranges' value numbers.
std::unordered_map<int, std::set<int>> conflict_table = {};
// Priority queue for coloring.
std::priority_queue<interval_ptr, std::vector<interval_ptr>, ordering> alloc_queue{};
int num_of_lives = 0;
int max_value_number = -1;
std::size_t required_bytes = 0;
// The earliest program point where an live interval ends.
int earliest_end_point = -1;
// The latest program point where an live interval ends.
int latest_end_point = -1;
// Whether to unify literals into coloring.
bool unify_literals = false;
std::string allocation_op{};
bool enable_verify;
ins_dep_map mod_implicit_deps;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -86,14 +86,24 @@ struct module_pm : module_pass_manager
assert(mod);
return *mod;
}
virtual module* create_module(const std::string& name) override
{
assert(prog);
return prog->create_module(name);
}
virtual module* get_common_parent() override { return common_parent; }
virtual module* get_root_module() override
{
assert(prog);
return prog->get_main_module();
}
virtual void run_pass(const pass& p) override
{
trace("Pass: ", p.name());
assert(mod);
assert(mod->validate() == mod->end());
if(enabled(MIGRAPHX_TIME_PASSES{}))
......
......@@ -38,27 +38,42 @@ std::function<void(const char*)> redirect_to(std::ostream& os)
return [&](const char* x) { os << x; };
}
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
template <class F>
int exec(const std::string& cmd, const char* type, F f)
{
int ec = 0;
if(enabled(MIGRAPHX_TRACE_CMD_EXECUTE{}))
std::cout << cmd << std::endl;
auto closer = [&](FILE* stream) {
auto status = pclose(stream);
ec = WIFEXITED(status) ? 0 : WEXITSTATUS(status); // NOLINT
ec = WIFEXITED(status) ? WEXITSTATUS(status) : 0; // NOLINT
};
{
// TODO: Use execve instead of popen
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), "r"), closer); // NOLINT
std::unique_ptr<FILE, decltype(closer)> pipe(popen(cmd.c_str(), type), closer); // NOLINT
if(not pipe)
MIGRAPHX_THROW("popen() failed: " + cmd);
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr)
std_out(buffer.data());
f(pipe.get());
}
return ec;
}
int exec(const std::string& cmd, const std::function<void(const char*)>& std_out)
{
return exec(cmd, "r", [&](FILE* f) {
std::array<char, 128> buffer;
while(fgets(buffer.data(), buffer.size(), f) != nullptr)
std_out(buffer.data());
});
}
int exec(const std::string& cmd, std::function<void(process::writer)> std_in)
{
return exec(cmd, "w", [&](FILE* f) {
std_in([&](const char* buffer, std::size_t n) { std::fwrite(buffer, 1, n, f); });
});
}
struct process_impl
{
std::string command{};
......@@ -72,6 +87,15 @@ struct process_impl
result += command;
return result;
}
template <class... Ts>
void check_exec(Ts&&... xs) const
{
int ec = migraphx::exec(std::forward<Ts>(xs)...);
if(ec != 0)
MIGRAPHX_THROW("Command " + get_command() + " exited with status " +
std::to_string(ec));
}
};
process::process(const std::string& cmd) : impl(std::make_unique<process_impl>())
......@@ -95,12 +119,11 @@ process& process::cwd(const fs::path& p)
return *this;
}
void process::exec()
void process::exec() { impl->check_exec(impl->get_command(), redirect_to(std::cout)); }
void process::write(std::function<void(process::writer)> pipe_in)
{
auto ec = migraphx::exec(impl->get_command(), redirect_to(std::cout));
if(ec != 0)
MIGRAPHX_THROW("Command " + impl->get_command() + " exited with status " +
std::to_string(ec));
impl->check_exec(impl->get_command(), std::move(pipe_in));
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert(not this->is_compiled());
this->impl->target_name = t.name();
this->impl->ctx = t.get_context();
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout};
options.trace(*this);
options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
// Validate and finalize
for(const auto& mod : reverse(mods))
{
......@@ -333,7 +331,8 @@ std::vector<argument> generic_eval(const module* mod,
MIGRAPHX_THROW("Parameter not found: " + param_name);
auto param = params[param_name];
// TODO: may want to check correct number of dimensions and/or was within bounds
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
if(not ins->get_shape().any_of_dynamic() and
param.get_shape() != ins->get_shape())
{
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name +
......@@ -381,7 +380,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert(results.find(ins) != results.end());
if(not ins->get_shape().dynamic())
if(not ins->get_shape().any_of_dynamic())
{
assert(results.at(ins).get_shape() == ins->get_shape());
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/promote_literals.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void promote_literals::apply(module_pass_manager& mpm) const
{
module& m = mpm.get_module();
module_ref root_module = mpm.get_root_module();
if(m.name() == "main")
return;
for(auto ins : iterator_for(m))
{
if(ins->name() == "@literal")
{
auto new_lit = root_module->add_literal(ins->get_literal());
for(auto out_ins : ins->outputs())
{
out_ins->replace_argument(out_ins, ins, new_lit);
}
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -27,11 +27,14 @@
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/env.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins)
{
if(ins->name() == "contiguous")
......@@ -44,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return false;
}
bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
void propagate_constant::apply(module& m) const
{
......@@ -54,14 +57,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m))
{
if(is_const(i) and i != last)
const bool is_const = is_const_ins(i);
if(is_const and i != last)
continue;
std::copy_if(
i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
if(i == last and is_const)
{
const_instrs.insert(i);
}
else
{
std::copy_if(i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) {
return is_const_ins(ins) and ins->name() != "@literal";
});
}
}
// Compute literals in parallel
......@@ -76,6 +88,19 @@ void propagate_constant::apply(module& m) const
{
if(not literals[i].empty())
{
if(enabled(MIGRAPHX_TRACE_PROPAGATE_CONSTANT{}))
{
std::cout << "Constant replace: " << std::endl;
std::vector<instruction_ref> inss;
fix([&](auto self, auto ins) {
if(contains(inss, ins))
return;
for(auto input : ins->inputs())
self(input);
inss.push_back(ins);
})(const_instrs_vec[i]);
m.debug_print(inss);
}
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
......
......@@ -35,7 +35,6 @@
#include <migraphx/stringutils.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
......@@ -63,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace migraphx {
migraphx::value to_value(py::kwargs kwargs);
......@@ -95,6 +95,10 @@ void visit_py(T x, F f)
{
f(x.template cast<std::string>());
}
else if(py::isinstance<migraphx::shape::dynamic_dimension>(x))
{
f(migraphx::to_value(x.template cast<migraphx::shape::dynamic_dimension>()));
}
else
{
MIGRAPHX_THROW("VISIT_PY: Unsupported data type!");
......@@ -165,7 +169,10 @@ template <class T>
py::buffer_info to_buffer_info(T& x)
{
migraphx::shape s = x.get_shape();
auto strides = s.strides();
assert(s.type() != migraphx::shape::tuple_type);
if(s.dynamic())
MIGRAPHX_THROW("MIGRAPHX PYTHON: dynamic shape argument passed to to_buffer_info");
auto strides = s.strides();
std::transform(
strides.begin(), strides.end(), strides.begin(), [&](auto i) { return i * s.type_size(); });
py::buffer_info b;
......@@ -177,7 +184,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<bool>::format(),
s.lens().size(),
s.ndim(),
s.lens(),
strides);
}
......@@ -186,7 +193,7 @@ py::buffer_info to_buffer_info(T& x)
b = py::buffer_info(x.data(),
as.size(),
py::format_descriptor<decltype(as())>::format(),
s.lens().size(),
s.ndim(),
s.lens(),
strides);
}
......@@ -236,10 +243,18 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
py::class_<migraphx::shape> shape_cls(m, "shape");
shape_cls
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
if(v.contains("dyn_dims"))
{
auto dyn_dims =
migraphx::from_value<std::vector<migraphx::shape::dynamic_dimension>>(
v.at("dyn_dims"));
return migraphx::shape(t, dyn_dims);
}
auto lens = v.get<std::size_t>("lens", {1});
if(v.contains("strides"))
return migraphx::shape(t, lens, v.at("strides").to_vector<std::size_t>());
......@@ -249,19 +264,34 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("type", &migraphx::shape::type)
.def("lens", &migraphx::shape::lens)
.def("strides", &migraphx::shape::strides)
.def("ndim", &migraphx::shape::ndim)
.def("elements", &migraphx::shape::elements)
.def("bytes", &migraphx::shape::bytes)
.def("type_string", &migraphx::shape::type_string)
.def("type_size", &migraphx::shape::type_size)
.def("dyn_dims", &migraphx::shape::dyn_dims)
.def("packed", &migraphx::shape::packed)
.def("transposed", &migraphx::shape::transposed)
.def("broadcasted", &migraphx::shape::broadcasted)
.def("standard", &migraphx::shape::standard)
.def("scalar", &migraphx::shape::scalar)
.def("dynamic", &migraphx::shape::dynamic)
.def("__eq__", std::equal_to<migraphx::shape>{})
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
py::class_<migraphx::shape::dynamic_dimension>(shape_cls, "dynamic_dimension")
.def(py::init<>())
.def(py::init<std::size_t, std::size_t>())
.def(py::init<std::size_t, std::size_t, std::set<std::size_t>>())
.def_readwrite("min", &migraphx::shape::dynamic_dimension::min)
.def_readwrite("max", &migraphx::shape::dynamic_dimension::max)
.def_readwrite("optimals", &migraphx::shape::dynamic_dimension::optimals)
.def("is_fixed", &migraphx::shape::dynamic_dimension::is_fixed);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) {
......@@ -283,7 +313,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::instruction_ref>(m, "instruction_ref");
py::class_<migraphx::instruction_ref>(m, "instruction_ref")
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
......@@ -329,15 +361,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("is_compiled", &migraphx::program::is_compiled)
.def(
"compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {
[](migraphx::program& p,
const migraphx::target& t,
bool offload_copy,
bool fast_math,
bool exhaustive_tune) {
migraphx::compile_options options;
options.offload_copy = offload_copy;
options.fast_math = fast_math;
options.offload_copy = offload_copy;
options.fast_math = fast_math;
options.exhaustive_tune = exhaustive_tune;
p.compile(t, options);
},
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
py::arg("offload_copy") = true,
py::arg("fast_math") = true,
py::arg("exhaustive_tune") = false)
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def(
"create_module",
......@@ -428,13 +466,18 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
......@@ -442,8 +485,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("default_dim_value") = 0,
py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
......@@ -452,20 +498,28 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx_buffer",
[](const std::string& onnx_buffer,
unsigned int default_dim_value,
migraphx::shape::dynamic_dimension default_dyn_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
options.map_input_dims = map_input_dims;
options.map_dyn_input_dims = map_dyn_input_dims;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx_buffer(onnx_buffer, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("default_dim_value") = 0,
py::arg("default_dyn_dim_value") = migraphx::shape::dynamic_dimension{1, 1},
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("map_dyn_input_dims") =
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
......
......@@ -33,7 +33,17 @@ std::unordered_map<std::string, operation>& op_map()
static std::unordered_map<std::string, operation> m; // NOLINT
return m;
}
void register_op_init() { (void)op_map(); }
void register_op(const operation& op) { op_map()[op.name()] = op; }
void unregister_op(const std::string& op_name)
{
assert(op_map().count(op_name));
op_map().erase(op_name);
}
operation load_op(const std::string& name)
{
return at(op_map(), name, "Operator not found: " + name);
......
......@@ -21,26 +21,48 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <string>
#include <unordered_map>
#include <migraphx/register_target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dynamic_loader.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void store_target_lib(const dynamic_loader& lib)
{
static std::vector<dynamic_loader> target_loader;
target_loader.emplace_back(lib);
}
std::unordered_map<std::string, target>& target_map()
{
static std::unordered_map<std::string, target> m; // NOLINT
return m;
}
void register_target_init() { (void)target_map(); }
void unregister_target(const std::string& name)
{
assert(target_map().count(name));
target_map().erase(name);
}
void register_target(const target& t) { target_map()[t.name()] = t; }
target make_target(const std::string& name)
{
if(not contains(target_map(), name))
{
std::string target_name = "libmigraphx_" + name + ".so";
store_target_lib(dynamic_loader(target_name));
}
const auto it = target_map().find(name);
if(it == target_map().end())
{
MIGRAPHX_THROW("Requested target '" + name + "' is not enabled or not supported");
MIGRAPHX_THROW("Requested target '" + name + "' is not loaded or not supported");
}
return it->second;
}
......
......@@ -104,19 +104,16 @@ void replace_allocate::apply(module& m) const
continue;
auto s = ins->get_shape();
if(not main_offload_copy and model.needs_out_params() and contains(mod_output_names, ins))
{
auto out_param = m.add_parameter(mod_output_names[ins], s);
m.replace_instruction(ins, out_param);
continue;
}
m.replace_instruction(
ins,
m.insert_instruction(ins,
make_op(model.name(), migraphx::value{{"shape", to_value(s)}})));
else
{
m.replace_instruction(ins,
make_op(model.name(), migraphx::value{{"shape", to_value(s)}}));
}
}
}
......
......@@ -327,10 +327,10 @@ struct stream_info
return [=](auto f) {
return fix<bool>([&](auto self, auto ins) {
return all_of(select(ins), [&](auto i) {
if(iweights.at(i) == 0)
return self(i);
else
if(has_stream(i))
return f(this->get_stream(i));
else
return self(i);
});
})(start);
};
......
......@@ -74,13 +74,23 @@ struct shape_impl
shape_impl(shape::type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
std::vector<std::set<std::size_t>> optimals_list)
: m_type(t)
{
assert(mins.size() == maxes.size() and maxes.size() == opts.size());
for(size_t i = 0; i < mins.size(); ++i)
if(optimals_list.empty())
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], opts[i]});
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i]});
}
}
else
{
assert(mins.size() == maxes.size() and maxes.size() == optimals_list.size());
for(size_t i = 0; i < mins.size(); ++i)
{
m_dyn_dims.push_back(shape::dynamic_dimension{mins[i], maxes[i], optimals_list[i]});
}
}
}
......@@ -147,7 +157,7 @@ struct shape_impl
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.min; });
[](const shape::dynamic_dimension& x) { return x.min; });
return ret;
}
......@@ -157,19 +167,20 @@ struct shape_impl
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.max; });
[](const shape::dynamic_dimension& x) { return x.max; });
return ret;
}
std::vector<std::size_t> opt_lens() const
std::vector<std::set<std::size_t>> opt_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
std::vector<std::set<std::size_t>> ret(m_dyn_dims.size());
std::transform(m_dyn_dims.cbegin(),
m_dyn_dims.cend(),
ret.begin(),
[](shape::dynamic_dimension x) { return x.opt; });
[](const shape::dynamic_dimension& x) { return x.optimals; });
return ret;
}
// Does the shape skip over elements?
bool skips() const
{
......@@ -240,8 +251,9 @@ shape::shape(type_t t, std::vector<shape::dynamic_dimension> dims)
shape::shape(type_t t,
std::vector<std::size_t> mins,
std::vector<std::size_t> maxes,
std::vector<std::size_t> opts)
: impl(std::make_shared<shape_impl>(t, std::move(mins), std::move(maxes), std::move(opts)))
std::vector<std::set<std::size_t>> optimals_list)
: impl(std::make_shared<shape_impl>(
t, std::move(mins), std::move(maxes), std::move(optimals_list)))
{
}
......@@ -349,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
std::vector<std::size_t> shape::multi(std::size_t i) const
std::vector<std::size_t> shape::multi(std::size_t idx) const
{
assert(this->standard());
assert(idx < elements());
std::vector<std::size_t> indices(lens().size());
multi_copy(i, indices.data(), indices.data() + lens().size());
multi_copy(idx, indices.data(), indices.data() + lens().size());
return indices;
}
void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
{
assert(this->standard());
size_t tidx = idx;
(void)end;
assert(idx < elements());
assert(lens().size() <= (end - start));
std::transform(strides().begin(),
strides().end(),
lens().begin(),
start,
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
for(size_t ii = lens().size() - 1; ii > 0; ii--)
{
*(start + ii) = tidx % lens()[ii];
tidx = tidx / lens()[ii];
}
*start = tidx;
}
bool shape::packed() const
......@@ -469,12 +478,44 @@ shape shape::with_type(type_t t) const
shape shape::to_dynamic() const
{
if(not sub_shapes().empty())
{
std::vector<shape> subs;
std::transform(sub_shapes().cbegin(),
sub_shapes().cend(),
std::back_inserter(subs),
[](auto s) { return s.to_dynamic(); });
return {subs};
}
if(this->dynamic())
{
return *this;
}
std::vector<std::size_t> zeroes(this->ndim(), 0);
return {type(), lens(), lens(), zeroes};
return {type(), lens(), lens(), {}};
}
shape shape::to_static(std::size_t x) const
{
if(not sub_shapes().empty())
{
std::vector<shape> subs;
std::transform(sub_shapes().cbegin(),
sub_shapes().cend(),
std::back_inserter(subs),
[&](auto s) { return s.to_static(x); });
return {subs};
}
if(not this->dynamic())
{
return *this;
}
auto static_lens = this->max_lens();
std::transform(static_lens.begin(),
static_lens.end(),
this->dyn_dims().cbegin(),
static_lens.begin(),
[&](auto sl, auto dd) { return dd.is_fixed() ? sl : x; });
return {type(), static_lens};
}
std::size_t shape::element_space() const { return impl->element_space(); }
......@@ -483,6 +524,17 @@ std::string shape::type_string() const { return name(this->type()); }
bool shape::dynamic() const { return not impl->m_dyn_dims.empty(); }
bool shape::any_of_dynamic() const
{
if(this->dynamic())
{
return true;
}
return std::any_of(this->sub_shapes().cbegin(), this->sub_shapes().cend(), [](auto s) {
return s.any_of_dynamic();
});
}
const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return impl->m_dyn_dims; }
std::vector<std::size_t> shape::min_lens() const
......@@ -495,23 +547,22 @@ std::vector<std::size_t> shape::max_lens() const
return this->dynamic() ? impl->max_lens() : this->lens();
}
std::vector<std::size_t> shape::opt_lens() const
{
return this->dynamic() ? impl->opt_lens() : this->lens();
}
std::vector<std::set<std::size_t>> shape::opt_lens() const { return impl->opt_lens(); }
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
bool shape::dynamic_dimension::has_optimal() const { return opt != 0; }
bool shape::dynamic_dimension::has_optimal() const { return not optimals.empty(); }
shape::dynamic_dimension& shape::dynamic_dimension::operator+=(const std::size_t& x)
{
this->min += x;
this->max += x;
if(this->opt != 0)
{
this->opt += x;
};
std::set<std::size_t> new_optimals;
std::transform(this->optimals.begin(),
this->optimals.end(),
std::inserter(new_optimals, new_optimals.begin()),
[&x](const auto& opt) { return (opt + x); });
this->optimals = new_optimals;
return *this;
}
......@@ -521,19 +572,23 @@ shape::dynamic_dimension& shape::dynamic_dimension::operator-=(const std::size_t
assert(this->max >= x);
this->min -= x;
this->max -= x;
if(this->opt != 0)
{
assert(this->opt >= x);
this->opt -= x;
}
std::set<std::size_t> new_optimals;
std::transform(this->optimals.begin(),
this->optimals.end(),
std::inserter(new_optimals, new_optimals.begin()),
[&x](const auto& opt) {
assert(opt >= x);
return (opt - x);
});
this->optimals = new_optimals;
return *this;
}
bool operator==(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
{
// don't check opt if both are fixed
// don't check optimals if both are fixed
return (x.min == y.min and x.max == y.max and
((x.is_fixed() and y.is_fixed()) or (x.opt == y.opt)));
((x.is_fixed() and y.is_fixed()) or (x.optimals == y.optimals)));
}
bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimension& y)
......@@ -542,7 +597,7 @@ bool operator!=(const shape::dynamic_dimension& x, const shape::dynamic_dimensio
}
std::ostream& operator<<(std::ostream& os, const shape::dynamic_dimension& x)
{
os << "[" << x.min << ", " << x.max << ", " << x.opt << "]";
os << "[ " << x.min << ", " << x.max << ", {" << migraphx::to_string_range(x.optimals) << "} ]";
return os;
}
......@@ -651,12 +706,10 @@ void migraphx_from_value(const value& v, shape& s)
{
auto v_dd = v.at("dynamic_dimensions");
std::vector<shape::dynamic_dimension> dyn_dims(v.at("dynamic_dimensions").size());
std::transform(v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](migraphx::value x) {
auto x_min = x.at("min").template to<size_t>();
auto x_max = x.at("max").template to<size_t>();
auto x_opt = x.at("opt").template to<size_t>();
return shape::dynamic_dimension{x_min, x_max, x_opt};
});
std::transform(
v_dd.begin(), v_dd.end(), dyn_dims.begin(), [](const migraphx::value& x) {
return from_value<shape::dynamic_dimension>(x);
});
s = shape{shape::parse_type(t), dyn_dims};
}
......
......@@ -204,7 +204,137 @@ struct find_mul_slice_conv
}
};
// a * (x + b) => a * x + a * b
struct find_mul_dot
{
auto matcher() const
{
auto is_dot_const_inputs =
match::name("dot")(match::any_of[match::inputs()](match::is_constant()));
return match::name("mul")(match::either_arg(0, 1)(
is_dot_const_inputs.bind("dot"), match::name("broadcast", "multibroadcast").bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = dot_ins->inputs()[0];
auto b_ins = dot_ins->inputs()[1];
auto c_ins = r.instructions["c"];
const auto& c_strides = c_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(c_strides.begin(), c_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
auto add_mul_const = [&](instruction_ref x_ins) {
if(not x_ins->can_eval())
return m.end();
auto broadcast_v = c_ins->get_operator().to_value();
broadcast_v["out_lens"] = x_ins->get_shape().lens();
auto cb_ins =
m.insert_instruction(ins, make_op(c_ins->name(), broadcast_v), c_ins->inputs());
return m.insert_instruction(ins, make_op("mul"), x_ins, cb_ins);
};
if(c_strides.back() == 1)
{
b_ins = add_mul_const(b_ins);
}
else if(c_strides[c_strides.size() - 2] == 1)
{
a_ins = add_mul_const(a_ins);
}
else if(c_ins->get_shape().scalar())
{
if(a_ins->can_eval())
a_ins = add_mul_const(a_ins);
else
b_ins = add_mul_const(b_ins);
}
else
{
return;
}
if(contains({a_ins, b_ins}, m.end()))
return;
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
struct find_dot_mul
{
auto matcher() const
{
auto const_broadcast = match::name("broadcast", "multibroadcast")(match::is_constant());
auto mul = match::name("mul")(
match::used_once(),
match::either_arg(0, 1)(const_broadcast.bind("d"),
match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
auto d_ins = r.instructions["d"];
auto c_ins = r.instructions["c"];
auto z_ins = r.instructions["z"];
const auto& d_strides = d_ins->get_shape().strides();
// There should only be one stride that is not zero
if(std::count_if(d_strides.begin(), d_strides.end(), [](auto s) { return s != 0; }) > 1)
return;
if(not d_ins->get_shape().scalar())
{
if(d_strides.back() == 1 and not b_ins->can_eval())
return;
if(d_strides[d_strides.size() - 2] == 1 and not a_ins->can_eval())
return;
}
auto broadcast_v = d_ins->get_operator().to_value();
auto c_lens = c_ins->get_shape().lens();
std::vector<int64_t> permutation(c_lens.size());
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation.back(), permutation[permutation.size() - 2]);
c_lens = reorder_dims(c_lens, permutation);
broadcast_v["out_lens"] = c_lens;
auto db_ins =
m.insert_instruction(ins, make_op(d_ins->name(), broadcast_v), d_ins->inputs());
auto db_transpose_ins =
m.insert_instruction(ins, make_op("transpose", {{"permutation", permutation}}), db_ins);
auto cd_ins = m.insert_instruction(ins, make_op("mul"), c_ins, db_transpose_ins);
if(c_ins == b_ins)
{
a_ins = z_ins;
b_ins = cd_ins;
}
else
{
a_ins = cd_ins;
b_ins = z_ins;
}
m.replace_instruction(ins, make_op("dot"), a_ins, b_ins);
}
};
// ******************************
// a * (x + b) => a * x + a * b
// ******************************
// When a * (x + b) is followed by another add of constant, then the
// additional add can be const folded. Also, better fusions can be applied
// when the add comes after.
struct find_mul_add
{
auto matcher() const
......@@ -356,30 +486,118 @@ struct find_inner_broadcast
{
auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
static auto non_scalar_op(const std::string& name)
{
return [=](instruction_ref ins) {
if(ins->get_shape().scalar())
return false;
return ins->name() == name;
};
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto broadcasts = ins->inputs();
if(broadcasts.empty())
return;
bool mixed_broadcasts = any_of(broadcasts, non_scalar_op("broadcast")) and
any_of(broadcasts, non_scalar_op("multibroadcast"));
// If the broadcast is not a single dimension, then dont perform inner_broadcast
if(mixed_broadcasts and any_of(broadcasts, [&](instruction_ref i) {
if(i->get_shape().scalar())
return false;
if(i->name() == "multibroadcast")
return false;
auto input = i->inputs().at(0);
const auto& lens = input->get_shape().lens();
return std::count_if(lens.begin(), lens.end(), [&](std::size_t d) {
return d == 1;
}) < (lens.size() - 1);
}))
return;
std::vector<instruction_ref> inputs;
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape() and
i->get_shape().elements() != 1;
}))
return;
auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
return not i->get_shape().scalar();
});
if(b_it == broadcasts.end())
b_it = broadcasts.begin();
[&](instruction_ref i) {
auto input = i->inputs().front();
if(mixed_broadcasts and not i->get_shape().scalar() and
i->get_shape().lens().size() > 1)
return m.insert_instruction(i, make_op("squeeze"), input);
return input;
});
std::sort(broadcasts.begin(), broadcasts.end(), by(std::less<>{}, [](instruction_ref i) {
if(i->get_shape().scalar())
return 2;
else if(i->name() == "broadcast")
return 0;
if(i->name() == "multibroadcast")
return 1;
return 3;
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, (*b_it)->get_operator(), op);
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
};
struct find_dot_broadcast
{
auto matcher() const
{
return match::name("dot")(match::all_of[match::inputs()](match::broadcast()));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a = ins->inputs()[0];
auto b = ins->inputs()[1];
if(a->get_operator().name() != b->get_operator().name())
return;
if(ins->get_shape().lens().size() < 3)
return;
auto nbatch_axes = ins->get_shape().lens().size() - 2;
const auto& a_strides = a->get_shape().strides();
const auto& b_strides = b->get_shape().strides();
// Find leading batch axes that are broadcasted
auto p =
std::mismatch(a_strides.begin(),
a_strides.begin() + nbatch_axes,
b_strides.begin(),
b_strides.begin() + nbatch_axes,
[](auto astride, auto bstride) { return astride == 0 and bstride == 0; });
auto naxes = p.first - a_strides.begin();
assert(naxes <= nbatch_axes);
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), 0);
auto insert_broadcast = [&](instruction_ref b_ins) -> instruction_ref {
auto input = b_ins->inputs()[0];
std::vector<std::size_t> lens(b_ins->get_shape().lens().begin() + naxes,
b_ins->get_shape().lens().end());
if(b_ins->name() == "multibroadcast")
{
return m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", lens}}), input);
}
else if(b_ins->name() == "broadcast")
{
auto v = b_ins->get_operator().to_value();
auto axis = v.at("axis").to<std::size_t>() - naxes;
return m.insert_instruction(
ins, make_op("broadcast", {{"axis", axis}, {"out_lens", lens}}), input);
}
assert(false);
return m.end();
};
auto a1 = insert_broadcast(a);
auto b1 = insert_broadcast(b);
auto dot = m.insert_instruction(ins, make_op("dot"), a1, b1);
auto broadcast = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", ins->get_shape().lens()}}), dot);
m.replace_instruction(ins, broadcast);
}
};
......@@ -388,7 +606,8 @@ struct find_concat_op
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](
match::any_of(match::pointwise(), match::name("broadcast")), match::used_once()));
match::any_of(match::pointwise(), match::name("broadcast", "multibroadcast")),
match::used_once()));
}
template <class Iterator>
......@@ -407,7 +626,8 @@ struct find_concat_op
static bool is_valid_op(const operation& op)
{
return op.name() == "broadcast" or op.attributes().contains("pointwise");
return contains({"broadcast", "multibroadcast"}, op.name()) or
op.attributes().contains("pointwise");
}
void apply(module& m, const match::matcher_result& r) const
......@@ -435,6 +655,16 @@ struct find_concat_op
op = b;
iaxis = 0;
}
else if(op.name() == "multibroadcast")
{
shape bshape = (*start)->get_shape();
auto input = (*start)->inputs()[0];
if(iaxis >= bshape.strides().size() or bshape.strides()[iaxis] == 0)
return {start, last};
op.from_value({{"out_lens", get_output_lens(start, last, iaxis)}});
auto delta = bshape.lens().size() - input->get_shape().lens().size();
iaxis -= delta;
}
std::vector<instruction_ref> concats;
for(std::size_t i = 0; i < x->inputs().size(); i++)
......@@ -1009,7 +1239,7 @@ struct find_neg_unit_ops
auto ins = r.result;
auto c_in = r.instructions["x"];
auto neg = m.add_instruction(make_op("neg"), c_in);
auto neg = m.insert_instruction(ins, make_op("neg"), c_in);
m.replace_instruction(ins, neg);
}
};
......@@ -1255,12 +1485,15 @@ void simplify_algebra::apply(module& m) const
{
match::find_matches(m,
find_inner_broadcast{},
find_dot_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_add_convs{},
find_conv_dot_horiz_fusion{},
find_mul_conv{},
find_mul_slice_conv{},
find_mul_dot{},
find_dot_mul{},
find_mul_add{},
find_unit_ops{},
find_neg_unit_ops{},
......
......@@ -762,7 +762,7 @@ struct find_transpose_slice
return;
// Compute axis before transpose to use for unsqueeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
auto preaxis = perm[axis];
// Make unsqueeze
std::vector<int64_t> steps(sdistance.size());
std::transform(
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct dynamic_dimensions_check
{
std::string dyn_param_str;
size_t dyn_index;
size_t min_dim;
size_t max_dim;
};
optional<dynamic_dimensions_check>
has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
{
// True if parameters contain exactly one dynamic shape with exactly one non-fixed
// dynamic_dimension.
auto is_dynamic = [](const auto& p) { return p.second.dynamic(); };
auto ps_it = std::find_if(param_shapes.begin(), param_shapes.end(), is_dynamic);
if(ps_it == param_shapes.end())
return std::nullopt;
// Check if there is a second dynamic parameter
if(std::any_of(std::next(ps_it), param_shapes.end(), is_dynamic))
return std::nullopt;
const auto& dds = ps_it->second.dyn_dims();
auto is_non_fixed = [](const auto& dd) { return not dd.is_fixed(); };
auto dds_it = std::find_if(dds.begin(), dds.end(), is_non_fixed);
if(dds_it == dds.end())
return std::nullopt;
// Check if there is a second non-fixed dynamic_dimension
if(std::any_of(std::next(dds_it), dds.end(), is_non_fixed))
return std::nullopt;
return dynamic_dimensions_check{ps_it->first,
static_cast<std::size_t>(std::distance(dds.begin(), dds_it)),
dds_it->min,
dds_it->max};
}
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
* work. Inserts select_module instruction to the top. Replaces return, bypassing other
* instructions. Skips if the dynamic parameter outputs to a select_module operator.
*/
void split_single_dyn_dim::apply(module_pass_manager& mpm) const
{
module_ref mm = &mpm.get_module();
auto param_names = mm->get_parameter_names();
auto param_shapes = mm->get_parameter_shapes();
optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
auto any_sm_next = [&](auto ddc) {
auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs();
return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
return ins->name() == "select_module";
});
};
if(dd_check.has_value() and not any_sm_next(dd_check))
{
const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
std::vector<module_ref> submodules;
// create submodules for each dimension size
for(size_t dim_size : migraphx::range(dd_check->min_dim, dd_check->max_dim + 1))
{
auto* submod = mpm.create_module("dim_" + std::to_string(dim_size));
// instruction map for new static shaped submodule parameters
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// create static shape using dim_size
auto static_lens = dyn_param_shape.max_lens();
static_lens.at(dd_check->dyn_index) = dim_size;
map_ins[dyn_param] = submod->add_parameter(
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod);
}
// redirect to select_module operator and return
std::vector<instruction_ref> sm_inputs;
std::transform(param_names.cbegin(),
param_names.cend(),
std::back_inserter(sm_inputs),
[&](auto pn) { return mm->get_parameter(pn); });
auto output_shapes = mm->get_output_shapes();
migraphx::shape out_attr = migraphx::shape{output_shapes};
auto sm_ins = mm->add_instruction(
migraphx::make_op("select_module",
{{"output_dyn_shapes", migraphx::to_value(out_attr)}}),
sm_inputs,
submodules);
std::vector<instruction_ref> outputs(output_shapes.size());
for(size_t i = 0; i < output_shapes.size(); ++i)
{
outputs.at(i) =
mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", i}}), sm_ins);
}
mm->replace_return(outputs);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_PARALLEL_HPP
// #define MIGRAPHX_DISABLE_OMP
#include <cmath>
#include <migraphx/config.hpp>
#ifdef MIGRAPHX_DISABLE_OMP
#include <migraphx/par_for.hpp>
......
......@@ -40,14 +40,11 @@ struct target
std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx, const compile_options&) const;
migraphx::context get_context() const { return context{}; }
argument copy_to(const argument& arg) const { return arg; }
argument copy_from(const argument& arg) const { return arg; }
argument allocate(const shape& s) const;
};
MIGRAPHX_REGISTER_TARGET(target);
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment