Commit 32b83c9c authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into inner_bcast_fix

parents 92f5a6cd 434a06cf
......@@ -223,7 +223,7 @@ void program::compile(const std::vector<target>& targets, std::vector<compile_op
// Gather all the target roots
std::unordered_multimap<std::size_t, module_ref> roots;
auto mods = this->get_modules();
for(auto* mod : mods)
for(const auto* mod : mods)
{
for(const auto& ins : *mod)
{
......@@ -548,7 +548,7 @@ std::vector<argument> program::eval(parameter_map params, execution_environment
ins_out[x] = ss.str();
});
ret = generic_eval(*this, contexts, std::move(params), [&](instruction_ref ins, auto f) {
auto& ctx = contexts[ins->get_target_id()];
const auto& ctx = contexts[ins->get_target_id()];
ctx.finish();
std::cout << "Run instruction: " << ins_out.at(ins) << std::endl;
timer t{};
......@@ -624,7 +624,7 @@ std::string get_migraphx_version()
program file version is for the data structure or format of the MXR file. Version should be bumped
if any changes occur to the format of the MXR file.
*/
const int program_file_version = 6;
const int program_file_version = 7;
value program::to_value() const
{
......@@ -728,7 +728,7 @@ static void mod_from_val(module_ref mod,
std::back_inserter(module_inputs),
[&](const value& i) { return map_mods.at(i.to<std::string>()); });
for(auto& smod : module_inputs)
for(const auto& smod : module_inputs)
{
mod_from_val(smod, v, instructions, map_mods);
}
......@@ -1186,7 +1186,7 @@ void program::remove_unused_modules()
std::vector<module*> unused;
generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
for(auto* m : unused)
for(const auto* m : unused)
this->remove_module(m->name());
}
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins)
bool skip_propagate(instruction_ref ins)
{
if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front());
return skip_propagate(ins->inputs().front());
auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar() and not s.packed())
return true;
......@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return false;
}
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propagate(ins); }
void propagate_constant::apply(module& m) const
{
......
......@@ -24,6 +24,7 @@
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
add_library(migraphx_py py_loader.cpp)
migraphx_generate_export_header(migraphx_py)
target_include_directories(migraphx_py PRIVATE include)
target_link_libraries(migraphx_py PUBLIC migraphx)
rocm_install_targets(TARGETS migraphx_py INCLUDE include)
......
......@@ -26,11 +26,12 @@
#include <migraphx/config.hpp>
#include <migraphx/program.hpp>
#include <migraphx/py/export.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
program load_py(const std::string& filename);
MIGRAPHX_PY_EXPORT program load_py(const std::string& filename);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -64,7 +64,7 @@ static dynamic_loader py_lib()
return lib;
}
program load_py(const std::string& filename)
MIGRAPHX_PY_EXPORT program load_py(const std::string& filename)
{
static auto f = py_lib().get_function<program(const std::string&)>("migraphx_load_py");
return f(filename);
......
......@@ -28,6 +28,7 @@
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -61,13 +62,10 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant = qt.max();
min_quant = qt.min();
});
auto s = add_zero_point->get_shape();
std::vector<int> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant);
auto min_arg = m.add_literal(literal(s, min_data));
auto max_arg = m.add_literal(literal(s, max_data));
auto saturate = m.insert_instruction(ins, make_op("clip"), add_zero_point, min_arg, max_arg);
auto s = add_zero_point->get_shape();
auto min_arg = m.add_literal(literal{shape{s.type()}, {min_quant}});
auto max_arg = m.add_literal(literal{shape{s.type()}, {max_quant}});
auto saturate = insert_common_op(m, ins, make_op("clip"), {add_zero_point, min_arg, max_arg});
m.replace_instruction(
ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
......
......@@ -50,13 +50,14 @@ struct shape_impl
{
assert(t != shape::tuple_type);
}
shape_impl(shape::type_t t, std::vector<std::size_t> l)
: m_type(t), m_lens(std::move(l)), m_standard(true)
{
assert(t != shape::tuple_type);
this->calculate_strides();
assert(m_lens.size() == m_strides.size());
}
shape_impl(shape::type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: m_type(t), m_lens(std::move(l)), m_strides(std::move(s))
{
......@@ -151,6 +152,22 @@ struct shape_impl
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::size_t get_index(size_t i) const
{
std::size_t result = 0;
std::size_t s = 1;
for(auto k : migraphx::reverse(migraphx::range(m_lens.size())))
{
std::size_t stride = m_strides[k];
std::size_t len = m_lens[k];
std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
std::vector<std::size_t> min_lens() const
{
std::vector<std::size_t> ret(m_dyn_dims.size());
......@@ -213,6 +230,7 @@ std::string shape::name(shape::type_t t)
}
MIGRAPHX_THROW("Invalid type");
}
std::string shape::cpp_type(shape::type_t t)
{
switch(t)
......@@ -229,10 +247,12 @@ std::string shape::cpp_type(shape::type_t t)
shape::shape() : impl(shape_impl::default_shape()) {}
shape::shape(type_t t) : impl(std::make_shared<shape_impl>(t)) {}
shape::shape(type_t t, std::vector<std::size_t> l)
: impl(std::make_shared<shape_impl>(t, std::move(l)))
{
}
shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
: impl(std::make_shared<shape_impl>(t, std::move(l), std::move(s)))
{
......@@ -358,21 +378,8 @@ std::size_t shape::index(std::size_t i) const
assert(this->lens().size() == this->strides().size());
if(this->standard())
return i;
else
{
std::size_t s = 1;
std::size_t result = 0;
for(std::size_t j = 0; j < this->lens().size(); j++)
{
const std::size_t k = this->lens().size() - j - 1;
const std::size_t stride = this->strides()[k];
const std::size_t len = this->lens()[k];
const std::size_t idx = (i % (s * len)) / s;
result += stride * idx;
s *= len;
}
return result;
}
return impl->get_index(i);
}
std::vector<std::size_t> shape::multi(std::size_t idx) const
......
......@@ -1121,8 +1121,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot"));
auto qdots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("quant_dot"));
auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution"));
return (dots >= 2 or convs >= 2);
return (dots >= 2 or convs >= 2 or qdots >= 2);
}
struct find_conv_dot_horiz_fusion
......@@ -1136,7 +1137,7 @@ struct find_conv_dot_horiz_fusion
auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator())
return false;
if(not contains({"dot", "convolution"}, i->name()))
if(not contains({"quant_dot", "dot", "convolution"}, i->name()))
return true;
auto x = i->inputs()[1]->get_shape().lens();
auto y = j->inputs()[1]->get_shape().lens();
......@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
return false;
// Check that non-axes match
int axis = 1;
if(i->name() == "dot")
if(i->name() == "dot" or i->name() == "quant_dot")
{
axis = x.size() - 1;
}
......@@ -1155,7 +1156,7 @@ struct find_conv_dot_horiz_fusion
if(std::distance(start, last) < 2)
return;
auto&& name = (*start)->name();
if(not contains({"dot", "convolution"}, name))
if(not contains({"quant_dot", "dot", "convolution"}, name))
return;
auto op = (*start)->get_operator();
int group = 1;
......@@ -1170,7 +1171,7 @@ struct find_conv_dot_horiz_fusion
start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); });
int axis = 1;
int concat_axis = 0;
if(name == "dot")
if(name == "dot" or name == "quant_dot")
{
axis = int(args.front()->get_shape().lens().size() - 1);
concat_axis = axis;
......@@ -1350,48 +1351,59 @@ struct find_split_reshape
void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front();
// Only apply simplification when slices are on a single axis
auto axes = any_cast<op::slice>(slc->get_operator()).axes;
if(axes.size() > 1)
{
return;
}
auto input = slc->inputs().front();
auto split_outputs = get_splits(input);
if(split_outputs.empty())
{
return;
}
// Only want to apply this optimization if each split output is followed by
// a contiguous op and a reshape
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
if(i->outputs().size() == 1)
{
auto cont = i->outputs().front();
return cont->outputs().size() != 1;
}
return false;
}))
// Find all the reshapes (similar to rsp) that can be simplified
std::vector<instruction_ref> conts;
std::vector<instruction_ref> vec_rsp;
// Iterate through slice and contiguous outputs to allow simplifications when
// slice is followed by multiple reshapes
for(auto& i : split_outputs)
{
return;
std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(conts),
[](auto j) { return j->name() == "contiguous"; });
}
std::vector<instruction_ref> vec_rsp(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
auto cont = i->outputs().front();
return cont->outputs().front();
});
for(auto& i : conts)
{
std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(vec_rsp),
[&](auto j) { return j->get_operator() == rsp->get_operator(); });
}
// all outputs are reshape and of the same shape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims;
if(not same_ops(vec_rsp))
// No simplification needed if there is only one slice -> cont -> reshape
if(vec_rsp.size() <= 1)
{
return;
}
// ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0];
auto axis = axes[0];
auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
auto input_lens = input->get_shape().lens();
auto input_size = input->get_shape().elements();
auto slc_axis_len = input_lens[axis];
// search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size
......@@ -1418,16 +1430,67 @@ struct find_split_reshape
{
rsp_axis = std::distance(rsp_strides.begin(), ait);
}
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) {
return is->get_shape().lens()[rsp_axis];
});
// Calculate reshape output shape
// Need to find a reshape such that data represented by instructions in vec_rsp can be
// written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = 1;
auto rsp_fixed_size = std::accumulate(
rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies<std::size_t>());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
return;
}
auto rsp_axis_len = input_size / rsp_fixed_size;
rsp_out_lens[rsp_axis] = rsp_axis_len;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std::vector<int64_t> new_starts(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).starts[0] * rsp_axis_len /
slc_axis_len;
});
std::vector<int64_t> new_ends(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).ends[0] * rsp_axis_len /
slc_axis_len;
});
// insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard())
......@@ -1438,16 +1501,14 @@ struct find_split_reshape
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
m.replace_instruction(
vec_rsp[i],
make_op(
"slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}),
{{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}),
rsp_ins);
start += vec_dims[i];
}
}
};
......@@ -1471,10 +1532,13 @@ struct find_split_transpose
{
return;
}
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front();
});
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct find_static_2in_broadcasts
{
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct find_const_3in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(3),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
if(not starts_arg.empty() and not ends_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto slice_val = ins->get_operator().to_value();
auto axes_vec = slice_val.at("axes").to_vector<int64_t>();
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct find_const_4in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(4),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()),
match::arg(3)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
argument axes_arg = inputs.at(3)->eval();
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary
}
};
struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("transpose")(
match::arg(0)(match::name("multibroadcast").bind("bcast_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
if(not input->get_shape().scalar())
return;
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
}
};
struct find_slice_transpose
{
auto matcher() const
......@@ -784,7 +808,7 @@ struct find_transpose_slice
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 4; i++)
for(int i = 0; i < depth; i++)
{
match::find_matches(m,
find_where_op{},
......@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice{},
find_nested_concat{},
find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m);
......
......@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max};
}
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those
......@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod);
}
// redirect to select_module operator and return
......
......@@ -48,6 +48,7 @@ struct sqlite_impl
template <class F>
void exec(const char* sql, F f)
{
// cppcheck-suppress constParameterPointer
auto callback = [](void* obj, auto... xs) -> int {
try
{
......
......@@ -43,7 +43,11 @@ struct dnnl_gemm : dnnl_extend_op<dnnl_gemm, dnnl::matmul, op::dot>
MIGRAPHX_DNNL_PREFIX(ARG_BIAS)};
}
void required(const check_shapes& cs) const { cs.not_broadcasted(); }
template <class T>
void required(const check_shapes<T>& cs) const
{
cs.not_broadcasted();
}
dnnl::matmul::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
......
......@@ -400,7 +400,11 @@ struct dnnl_extend_op : dnnl_op<Derived, Primitive>
}
// dnnl has some issues with non-packed inputs
void required(const check_shapes& cs) const { cs.packed_or_broadcasted(); }
template <class T>
void required(const check_shapes<T>& cs) const
{
cs.packed_or_broadcasted();
}
std::string name() const { return "dnnl::" + op.name(); }
shape compute_shape(std::vector<shape> inputs) const
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
......
......@@ -61,7 +61,7 @@ namespace cpu {
std::string target::name() const { return "cpu"; }
// cppcheck-suppress constParameter
// cppcheck-suppress constParameterReference
std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options&) const
{
auto& ctx = any_cast<context>(gctx);
......
......@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS})
......@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_BINAR_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
......@@ -123,6 +125,7 @@ add_library(migraphx_gpu
lrn.cpp
mlir.cpp
multinomial.cpp
no_device.cpp
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
......
......@@ -331,7 +331,7 @@ static std::vector<std::string> get_op_names(const module& m)
{
if(starts_with(ins.name(), "@"))
continue;
if(ins.name() == "multibroadcast")
if(contains({"multibroadcast", "contiguous"}, ins.name()))
continue;
if(ins.name() == "pointwise")
{
......
......@@ -115,6 +115,12 @@ struct hiprtc_program
std::string cpp_src = "";
std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<hiprtc_src_file> srcs)
{
for(auto&& src : srcs)
......@@ -130,6 +136,14 @@ struct hiprtc_program
include_names.push_back(std::move(src.path));
}
}
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(),
headers.size(),
......@@ -137,7 +151,7 @@ struct hiprtc_program
include_names.data());
}
void compile(const std::vector<std::string>& options) const
void compile(const std::vector<std::string>& options, bool quiet = false) const
{
if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
......@@ -148,7 +162,7 @@ struct hiprtc_program
[](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log();
if(not prog_log.empty())
if(not prog_log.empty() and not quiet)
{
std::cerr << prog_log << std::endl;
}
......@@ -210,6 +224,20 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
return {prog.get_code_obj()};
}
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
......@@ -323,6 +351,29 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)};
}
bool hip_has_flags(const std::vector<std::string>& flags)
{
src_compiler compiler;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
compiler.flags =
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src;
src_file input;
input.path = "main.cpp";
input.content = std::make_pair(src.data(), src.data() + src.size());
try
{
compiler.compile({input});
return true;
}
catch(...)
{
return false;
}
}
#endif // MIGRAPHX_USE_HIPRTC
std::string enum_params(std::size_t count, std::string param)
......
......@@ -91,28 +91,39 @@ __content__
return replace_string(args_hpp, "__content__", inner);
}
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = {"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions"};
static std::vector<std::string> warnings = get_compiler_warnings();
return warnings;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment