Commit 5af9aac0 authored by charlie

Merge branch 'dyn_batch_pass' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_test_runner

parents 7b2516e0 05e81ed3
/*
* The MIT License (MIT)
*
-* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -42,9 +42,17 @@ struct where
shape compute_shape(std::vector<shape> inputs) const
{
-check_shapes{inputs, *this}.has(3).same_dims();
+check_shapes{inputs, *this, true}.has(3).same_dims();
auto s1 = inputs.at(1);
auto s2 = inputs.at(2);
+if(s1.dynamic() or s2.dynamic())
+{
+if(s1 == s2)
+return s1;
+MIGRAPHX_THROW("WHERE: dynamic input shapes must be the same");
+}
+// Compare two static shapes, returning a standard shape
if(s1 == s2 and s1.packed())
{
return s1;
@@ -63,12 +71,12 @@ struct where
}
}
-argument compute(const shape& output_shape, std::vector<argument> args) const
+argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
-argument result{output_shape};
+argument result{dyn_out.computed_shape};
visit_all(result, args[1], args[2])([&](auto output, const auto x, const auto y) {
args[0].visit([&](const auto condition) {
-par_for(output_shape.elements(),
+par_for(dyn_out.computed_shape.elements(),
[&](auto i) { output[i] = condition[i] ? x[i] : y[i]; });
});
});
......
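For reference, where computes an elementwise select over three same-sized inputs. A minimal standalone sketch of the same computation (plain C++, not using the MIGraphX types; names here are illustrative):

    // output[i] = condition[i] ? x[i] : y[i], as in where::compute above.
    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<float> where_ref(const std::vector<std::uint8_t>& condition,
                                 const std::vector<float>& x,
                                 const std::vector<float>& y)
    {
        assert(condition.size() == x.size() and x.size() == y.size());
        std::vector<float> output(x.size());
        for(std::size_t i = 0; i < output.size(); i++)
            output[i] = condition[i] != 0 ? x[i] : y[i];
        return output;
    }

    int main()
    {
        auto out = where_ref({1, 0, 1, 0}, {1, 2, 3, 4}, {5, 6, 7, 8});
        for(auto v : out)
            std::cout << v << " "; // prints: 1 6 3 8
        std::cout << "\n";
    }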
@@ -140,6 +140,8 @@ template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
+if(inputs.empty())
+MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#define MIGRAPHX_GUARD_RTGLIB_OPTIMIZE_MODULE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
/**
* Runs several passes in a loop
*/
struct optimize_module
{
std::string name() const { return "optimize_module"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -130,6 +130,7 @@ struct program
// module related api
module* create_module(const std::string& name);
module* copy_module(const std::string& in_name, const std::string& out_name);
+module* get_module(const std::string& name);
const module* get_module(const std::string& name) const;
......
@@ -28,6 +28,7 @@
#include <migraphx/value.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp>
+#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp>
#include <type_traits>
@@ -60,11 +61,12 @@ value to_value_impl(rank<0>, const T&)
return value::object{};
}
-template <class T, class U>
-value to_value_impl(rank<1>, const std::pair<T, U>& x)
+template <class T>
+auto to_value_impl(rank<1>, const T& x) -> decltype(std::tuple_size<T>{}, value{})
{
-return {x.first, x.second};
+value result = value::array{};
+repeat_c<std::tuple_size<T>{}>([&](auto i) { result.push_back(to_value(std::get<i>(x))); });
+return result;
}
template <class T>
@@ -86,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return result;
}
+template <class T>
+auto to_value_impl(rank<4>, const optional<T>& x)
+{
+value result{};
+if(x.has_value())
+return to_value(*x);
+return result;
+}
template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
-value to_value_impl(rank<4>, const T& x)
+value to_value_impl(rank<5>, const T& x)
{
return std::int64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
-value to_value_impl(rank<5>, const T& x)
+value to_value_impl(rank<6>, const T& x)
{
return std::uint64_t{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
-value to_value_impl(rank<6>, const T& x)
+value to_value_impl(rank<7>, const T& x)
{
return double{x};
}
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
-value to_value_impl(rank<7>, const T& x)
+value to_value_impl(rank<8>, const T& x)
{
return x;
}
-inline value to_value_impl(rank<8>, const std::string& x) { return x; }
+inline value to_value_impl(rank<9>, const std::string& x) { return x; }
template <class T>
-auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x))
+auto to_value_impl(rank<10>, const T& x) -> decltype(migraphx_to_value(x))
{
return migraphx_to_value(x);
}
template <class T>
-auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value())
+auto to_value_impl(rank<11>, const T& x) -> decltype(x.to_value())
{
return x.to_value();
}
template <class T>
-auto to_value_impl(rank<11>, const T& x)
+auto to_value_impl(rank<12>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{
value v;
@@ -144,7 +155,14 @@ void from_value_impl(rank<0>, const value& v, T& x)
}
template <class T>
-auto from_value_impl(rank<1>, const value& v, T& x)
+auto from_value_impl(rank<1>, const value& v, T& x) -> decltype(std::tuple_size<T>{}, void())
+{
+repeat_c<std::tuple_size<T>{}>(
+[&](auto i) { std::get<i>(x) = from_value<std::tuple_element_t<i, T>>(v[i]); });
+}
+template <class T>
+auto from_value_impl(rank<2>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void())
{
x.clear();
@@ -153,7 +171,7 @@ auto from_value_impl(rank<1>, const value& v, T& x)
}
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})>
-auto from_value_impl(rank<2>, const value& v, T& x)
+auto from_value_impl(rank<3>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void())
{
x.clear();
@@ -170,7 +188,7 @@ auto from_value_impl(rank<2>, const value& v, T& x)
}
template <class T>
-auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void())
+auto from_value_impl(rank<4>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void())
{
x.clear();
for(auto&& e : v)
@@ -178,7 +196,7 @@ auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begi
}
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
-void from_value_impl(rank<4>, const value& v, T& x)
+void from_value_impl(rank<5>, const value& v, T& x)
{
reflect_each(x, [&](auto& y, const std::string& name) {
using type = std::decay_t<decltype(y)>;
@@ -187,28 +205,29 @@ void from_value_impl(rank<4>, const value& v, T& x)
});
}
-template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
-void from_value_impl(rank<5>, const value& v, T& x)
+template <class T>
+void from_value_impl(rank<6>, const value& v, optional<T>& x)
{
-x = v.to<T>();
+if(not v.is_null())
+x = from_value<T>(v);
}
-template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
-void from_value_impl(rank<6>, const value& v, T& x)
+template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{} or std::is_enum<T>{})>
+void from_value_impl(rank<7>, const value& v, T& x)
{
x = v.to<T>();
}
-inline void from_value_impl(rank<7>, const value& v, std::string& x) { x = v.to<std::string>(); }
+inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T>
-auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(x.from_value(v), void())
+auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
{
x.from_value(v);
}
template <class T>
-auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
+auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{
migraphx_from_value(v, x);
}
@@ -218,13 +237,13 @@ auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_va
template <class T>
value to_value(const T& x)
{
-return detail::to_value_impl(rank<11>{}, x);
+return detail::to_value_impl(rank<12>{}, x);
}
template <class T>
void from_value(const value& v, T& x)
{
-detail::from_value_impl(rank<9>{}, v, x);
+detail::from_value_impl(rank<10>{}, v, x);
}
} // namespace MIGRAPHX_INLINE_NS
......
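The renumbering above is mechanical: rank<N> implements ordered overload preference, and inserting the new optional<T> and tuple overloads into the middle of the chain shifts every higher-priority overload up by one. A minimal standalone sketch of the idiom (only the rank<N> shape is taken from the header above; everything else is illustrative):

    #include <iostream>
    #include <string>

    template <int N>
    struct rank : rank<N - 1>
    {
    };
    template <>
    struct rank<0>
    {
    };

    // Lowest priority: the fallback.
    template <class T>
    void describe(rank<0>, const T&)
    {
        std::cout << "fallback\n";
    }

    // Higher priority, viable only for types with a value_type member.
    template <class T>
    auto describe(rank<1>, const T&) -> decltype(typename T::value_type{}, void())
    {
        std::cout << "container-like\n";
    }

    int main()
    {
        describe(rank<1>{}, 5);                    // prints: fallback
        describe(rank<1>{}, std::string("hello")); // prints: container-like
    }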
@@ -243,6 +243,9 @@ struct shape
/// Return true if the shape is dynamic
bool dynamic() const;
+/// Return true if this shape or any of the sub_shapes are dynamic
+bool any_of_dynamic() const;
shape normalize_standard() const;
shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#define MIGRAPHX_GUARD_RTGLIB_SPLIT_SINGLE_DYN_DIM_HPP
#include <string>
#include <migraphx/program.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Split dynamic batch dimension over submodules if exactly one dimension in the parameter list
* is dynamic. Should only run on the main module.
*/
struct split_single_dyn_dim
{
std::string name() const { return "split_single_dyn_dim"; }
void apply(module_pass_manager& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);
using instruction_set = std::unordered_set<instruction_ref>;
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;
// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template <class F>
void liveness(const module& m, F f)
{
auto implicit_deps = m.calc_implicit_deps();
instruction_set live_set;
auto rp = reverse(m);
for(auto rins : iterator_for(rp)) // NOLINT
{
// The base iterator is one ahead, so we need to use the previous iterator
auto ins = std::prev(rins.base());
// Add live variables
auto add_live_variables = [&](const auto& inputs) {
for(auto input : inputs)
{
auto i = instruction::get_output_alias(input);
// Skip if variable comes from parent
if(not m.has_instruction(i))
continue;
live_set.insert(i);
}
};
add_live_variables(ins->inputs());
add_live_variables(implicit_deps[ins]);
// Remove last usage
auto it = live_set.find(ins);
if(it != live_set.end())
{
live_set.erase(it);
f(ins, live_set);
}
}
}
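// Worked example (illustrative, not from the original file): for a module
//   a = allocate()
//   b = allocate()
//   c = add(a, b)
// the reverse walk visits c first and inserts its inputs {a, b} into the
// live set; c itself is never in the set (nothing uses it), so f is not
// called for it. At b, b is erased and f(b, {a}) fires: a and b are live
// together. At a, f(a, {}) fires. The (a, b) overlap is exactly what the
// conflict table below records.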
// This will build the conflict table or interference graph. This is
// essentially a map from one instruction to the set of instructions that
// are live at the same time. Each key is an allocation instruction.
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
{
instruction_set_map conflict_table;
liveness(m, [&](auto ins, auto live_set) {
// Skip variables that aren't allocations
if(ins->name() != allocation_op)
return;
// Skip zero allocations
if(ins->get_shape().bytes() == 0)
return;
conflict_table[ins];
for(auto i : live_set)
{
if(i == ins)
continue;
// Skip variables that aren't allocations
if(i->name() != allocation_op)
continue;
// Skip zero allocations
if(i->get_shape().bytes() == 0)
continue;
conflict_table[i].insert(ins);
conflict_table[ins].insert(i);
}
});
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
return pp.second.count(pp.first) == 0;
}));
return conflict_table;
}
// Check if intervals overlap
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
return std::max(x.first, y.first) < std::min(x.second, y.second);
}
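// Note that segments are half-open intervals [first, second): for example,
// is_overlap({0, 4}, {4, 8}) is false because max(0, 4) < min(4, 8)
// reduces to 4 < 4, while is_overlap({0, 5}, {4, 8}) is true because
// 4 < 5. Two allocations may share a boundary but never a byte.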
struct allocation_segment
{
using segment = std::pair<std::size_t, std::size_t>;
std::unordered_map<instruction_ref, segment> ins2segment;
const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }
const segment* get_segment(instruction_ref ins) const
{
auto it = ins2segment.find(ins);
if(it == ins2segment.end())
return nullptr;
return &it->second;
}
// Remove segment for an instruction
void remove(instruction_ref ins)
{
auto it = ins2segment.find(ins);
if(it != ins2segment.end())
{
ins2segment.erase(it);
}
}
std::size_t max()
{
std::size_t n = 0;
for(auto&& pp : ins2segment)
{
auto seg = pp.second;
n = std::max(n, seg.second);
}
return n;
}
template <class Iterator>
static bool overlaps(Iterator first, Iterator last, const segment& s)
{
return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
}
static bool overlaps(const std::set<segment>& segments, const segment& s)
{
return overlaps(segments.begin(), segments.end(), s);
}
static auto find_gap(const std::set<segment>& segments, std::size_t n)
{
std::size_t max_end = 0;
return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
if(x.second < max_end)
return false;
max_end = x.second;
if(is_overlap(x, y))
return false;
assert(y.first >= x.second);
auto k = y.first - x.second;
return (k >= n);
});
}
static std::size_t max_type_size(const shape& s)
{
return std::accumulate(
s.sub_shapes().begin(),
s.sub_shapes().end(),
s.type_size(),
[](auto size, const auto& sub) { return std::max(size, max_type_size(sub)); });
}
static std::size_t compute_alignment(instruction_ref ins)
{
auto alignment = max_type_size(ins->get_shape());
// A rough estimate for the total number of elements
auto n = ins->get_shape().bytes() / alignment;
// Check for vectorized alignment
if(n > 4)
{
auto d = n % 4;
if(d == 0)
alignment *= 4;
if(d == 2)
alignment *= 2;
}
return alignment;
}
static segment
next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
{
assert(ins->get_shape().bytes() > 0);
// Compute alignment
auto n = 1 + (ins->get_shape().bytes() - 1) / alignment;
assert(n > 0);
auto start = 0;
// Insert at the end if it can't fit at the beginning
if(segments.empty() or segments.begin()->first <= n)
{
auto it = find_gap(segments, n);
if(it == segments.end())
it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
return x.second < y.second;
});
if(it != segments.end())
start = it->second;
}
auto s = segment{start, start + n};
assert(not overlaps(segments, s));
segments.insert(s);
return s;
}
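// Worked example (illustrative): with segments {0, 4} and {6, 10} already
// placed and a request of n = 2 units, find_gap spots the hole between
// them (6 - 4 >= 2) and the new segment becomes {4, 6}. If no gap is wide
// enough, start falls back to the largest existing end offset, growing
// the arena instead of overlapping.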
static std::unordered_map<instruction_ref, int>
create_allocation_index(const module& m, const instruction_set_map& conflict_table)
{
std::unordered_map<instruction_ref, int> result;
int i = 0;
for(auto ins : iterator_for(m))
{
if(not contains(conflict_table, ins))
continue;
result[ins] = i++;
}
return result;
}
// Build the allocation_segment from the conflict_table
static allocation_segment
build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
{
allocation_segment as{};
std::vector<instruction_ref> conflict_queue;
// Add all allocations to the conflict_queue
std::transform(conflict_table.begin(),
conflict_table.end(),
std::back_inserter(conflict_queue),
[](auto&& pp) { return pp.first; });
auto alloc_index = create_allocation_index(m, conflict_table);
// Sort the conflict queue so we process the allocation with the most
// number of adjacent allocations first
std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(
conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
}));
// Process the conflict_queue, we refer to the current allocation as
// the parent and the adjacent allocations as children
for(auto parent : conflict_queue)
{
// Sort children by size
std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
conflict_table.at(parent).end());
std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
}));
assert(not contains(children, parent));
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
assert(as.get_segment(parent) == nullptr);
as.add_segment(parent, next_segment(segments, parent, alignment));
}
// Reduce the number of segments
for(std::size_t n = 0; n < 3; n++)
{
for(auto parent : conflict_queue)
{
auto children = conflict_table.at(parent);
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
// Get the segment for the parent
const auto* parent_segment = as.get_segment(parent);
assert(parent_segment != nullptr);
auto s = next_segment(segments, parent, alignment);
if(s != *parent_segment and s.second <= as.max())
{
as.add_segment(parent, s);
}
}
}
return as;
}
};
static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
{
std::size_t alignment = 1;
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
alignment = std::max(allocation_segment::compute_alignment(ins), alignment);
}
return alignment;
}
void memory_coloring::apply(module& m) const
{
const std::size_t alignment = find_max_alignment(m, allocation_op);
auto conflict_table = build_conflict_table(m, allocation_op);
auto as = allocation_segment::build(m, conflict_table, alignment);
// All allocations should have a segment
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
return as.get_segment(pp.first);
}));
// Adjacent allocations should not have overlapping segments
assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
auto* x = as.get_segment(pp.first);
return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
auto* y = as.get_segment(ins);
assert(x and y);
return is_overlap(*x, *y);
});
}));
// Print out segments
if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
{
for(auto&& pp : conflict_table)
{
std::cout << "------- conflict -------" << std::endl;
auto s1 = as.ins2segment.at(pp.first);
std::cout << s1.first << ", " << s1.second << ": ";
m.debug_print(pp.first);
for(auto ins : pp.second)
{
auto s2 = as.ins2segment.at(ins);
std::cout << s2.first << ", " << s2.second << ": ";
m.debug_print(ins);
}
}
}
// Total memory
std::size_t n = as.max() * alignment;
// Replace allocations
auto mem = m.add_parameter("scratch", shape{shape::int8_type, {n}});
for(auto&& [ins, seg] : as.ins2segment)
{
assert(ins->name() == allocation_op);
auto s = ins->get_shape();
std::size_t offset = seg.first * alignment;
assert(offset < n);
m.replace_instruction(ins, op::load{s, offset}, mem);
}
// Replace zero allocation
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
assert(ins->get_shape().bytes() == 0);
m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
}
// Remove the scratch parameter if it's not used
if(mem->outputs().empty())
{
m.remove_instruction(mem);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
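Taken together, memory_coloring::apply replaces every allocation with a load at a computed byte offset into a single scratch parameter. A sketch of the effect on a toy module with two conflicting 4-float allocations (the IR rendering is illustrative, not exact MIGraphX output; the offsets follow the arithmetic above with a 4-byte alignment):

    before:
    a = allocate() : float[4]
    b = allocate() : float[4]
    c = add(a, b)

    after:
    scratch = @param : int8[32]
    a = load[offset=0](scratch) : float[4]
    b = load[offset=16](scratch) : float[4]
    c = add(a, b)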
@@ -822,7 +822,8 @@ static void print_make_op(std::ostream& os, const operation& op)
static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
os << "migraphx.shape(type=" << to_json_string(s.type_string())
<< ", lens=" << to_json_string(s.lens());
if(not s.standard())
os << ", strides=" << to_json_string(s.strides());
os << ")";
......
@@ -30,13 +30,16 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
-// different attributes
-// 1) use_input(default)/use_output
-// 2) use_rank(default)/use_len
-// 3) clip_min(default)/not_clip_min
-// 3.1) include_min(default)/exclude_min
-// 4) clip_max(default)/not_clip_max
-// 4.1) exclude_max(default)/include_max
+/**
+ * Parameters:
+ * vec: the vector attribute to normalize
+ * axes: the operator's axes attribute if it exists, empty otherwise
+ * val: the normalize_axes key and options. Ex: normalize["axes"] = value::array{normalize_attribute::include_min};
+ * lens: shape dimensions passed when calling normalize_attributes(op&, lens)
+ *
+ * See normalize_attribute.hpp for explaining the options.
+ */
auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes,
const value& val,
@@ -151,6 +154,11 @@ auto tune_pad_attribute(const value& val)
return result;
}
+/**
+ * Assumptions:
+ * Dimensions to pad start from the third dimension (index 2).
+ * Called by compute_shape_op() with the `lens` of the first input.
+ */
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{
bool tuned = false;
@@ -160,7 +168,6 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{
auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size();
-// for now, assume the dimensions to pad start at dim 2
auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start))
......
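A concrete reading of the padding-size check above, assuming an NCHW input (lens = {N, C, H, W}) where only the spatial dimensions are padded:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> lens{1, 3, 224, 224}; // illustrative NCHW dims
        std::size_t padding_start = 2;                 // pad dims from index 2 on (H, W)
        // Assumed layout for the asymmetric case: {H_begin, W_begin, H_end, W_end}
        std::vector<int> padding{1, 1, 2, 2};
        bool asymmetric = padding.size() == 2 * (lens.size() - padding_start);
        std::cout << std::boolalpha << asymmetric << "\n"; // prints: true
    }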
@@ -113,7 +113,8 @@ struct onnx_parser
void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size);
-void parse_graph(module* mod, const onnx::GraphProto& graph);
+std::vector<instruction_ref>
+parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
......
@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if(model.has_graph())
{
-this->parse_graph(mm, model.graph());
+(void)this->parse_graph(mm, model.graph());
}
}
else
@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if(model.has_graph())
{
-this->parse_graph(mm, model.graph());
+(void)this->parse_graph(mm, model.graph());
}
}
else
@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version;
}
-void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
+std::vector<instruction_ref>
+onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; });
+if(not inlining)
+{
// add the return instruction
mod->add_return(output_ins);
-// remove instructions added in this mod
+// Remove instructions added in module (this is turned off for subgraph inlining)
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
+}
+return output_ins;
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!");
}
+// Fold the instruction if the condition is constant and thus can be
+// evaluated prior to inference
+if(args.front()->can_eval())
+{
+auto cond_arg = args.front()->eval();
+auto* mod = info.mod;
+// then branch
+if(cond_arg.at<bool>())
+{
+return parser.parse_graph(mod, then_graph, true);
+}
+// else branch
+else
+{
+return parser.parse_graph(mod, else_graph, true);
+}
+}
std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name);
@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref else_mdl = parser.prog.create_module(else_name);
// parse the then sub_graph
-parser.parse_graph(then_mdl, then_graph);
+(void)parser.parse_graph(then_mdl, then_graph);
// parse the else sub_graph
-parser.parse_graph(else_mdl, else_graph);
+(void)parser.parse_graph(else_mdl, else_graph);
auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes();
......
@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph
-parser.parse_graph(sub_mod, sub_graph);
+(void)parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
......
@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std::vector<int64_t> steps;
// slice can have up to 5 inputs; we first check the 5th one
-// to decide whether MIGRAPHX can handle this slice
+// to decide whether MIGRAPHX can handle this slice.
if(args.size() == 5)
{
migraphx::argument step_arg = args.back()->eval();
@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
+// If axes arg is not given, the default is all of them.
if(op.axes.empty())
{
-std::vector<int64_t> axes(args[0]->get_shape().lens().size());
+std::vector<int64_t> axes(args[0]->get_shape().ndim());
std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes;
}
@@ -103,6 +104,7 @@
assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size());
+// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(steps.size()))
{
if(steps[i] >= 0)
@@ -117,7 +119,10 @@
auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty())
+{
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
+}
+// If any steps are other than default 1, add a "steps" op
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
......
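To see why a negative step becomes slice + reverse + steps, a standalone 1-D walkthrough (plain C++; the decomposition mirrors the ops added above, but these are not the MIGraphX ops themselves):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<int> data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        // ONNX Slice with starts=9, ends=0, steps=-2 selects {9, 7, 5, 3, 1}.
        // Step 1: slice the covered range [1, 10).
        std::vector<int> sliced(data.begin() + 1, data.begin() + 10);
        // Step 2: reverse along the sliced axis.
        std::reverse(sliced.begin(), sliced.end());
        // Step 3: keep every |step|-th element.
        std::vector<int> result;
        for(std::size_t i = 0; i < sliced.size(); i += 2)
            result.push_back(sliced[i]);
        for(auto v : result)
            std::cout << v << " "; // prints: 9 7 5 3 1
        std::cout << "\n";
    }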
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_trilu : op_parser<parse_trilu>
{
std::vector<op_desc> operators() const { return {{"Trilu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
assert(input_shape.ndim() >= 2);
auto input_lens = input_shape.lens();
size_t num_rows = *(input_lens.rbegin() + 1);
size_t num_cols = input_lens.back();
int k = 0;
bool upper = true;
if(args.size() > 1)
{
auto arg_k = args[1]->eval();
check_arg_empty(arg_k, "PARSE_TRILU: dynamic k not supported");
k = arg_k.at<int>();
}
if(k < 0)
MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
if(contains(info.attributes, "upper"))
{
upper = static_cast<bool>(info.attributes.at("upper").i());
}
shape::type_t output_type = args[0]->get_shape().type();
// when creating the mask, if upper == 1,
// the elements below the k-th diagonal are set to 0
std::vector<bool> mask_mat(num_rows * num_cols, upper);
for(size_t i = 0; i < num_rows; i++)
{
for(size_t j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
{
mask_mat[i * num_cols + j] = not upper;
}
k++;
}
auto mask = info.add_literal(
migraphx::literal{migraphx::shape{output_type, {num_rows, num_cols}}, mask_mat});
return info.add_broadcastable_binary_op("mul", mask, args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
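To make the mask construction concrete, running the same loop standalone for a 3x3 input with k = 1 and upper = 1 gives the mask below; elementwise multiplication with the input then keeps only the strict upper triangle:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::size_t num_rows = 3, num_cols = 3;
        int k      = 1;
        bool upper = true;
        std::vector<bool> mask_mat(num_rows * num_cols, upper);
        for(std::size_t i = 0; i < num_rows; i++)
        {
            for(std::size_t j = 0; j < std::min<std::size_t>(k, num_cols); j++)
                mask_mat[i * num_cols + j] = not upper;
            k++;
        }
        for(std::size_t i = 0; i < num_rows; i++)
        {
            for(std::size_t j = 0; j < num_cols; j++)
                std::cout << mask_mat[i * num_cols + j] << " ";
            std::cout << "\n"; // prints rows: 0 1 1 / 0 0 1 / 0 0 0
        }
    }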
/*
* The MIT License (MIT)
*
-* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -40,6 +40,19 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
+// TODO: broadcasting for dynamic shapes is only implemented
+// for binary ops at the time of writing, not ternary ops.
+// When it becomes available, add multibroadcasting steps in the dynamic shape case.
+// For now, for dynamic shapes, just insert the where op. All shapes must be the
+// same for it to succeed.
+if(std::all_of(args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
+{
+return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
+}
+else if(std::none_of(
+args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
+{
+// If shapes are static and any are broadcasted, insert multibroadcast ops
auto lens =
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
@@ -63,6 +76,9 @@ struct parse_where : op_parser<parse_where>
return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
+else
+MIGRAPHX_THROW("PARSE_WHERE: doesn't support mixed static and dynamic shape inputs");
}
};
} // namespace onnx
......
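The static branch above relies on compute_broadcasted_lens, which pairs dimensions from the right, NumPy-style. A sketch of that behavior (this helper is an illustration, not the MIGraphX implementation):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <stdexcept>
    #include <vector>

    std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a,
                                            std::vector<std::size_t> b)
    {
        if(a.size() < b.size())
            std::swap(a, b);
        // Align b against the tail of a; dims may differ only if one is 1.
        auto offset = a.size() - b.size();
        for(std::size_t i = 0; i < b.size(); i++)
        {
            auto& x = a[i + offset];
            if(x != b[i] and x != 1 and b[i] != 1)
                throw std::runtime_error("incompatible lens");
            x = std::max(x, b[i]);
        }
        return a;
    }

    int main()
    {
        auto lens = broadcast_lens({2, 3, 1}, {3, 4});
        for(auto d : lens)
            std::cout << d << " "; // prints: 2 3 4
        std::cout << "\n";
    }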
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void optimize_module::apply(module_pass_manager& mpm) const
{
for(int i = 0; i < 2; i++)
{
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{});
mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{});
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -39,6 +39,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES);
void validate_pass(module& mod, const pass& p, tracer trace)
{
@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
virtual void run_pass(const pass& p) override
{
assert(mod);
-timer ts{};
-using seconds = std::chrono::duration<double>;
trace("Module: ", mod->name(), ", Pass: ", p.name());
-const double t1 = ts.record<seconds>();
assert(mod->validate() == mod->end());
+if(enabled(MIGRAPHX_TIME_PASSES{}))
+{
+using milliseconds = std::chrono::duration<double, std::milli>;
+auto ms = time<milliseconds>([&] { p.apply(*this); });
+std::cout << p.name() << ": " << ms << "ms\n";
+}
+else
+{
p.apply(*this);
+}
trace(*mod);
validate_pass(*mod, p, *t);
-const double t2 = ts.record<seconds>();
-trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
}
};
......
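The time<milliseconds> call above wraps a callable in a chrono stopwatch. A minimal version of such a helper, named time_call here to stay self-contained (assumed shape; the actual MIGraphX utility may differ):

    #include <chrono>
    #include <iostream>
    #include <thread>

    template <class Duration, class F>
    double time_call(F f)
    {
        auto start = std::chrono::steady_clock::now();
        f(); // run the pass (or any callable) exactly once
        auto stop = std::chrono::steady_clock::now();
        return std::chrono::duration_cast<Duration>(stop - start).count();
    }

    int main()
    {
        using milliseconds = std::chrono::duration<double, std::milli>;
        auto ms = time_call<milliseconds>(
            [] { std::this_thread::sleep_for(std::chrono::milliseconds(5)); });
        std::cout << "took " << ms << "ms\n";
    }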
@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert(not this->is_compiled());
this->impl->target_name = t.name();
this->impl->ctx = t.get_context();
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout};
options.trace(*this);
options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
// Validate and finalize
for(const auto& mod : reverse(mods))
{
@@ -336,7 +334,8 @@ std::vector<argument> generic_eval(const module* mod,
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
{
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name);
"} for parameter: " + param_name +
" should be: " + to_string(ins->get_shape()));
}
return param;
}));
@@ -380,7 +379,7 @@ std::vector<argument> generic_eval(const module* mod,
}));
}
assert(results.find(ins) != results.end());
-if(not ins->get_shape().dynamic())
+if(not ins->get_shape().any_of_dynamic())
{
assert(results.at(ins).get_shape() == ins->get_shape());
}
......