Commit 32b83c9c authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into inner_bcast_fix

parents 92f5a6cd 434a06cf
...@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op ...@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
// Gather all the target roots // Gather all the target roots
std::unordered_multimap<std::size_t, module_ref> roots; std::unordered_multimap<std::size_t, module_ref> roots;
auto mods = this->get_modules(); auto mods = this->get_modules();
for(auto* mod : mods) for(const auto* mod : mods)
{ {
for(const auto& ins : *mod) for(const auto& ins : *mod)
{ {
...@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment ...@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
ins_out[x] = ss.str(); ins_out[x] = ss.str();
}); });
ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) { ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
auto& ctx = contexts[ins->get_target_id()]; const auto& ctx = contexts[ins->get_target_id()];
ctx.finish(); ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{}; timer t{};
...@@ -624,7 +624,7 @@ std::string get_migraphx_version() ...@@ -624,7 +624,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file. if any changes occur to the format of the MXR file.
*/ */
const int program_file_version = 6; const int program_file_version = 7;
value program::to_value() const value program::to_value() const
{ {
...@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod, ...@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod,
std::back_inserter(module_inputs), std::back_inserter(module_inputs),
[&](const value& i) { return map_mods.at(i.to<std::string>()); }); [&](const value& i) { return map_mods.at(i.to<std::string>()); });
for(auto& smod : module_inputs) for(const auto& smod : module_inputs)
{ {
mod_from_val(smod, v, instructions, map_mods); mod_from_val(smod, v, instructions, map_mods);
} }
...@@ -1186,7 +1186,7 @@ void program::remove_unused_modules() ...@@ -1186,7 +1186,7 @@ void program::remove_unused_modules()
std::vector<module*> unused; std::vector<module*> unused;
generic_get_unused_modules( generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused)); impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
for(auto* m : unused) for(const auto* m : unused)
this->remove_module(m->name()); this->remove_module(m->name());
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins) bool skip_propagate(instruction_ref ins)
{ {
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front()); return skip_propagate(ins->inputs().front());
auto&& s = ins->get_shape(); auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar() and not s.packed()) if(s.broadcasted() and not s.scalar() and not s.packed())
return true; return true;
...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propagate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
add_library(migraphx_py py_loader.cpp) add_library(migraphx_py py_loader.cpp)
migraphx_generate_export_header(migraphx_py)
target_include_directories(migraphx_py PRIVATE include) target_include_directories(migraphx_py PRIVATE include)
target_link_libraries(migraphx_py PUBLIC migraphx) target_link_libraries(migraphx_py PUBLIC migraphx)
rocm_install_targets(TARGETS migraphx_py INCLUDE include) rocm_install_targets(TARGETS migraphx_py INCLUDE include)
......
...@@ -26,11 +26,12 @@ ...@@ -26,11 +26,12 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
program load_py(const std::string& filename); MIGRAPHX_PY_EXPORT program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -64,7 +64,7 @@ static dynamic_loader py_lib() ...@@ -64,7 +64,7 @@ static dynamic_loader py_lib()
return lib; return lib;
} }
program load_py(const std::string& filename) MIGRAPHX_PY_EXPORT program load_py(const std::string& filename)
{ {
static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py"); static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py");
return f(filename); return f(filename);
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp> #include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
}); });
auto s = add_zero_point->get_shape(); auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant); auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
std::vector<int> max_data(s.elements(), max_quant); auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto min_arg = m.add_literal(literal(s, min_data)); auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
m.replace_instruction( m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate); ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
} }
......
...@@ -50,13 +50,14 @@ struct shape_impl ...@@ -50,13 +50,14 @@ struct shape_impl
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l) shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true) : m_type(t), m_lens(std::move(l)), m_standard(true)
{ {
assert(t != shape::tuple_type); assert(t != shape::tuple_type);
this->calculate_strides(); this->calculate_strides();
assert(m_lens.size() == m_strides.size());
} }
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s)) : m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{ {
...@@ -151,6 +152,22 @@ struct shape_impl ...@@ -151,6 +152,22 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::size_t get_index(size_t i) const
{
std::size_t result = 0;
std::size_t s = 1;
for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
{
std::size_t stride = m_strides[k];
std::size_t len = m_lens[k];
std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
std::vector<std::size_t> min_lens() const std::vector<std::size_t> min_lens() const
{ {
std::vector<std::size_t> ret(m_dyn_dims.size()); std::vector<std::size_t> ret(m_dyn_dims.size());
...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t) ...@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
} }
MIGRAPHX_THROW("Invalid type"); MIGRAPHX_THROW("Invalid type");
} }
std::string shape::cpp_type(shape::type_t t) std::string shape::cpp_type(shape::type_t t)
{ {
switch(t) switch(t)
...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t) ...@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
shape::shape() : impl(shape_impl::default_shape()) {} shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {} shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<std::size_t> l) shape::shape(type_t t, std::vector<std::size_t> l)
: impl(std::make_shared<shape_impl>(t, std::move(l))) : impl(std::make_shared<shape_impl>(t, std::move(l)))
{ {
} }
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s))) : impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{ {
...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const ...@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
if(this->standard()) if(this->standard())
return i; return i;
else
{ return impl->get_index(i);
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens().size(); j++)
{
const std::size_t k = this->lens().size() - j - 1;
const std::size_t stride = this->strides()[k];
const std::size_t len = this->lens()[k];
const std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
} }
std::vector<std::size_t> shape::multi(std::size_t idx) const std::vector<std::size_t> shape::multi(std::size_t idx) const
......
...@@ -1121,8 +1121,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) ...@@ -1121,8 +1121,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
}; };
}; };
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2); return (dots >= 2 or convs >= 2 or qdots >= 2);
} }
struct find_conv_dot_horiz_fusion struct find_conv_dot_horiz_fusion
...@@ -1136,7 +1137,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1136,7 +1137,7 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) { auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator()) if(i->get_operator() != j->get_operator())
return false; return false;
if(not contains({"dot", "convolution"}, i->name())) if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true; return true;
auto x = i->inputs()[1]->get_shape().lens(); auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens(); auto y = j->inputs()[1]->get_shape().lens();
...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
return false; return false;
// Check that non-axes match // Check that non-axes match
int axis = 1; int axis = 1;
if(i->name() == "dot") if(i->name() == "dot" or i->name() == "quant_dot")
{ {
axis = x.size() - 1; axis = x.size() - 1;
} }
...@@ -1155,7 +1156,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1155,7 +1156,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2) if(std::distance(start, last) < 2)
return; return;
auto&& name = (*start)->name(); auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name)) if(not contains({"quant_dot", "dot", "convolution"}, name))
return; return;
auto op = (*start)->get_operator(); auto op = (*start)->get_operator();
int group = 1; int group = 1;
...@@ -1170,7 +1171,7 @@ struct find_conv_dot_horiz_fusion ...@@ -1170,7 +1171,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); }); start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1; int axis = 1;
int concat_axis = 0; int concat_axis = 0;
if(name == "dot") if(name == "dot" or name == "quant_dot")
{ {
axis = int(args.front()->get_shape().lens().size() - 1); axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis; concat_axis = axis;
...@@ -1350,48 +1351,59 @@ struct find_split_reshape ...@@ -1350,48 +1351,59 @@ struct find_split_reshape
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front();
// Only apply simplification when slices are on a single axis
auto axes = any_cast<op::slice>(slc->get_operator()).axes;
if(axes.size() > 1)
{
return;
}
auto input = slc->inputs().front();
auto split_outputs = get_splits(input); auto split_outputs = get_splits(input);
if(split_outputs.empty()) if(split_outputs.empty())
{ {
return; return;
} }
// Only want to apply this optimization if each split output is followed by // Find all the reshapes (similar to rsp) that can be simplified
// a contiguous op and a reshape std::vector<instruction_ref> conts;
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) { std::vector<instruction_ref> vec_rsp;
if(i->outputs().size() == 1)
{ // Iterate through slice and contiguous outputs to allow simplifications when
auto cont = i->outputs().front(); // slice is followed by multiple reshapes
return cont->outputs().size() != 1; for(auto& i : split_outputs)
}
return false;
}))
{ {
return; std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(conts),
[](auto j) { return j->name() == "contiguous"; });
} }
std::vector<instruction_ref> vec_rsp(split_outputs.size()); for(auto& i : conts)
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) { {
auto cont = i->outputs().front(); std::copy_if(i->outputs().begin(),
return cont->outputs().front(); i->outputs().end(),
}); std::back_inserter(vec_rsp),
[&](auto j) { return j->get_operator() == rsp->get_operator(); });
}
// all outputs are reshape and of the same shape // No simplification needed if there is only one slice -> cont -> reshape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims; if(vec_rsp.size() <= 1)
if(not same_ops(vec_rsp))
{ {
return; return;
} }
// ensure reshape happens after the axis dimension // ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0]; auto axis = axes[0];
auto slc_lens = slc->get_shape().lens(); auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate( auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>()); slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
auto input_lens = input->get_shape().lens();
auto input_size = input->get_shape().elements();
auto slc_axis_len = input_lens[axis];
// search the reshape output (standard shape) to decide which axis are // search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size // in its output corresponding to the slc_dim_size
...@@ -1418,16 +1430,67 @@ struct find_split_reshape ...@@ -1418,16 +1430,67 @@ struct find_split_reshape
{ {
rsp_axis = std::distance(rsp_strides.begin(), ait); rsp_axis = std::distance(rsp_strides.begin(), ait);
} }
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { // Calculate reshape output shape
return is->get_shape().lens()[rsp_axis]; // Need to find a reshape such that data represented by instructions in vec_rsp can be
}); // written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = 1;
auto rsp_fixed_size = std::accumulate(
rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies<std::size_t>());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); // cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
return;
}
auto rsp_axis_len = input_size / rsp_fixed_size;
rsp_out_lens[rsp_axis] = rsp_axis_len;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std::vector<int64_t> new_starts(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).starts[0] * rsp_axis_len /
slc_axis_len;
});
std::vector<int64_t> new_ends(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).ends[0] * rsp_axis_len /
slc_axis_len;
});
// insert the reshape instruction and add contiguous if needed // insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard()) if(not input->get_shape().standard())
...@@ -1438,16 +1501,14 @@ struct find_split_reshape ...@@ -1438,16 +1501,14 @@ struct find_split_reshape
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice // replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i) for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{ {
m.replace_instruction( m.replace_instruction(
vec_rsp[i], vec_rsp[i],
make_op( make_op(
"slice", "slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}), {{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}),
rsp_ins); rsp_ins);
start += vec_dims[i];
} }
} }
}; };
...@@ -1471,10 +1532,13 @@ struct find_split_transpose ...@@ -1471,10 +1532,13 @@ struct find_split_transpose
{ {
return; return;
} }
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size()); std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) { std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front(); return i->outputs().front();
}); });
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct find_static_2in_broadcasts
{
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct find_const_3in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(3),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
if(not starts_arg.empty() and not ends_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto slice_val = ins->get_operator().to_value();
auto axes_vec = slice_val.at("axes").to_vector<int64_t>();
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct find_const_4in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(4),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()),
match::arg(3)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
argument axes_arg = inputs.at(3)->eval();
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("transpose")(
match::arg(0)(match::name("multibroadcast").bind("bcast_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
if(not input->get_shape().scalar())
return;
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
}
};
struct find_slice_transpose struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
...@@ -784,7 +808,7 @@ struct find_transpose_slice ...@@ -784,7 +808,7 @@ struct find_transpose_slice
void simplify_reshapes::apply(module& m) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 4; i++) for(int i = 0; i < depth; i++)
{ {
match::find_matches(m, match::find_matches(m,
find_where_op{}, find_where_op{},
...@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_slice{}, find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{}, find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
......
...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes) ...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max}; dds_it->max};
} }
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/** /**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if` * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those * and `loop` instructions, depending on how the submodules for those
...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens}); dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return // redirect to select_module operator and return
......
...@@ -48,6 +48,7 @@ struct sqlite_impl ...@@ -48,6 +48,7 @@ struct sqlite_impl
template <class F> template <class F>
void exec(const char* sql, F f) void exec(const char* sql, F f)
{ {
// cppcheck-suppress constParameterPointer
auto callback = [](void* obj, auto... xs) -> int { auto callback = [](void* obj, auto... xs) -> int {
try try
{ {
......
...@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot> ...@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)}; MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
} }
void required(const check_shapes& cs) const { cs.not_broadcasted(); } template <class T>
void required(const check_shapes<T>& cs) const
{
cs.not_broadcasted();
}
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
......
...@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive> ...@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
} }
// dnnl has some issues with non-packed inputs // dnnl has some issues with non-packed inputs
void required(const check_shapes& cs) const { cs.packed_or_broadcasted(); } template <class T>
void required(const check_shapes<T>& cs) const
{
cs.packed_or_broadcasted();
}
std::string name() const { return "dnnl::" + op.name(); } std::string name() const { return "dnnl::" + op.name(); }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
...@@ -61,7 +61,7 @@ namespace cpu { ...@@ -61,7 +61,7 @@ namespace cpu {
std::string target::name() const { return "cpu"; } std::string target::name() const { return "cpu"; }
// cppcheck-suppress constParameter // cppcheck-suppress constParameterReference
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const
{ {
auto& ctx = any_cast<context>(gctx); auto& ctx = any_cast<context>(gctx);
......
...@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS ...@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(migraphx_device ${DEVICE_GPU_SRCS})
...@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device) ...@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device PUBLIC migraphx) target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu) target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_BINAR_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>) target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes) target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device) migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
...@@ -123,6 +125,7 @@ add_library(migraphx_gpu ...@@ -123,6 +125,7 @@ add_library(migraphx_gpu
lrn.cpp lrn.cpp
mlir.cpp mlir.cpp
multinomial.cpp multinomial.cpp
no_device.cpp
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp pack_int8_args.cpp
......
...@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m)
{ {
if(starts_with(ins.name(), "@")) if(starts_with(ins.name(), "@"))
continue; continue;
if(ins.name() == "multibroadcast") if(contains({"multibroadcast", "contiguous"}, ins.name()))
continue; continue;
if(ins.name() == "pointwise") if(ins.name() == "pointwise")
{ {
......
...@@ -115,6 +115,12 @@ struct hiprtc_program ...@@ -115,6 +115,12 @@ struct hiprtc_program
std::string cpp_src = ""; std::string cpp_src = "";
std::string cpp_name = ""; std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<hiprtc_src_file> srcs) hiprtc_program(std::vector<hiprtc_src_file> srcs)
{ {
for(auto&& src : srcs) for(auto&& src : srcs)
...@@ -130,6 +136,14 @@ struct hiprtc_program ...@@ -130,6 +136,14 @@ struct hiprtc_program
include_names.push_back(std::move(src.path)); include_names.push_back(std::move(src.path));
} }
} }
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(), prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(), cpp_name.c_str(),
headers.size(), headers.size(),
...@@ -137,7 +151,7 @@ struct hiprtc_program ...@@ -137,7 +151,7 @@ struct hiprtc_program
include_names.data()); include_names.data());
} }
void compile(const std::vector<std::string>& options) const void compile(const std::vector<std::string>& options, bool quiet = false) const
{ {
if(enabled(MIGRAPHX_TRACE_HIPRTC{})) if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl; std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
...@@ -148,7 +162,7 @@ struct hiprtc_program ...@@ -148,7 +162,7 @@ struct hiprtc_program
[](const std::string& s) { return s.c_str(); }); [](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log(); auto prog_log = log();
if(not prog_log.empty()) if(not prog_log.empty() and not quiet)
{ {
std::cerr << prog_log << std::endl; std::cerr << prog_log << std::endl;
} }
...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return {prog.get_code_obj()}; return {prog.get_code_obj()};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
...@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)}; return {compiler.compile(srcs)};
} }
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src;
src_file input;
input.path = "main.cpp";
input.content = std::make_pair(src.data(), src.data() + src.size());
try
{
compiler.compile({input});
return true;
}
catch(...)
{
return false;
}
}
#endif // MIGRAPHX_USE_HIPRTC #endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param) std::string enum_params(std::size_t count, std::string param)
......
...@@ -91,28 +91,39 @@ __content__ ...@@ -91,28 +91,39 @@ __content__
return replace_string(args_hpp, "__content__", inner); return replace_string(args_hpp, "__content__", inner);
} }
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings() const std::vector<std::string>& compiler_warnings()
{ {
static std::vector<std::string> warnings = {"-Weverything", static std::vector<std::string> warnings = get_compiler_warnings();
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions"};
return warnings; return warnings;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment