Commit 8ae761df authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_contiguous' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_unsqueeze

parents 38a196f6 2cf7ae45
......@@ -55,6 +55,7 @@ add_library(migraphx
insert_pad.cpp
instruction.cpp
json.cpp
layout_nhwc.cpp
load_save.cpp
make_op.cpp
module.cpp
......@@ -144,6 +145,7 @@ register_migraphx_ops(
if_op
im2col
isnan
layout
leaky_relu
less
load
......
......@@ -59,6 +59,8 @@ void auto_contiguous::apply(module& m) const
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() == "layout")
continue;
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
continue;
......
......@@ -42,6 +42,13 @@ static bool try_compute_shape(instruction_ref ins,
try
{
shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// Cannot tell if a dynamic shape will need to be made contiguous
if(new_shape.dynamic())
{
return false;
}
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
......@@ -133,14 +140,20 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
}
}
// Perform evaluations in parallel
// Perform static contiguous evaluations in parallel
std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
// compute the output contiguous shape from the previous instruction shape
shape computed_shape = c.compute_shape({prev->get_shape()});
const std::vector<argument>& prev_eval = {prev->eval()};
// prev_eval should not be used in make_compute_output_shape() as computed_shape is static
auto co_shape = make_compute_output_shape(pack(c, computed_shape, prev_eval));
literals[i] = c.compute(co_shape, prev_eval);
});
// Replace static contiguous operations with a literal
for(size_t i = 0; i < const_instructions.size(); i++)
{
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
......
......@@ -56,6 +56,8 @@ static void create_pointwise_modules(module_pass_manager& mpm)
{
if(not ins->get_operator().attributes().get("pointwise", false))
continue;
if(ins->get_operator().name() == "layout")
continue;
assert(ins->get_operator().attributes().contains("point_op"));
auto* pm = mpm.create_module(mpm.get_module().name() + ":pointwise" + std::to_string(n++));
pm->set_bypass();
......
......@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/permutation.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
......@@ -232,6 +233,19 @@ struct check_shapes
return *this;
}
/*!
* Check all shapes are packed with certain layouts
*/
const check_shapes&
packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
{
if(not this->all_of([&](const shape& s) {
return s.packed() and contains(layouts, find_permutation(s));
}))
MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
return *this;
}
/*!
* Check all shapes are packed or broadcasted.
*/
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
/**
* Transform convolutions to nhwc
*/
struct layout_nhwc
{
    // Pass name reported to the pass manager
    std::string name() const { return "layout_nhwc"; }
    // Rewrites eligible convolutions in the module to operate in nhwc layout;
    // implementation lives in layout_nhwc.cpp
    void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
......@@ -111,21 +111,7 @@ struct literal : raw_data<literal>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
else
{
auto it = start;
m_shape.visit_type([&](auto as) {
auto output = make_view(m_shape, as.from(buffer.get()));
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = *it; // NOLINT(bugprone-signed-char-misuse)
it++;
});
});
}
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
}
};
......
......@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -42,19 +43,31 @@ namespace op {
struct contiguous
{
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front();
if(s0.dynamic())
{
return s0;
}
else
{
if(s0.standard())
{
return inputs.front();
}
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
assert(output_shape.standard());
argument result{output_shape};
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OP_LAYOUT_HPP
#define MIGRAPHX_GUARD_OP_LAYOUT_HPP
#include <migraphx/config.hpp>
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct layout : unary<layout>
{
    // Target ordering of the dimensions, e.g. {0, 2, 3, 1} for nhwc
    std::vector<int64_t> permutation;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.permutation, "permutation"));
    }

    // Output keeps the input's type and lens, but the strides are laid out
    // according to `permutation`.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).only_dims(permutation.size());
        const shape& input = inputs.front();
        return shape::from_permutation(input.type(), input.lens(), permutation);
    }

    // Element-wise identity: the op only changes how the data is laid out
    auto apply() const
    {
        return [](auto x) { return x; };
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_OP_LAYOUT_HPP
......@@ -31,6 +31,10 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Iterates the given function over the standard shape indices.
* Will iterate using standard strides if given a non-standard shape.
*/
template <class F>
void shape_for_each(const migraphx::shape& s, F f)
{
......@@ -51,7 +55,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
File mode changed from 100644 to 100755
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>

#include <unordered_set>

#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
std::vector<instruction_ref> result;
fix([&](auto self, auto ins) {
if(pred(ins))
{
result.push_back(ins);
return;
}
for(auto input : ins->inputs())
self(input);
})(std::prev(m.end()));
return result;
}
std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
    // Pin the current layout of the trailing 4-d convolutions that feed the
    // module output by inserting an explicit layout instruction after each,
    // so later passes cannot change the output format.
    std::unordered_set<instruction_ref> pinned;
    auto convs = find_lasts(m, [](auto ins) {
        if(ins->name() != "convolution")
            return false;
        return ins->get_shape().lens().size() == 4;
    });
    for(auto conv : convs)
    {
        auto perm       = find_permutation(conv->get_shape());
        auto layout_ins = m.insert_instruction(
            std::next(conv), make_op("layout", {{"permutation", perm}}), conv);
        pinned.insert(m.replace_instruction(conv, layout_ins));
    }
    return pinned;
}
void transform_convolutions(module& m)
{
    // Rewrite each eligible convolution to consume nhwc inputs and produce a
    // contiguous (standard-layout) result in place of the original.
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "convolution")
            continue;
        // Only 4-d (batched 2-d) convolutions are handled
        if(ins->get_shape().lens().size() != 4)
            continue;
        // Grouped convolutions are left untouched
        auto op_value = ins->get_operator().to_value();
        if(op_value.at("group").to<int>() > 1)
            continue;
        std::vector<instruction_ref> nhwc_args;
        nhwc_args.reserve(ins->inputs().size());
        for(const auto& input : ins->inputs())
        {
            nhwc_args.push_back(m.insert_instruction(
                ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), input));
        }
        auto nhwc_conv   = m.insert_instruction(ins, ins->get_operator(), nhwc_args);
        auto back_to_std = m.insert_instruction(ins, make_op("contiguous"), nhwc_conv);
        m.replace_instruction(ins, back_to_std);
    }
}
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{
    // Drop layout instructions that do not actually change their input's
    // shape, except those deliberately kept to preserve the output layout.
    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "layout")
            continue;
        auto input         = ins->inputs().front();
        const bool is_noop = ins->get_shape() == input->get_shape();
        if(not is_noop)
            continue;
        if(contains(output_layouts, ins))
            continue;
        m.replace_instruction(ins, input);
    }
}
void layout_nhwc::apply(module_pass_manager& mpm) const
{
    // Record which layout instructions pin the module's output format so the
    // cleanup below does not strip them. Must run before any rewriting.
    std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
    // Rewrite eligible convolutions to consume nhwc inputs
    transform_convolutions(mpm.get_module());
    mpm.run_pass(dead_code_elimination{});
    // Remove contiguous ops made redundant by the layout rewrite
    mpm.run_pass(eliminate_contiguous{"contiguous"});
    mpm.run_pass(dead_code_elimination{});
    // Strip no-op layout instructions, keeping the pinned output ones
    remove_layout(mpm.get_module(), output_layouts);
    mpm.run_pass(dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -26,6 +26,9 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -55,12 +58,12 @@ struct parse_split : op_parser<parse_split>
{
literal s = parser.parse_value(info.attributes.at("split"));
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
}
}
else if(args.size() == 2)
{
auto s = args[1]->eval();
check_arg_empty(s, "Split: dynamic shape is not supported");
s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });
}
// no split attribute, input is equally divided
else
......@@ -74,6 +77,15 @@ struct parse_split : op_parser<parse_split>
vec_splits.resize(info.num_outputs, dl);
}
if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
static_cast<int64_t>(lens[tuned_axis]))
{
MIGRAPHX_THROW(
"PARSE_SPLIT: sum of split attribute unequal to dim size of axis! tuned axis:" +
std::to_string(lens[tuned_axis]) + " Output " + to_string_range(vec_splits) +
" Rank " + std::to_string(n_rank) + " Len outs " + to_string_range(lens));
}
std::vector<instruction_ref> ret_ins;
int64_t start = 0;
for(auto sl : vec_splits)
......
......@@ -51,7 +51,18 @@ struct dnnl_binary : dnnl_op<dnnl_binary, dnnl::binary>
auto r = s0;
if(s0 != s1 or not s0.packed())
{
r = shape{s0.type(), s0.lens()};
if(s0.packed() != s1.packed())
{
r = s0.packed() ? s0 : s1;
}
else if(s0.broadcasted() != s1.broadcasted())
{
r = s0.broadcasted() ? s1.with_lens(s0.lens()) : s0.with_lens(s0.lens());
}
else
{
r = {s0.type(), s0.lens()};
}
}
// Call to get_primitive to make sure an algo is available
this->get_primitive(this->to_memory_desc(r, inputs));
......
......@@ -43,9 +43,9 @@ struct dnnl_convolution
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
shape adjust_shape(const shape& x, int i, const shape& output) const
{
auto s = base_adjust_shape(x);
auto s = base_adjust_shape(x, output);
if(i == 1 and op.group > 1)
{
// TODO: Add support for transposed weights
......
......@@ -37,9 +37,9 @@ struct dnnl_deconvolution
return {MIGRAPHX_DNNL_PREFIX(ARG_SRC), MIGRAPHX_DNNL_PREFIX(ARG_WEIGHTS)};
}
shape adjust_shape(const shape& x, int i) const
shape adjust_shape(const shape& x, int i, const shape& output) const
{
auto s = base_adjust_shape(x);
auto s = base_adjust_shape(x, output);
if(i == 1)
{
// The input and output channels are flipped for dnnl
......
......@@ -167,7 +167,7 @@ struct dnnl_op : auto_register_op<Derived>
std::iota(result.begin(), result.end(), MIGRAPHX_DNNL_PREFIX(ARG_SRC_0));
return result;
}
shape base_adjust_shape(const shape& s) const
shape base_adjust_shape(const shape& s, const shape& output) const
{
if(s.broadcasted())
{
......@@ -183,7 +183,8 @@ struct dnnl_op : auto_register_op<Derived>
else
return len;
});
return shape{s.type(), lens};
// Use the permutation of the output
return output.with_lens(s.type(), lens);
}
return s;
}
......@@ -204,7 +205,10 @@ struct dnnl_op : auto_register_op<Derived>
i++;
}
}
shape adjust_shape(const shape& s, int) const { return base_adjust_shape(s); }
shape adjust_shape(const shape& s, int, const shape& output) const
{
return base_adjust_shape(s, output);
}
std::vector<int> create_arg_map(std::size_t input_size) const
{
const auto& self = static_cast<const Derived&>(*this);
......@@ -224,12 +228,12 @@ struct dnnl_op : auto_register_op<Derived>
const auto& self = static_cast<const Derived&>(*this);
std::unordered_map<int, dnnl::memory::desc> result;
result[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size()));
to_dnnl_memory_desc(self.adjust_shape(output_shape, inputs.size(), output_shape));
auto m = create_arg_map(inputs.size());
assert(m.size() >= inputs.size());
for(int i = 0; i < inputs.size(); i++)
{
result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i));
result[m[i]] = to_dnnl_memory_desc(self.adjust_shape(inputs[i], i, output_shape));
}
return result;
}
......
......@@ -32,7 +32,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
{
std::string name() const { return "dnnl::reorder"; }
shape adjust_shape(const shape& x, int) const { return x; }
shape adjust_shape(const shape& x, int, const shape&) const { return x; }
shape compute_shape(const std::vector<shape>& inputs) const
{
......
......@@ -33,6 +33,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
......@@ -82,6 +83,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
simplify_algebra{},
simplify_reshapes{},
layout_nhwc{},
dead_code_elimination{},
simplify_reshapes{},
simplify_algebra{},
auto_contiguous{},
simplify_reshapes{},
......
......@@ -83,6 +83,7 @@ add_library(migraphx_gpu
compile_gen.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_miopen.cpp
compiler.cpp
device_name.cpp
fuse_mlir.cpp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment