Unverified Commit 0ca5e1ce authored by shivadbhavsar, committed by GitHub

Merge branch 'develop' into conv_dot_fuse_reshape_bugfix

parents 3b981d9d df32040d
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/tensor_view.hpp>
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Output, class T, class Padding, class Stride>
void convolution(Output output, T input, T weights, Padding padding, Stride stride, int group)
{
auto output_shape = output.get_shape();
auto in_lens = input.get_shape().lens();
auto wei_lens = weights.get_shape().lens();
auto wei_n = wei_lens[0];
auto wei_c = wei_lens[1];
std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());
par_for(output_shape.elements(), [&](auto i) {
auto idx_o = output_shape.multi(i);
auto w = idx_o[1];
auto n_dim = idx_o.size();
std::vector<std::ptrdiff_t> win_start;
for(std::size_t dim = 2; dim < n_dim; ++dim)
{
auto d_2 = dim - 2;
win_start.push_back(std::ptrdiff_t(idx_o[dim] * stride[d_2]) -
std::ptrdiff_t(padding[d_2]));
}
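// Each output channel w belongs to group w / (wei_n / group); its input
// channels are the contiguous block starting at group_id * wei_c.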
const auto group_id = w / (wei_n / group);
shape win_shape{output_shape.type(), win_size};
double acc = 0.0;
shape_for_each(win_shape, [&](auto idx_win) {
auto k = idx_win[0];
const auto in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
idx[1] = in_ch;
std::transform(idx_win.begin() + 1,
idx_win.end(),
win_start.begin(),
idx.begin() + 2,
[](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
idx_wei[0] = w;
std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
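// A window tap contributes only when every spatial index lands inside the
// input; taps in the padding region (negative or >= in_lens) are skipped,
// which implicitly zero-pads the input.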
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx.begin(),
idx.end(),
in_lens.begin(),
in_lens.end(),
std::less<std::ptrdiff_t>{}))
{
acc += input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
}
});
output[i] = acc;
});
}
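// A minimal usage sketch (hypothetical shapes, not part of the original
// header; Padding and Stride can be any indexable containers):
//
//   migraphx::shape in_s{migraphx::shape::float_type, {1, 3, 32, 32}};
//   migraphx::shape w_s{migraphx::shape::float_type, {4, 3, 3, 3}};
//   migraphx::shape out_s{migraphx::shape::float_type, {1, 4, 30, 30}};
//   migraphx::argument out{out_s}, in{in_s}, w{w_s};
//   visit_all(out, in, w)([&](auto o, auto x, auto k) {
//       convolution(o, x, k, std::vector<std::size_t>{0, 0},
//                   std::vector<std::size_t>{1, 1}, 1);
//   });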
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
@@ -24,9 +24,12 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <utility>
@@ -210,6 +213,37 @@ struct convolution
check_attribute_size();
return stride.size();
}
argument compute(shape output_shape, std::vector<argument> args) const
{
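// For same_upper/same_lower auto-padding, the effective padding depends on
// the runtime input size, so it is recomputed per call and the output
// shape is rederived from the padded dimensions.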
std::vector<std::size_t> new_padding;
if(padding_mode != op::padding_mode_t::default_)
{
auto input_lens = args[0].get_shape().lens();
auto weights_lens = args[1].get_shape().lens();
new_padding =
padding_mode == op::same_upper
? calc_dyn_auto_pad(input_lens, weights_lens, stride, dilation, true)
: calc_dyn_auto_pad(input_lens, weights_lens, stride, dilation, false);
output_shape = compute_padded_shape(
args[0].get_shape(), args[1].get_shape(), new_padding, stride, dilation);
}
else
{
new_padding = padding;
if(output_shape.dynamic())
{
output_shape =
normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
}
}
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
migraphx::convolution(output, input, weights, new_padding, stride, group);
});
return result;
}
};
} // namespace op
......
@@ -25,8 +25,10 @@
#define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/value.hpp>
#include <cmath>
#include <utility>
@@ -114,6 +116,17 @@ struct quant_convolution
check_attribute_size();
return stride.size();
}
argument compute(shape output_shape, std::vector<argument> args) const
{
argument result{output_shape};
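// The accumulator/output type (int32) differs from the int8 inputs, so the
// result is visited on its own and the two inputs with visit_all, rather
// than one visit_all over all three tensors.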
result.visit([&](auto output) {
visit_all(args[0], args[1])([&](auto input, auto weights) {
migraphx::convolution(output, input, weights, padding, stride, group);
});
});
return result;
}
};
} // namespace op
......
@@ -52,8 +52,9 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
auto conv_const_weights()
{
-return match::name("convolution")(match::used_once(),
-                                  match::args(match::any(), match::is_constant().bind("w")));
+return match::name("convolution")(
+    match::used_once(),
+    match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
}
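// Note: the none_of(is_constant()) guard on the first argument appears
// intended to skip fully-constant convolutions, such as the conv(a, w) that
// find_conv_add (added below) produces; those are left to constant
// propagation instead.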
auto reduction() { return match::name_contains("reduce"); }
@@ -203,7 +204,12 @@ struct find_mul_slice_conv
}
};
-// a * (x + b) => a * x + a * b
+// ******************************
+// a * (x + b) => a * x + a * b
+// ******************************
+// When a * (x + b) is followed by another add of a constant, the additional
+// add can be constant-folded. Better fusions are also possible when the add
+// comes after the multiply.
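// For example, with constants a, b, c (names illustrative):
//   y = a * (x + b) + c   ==>   y = a * x + (a * b + c)
// and a * b + c collapses into a single literal.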
struct find_mul_add
{
auto matcher() const
@@ -268,6 +274,32 @@ struct find_dot_add
}
};
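// Convolution is linear in its input: conv(x + a, w) == conv(x, w) + conv(a, w).
// When both a and w are constants, conv(a, w) can later be folded into a
// literal, which is what makes the rewrite below profitable.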
struct find_conv_add
{
auto matcher() const
{
auto add = match::name("add")(
match::either_arg(0, 1)(match::any().bind("x"),
match::any_of(match::is_constant()).bind("a")),
match::used_once());
return match::name("convolution")(match::used_once(),
match::args(add, match::is_constant().bind("w")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto x_ins = r.instructions["x"];
auto w_ins = r.instructions["w"];
auto conv1 = m.insert_instruction(ins, ins->get_operator(), a_ins, w_ins);
auto conv2 = m.insert_instruction(ins, ins->get_operator(), x_ins, w_ins);
m.replace_instruction(ins, make_op("add"), conv1, conv2);
}
};
struct find_add_lit_broadcast
{
auto matcher() const
@@ -1244,6 +1276,7 @@ void simplify_algebra::apply(module& m) const
find_neg_unit_ops{},
find_zero_ops{},
find_dot_add{},
find_conv_add{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
......
@@ -762,7 +762,7 @@ struct find_transpose_slice
return;
// Compute axis before transpose to use for unsqueeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
-auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
+auto preaxis = perm[axis];
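// The slice axis indexes the transposed output; the input axis it comes
// from (and where the unsqueeze must be inserted) is perm[axis].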
// Make unsqueeze
std::vector<int64_t> steps(sdistance.size());
std::transform(
......
@@ -82,7 +82,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
simplify_algebra{},
simplify_reshapes{},
layout_nhwc{},
dead_code_elimination{},
simplify_reshapes{},
simplify_algebra{},
......
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
-#ifndef MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
-#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP
+#define MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
......
@@ -132,109 +132,6 @@ auto visit_quantize(T&& x, Ts&&... xs)
};
}
template <class Op>
struct ref_convolution : auto_register_op<ref_convolution<Op>>
{
ref_convolution() = default;
ref_convolution(Op pop) : op(std::move(pop)) {}
Op op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "ref::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const
{
return op.normalize_compute_shape(inputs);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
{
std::vector<std::size_t> padding;
if(op.padding_mode != op::padding_mode_t::default_)
{
auto input_lens = args[0].get_shape().lens();
auto weights_lens = args[1].get_shape().lens();
padding =
op.padding_mode == op::same_upper
? calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, true)
: calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, false);
output_shape = compute_padded_shape(
args[0].get_shape(), args[1].get_shape(), padding, op.stride, op.dilation);
}
else
{
padding = op.padding;
if(output_shape.dynamic())
{
output_shape =
op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
}
}
argument result{output_shape};
visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
auto in_lens = input.get_shape().lens();
auto wei_lens = weights.get_shape().lens();
auto wei_n = wei_lens[0];
auto wei_c = wei_lens[1];
std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());
par_for(output_shape.elements(), [&](auto i) {
auto idx_o = output_shape.multi(i);
auto w = idx_o[1];
auto n_dim = idx_o.size();
std::vector<std::ptrdiff_t> win_start;
for(std::size_t dim = 2; dim < n_dim; ++dim)
{
auto d_2 = dim - 2;
win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
std::ptrdiff_t(padding[d_2]));
}
const auto group_id = w / (wei_n / op.group);
shape win_shape{output_shape.type(), win_size};
double acc = 0.0;
shape_for_each(win_shape, [&](auto idx_win) {
auto k = idx_win[0];
const auto in_ch = group_id * wei_c + k;
std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
idx[1] = in_ch;
std::transform(idx_win.begin() + 1,
idx_win.end(),
win_start.begin(),
idx.begin() + 2,
[](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
idx_wei[0] = w;
std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
std::equal(idx.begin(),
idx.end(),
in_lens.begin(),
in_lens.end(),
std::less<std::ptrdiff_t>{}))
{
acc +=
input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
}
});
output[i] = acc;
});
});
return result;
}
};
struct ref_im2col
{
op::im2col op;
@@ -564,11 +461,8 @@ struct ref_apply
void init()
{
apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>();
apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["quant_convolution"] =
extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
apply_map["dot"] = extend_op<ref_gemm, op::dot>();
apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
......
@@ -509,6 +509,34 @@ TEST_CASE(simplify_dot_add)
EXPECT(m1 == m2);
}
TEST_CASE(simplify_conv_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto c = m1.add_literal(migraphx::generate_literal(s, 1));
auto w = m1.add_literal(migraphx::generate_literal(ws, 2));
auto sum = m1.add_instruction(migraphx::make_op("add"), c, x);
auto conv = m1.add_instruction(migraphx::make_op("convolution"), sum, w);
m1.add_instruction(pass_op{}, conv);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto c = m2.add_literal(migraphx::generate_literal(s, 1));
auto w = m2.add_literal(migraphx::generate_literal(ws, 2));
auto conv1 = m2.add_instruction(migraphx::make_op("convolution"), c, w);
auto conv2 = m2.add_instruction(migraphx::make_op("convolution"), x, w);
auto sum = m2.add_instruction(migraphx::make_op("add"), conv1, conv2);
m2.add_instruction(pass_op{}, sum);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_inner_broadcast1)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
......
@@ -1322,6 +1322,46 @@ TEST_CASE(transpose_slice)
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_unsqueeze)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {4, 1024, 96, 64}});
auto transpose1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}),
transpose1);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {16}}, {"ends", {24}}}),
transpose1);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {32}}, {"ends", {40}}}),
transpose1);
m1.add_return({slice1, slice2, slice3});
}
run_pass(m1);
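// Expected rewrite, with shapes worked out by hand: the unsqueeze splits the
// sliced 96-axis into 12 steps of 8, {4,1024,96,64} -> {4,1024,12,8,64}; the
// transpose {2,0,3,4,1} moves the step dimension to the front, giving
// {12,4,8,64,1024}, so each original slice becomes a one-step slice along
// axis 0 followed by a squeeze.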
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {4, 1024, 96, 64}});
auto unsq =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 4, 1}}}), unsq);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto sq1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
auto sq2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {5}}}), transpose);
auto sq3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({sq1, sq2, sq3});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_diff_perm)
{
migraphx::module m1;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_add_conv_constant : verify_program<test_add_conv_constant>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
auto x = mm->add_parameter("x", s);
auto c = mm->add_literal(migraphx::generate_literal(s, 1));
auto w = mm->add_literal(migraphx::generate_literal(ws, 2));
auto sum = mm->add_instruction(migraphx::make_op("add"), c, x);
mm->add_instruction(migraphx::make_op("convolution"), sum, w);
return p;
}
};
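// This verify test exercises the conv(x + c, w) => conv(x, w) + conv(c, w)
// rewrite end to end: the verify_program harness runs the same program through
// the reference and optimized paths and compares the results.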