Commit 84725d72 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_batch_pass

parents 7f1e8443 bfd77388
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <type_traits> #include <type_traits>
...@@ -60,11 +61,12 @@ value to_value_impl(rank<0>, const T&) ...@@ -60,11 +61,12 @@ value to_value_impl(rank<0>, const T&)
return value::object{}; return value::object{};
} }
template <class T, class U> template <class T>
value to_value_impl(rank<1>, const std::pair<T, U>& x) auto to_value_impl(rank<1>, const T& x) -> decltype(std::tuple_size<T>{}, value{})
{ {
value result = value::array{};
return {x.first, x.second}; repeat_c<std::tuple_size<T>{}>([&](auto i) { result.push_back(to_value(std::get<i>(x))); });
return result;
} }
template <class T> template <class T>
...@@ -86,46 +88,55 @@ value to_value_impl(rank<3>, const T& x) ...@@ -86,46 +88,55 @@ value to_value_impl(rank<3>, const T& x)
return result; return result;
} }
template <class T>
auto to_value_impl(rank<4>, const optional<T>& x)
{
value result{};
if(x.has_value())
return to_value(*x);
return result;
}
template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_signed<T>{})>
value to_value_impl(rank<4>, const T& x) value to_value_impl(rank<5>, const T& x)
{ {
return std::int64_t{x}; return std::int64_t{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_unsigned<T>{})>
value to_value_impl(rank<5>, const T& x) value to_value_impl(rank<6>, const T& x)
{ {
return std::uint64_t{x}; return std::uint64_t{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
value to_value_impl(rank<6>, const T& x) value to_value_impl(rank<7>, const T& x)
{ {
return double{x}; return double{x};
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
value to_value_impl(rank<7>, const T& x) value to_value_impl(rank<8>, const T& x)
{ {
return x; return x;
} }
inline value to_value_impl(rank<8>, const std::string& x) { return x; } inline value to_value_impl(rank<9>, const std::string& x) { return x; }
template <class T> template <class T>
auto to_value_impl(rank<9>, const T& x) -> decltype(migraphx_to_value(x)) auto to_value_impl(rank<10>, const T& x) -> decltype(migraphx_to_value(x))
{ {
return migraphx_to_value(x); return migraphx_to_value(x);
} }
template <class T> template <class T>
auto to_value_impl(rank<10>, const T& x) -> decltype(x.to_value()) auto to_value_impl(rank<11>, const T& x) -> decltype(x.to_value())
{ {
return x.to_value(); return x.to_value();
} }
template <class T> template <class T>
auto to_value_impl(rank<11>, const T& x) auto to_value_impl(rank<12>, const T& x)
-> decltype(migraphx_to_value(std::declval<value&>(), x), value{}) -> decltype(migraphx_to_value(std::declval<value&>(), x), value{})
{ {
value v; value v;
...@@ -144,7 +155,14 @@ void from_value_impl(rank<0>, const value& v, T& x) ...@@ -144,7 +155,14 @@ void from_value_impl(rank<0>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<1>, const value& v, T& x) auto from_value_impl(rank<1>, const value& v, T& x) -> decltype(std::tuple_size<T>{}, void())
{
repeat_c<std::tuple_size<T>{}>(
[&](auto i) { std::get<i>(x) = from_value<std::tuple_element_t<i, T>>(v[i]); });
}
template <class T>
auto from_value_impl(rank<2>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void()) -> decltype(x.insert(x.end(), *x.begin()), void())
{ {
x.clear(); x.clear();
...@@ -153,7 +171,7 @@ auto from_value_impl(rank<1>, const value& v, T& x) ...@@ -153,7 +171,7 @@ auto from_value_impl(rank<1>, const value& v, T& x)
} }
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})> template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<typename T::value_type>{})>
auto from_value_impl(rank<2>, const value& v, T& x) auto from_value_impl(rank<3>, const value& v, T& x)
-> decltype(x.insert(x.end(), *x.begin()), void()) -> decltype(x.insert(x.end(), *x.begin()), void())
{ {
x.clear(); x.clear();
...@@ -170,7 +188,7 @@ auto from_value_impl(rank<2>, const value& v, T& x) ...@@ -170,7 +188,7 @@ auto from_value_impl(rank<2>, const value& v, T& x)
} }
template <class T> template <class T>
auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void()) auto from_value_impl(rank<4>, const value& v, T& x) -> decltype(x.insert(*x.begin()), void())
{ {
x.clear(); x.clear();
for(auto&& e : v) for(auto&& e : v)
...@@ -178,7 +196,7 @@ auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begi ...@@ -178,7 +196,7 @@ auto from_value_impl(rank<3>, const value& v, T& x) -> decltype(x.insert(*x.begi
} }
template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})> template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
void from_value_impl(rank<4>, const value& v, T& x) void from_value_impl(rank<5>, const value& v, T& x)
{ {
reflect_each(x, [&](auto& y, const std::string& name) { reflect_each(x, [&](auto& y, const std::string& name) {
using type = std::decay_t<decltype(y)>; using type = std::decay_t<decltype(y)>;
...@@ -187,28 +205,29 @@ void from_value_impl(rank<4>, const value& v, T& x) ...@@ -187,28 +205,29 @@ void from_value_impl(rank<4>, const value& v, T& x)
}); });
} }
template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})> template <class T>
void from_value_impl(rank<5>, const value& v, T& x) void from_value_impl(rank<6>, const value& v, optional<T>& x)
{ {
x = v.to<T>(); if(not v.is_null())
x = from_value<T>(v);
} }
template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{} or std::is_enum<T>{})>
void from_value_impl(rank<6>, const value& v, T& x) void from_value_impl(rank<7>, const value& v, T& x)
{ {
x = v.to<T>(); x = v.to<T>();
} }
inline void from_value_impl(rank<7>, const value& v, std::string& x) { x = v.to<std::string>(); } inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
template <class T> template <class T>
auto from_value_impl(rank<8>, const value& v, T& x) -> decltype(x.from_value(v), void()) auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
{ {
x.from_value(v); x.from_value(v);
} }
template <class T> template <class T>
auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void()) auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
{ {
migraphx_from_value(v, x); migraphx_from_value(v, x);
} }
...@@ -218,13 +237,13 @@ auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_va ...@@ -218,13 +237,13 @@ auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(migraphx_from_va
template <class T> template <class T>
value to_value(const T& x) value to_value(const T& x)
{ {
return detail::to_value_impl(rank<11>{}, x); return detail::to_value_impl(rank<12>{}, x);
} }
template <class T> template <class T>
void from_value(const value& v, T& x) void from_value(const value& v, T& x)
{ {
detail::from_value_impl(rank<9>{}, v, x); detail::from_value_impl(rank<10>{}, v, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/memory_coloring.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
#include <unordered_map>
#include <map>
#include <set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DEBUG_MEMORY_COLORING);
using instruction_set = std::unordered_set<instruction_ref>;
using instruction_set_map = std::unordered_map<instruction_ref, instruction_set>;
// This will do liveness analysis on the module, and it will call the
// function `f` with the instruction and the set of the other instructions
// that are live
template <class F>
void liveness(const module& m, F f)
{
auto implicit_deps = m.calc_implicit_deps();
instruction_set live_set;
auto rp = reverse(m);
for(auto rins : iterator_for(rp)) // NOLINT
{
// The base iterator is one ahead, so we need to use the previous iterator
auto ins = std::prev(rins.base());
// Add live variables
auto add_live_variables = [&](const auto& inputs) {
for(auto input : inputs)
{
auto i = instruction::get_output_alias(input);
// Skip if variable comes from parent
if(not m.has_instruction(i))
continue;
live_set.insert(i);
}
};
add_live_variables(ins->inputs());
add_live_variables(implicit_deps[ins]);
// Remove last usage
auto it = live_set.find(ins);
if(it != live_set.end())
{
live_set.erase(it);
f(ins, live_set);
}
}
}
// This will build the conflict table or interference graph. This is
// essentially a map from one instruction to a set of instruction that are
// used together. Each instruction will be the allocation instruction.
instruction_set_map build_conflict_table(const module& m, std::string allocation_op)
{
instruction_set_map conflict_table;
liveness(m, [&](auto ins, auto live_set) {
// Skip variables that aren't allocations
if(ins->name() != allocation_op)
return;
// Skip zero allocations
if(ins->get_shape().bytes() == 0)
return;
conflict_table[ins];
for(auto i : live_set)
{
if(i == ins)
continue;
// Skip variables that aren't allocations
if(i->name() != allocation_op)
continue;
// Skip zero allocations
if(i->get_shape().bytes() == 0)
continue;
conflict_table[i].insert(ins);
conflict_table[ins].insert(i);
}
});
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [](auto&& pp) {
return pp.second.count(pp.first) == 0;
}));
return conflict_table;
}
// Check if intervals overlap
bool is_overlap(std::pair<std::size_t, std::size_t> x, std::pair<std::size_t, std::size_t> y)
{
return std::max(x.first, y.first) < std::min(x.second, y.second);
}
struct allocation_segment
{
using segment = std::pair<std::size_t, std::size_t>;
std::unordered_map<instruction_ref, segment> ins2segment;
const segment* add_segment(instruction_ref ins, segment s) { return &(ins2segment[ins] = s); }
const segment* get_segment(instruction_ref ins) const
{
auto it = ins2segment.find(ins);
if(it == ins2segment.end())
return nullptr;
return &it->second;
}
// Remove segment for an instruction
void remove(instruction_ref ins)
{
auto it = ins2segment.find(ins);
if(it != ins2segment.end())
{
ins2segment.erase(it);
}
}
std::size_t max()
{
std::size_t n = 0;
for(auto&& pp : ins2segment)
{
auto seg = pp.second;
n = std::max(n, seg.second);
}
return n;
}
template <class Iterator>
static bool overlaps(Iterator first, Iterator last, const segment& s)
{
return std::any_of(first, last, [&](auto&& t) { return is_overlap(s, t); });
}
static bool overlaps(const std::set<segment>& segments, const segment& s)
{
return overlaps(segments.begin(), segments.end(), s);
}
static auto find_gap(const std::set<segment>& segments, std::size_t n)
{
std::size_t max_end = 0;
return std::adjacent_find(segments.begin(), segments.end(), [&](segment x, segment y) {
if(x.second < max_end)
return false;
max_end = x.second;
if(is_overlap(x, y))
return false;
assert(y.first >= x.second);
auto k = y.first - x.second;
return (k >= n);
});
}
static std::size_t max_type_size(const shape& s)
{
return std::accumulate(
s.sub_shapes().begin(),
s.sub_shapes().end(),
s.type_size(),
[](auto size, const auto& sub) { return std::max(size, max_type_size(sub)); });
}
static std::size_t compute_alignment(instruction_ref ins)
{
auto alignment = max_type_size(ins->get_shape());
// A rough estimate for the total number of elements
auto n = ins->get_shape().bytes() / alignment;
// Check for vectorized alignment
if(n > 4)
{
auto d = n % 4;
if(d == 0)
alignment *= 4;
if(d == 2)
alignment *= 2;
}
return alignment;
}
static segment
next_segment(std::set<segment>& segments, instruction_ref ins, std::size_t alignment)
{
assert(ins->get_shape().bytes() > 0);
// Compute alignment
auto n = 1 + (ins->get_shape().bytes() - 1) / alignment;
assert(n > 0);
auto start = 0;
// Insert at end if it cant fit at the begining
if(segments.empty() or segments.begin()->first <= n)
{
auto it = find_gap(segments, n);
if(it == segments.end())
it = std::max_element(segments.begin(), segments.end(), [&](segment x, segment y) {
return x.second < y.second;
});
if(it != segments.end())
start = it->second;
}
auto s = segment{start, start + n};
assert(not overlaps(segments, s));
segments.insert(s);
return s;
}
static std::unordered_map<instruction_ref, int>
create_allocation_index(const module& m, const instruction_set_map& conflict_table)
{
std::unordered_map<instruction_ref, int> result;
int i = 0;
for(auto ins : iterator_for(m))
{
if(not contains(conflict_table, ins))
continue;
result[ins] = i++;
}
return result;
}
// Build the allocation_color class from the conflict_table
static allocation_segment
build(const module& m, const instruction_set_map& conflict_table, std::size_t alignment)
{
allocation_segment as{};
std::vector<instruction_ref> conflict_queue;
// Add all allocations to the conflict_queue
std::transform(conflict_table.begin(),
conflict_table.end(),
std::back_inserter(conflict_queue),
[](auto&& pp) { return pp.first; });
auto alloc_index = create_allocation_index(m, conflict_table);
// Sort the conflict queue so we process the allocation with the most
// number of adjacent allocations first
std::sort(conflict_queue.begin(), conflict_queue.end(), by(std::greater<>{}, [&](auto x) {
return std::make_tuple(
conflict_table.at(x).size(), x->get_shape().bytes(), alloc_index.at(x));
}));
// Process the conflict_queue, we refer to the current allocation as
// the parent and the adjacent allocations as children
for(auto parent : conflict_queue)
{
// Sort children by size
std::vector<instruction_ref> children(conflict_table.at(parent).begin(),
conflict_table.at(parent).end());
std::sort(children.begin(), children.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(x->get_shape().bytes(), alloc_index.at(x));
}));
assert(not contains(children, parent));
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
assert(as.get_segment(parent) == nullptr);
as.add_segment(parent, next_segment(segments, parent, alignment));
}
// Reduce the number of segments
for(std::size_t n = 0; n < 3; n++)
{
for(auto parent : conflict_queue)
{
auto children = conflict_table.at(parent);
// This set is to track the segments already processed
std::set<segment> segments;
// Add all segments for the children to the segments already processed
transform_if(
children.begin(),
children.end(),
std::inserter(segments, segments.begin()),
[&](auto child) { return as.get_segment(child); },
[&](auto child) { return *as.get_segment(child); });
// Get the segment for the parent
const auto* parent_segment = as.get_segment(parent);
assert(parent_segment != nullptr);
auto s = next_segment(segments, parent, alignment);
if(s != *parent_segment and s.second <= as.max())
{
as.add_segment(parent, s);
}
}
}
return as;
}
};
static std::size_t find_max_alignment(const module& m, const std::string& allocation_op)
{
std::size_t alignment = 1;
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
alignment = std::max(allocation_segment::compute_alignment(ins), alignment);
}
return alignment;
}
void memory_coloring::apply(module& m) const
{
const std::size_t alignment = find_max_alignment(m, allocation_op);
auto conflict_table = build_conflict_table(m, allocation_op);
auto as = allocation_segment::build(m, conflict_table, alignment);
// All allocations should have a segment
assert(std::all_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
return as.get_segment(pp.first);
}));
// Adjacent allocations should not have overlapping segments
assert(std::none_of(conflict_table.begin(), conflict_table.end(), [&](auto&& pp) {
auto* x = as.get_segment(pp.first);
return std::any_of(pp.second.begin(), pp.second.end(), [&](auto ins) {
auto* y = as.get_segment(ins);
assert(x and y);
return is_overlap(*x, *y);
});
}));
// Print out segments
if(enabled(MIGRAPHX_DEBUG_MEMORY_COLORING{}))
{
for(auto&& pp : conflict_table)
{
std::cout << "------- conflict -------" << std::endl;
auto s1 = as.ins2segment.at(pp.first);
std::cout << s1.first << ", " << s1.second << ": ";
m.debug_print(pp.first);
for(auto ins : pp.second)
{
auto s2 = as.ins2segment.at(ins);
std::cout << s2.first << ", " << s2.second << ": ";
m.debug_print(ins);
}
}
}
// Total memory
std::size_t n = as.max() * alignment;
// Replace allocations
auto mem = m.add_parameter("scratch", shape{shape::int8_type, {n}});
for(auto&& [ins, seg] : as.ins2segment)
{
assert(ins->name() == allocation_op);
auto s = ins->get_shape();
std::size_t offset = seg.first * alignment;
assert(offset < n);
m.replace_instruction(ins, op::load{s, offset}, mem);
}
// Replace zero allocation
for(auto ins : iterator_for(m))
{
if(ins->name() != allocation_op)
continue;
assert(ins->get_shape().bytes() == 0);
m.replace_instruction(ins, op::load{ins->get_shape(), 0}, mem);
}
// Remove scratch parameter if its not used
if(mem->outputs().empty())
{
m.remove_instruction(mem);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -822,7 +822,8 @@ static void print_make_op(std::ostream& os, const operation& op) ...@@ -822,7 +822,8 @@ static void print_make_op(std::ostream& os, const operation& op)
static void print_py_shape(std::ostream& os, const migraphx::shape& s) static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{ {
os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens()); os << "migraphx.shape(type=" << to_json_string(s.type_string())
<< ", lens=" << to_json_string(s.lens());
if(not s.standard()) if(not s.standard())
os << ", strides=" << to_json_string(s.strides()); os << ", strides=" << to_json_string(s.strides());
os << ")"; os << ")";
......
...@@ -30,13 +30,16 @@ ...@@ -30,13 +30,16 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// different attributes /**
// 1) use_input(default)/use_output * Parameters:
// 2) use_rank(default)/use_len * vec: the vector attribute to normalize
// 3) clip_min(default)/not_clip_min * axes: the operator's axes attribute if it exists, empty otherwise
// 3.1) include_min(default)/exclude_min * val: the normalize_axes key and options. Ex: normalize["axes"] =
// 4) clip_max(default)/not_clip_max * value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling
// 4.1) exclude_max(default)/include_max * normalize_attributes(op&, lens)
*
* See normalize_attribute.hpp for explaining the options.
*/
auto tune_attribute(const std::vector<int64_t>& vec, auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const value& val, const value& val,
...@@ -151,6 +154,11 @@ auto tune_pad_attribute(const value& val) ...@@ -151,6 +154,11 @@ auto tune_pad_attribute(const value& val)
return result; return result;
} }
/**
* Assumptions:
* Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input.
*/
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
{ {
bool tuned = false; bool tuned = false;
...@@ -158,9 +166,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -158,9 +166,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto val = op.to_value(); auto val = op.to_value();
if(attrs.contains("normalize_padding")) if(attrs.contains("normalize_padding"))
{ {
auto padding = val.at(attrs.at("normalize_padding").to<std::string>()); auto padding = val.at(attrs.at("normalize_padding").to<std::string>());
auto padding_size = padding.size(); auto padding_size = padding.size();
// for now, assume the dimensions to pad start at dim 2
auto padding_start = 2; auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start)) if(padding_size == 2 * (lens.size() - padding_start))
......
...@@ -113,7 +113,8 @@ struct onnx_parser ...@@ -113,7 +113,8 @@ struct onnx_parser
void parse_from(std::istream& is, std::string name = ""); void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size); void parse_from(const void* data, std::size_t size);
void parse_graph(module* mod, const onnx::GraphProto& graph); std::vector<instruction_ref>
parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const; literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const; literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const; shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
......
...@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name) ...@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if(model.has_graph()) if(model.has_graph())
{ {
this->parse_graph(mm, model.graph()); (void)this->parse_graph(mm, model.graph());
} }
} }
else else
...@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size) ...@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if(model.has_graph()) if(model.has_graph())
{ {
this->parse_graph(mm, model.graph()); (void)this->parse_graph(mm, model.graph());
} }
} }
else else
...@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) ...@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version; return version;
} }
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) std::vector<instruction_ref>
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{ {
std::unordered_map<std::string, instruction_ref> mod_insts; std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
...@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) ...@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::back_inserter(output_ins), std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; }); [&](const auto& name) { return instructions[name]; });
// add the return instuction if(not inlining)
mod->add_return(output_ins); {
// add the return instuction
mod->add_return(output_ins);
// Remove instructions added in module (this is turned off for subgraph inlining)
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
// remove instructions added in this mod return output_ins;
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
} }
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
...@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if> ...@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!"); " condition input can have only one element!");
} }
// Fold instruction if condition is constant thus can be evaled
// prior to inference
if(args.front()->can_eval())
{
auto cond_arg = args.front()->eval();
auto* mod = info.mod;
// then branch
if(cond_arg.at<bool>())
{
return parser.parse_graph(mod, then_graph, true);
}
// else branch
else
{
return parser.parse_graph(mod, else_graph, true);
}
}
std::string then_name = info.name + "_if"; std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name); module_ref then_mdl = parser.prog.create_module(then_name);
...@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if> ...@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref else_mdl = parser.prog.create_module(else_name); module_ref else_mdl = parser.prog.create_module(else_name);
// parse the then sub_graph // parse the then sub_graph
parser.parse_graph(then_mdl, then_graph); (void)parser.parse_graph(then_mdl, then_graph);
// parse_the else sub_graph // parse_the else sub_graph
parser.parse_graph(else_mdl, else_graph); (void)parser.parse_graph(else_mdl, else_graph);
auto then_out_shapes = then_mdl->get_output_shapes(); auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes(); auto else_out_shapes = else_mdl->get_output_shapes();
......
...@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop> ...@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref sub_mod = parser.prog.create_module(mod_name); module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph // parse the sub_graph
parser.parse_graph(sub_mod, sub_graph); (void)parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction( auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod}); make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
......
...@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice> ...@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std::vector<int64_t> steps; std::vector<int64_t> steps;
// slice can have up to 5 inputs, we first check the 5th one // slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice // to decide whether MIGRAPHX can handle this slice.
if(args.size() == 5) if(args.size() == 5)
{ {
migraphx::argument step_arg = args.back()->eval(); migraphx::argument step_arg = args.back()->eval();
...@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice> ...@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
} }
// If axes arg is not given, the default is all of them.
if(op.axes.empty()) if(op.axes.empty())
{ {
std::vector<int64_t> axes(args[0]->get_shape().lens().size()); std::vector<int64_t> axes(args[0]->get_shape().ndim());
std::iota(axes.begin(), axes.end(), int64_t{0}); std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes; op.axes = axes;
} }
...@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice> ...@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert(op.axes.size() == op.starts.size()); assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size()); assert(op.axes.size() == op.ends.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(steps.size())) for(auto i : range(steps.size()))
{ {
if(steps[i] >= 0) if(steps[i] >= 0)
...@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice> ...@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto ins = info.add_instruction(op, args[0]); auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty()) if(not raxes.empty())
{
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins); ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
}
// If any steps are other than default 1, add a "steps" op
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; })) if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{ {
std::vector<int64_t> nsteps; std::vector<int64_t> nsteps;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where> ...@@ -40,28 +40,44 @@ struct parse_where : op_parser<parse_where>
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto lens = // TODO: broadcasting for dynamic shapes is only implemented
compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens()); // for binary ops at time of writing, not ternary ops.
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens()); // When it becomes available, add multibroadcasting steps in the dynamic shape case.
if(args[0]->get_shape().lens() != lens) // For now for dynamic shapes, just insert the Where op. All shapes must be the
// same for it to succeed.
if(std::all_of(args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{ {
args[0] = return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
} }
else if(std::none_of(
if(args[1]->get_shape().lens() != lens) args.begin(), args.end(), [](auto v) { return v->get_shape().dynamic(); }))
{ {
args[1] = // If shapes are static and any are broadcasted, insert multibroadcast ops
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]); auto lens =
} compute_broadcasted_lens(args[0]->get_shape().lens(), args[1]->get_shape().lens());
lens = compute_broadcasted_lens(lens, args[2]->get_shape().lens());
if(args[0]->get_shape().lens() != lens)
{
args[0] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[0]);
}
if(args[2]->get_shape().lens() != lens) if(args[1]->get_shape().lens() != lens)
{ {
args[2] = args[1] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]); info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[1]);
} }
if(args[2]->get_shape().lens() != lens)
{
args[2] =
info.add_instruction(make_op("multibroadcast", {{"out_lens", lens}}), args[2]);
}
return info.add_instruction(make_op("where"), args[0], args[1], args[2]); return info.add_instruction(make_op("where"), args[0], args[1], args[2]);
}
else
MIGRAPHX_THROW("PARSE_WHERE: doesn't support mixed static and dynamic shape inputs");
} }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/optimize_module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/propagate_constant.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void optimize_module::apply(module_pass_manager& mpm) const
{
for(int i = 0; i < 2; i++)
{
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{});
mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{});
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -39,6 +39,7 @@ namespace migraphx { ...@@ -39,6 +39,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PASSES);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_PASSES);
void validate_pass(module& mod, const pass& p, tracer trace) void validate_pass(module& mod, const pass& p, tracer trace)
{ {
...@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager ...@@ -94,19 +95,19 @@ struct module_pm : module_pass_manager
virtual void run_pass(const pass& p) override virtual void run_pass(const pass& p) override
{ {
assert(mod); assert(mod);
timer ts{};
using seconds = std::chrono::duration<double>;
trace("Module: ", mod->name(), ", Pass: ", p.name());
const double t1 = ts.record<seconds>();
assert(mod->validate() == mod->end()); assert(mod->validate() == mod->end());
p.apply(*this); if(enabled(MIGRAPHX_TIME_PASSES{}))
{
using milliseconds = std::chrono::duration<double, std::milli>;
auto ms = time<milliseconds>([&] { p.apply(*this); });
std::cout << p.name() << ": " << ms << "ms\n";
}
else
{
p.apply(*this);
}
trace(*mod); trace(*mod);
validate_pass(*mod, p, *t); validate_pass(*mod, p, *t);
const double t2 = ts.record<seconds>();
trace("Pass: ", p.name(), " completed in (s): ", (t2 - t1));
} }
}; };
......
...@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options) ...@@ -210,17 +210,15 @@ void program::compile(const target& t, compile_options options)
assert(not this->is_compiled()); assert(not this->is_compiled());
this->impl->target_name = t.name(); this->impl->target_name = t.name();
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
if(enabled(MIGRAPHX_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
options.trace = tracer{std::cout}; options.trace = tracer{std::cout};
options.trace(*this); options.trace(*this);
options.trace(); options.trace();
auto&& passes = t.get_passes(this->impl->ctx, options); auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace); run_passes(*this, passes, options.trace);
auto mods = this->get_modules(); auto mods = this->get_modules();
// Validate and finalize // Validate and finalize
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
...@@ -336,7 +334,8 @@ std::vector<argument> generic_eval(const module* mod, ...@@ -336,7 +334,8 @@ std::vector<argument> generic_eval(const module* mod,
if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape()) if(not ins->get_shape().dynamic() and param.get_shape() != ins->get_shape())
{ {
MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) + MIGRAPHX_THROW("Incorrect shape {" + to_string(param.get_shape()) +
"} for parameter: " + param_name); "} for parameter: " + param_name +
" should be: " + to_string(ins->get_shape()));
} }
return param; return param;
})); }));
......
...@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -329,15 +329,21 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("is_compiled", &migraphx::program::is_compiled) .def("is_compiled", &migraphx::program::is_compiled)
.def( .def(
"compile", "compile",
[](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) { [](migraphx::program& p,
const migraphx::target& t,
bool offload_copy,
bool fast_math,
bool exhaustive_tune) {
migraphx::compile_options options; migraphx::compile_options options;
options.offload_copy = offload_copy; options.offload_copy = offload_copy;
options.fast_math = fast_math; options.fast_math = fast_math;
options.exhaustive_tune = exhaustive_tune;
p.compile(t, options); p.compile(t, options);
}, },
py::arg("t"), py::arg("t"),
py::arg("offload_copy") = true, py::arg("offload_copy") = true,
py::arg("fast_math") = true) py::arg("fast_math") = true,
py::arg("exhaustive_tune") = false)
.def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); }) .def("get_main_module", [](const migraphx::program& p) { return p.get_main_module(); })
.def( .def(
"create_module", "create_module",
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/common.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
...@@ -340,12 +341,18 @@ struct find_inner_broadcast ...@@ -340,12 +341,18 @@ struct find_inner_broadcast
std::back_inserter(inputs), std::back_inserter(inputs),
[](auto i) { return i->inputs().front(); }); [](auto i) { return i->inputs().front(); });
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) { if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape() != inputs.front()->get_shape(); return i->get_shape() != inputs.front()->get_shape() and
i->get_shape().elements() != 1;
})) }))
return; return;
auto op = m.insert_instruction(ins, ins->get_operator(), inputs); auto b_it = std::find_if(broadcasts.begin(), broadcasts.end(), [&](auto i) {
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); return not i->get_shape().scalar();
});
if(b_it == broadcasts.end())
b_it = broadcasts.begin();
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
m.replace_instruction(ins, (*b_it)->get_operator(), op);
} }
}; };
...@@ -975,7 +982,7 @@ struct find_neg_unit_ops ...@@ -975,7 +982,7 @@ struct find_neg_unit_ops
auto ins = r.result; auto ins = r.result;
auto c_in = r.instructions["x"]; auto c_in = r.instructions["x"];
auto neg = m.add_instruction(make_op("neg"), c_in); auto neg = m.insert_instruction(ins, make_op("neg"), c_in);
m.replace_instruction(ins, neg); m.replace_instruction(ins, neg);
} }
}; };
......
##################################################################################### # ####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc) list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip)
find_package(miopen) find_package(miopen)
# rocblas # rocblas
...@@ -33,6 +33,8 @@ if(NOT TARGET MIOpen) ...@@ -33,6 +33,8 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen") message(SEND_ERROR "Cant find miopen")
endif() endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
include(Embed) include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS} file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
...@@ -46,9 +48,10 @@ add_library(compile_for_gpu INTERFACE) ...@@ -46,9 +48,10 @@ add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns) target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored) target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE) check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE) if(HAS_HIP_LAMBDA_HOST_DEVICE)
message(STATUS "Enable -fhip-lambda-host-device") message(STATUS "Enable -fhip-lambda-host-device")
target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device) target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
endif() endif()
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
...@@ -60,11 +63,13 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR ...@@ -60,11 +63,13 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>) target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
add_library(kernel_file_check EXCLUDE_FROM_ALL) add_library(kernel_file_check EXCLUDE_FROM_ALL)
foreach(KERNEL_FILE ${KERNEL_FILES}) foreach(KERNEL_FILE ${KERNEL_FILES})
get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE) get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n") file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp) target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
endforeach() endforeach()
target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256) target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>) target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu) target_link_libraries(kernel_file_check compile_for_gpu)
...@@ -125,6 +130,7 @@ function(register_migraphx_gpu_ops PREFIX) ...@@ -125,6 +130,7 @@ function(register_migraphx_gpu_ops PREFIX)
register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp) register_op(migraphx_gpu HEADER migraphx/gpu/${OP}.hpp OPERATORS gpu::${PREFIX}${OP} INCLUDES migraphx/gpu/context.hpp)
endforeach() endforeach()
endfunction() endfunction()
register_migraphx_gpu_ops(hip_ register_migraphx_gpu_ops(hip_
argmax argmax
argmin argmin
...@@ -146,47 +152,26 @@ register_migraphx_gpu_ops(miopen_ ...@@ -146,47 +152,26 @@ register_migraphx_gpu_ops(miopen_
lrn lrn
pooling pooling
) )
register_op(migraphx_gpu register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp) INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu register_op(migraphx_gpu
HEADER migraphx/gpu/int8_gemm_pack.hpp HEADER migraphx/gpu/int8_gemm_pack.hpp
OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b OPERATORS gpu::hip_int8_gemm_pack_a gpu::hip_int8_gemm_pack_b
INCLUDES migraphx/gpu/context.hpp) INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot> OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp) INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution> OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp) INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
# look for offload bundler
get_filename_component(CMAKE_CXX_COMPILER_PATH "${CMAKE_CXX_COMPILER}" PATH)
if(CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+$")
find_program(MIGRAPHX_OFFLOADBUNDLER_BIN clang-offload-bundler
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATH_SUFFIXES bin
PATHS /opt/rocm/llvm
)
else()
find_program(MIGRAPHX_EXTRACT_KERNEL extractkernel
PATH_SUFFIXES bin
HINTS ${CMAKE_CXX_COMPILER_PATH}
PATHS
/opt/rocm/hip
/opt/rocm/hcc
/opt/rocm
)
endif()
message(STATUS "clang-offload-bundler: ${MIGRAPHX_OFFLOADBUNDLER_BIN}")
message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
# Find package rocMLIR # Find package rocMLIR
find_package(rocMLIR 1.0.0 CONFIG REQUIRED) find_package(rocMLIR 1.0.0 CONFIG REQUIRED)
...@@ -195,36 +180,38 @@ if(MIGRAPHX_ENABLE_MLIR) ...@@ -195,36 +180,38 @@ if(MIGRAPHX_ENABLE_MLIR)
target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler) target_link_libraries(migraphx_gpu PUBLIC rocMLIR::rockCompiler)
endif() endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
if(MIGRAPHX_USE_HIPRTC) if(MIGRAPHX_USE_HIPRTC)
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1) message(STATUS "MIGraphX is using hipRTC")
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
else() else()
# Get flags needed to compile hip message(STATUS "MIGraphX is using HIP Clang")
include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat it as a source file
string(APPEND HIP_COMPILER_FLAGS " ")
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}") # Get flags needed to compile hip
target_compile_definitions(migraphx_gpu PRIVATE include(TargetFlags)
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}" target_flags(HIP_COMPILER_FLAGS hip::device)
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_OFFLOADBUNDLER_BIN=${MIGRAPHX_OFFLOADBUNDLER_BIN}" # Remove cuda arch flags
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}" string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
"-DMIGRAPHX_USE_HIPRTC=0" string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER) # Skip library paths since hip will incorrectly treat it as a source file
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER) string(APPEND HIP_COMPILER_FLAGS " ")
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif() endif()
# Check miopen find mode api # Check miopen find mode api
...@@ -233,10 +220,9 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION) ...@@ -233,10 +220,9 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API) check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# TODO: Set default to HAS_FIND_2_API set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")
set(MIGRAPHX_USE_FIND_2_API OFF CACHE BOOL "")
if(MIGRAPHX_USE_FIND_2_API) if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API) target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen") message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else() else()
...@@ -250,16 +236,13 @@ else() ...@@ -250,16 +236,13 @@ else()
message(STATUS "MIOpen does not have find mode api") message(STATUS "MIOpen does not have find mode api")
endif() endif()
# Workaround broken rocblas headers
target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas) target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels) target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
add_subdirectory(driver) add_subdirectory(driver)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_gpu migraphx_device compile_for_gpu TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/include
) )
...@@ -29,10 +29,9 @@ ...@@ -29,10 +29,9 @@
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#if MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h> #include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/env.hpp>
#else #else
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp> #include <migraphx/process.hpp>
...@@ -48,9 +47,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); ...@@ -48,9 +47,10 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#if MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
std::string hiprtc_error(hiprtcResult err, const std::string& msg) std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{ {
...@@ -143,25 +143,29 @@ struct hiprtc_program ...@@ -143,25 +143,29 @@ struct hiprtc_program
options.end(), options.end(),
std::back_inserter(c_options), std::back_inserter(c_options),
[](const std::string& s) { return s.c_str(); }); [](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
std::cerr << log() << std::endl; auto prog_log = log();
if(not prog_log.empty())
{
std::cerr << prog_log << std::endl;
}
if(result != HIPRTC_SUCCESS) if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Compilation failed."); MIGRAPHX_HIPRTC_THROW(result, "Compilation failed.");
} }
std::string log() std::string log() const
{ {
std::size_t n = 0; std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n)); MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n < 2) if(n == 0)
return {}; return {};
std::vector<char> buffer(n); std::string buffer(n, '\0');
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() == 0); assert(buffer.back() != 0);
return {buffer.begin(), buffer.end() - 1}; return buffer;
} }
std::vector<char> get_code_obj() std::vector<char> get_code_obj() const
{ {
std::size_t n = 0; std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n)); MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
...@@ -176,6 +180,17 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -176,6 +180,17 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{ {
hiprtc_program prog(srcs); hiprtc_program prog(srcs);
auto options = split_string(params, ' '); auto options = split_string(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
// remove following three compilation flags for HIPRTC once fixes from hipRTC are available in
if(enabled(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS{}))
{
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG"); options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) { if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
...@@ -183,7 +198,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -183,7 +198,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
})) }))
options.push_back("-std=c++17"); options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc"); options.push_back("-fno-gpu-rdc");
options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3")); options.push_back("-O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-Wno-cuda-compat"); options.push_back("-Wno-cuda-compat");
options.push_back("--offload-arch=" + arch); options.push_back("--offload-arch=" + arch);
prog.compile(options); prog.compile(options);
...@@ -192,12 +207,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -192,12 +207,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
#else // MIGRAPHX_USE_HIPRTC #else // MIGRAPHX_USE_HIPRTC
bool is_hcc_compiler()
{
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "hcc");
return result;
}
bool is_hip_clang_compiler() bool is_hip_clang_compiler()
{ {
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++"); static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "clang++");
...@@ -221,7 +230,7 @@ std::vector<std::vector<char>> ...@@ -221,7 +230,7 @@ std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
assert(not srcs.empty()); assert(not srcs.empty());
if(not is_hcc_compiler() and not is_hip_clang_compiler()) if(not is_hip_clang_compiler())
MIGRAPHX_THROW("Unknown hip compiler: " + MIGRAPHX_THROW("Unknown hip compiler: " +
std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER))); std::string(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER)));
...@@ -231,16 +240,9 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -231,16 +240,9 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(enabled(MIGRAPHX_GPU_DEBUG_SYM{})) if(enabled(MIGRAPHX_GPU_DEBUG_SYM{}))
params += " -g"; params += " -g";
params += " -c"; params += " -c";
if(is_hcc_compiler()) params += " --offload-arch=" + arch;
{ params += " --cuda-device-only";
params += " -amdgpu-target=" + arch; params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
else if(is_hip_clang_compiler())
{
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG"; params += " -DMIGRAPHX_DEBUG";
...@@ -255,24 +257,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -255,24 +257,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if(has_compiler_launcher()) if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER); compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif #endif
if(is_hcc_compiler())
compiler.process = [&](const fs::path& obj_path) -> fs::path {
process{MIGRAPHX_STRINGIZE(MIGRAPHX_EXTRACT_KERNEL) + std::string{" -i "} +
obj_path.string()}
.cwd(obj_path.parent_path());
for(const auto& entry : fs::directory_iterator{obj_path.parent_path()})
{
const auto& hsaco_path = entry.path();
if(not fs::is_regular_file(hsaco_path))
continue;
if(hsaco_path.extension() != ".hsaco")
continue;
return hsaco_path;
}
MIGRAPHX_THROW("Missing hsaco");
};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{})) if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{ {
for(const auto& src : srcs) for(const auto& src : srcs)
...@@ -292,6 +276,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -292,6 +276,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)}; return {compiler.compile(srcs)};
} }
#endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param) std::string enum_params(std::size_t count, std::string param)
{ {
std::vector<std::string> items(count); std::vector<std::string> items(count);
...@@ -299,8 +285,6 @@ std::string enum_params(std::size_t count, std::string param) ...@@ -299,8 +285,6 @@ std::string enum_params(std::size_t count, std::string param)
return join_strings(items, ","); return join_strings(items, ",");
} }
#endif // MIGRAPHX_USE_HIPRTC
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx_kernels.hpp> #include <migraphx_kernels.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -80,6 +79,7 @@ std::string generate_args_hpp(const std::vector<shape>& inputs) ...@@ -80,6 +79,7 @@ std::string generate_args_hpp(const std::vector<shape>& inputs)
#include <migraphx/kernels/args.hpp> #include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tensor_view.hpp> #include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/types.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -112,14 +112,8 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024) ...@@ -112,14 +112,8 @@ inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
#ifdef MIGRAPHX_USE_CLANG_TIDY #ifdef MIGRAPHX_USE_CLANG_TIDY
#define MIGRAPHX_DEVICE_SHARED #define MIGRAPHX_DEVICE_SHARED
#else #else
// Workaround hcc's broken tile_static macro
#ifdef tile_static
#undef tile_static
#define MIGRAPHX_DEVICE_SHARED __attribute__((tile_static))
#else
#define MIGRAPHX_DEVICE_SHARED __shared__ #define MIGRAPHX_DEVICE_SHARED __shared__
#endif #endif
#endif
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -36,6 +36,7 @@ namespace gpu { ...@@ -36,6 +36,7 @@ namespace gpu {
namespace device { namespace device {
#ifdef MIGRAPHX_NO_DPP #ifdef MIGRAPHX_NO_DPP
template <index_int N, template <index_int N,
class Op, class Op,
class T, class T,
...@@ -62,6 +63,7 @@ __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f) ...@@ -62,6 +63,7 @@ __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
} }
return buffer[0]; return buffer[0];
} }
#else #else
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; } constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
...@@ -96,11 +98,7 @@ __device__ T dpp_mov(T& x) ...@@ -96,11 +98,7 @@ __device__ T dpp_mov(T& x)
input.data = x; input.data = x;
for(index_int i = 0; i < n; i++) for(index_int i = 0; i < n; i++)
{ {
#if defined(__HCC__)
output.reg[i] = __llvm_amdgcn_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
#else
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl); output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
#endif
} }
return output.data; return output.data;
} }
...@@ -310,4 +308,4 @@ void reduce(hipStream_t stream, ...@@ -310,4 +308,4 @@ void reduce(hipStream_t stream,
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif // MIGRAPHX_NO_DPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment