Commit 00ef197c authored by Shiv

add conv optim

parents 18e96032 673ca71c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/tensor_view.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Reference N-dimensional grouped convolution, shared by op::convolution::compute
// and op::quant_convolution::compute. Accumulation is done in double and written
// back to the output element type.
template <class Output, class T, class Padding, class Stride>
void convolution(Output output, T input, T weights, Padding padding, Stride stride, int group)
{
    auto output_shape = output.get_shape();
    auto in_lens = input.get_shape().lens();
    auto wei_lens = weights.get_shape().lens();
    auto wei_n = wei_lens[0]; // number of output channels
    auto wei_c = wei_lens[1]; // input channels per group
    // Window size: the weight dims without the output-channel dim
    std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());

    // One parallel task per output element
    par_for(output_shape.elements(), [&](auto i) {
        auto idx_o = output_shape.multi(i);
        auto w = idx_o[1]; // output channel of this element
        auto n_dim = idx_o.size();
        // Top-left corner of the receptive field: out_index * stride - padding
        std::vector<std::ptrdiff_t> win_start;
        for(std::size_t dim = 2; dim < n_dim; ++dim)
        {
            auto d_2 = dim - 2;
            win_start.push_back(std::ptrdiff_t(idx_o[dim] * stride[d_2]) -
                                std::ptrdiff_t(padding[d_2]));
        }
        // Grouped convolution: an output channel only sees its own group's input channels
        const auto group_id = w / (wei_n / group);

        shape win_shape{output_shape.type(), win_size};
        double acc = 0.0;
        // Walk every (input channel within the group, spatial offset) position of the window
        shape_for_each(win_shape, [&](auto idx_win) {
            auto k = idx_win[0];
            const auto in_ch = group_id * wei_c + k;
            std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
            idx[1] = in_ch;
            std::transform(idx_win.begin() + 1,
                           idx_win.end(),
                           win_start.begin(),
                           idx.begin() + 2,
                           [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
            std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
            idx_wei[0] = w;
            std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
            // Accumulate only when the window position lies inside the input;
            // positions that fall in the zero padding contribute nothing
            if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
               std::equal(idx.begin(),
                          idx.end(),
                          in_lens.begin(),
                          in_lens.end(),
                          std::less<std::ptrdiff_t>{}))
            {
                acc += input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
            }
        });
        output[i] = acc;
    });
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
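
For reference, the index arithmetic above reduces in the 1-D, single-channel case to the short, self-contained sketch below. It is illustrative only (conv1d is not a library function): the window for output element o starts at o * stride - padding, and positions that fall into the zero padding are skipped rather than read.

#include <cassert>
#include <cstddef>
#include <vector>

// Plain 1-D convolution (cross-correlation, as in the header above) over a single
// channel, using the same index arithmetic: the window for output o starts at
// o * stride - padding, and positions outside the input are treated as zeros.
std::vector<double> conv1d(const std::vector<double>& x,
                           const std::vector<double>& w,
                           std::ptrdiff_t stride,
                           std::ptrdiff_t padding)
{
    std::ptrdiff_t n = x.size();
    std::ptrdiff_t k = w.size();
    std::ptrdiff_t out = (n + 2 * padding - k) / stride + 1;
    std::vector<double> y(out, 0.0);
    for(std::ptrdiff_t o = 0; o < out; ++o)
    {
        std::ptrdiff_t win_start = o * stride - padding;
        double acc = 0.0;
        for(std::ptrdiff_t j = 0; j < k; ++j)
        {
            std::ptrdiff_t ii = win_start + j;
            if(ii >= 0 and ii < n) // skip the zero padding
                acc += x[ii] * w[j];
        }
        y[o] = acc;
    }
    return y;
}

int main()
{
    // x = [1 2 3 4], w = [1 1 1], stride 1, padding 1 -> y = [3 6 9 7]
    auto y = conv1d({1, 2, 3, 4}, {1, 1, 1}, 1, 1);
    assert((y == std::vector<double>{3, 6, 9, 7}));
}

The templated version above applies the same rule per spatial dimension, with the extra batch, channel, and group bookkeeping.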
@@ -24,9 +24,12 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
 #define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
+#include <migraphx/argument.hpp>
 #include <migraphx/op/common.hpp>
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/convolution.hpp>
+#include <migraphx/pad_calc.hpp>
 #include <migraphx/value.hpp>
 #include <cmath>
 #include <utility>
@@ -201,6 +204,37 @@ struct convolution
         check_attribute_size();
         return stride.size();
     }
+
+    argument compute(shape output_shape, std::vector<argument> args) const
+    {
+        std::vector<std::size_t> new_padding;
+        if(padding_mode != op::padding_mode_t::default_)
+        {
+            auto input_lens = args[0].get_shape().lens();
+            auto weights_lens = args[1].get_shape().lens();
+            new_padding =
+                padding_mode == op::same_upper
+                    ? calc_dyn_auto_pad(input_lens, weights_lens, stride, dilation, true)
+                    : calc_dyn_auto_pad(input_lens, weights_lens, stride, dilation, false);
+            output_shape = compute_padded_shape(
+                args[0].get_shape(), args[1].get_shape(), new_padding, stride, dilation);
+        }
+        else
+        {
+            new_padding = padding;
+            if(output_shape.dynamic())
+            {
+                output_shape =
+                    normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
+            }
+        }
+
+        argument result{output_shape};
+        visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
+            migraphx::convolution(output, input, weights, new_padding, stride, group);
+        });
+        return result;
+    }
 };
 } // namespace op
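
When padding_mode is not default_, compute() resolves the padding from the runtime input shape via calc_dyn_auto_pad before calling the reference kernel. As a rough sketch of what that resolution conventionally computes, assuming the usual ONNX-style "same" rule — the authoritative definition is in pad_calc.hpp, and same_pad_1d below is a hypothetical helper, not library code:

#include <algorithm>
#include <cstddef>
#include <utility>

// Conventional "same" auto-padding for one spatial dimension (illustrative sketch):
// the total padding makes the output length ceil(input / stride); same_upper puts
// the extra element of an odd total at the end, same_lower at the beginning.
std::pair<std::size_t, std::size_t> same_pad_1d(
    std::size_t input, std::size_t kernel, std::size_t stride, std::size_t dilation, bool upper)
{
    std::size_t out_len = (input + stride - 1) / stride; // ceil(input / stride)
    std::size_t eff_kernel = dilation * (kernel - 1) + 1;
    std::ptrdiff_t total =
        std::ptrdiff_t((out_len - 1) * stride + eff_kernel) - std::ptrdiff_t(input);
    std::size_t pad = std::size_t(std::max<std::ptrdiff_t>(total, 0));
    std::size_t lo = upper ? pad / 2 : pad - pad / 2;
    return {lo, pad - lo};
}

// e.g. input 5, kernel 3, stride 1, dilation 1 -> {1, 1};
//      kernel 4 -> {1, 2} with same_upper, {2, 1} with same_lower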
......
@@ -25,8 +25,10 @@
 #define MIGRAPHX_GUARD_OPERATORS_QUANT_CONVOLUTION_HPP
 #include <migraphx/op/common.hpp>
+#include <migraphx/argument.hpp>
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/convolution.hpp>
 #include <migraphx/value.hpp>
 #include <cmath>
 #include <utility>
@@ -114,6 +116,17 @@ struct quant_convolution
         check_attribute_size();
         return stride.size();
     }
+
+    argument compute(shape output_shape, std::vector<argument> args) const
+    {
+        argument result{output_shape};
+        result.visit([&](auto output) {
+            visit_all(args[0], args[1])([&](auto input, auto weights) {
+                migraphx::convolution(output, input, weights, padding, stride, group);
+            });
+        });
+        return result;
+    }
 };
 } // namespace op
......
@@ -52,8 +52,9 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
 auto conv_const_weights()
 {
-    return match::name("convolution")(match::used_once(),
-                                      match::args(match::any(), match::is_constant().bind("w")));
+    return match::name("convolution")(
+        match::used_once(),
+        match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
 }
 auto reduction() { return match::name_contains("reduce"); }
@@ -268,6 +269,32 @@ struct find_dot_add
     }
 };
+
+struct find_conv_add
+{
+    auto matcher() const
+    {
+        auto add = match::name("add")(
+            match::either_arg(0, 1)(match::any().bind("x"),
+                                    match::any_of(match::is_constant()).bind("a")),
+            match::used_once());
+        return match::name("convolution")(match::used_once(),
+                                          match::args(add, match::is_constant().bind("w")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto x_ins = r.instructions["x"];
+        auto w_ins = r.instructions["w"];
+
+        auto conv1 = m.insert_instruction(ins, ins->get_operator(), a_ins, w_ins);
+        auto conv2 = m.insert_instruction(ins, ins->get_operator(), x_ins, w_ins);
+        m.replace_instruction(ins, make_op("add"), conv1, conv2);
+    }
+};
+
 struct find_add_lit_broadcast
 {
     auto matcher() const
@@ -1239,6 +1266,7 @@ void simplify_algebra::apply(module& m) const
             find_neg_unit_ops{},
             find_zero_ops{},
             find_dot_add{},
+            find_conv_add{},
             find_div_const{},
             find_sub_const{},
             find_rsqrt{},
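
find_conv_add, registered above, is the rewrite behind the commit message: convolution(add(x, a), w) with a and w constant becomes add(convolution(a, w), convolution(x, w)). Both operands of the first convolution are then constants, so it can be folded away by constant propagation, leaving a single runtime convolution plus an elementwise add. The rewrite is sound because convolution is linear in its input; the self-contained 1-D check below (a plain valid-mode correlation, not library code) illustrates the identity:

#include <cassert>
#include <cstddef>
#include <vector>

// Numeric check of the identity behind find_conv_add (illustrative, 1-D, no
// padding or stride): conv(x + a, w) == conv(x, w) + conv(a, w), because
// convolution is linear in its input.
std::vector<double> conv(const std::vector<double>& x, const std::vector<double>& w)
{
    std::vector<double> y(x.size() - w.size() + 1, 0.0);
    for(std::size_t o = 0; o < y.size(); ++o)
        for(std::size_t j = 0; j < w.size(); ++j)
            y[o] += x[o + j] * w[j];
    return y;
}

int main()
{
    std::vector<double> x{1, 2, 3, 4}, a{5, 6, 7, 8}, w{1, -1};
    std::vector<double> xa(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        xa[i] = x[i] + a[i];
    auto lhs = conv(xa, w);
    auto cx  = conv(x, w);
    auto ca  = conv(a, w);
    for(std::size_t i = 0; i < lhs.size(); ++i)
        assert(lhs[i] == cx[i] + ca[i]);
}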
......
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
-#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP
+#define MIGRAPHX_GUARD_RTGLIB_GPU_CONVOLUTION_HPP
 #include <migraphx/shape.hpp>
 #include <migraphx/generate.hpp>
......
@@ -132,109 +132,6 @@ auto visit_quantize(T&& x, Ts&&... xs)
     };
 }
-template <class Op>
-struct ref_convolution : auto_register_op<ref_convolution<Op>>
-{
-    ref_convolution() = default;
-    ref_convolution(Op pop) : op(std::move(pop)) {}
-
-    Op op;
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return migraphx::reflect(self.op, f);
-    }
-
-    std::string name() const { return "ref::" + op.name(); }
-    shape compute_shape(const std::vector<shape>& inputs) const
-    {
-        return op.normalize_compute_shape(inputs);
-    }
-
-    argument compute(context&, shape output_shape, std::vector<argument> args) const
-    {
-        std::vector<std::size_t> padding;
-        if(op.padding_mode != op::padding_mode_t::default_)
-        {
-            auto input_lens = args[0].get_shape().lens();
-            auto weights_lens = args[1].get_shape().lens();
-            padding =
-                op.padding_mode == op::same_upper
-                    ? calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, true)
-                    : calc_dyn_auto_pad(input_lens, weights_lens, op.stride, op.dilation, false);
-            output_shape = compute_padded_shape(
-                args[0].get_shape(), args[1].get_shape(), padding, op.stride, op.dilation);
-        }
-        else
-        {
-            padding = op.padding;
-            if(output_shape.dynamic())
-            {
-                output_shape =
-                    op.normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
-            }
-        }
-
-        argument result{output_shape};
-        visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
-            auto in_lens = input.get_shape().lens();
-            auto wei_lens = weights.get_shape().lens();
-            auto wei_n = wei_lens[0];
-            auto wei_c = wei_lens[1];
-            std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());
-            par_for(output_shape.elements(), [&](auto i) {
-                auto idx_o = output_shape.multi(i);
-                auto w = idx_o[1];
-                auto n_dim = idx_o.size();
-                std::vector<std::ptrdiff_t> win_start;
-                for(std::size_t dim = 2; dim < n_dim; ++dim)
-                {
-                    auto d_2 = dim - 2;
-                    win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
-                                        std::ptrdiff_t(padding[d_2]));
-                }
-                const auto group_id = w / (wei_n / op.group);
-                shape win_shape{output_shape.type(), win_size};
-                double acc = 0.0;
-                shape_for_each(win_shape, [&](auto idx_win) {
-                    auto k = idx_win[0];
-                    const auto in_ch = group_id * wei_c + k;
-                    std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
-                    idx[1] = in_ch;
-                    std::transform(idx_win.begin() + 1,
-                                   idx_win.end(),
-                                   win_start.begin(),
-                                   idx.begin() + 2,
-                                   [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
-                    std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
-                    idx_wei[0] = w;
-                    std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
-                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
-                       std::equal(idx.begin(),
-                                  idx.end(),
-                                  in_lens.begin(),
-                                  in_lens.end(),
-                                  std::less<std::ptrdiff_t>{}))
-                    {
-                        acc +=
-                            input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
-                    }
-                });
-                output[i] = acc;
-            });
-        });
-        return result;
-    }
-};
-
 struct ref_im2col
 {
     op::im2col op;
@@ -564,11 +461,8 @@ struct ref_apply
     void init()
     {
-        apply_map["convolution"] = extend_op<ref_convolution<op::convolution>, op::convolution>();
         apply_map["dot"] = extend_op<ref_gemm, op::dot>();
         apply_map["quant_dot"] = extend_op<ref_quant_gemm, op::quant_dot>();
-        apply_map["quant_convolution"] =
-            extend_op<ref_convolution<op::quant_convolution>, op::quant_convolution>();
         apply_map["im2col"] = extend_op<ref_im2col, op::im2col>();
         apply_map["logsoftmax"] = extend_op<ref_softmax<op::logsoftmax>, op::logsoftmax>();
         apply_map["lrn"] = extend_op<ref_lrn, op::lrn>();
......