Commit 9b929d4e authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test

parents c4b1102e 4394e9b3
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -56,13 +57,21 @@ struct argmax ...@@ -56,13 +57,21 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto lens = inputs[0].lens(); const auto& s0 = inputs[0];
if(s0.dynamic())
{
auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0};
return {shape::int64_type, dyn_dims};
}
else
{
auto lens = s0.lens();
lens[axis] = 1; lens[axis] = 1;
return {shape::int64_type, lens}; return {shape::int64_type, lens};
} }
}
template <class T> template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
...@@ -79,19 +88,18 @@ struct argmax ...@@ -79,19 +88,18 @@ struct argmax
max_index = i; max_index = i;
} }
} }
return max_index; return max_index;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{dyn_out.computed_shape};
auto batch_item_num = args.front().get_shape().lens()[axis]; auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) { par_for(dyn_out.computed_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i); auto data_idx = dyn_out.computed_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num); output[i] = this->calc_argmax(input, data_idx, batch_item_num);
}); });
}); });
......
...@@ -128,8 +128,8 @@ struct broadcast ...@@ -128,8 +128,8 @@ struct broadcast
{ {
MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis " MIGRAPHX_THROW("BROADCAST_2in: s0 length doesn't match with static s1 axis "
"dimension length (" + "dimension length (" +
migraphx::to_string(s0.dyn_dims()[0]) + migraphx::to_string(s0.lens()[0]) +
" != " + migraphx::to_string(s1.dyn_dims()[axis]) + ")"); " != " + migraphx::to_string(s1.lens()[axis]) + ")");
} }
std::vector<size_t> bcast_strides(s1.ndim(), 0); std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -42,19 +43,27 @@ namespace op { ...@@ -42,19 +43,27 @@ namespace op {
struct contiguous struct contiguous
{ {
std::string name() const { return "contiguous"; } std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
if(inputs.front().standard()) auto s0 = inputs.front();
return inputs.front(); if(s0.dynamic() or s0.standard())
auto lens = inputs.at(0).lens(); {
auto t = inputs.at(0).type(); return s0;
}
else
{
const auto& lens = s0.lens();
auto t = s0.type();
return {t, lens}; return {t, lens};
} }
argument compute(const shape& output_shape, std::vector<argument> args) const }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
assert(output_shape.standard()); assert(dyn_out.computed_shape.standard());
argument result{output_shape}; argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()); output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/gemm.hpp> #include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -38,41 +39,69 @@ struct dot ...@@ -38,41 +39,69 @@ struct dot
std::string name() const { return "dot"; } std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.same_type().has(2); check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
const shape& a = inputs.at(0); const shape& a = inputs.at(0);
const shape& b = inputs.at(1); const shape& b = inputs.at(1);
auto t = a.type(); auto t = a.type();
if(not std::all_of( if(not std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.ndim() >= 2; }))
inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
{ {
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands"); MIGRAPHX_THROW("DOT: dot only accepts operands with 2 or more dimensions ");
} }
if(a.dynamic() or b.dynamic())
// only handle the case that the batch size of a and b are the same {
auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2,
s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend()))
{
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1;
if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
{
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
auto out_dyn_dims = s0.dyn_dims();
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims};
}
else
{
// only handle the case that all the dimensions except the last two are the same
if(not std::equal( if(not std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend())) a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{ {
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT: static outer dimensions of A and B mismatch: {" +
"} x {" + to_string_range(b.lens()) + "}"); to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
} }
std::size_t dim_0 = a.lens().size() - 2; std::size_t dim_0 = a.ndim() - 2;
std::size_t dim_1 = a.lens().size() - 1; std::size_t dim_1 = a.ndim() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0]) if(a.lens()[dim_1] != b.lens()[dim_0])
{ {
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) + MIGRAPHX_THROW("DOT: static inner dimensions do not match: {" +
"} x {" + to_string_range(b.lens()) + "}"); to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
} }
auto out_lens = a.lens(); auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1]; out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens}; return {t, out_lens};
} }
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
argument result = argument{output_shape}; argument result = argument{dyn_out.computed_shape};
visit_all(result, args[0], args[1])( visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); }); [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result; return result;
......
...@@ -55,17 +55,47 @@ struct flatten ...@@ -55,17 +55,47 @@ struct flatten
std::string name() const { return "flatten"; } std::string name() const { return "flatten"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this, true}.has(1);
auto&& lens = inputs.front().lens(); auto s = inputs[0];
auto x = if(s.dynamic())
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{}); {
auto y = auto min_lens = s.min_lens();
std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{}); auto max_lens = s.max_lens();
return {inputs.at(0).type(), {x, y}}; auto opt_lens = s.opt_lens();
// If any of the opt values is 0, output opt will be 0
shape::dynamic_dimension x = {
std::accumulate(
min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(
max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(opt_lens.begin(),
opt_lens.begin() + axis,
std::size_t{1},
std::multiplies<>{})};
shape::dynamic_dimension y = {
std::accumulate(
min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate(
max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate(
opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
};
return {s.type(), {x, y}};
}
else
{
check_shapes{inputs, *this}.standard();
auto&& lens = s.lens();
auto x = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {s.type(), {x, y}};
}
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OP_LAYOUT_HPP
#define MIGRAPHX_GUARD_OP_LAYOUT_HPP
#include <migraphx/config.hpp>
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/op/unary.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct layout : unary<layout>
{
std::vector<int64_t> permutation;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.permutation, "permutation"));
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).only_dims(permutation.size());
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return shape::from_permutation(t, lens, permutation);
}
auto apply() const
{
return [](auto x) { return x; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_OP_LAYOUT_HPP
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/int_divide.hpp> #include <migraphx/dyn_output.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -49,6 +49,9 @@ struct pooling ...@@ -49,6 +49,9 @@ struct pooling
bool ceil_mode = false; bool ceil_mode = false;
int lp_order = 2; int lp_order = 2;
// Global pooling with dynamic shape input
bool dyn_global = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
...@@ -57,7 +60,8 @@ struct pooling ...@@ -57,7 +60,8 @@ struct pooling
f(self.stride, "stride"), f(self.stride, "stride"),
f(self.lengths, "lengths"), f(self.lengths, "lengths"),
f(self.ceil_mode, "ceil_mode"), f(self.ceil_mode, "ceil_mode"),
f(self.lp_order, "lp_order")); f(self.lp_order, "lp_order"),
f(self.dyn_global, "dyn_global"));
} }
std::string name() const { return "pooling"; } std::string name() const { return "pooling"; }
...@@ -65,51 +69,111 @@ struct pooling ...@@ -65,51 +69,111 @@ struct pooling
void check_attribute_size() const void check_attribute_size() const
{ {
if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
stride.size() != lengths.size()) (not dyn_global and stride.size() != lengths.size()))
{ {
MIGRAPHX_THROW("POOLING: inconsistent attribute sizes"); MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
} }
} }
size_t kdims() const
{
check_attribute_size();
return stride.size();
}
value attributes() const { return {{"normalize_padding", "padding"}}; } value attributes() const { return {{"normalize_padding", "padding"}}; }
std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
std::size_t kdims) const
{
std::vector<std::size_t> output_lens{};
for(size_t i = 0; i < kdims; ++i)
{
if(input_lens[i + 2] == 0)
{
// handle opt = 0
output_lens.push_back(0);
}
else
{
std::size_t padding_factor = 2 * padding[i];
if(padding.size() == 2 * kdims)
padding_factor = padding[i] + padding[i + kdims];
assert(input_lens[i + 2] + padding_factor >= lengths[i]);
std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
std::size_t len =
(ceil_mode)
? dim_size / stride[i] + static_cast<std::size_t>((dim_size % stride[i] !=
0)) // ceil uint divide
: dim_size / stride[i]; // floor divide
output_lens.push_back(len + 1);
}
}
return output_lens;
}
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
check_attribute_size();
const shape& input = inputs.at(0); const shape& input = inputs.at(0);
auto input_lens = input.lens();
size_t kdims = input_lens.size() - 2;
auto input_size = inputs[0].lens().size();
auto padding_size = padding.size(); auto padding_size = padding.size();
if(input_size != padding_size / 2 + 2 and input_size != padding_size + 2) size_t kdims = input.ndim() - 2;
if(input.ndim() != padding_size / 2 + 2 and input.ndim() != padding_size + 2)
{ {
MIGRAPHX_THROW("POOLING: input and attribute size mismatch!"); MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
} }
std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2); if(input.dynamic())
for(size_t i = 0; i < kdims; i++)
{ {
std::ptrdiff_t dim_size; auto input_dyn_dims = input.dyn_dims();
auto padding_factor = 2 * padding[i]; std::vector<shape::dynamic_dimension> output_dyn_dims(input_dyn_dims.begin(),
if(padding_size == 2 * kdims) input_dyn_dims.begin() + 2);
padding_factor = padding[i] + padding[i + kdims]; if(dyn_global)
dim_size = input_lens[i + 2] + padding_factor - lengths[i]; {
assert(dim_size >= 0); for(size_t i = 0; i < kdims; ++i)
std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i]) {
: floor_divide<std::ptrdiff_t>(dim_size, stride[i]); output_dyn_dims.push_back(shape::dynamic_dimension{1, 1, 1});
output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1)));
} }
return inputs[0].with_lens(output_lens); return {input.type(), output_dyn_dims};
}
else
{
auto min_spatial_dims = calc_spatial_dim_out(input.min_lens(), kdims);
auto max_spatial_dims = calc_spatial_dim_out(input.max_lens(), kdims);
auto opt_spatial_dims = calc_spatial_dim_out(input.opt_lens(), kdims);
for(size_t i = 0; i < kdims; ++i)
{
output_dyn_dims.push_back(shape::dynamic_dimension{
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
}
return {input.type(), output_dyn_dims};
} }
}
else
{
auto input_lens = input.lens();
size_t kdims() const std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
// Used for when normalize_compute_shape() is called again at model eval time
// for an originally dynamic shape. Since kernel shape is not used with dyn_global.
if(dyn_global)
{ {
check_attribute_size(); for(size_t i = 0; i < kdims; ++i)
return stride.size(); {
output_lens.push_back(1);
}
return {input.type(), output_lens};
}
else
{
auto output_spatial_lens = calc_spatial_dim_out(input_lens, kdims);
output_lens.insert(
output_lens.end(), output_spatial_lens.begin(), output_spatial_lens.end());
return inputs[0].with_lens(output_lens);
}
}
} }
struct lpnorm_pool struct lpnorm_pool
...@@ -158,7 +222,11 @@ struct pooling ...@@ -158,7 +222,11 @@ struct pooling
}; };
template <class Type, class Out, class In, class Op> template <class Type, class Out, class In, class Op>
void calc_pooling(const shape& output_shape, Out& output, const In& input, Op op) const void calc_pooling(const shape& output_shape,
Out& output,
const In& input,
const std::vector<std::size_t>& kernel_dims,
Op op) const
{ {
auto in_s = input.get_shape(); auto in_s = input.get_shape();
auto in_lens = in_s.lens(); auto in_lens = in_s.lens();
...@@ -172,7 +240,7 @@ struct pooling ...@@ -172,7 +240,7 @@ struct pooling
auto d_2 = dim - 2; auto d_2 = dim - 2;
int start = int start =
static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]); static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]);
int end = std::min(start + lengths[d_2], in_lens[dim]); int end = std::min(start + kernel_dims[d_2], in_lens[dim]);
start = std::max(start, 0); start = std::max(start, 0);
win_start.push_back(start); win_start.push_back(start);
win_size.push_back(end - start); win_size.push_back(end - start);
...@@ -198,21 +266,32 @@ struct pooling ...@@ -198,21 +266,32 @@ struct pooling
}); });
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{dyn_out.computed_shape};
auto input_lens = args[0].get_shape().lens();
std::vector<std::size_t> kernel_dims;
if(dyn_global)
{ {
argument result{output_shape}; kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
}
else
{
kernel_dims = this->lengths;
}
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
using type = typename decltype(output)::value_type; using type = typename decltype(output)::value_type;
switch(mode) switch(mode)
{ {
case migraphx::op::pooling_mode::average: case migraphx::op::pooling_mode::average:
calc_pooling<type>(output_shape, output, input, avg_pool{}); calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, avg_pool{});
break; break;
case migraphx::op::pooling_mode::max: case migraphx::op::pooling_mode::max:
calc_pooling<type>(output_shape, output, input, max_pool{}); calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, max_pool{});
break; break;
case migraphx::op::pooling_mode::lpnorm: case migraphx::op::pooling_mode::lpnorm:
calc_pooling<type>(output_shape, output, input, lpnorm_pool{lp_order}); calc_pooling<type>(
dyn_out.computed_shape, output, input, kernel_dims, lpnorm_pool{lp_order});
break; break;
} }
}); });
......
...@@ -53,15 +53,15 @@ struct softmax ...@@ -53,15 +53,15 @@ struct softmax
std::string name() const { return "softmax"; } std::string name() const { return "softmax"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
if(inputs.at(0).packed()) auto s0 = inputs[0];
if(s0.dynamic() or s0.packed())
{ {
return inputs.at(0); return s0;
} }
else else
{ {
auto lens = inputs.at(0).lens(); return {s0.type(), s0.lens()};
return {inputs.at(0).type(), lens};
} }
} }
......
...@@ -59,9 +59,8 @@ struct squeeze ...@@ -59,9 +59,8 @@ struct squeeze
auto input_shape = inputs[0]; auto input_shape = inputs[0];
if(input_shape.dynamic()) if(input_shape.dynamic())
{ {
std::vector<shape::dynamic_dimension> one_dyn_dims{{1, 1, 0}, {1, 1, 1}};
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not contains(one_dyn_dims, input_shape.dyn_dims()[axis]); return input_shape.dyn_dims()[axis] != 1;
})) }))
{ {
MIGRAPHX_THROW( MIGRAPHX_THROW(
...@@ -70,14 +69,10 @@ struct squeeze ...@@ -70,14 +69,10 @@ struct squeeze
std::vector<shape::dynamic_dimension> dyn_dims = {}; std::vector<shape::dynamic_dimension> dyn_dims = {};
if(axes.empty()) if(axes.empty())
{ {
for(auto i : range(input_shape.ndim())) std::copy_if(input_shape.dyn_dims().cbegin(),
{ input_shape.dyn_dims().cend(),
auto dd = input_shape.dyn_dims()[i]; std::back_inserter(dyn_dims),
if(not contains(one_dyn_dims, dd)) [&](auto dd) { return dd != 1; });
{
dyn_dims.push_back(dd);
}
}
} }
else else
{ {
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -45,17 +46,15 @@ struct transpose ...@@ -45,17 +46,15 @@ struct transpose
} }
std::string name() const { return "transpose"; } std::string name() const { return "transpose"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
auto input_lens = input.lens();
auto input_strides = input.strides();
auto t = input.type();
if(dims.size() != input_lens.size()) if(dims.size() != input.ndim())
{ {
MIGRAPHX_THROW("Permutation has wrong number of axes"); MIGRAPHX_THROW("TRANSPOSE: Permutation has wrong number of axes");
} }
std::vector<int64_t> axes(dims.size()); std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0); std::iota(axes.begin(), axes.end(), 0);
...@@ -63,19 +62,36 @@ struct transpose ...@@ -63,19 +62,36 @@ struct transpose
{ {
MIGRAPHX_THROW("TRANSPOSE: Invalid permutation"); MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
} }
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size()); if(input.dynamic())
for(std::size_t i = 0; i < output_lens.size(); i++) {
std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
return input.dyn_dims()[dim];
});
return {input.type(), output_dyn_dims};
}
else
{
auto input_lens = input.lens();
auto input_strides = input.strides();
std::vector<size_t> output_lens(input.ndim());
std::vector<size_t> output_strides(input.ndim());
for(std::size_t i = 0; i < input.ndim(); i++)
{ {
output_lens[i] = input_lens[dims[i]]; output_lens[i] = input_lens[dims[i]];
output_strides[i] = input_strides[dims[i]]; output_strides[i] = input_strides[dims[i]];
} }
return {t, output_lens, output_strides}; return {input.type(), output_lens, output_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -115,6 +115,7 @@ struct program ...@@ -115,6 +115,7 @@ struct program
print_func) const; print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_py(std::ostream& os) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
void dry_run(parameter_map params) const; void dry_run(parameter_map params) const;
......
...@@ -101,6 +101,12 @@ struct shape ...@@ -101,6 +101,12 @@ struct shape
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y); friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x); friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
// compare to fixed std::size_t dimension
friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
}; };
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
......
...@@ -31,6 +31,9 @@ ...@@ -31,6 +31,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/**
* Iterates the given function over the indices from the shape in order.
*/
template <class F> template <class F>
void shape_for_each(const migraphx::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
{ {
...@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f) ...@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
call(indices); call(indices);
} }
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref& ...@@ -77,14 +77,14 @@ static void update_pooling(const instruction_ref& input, const instruction_ref&
{ {
return; return;
} }
auto kdims = input->get_shape().lens().size() - 2; auto kdims = input->get_shape().ndim() - 2;
if(std::equal(op.padding.begin(), if(std::equal(op.padding.begin(),
op.padding.begin() + kdims, op.padding.begin() + kdims,
op.padding.begin() + kdims, op.padding.begin() + kdims,
op.padding.end())) op.padding.end()))
return; return;
std::vector<int64_t> padding(input->get_shape().lens().size() * 2, 0); std::vector<int64_t> padding(input->get_shape().ndim() * 2, 0);
std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> pads_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> pads_r(op.padding.begin() + kdims, op.padding.end());
op.padding = std::vector<size_t>(kdims * 2, 0); op.padding = std::vector<size_t>(kdims * 2, 0);
......
...@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod) ...@@ -302,6 +302,24 @@ void instruction::replace_mod_argument(module_ref old, module_ref new_mod)
std::replace(module_args.begin(), module_args.end(), old, new_mod); std::replace(module_args.begin(), module_args.end(), old, new_mod);
} }
bool instruction::is_undefined() const
{
if(op.name() == "undefined")
{
return true;
}
else if(this->inputs().empty())
{
return false;
}
else
{
return std::all_of(this->inputs().begin(), this->inputs().end(), [](auto arg) {
return arg->is_undefined();
});
}
}
bool instruction::can_eval() const bool instruction::can_eval() const
{ {
if(op.name() == "@literal") if(op.name() == "@literal")
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Predicate>
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
std::vector<instruction_ref> result;
fix([&](auto self, auto ins) {
if(pred(ins))
{
result.push_back(ins);
return;
}
for(auto input : ins->inputs())
self(input);
})(std::prev(m.end()));
return result;
}
std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
std::unordered_set<instruction_ref> result;
std::vector<instruction_ref> outputs = find_lasts(m, [](auto ins) {
return ins->name() == "convolution" and ins->get_shape().lens().size() == 4;
});
for(auto output : outputs)
{
auto permutation = find_permutation(output->get_shape());
auto layout = m.insert_instruction(
std::next(output), make_op("layout", {{"permutation", permutation}}), output);
result.insert(m.replace_instruction(output, layout));
}
return result;
}
void transform_convolutions(module& m)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "convolution")
continue;
if(ins->get_shape().lens().size() != 4)
continue;
auto v = ins->get_operator().to_value();
if(v.at("group").to<int>() > 1)
continue;
auto args = ins->inputs();
std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
});
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
m.replace_instruction(ins, c);
}
}
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "layout")
continue;
if(ins->get_shape() != ins->inputs().front()->get_shape())
continue;
if(contains(output_layouts, ins))
continue;
m.replace_instruction(ins, ins->inputs().front());
}
}
void layout_nhwc::apply(module_pass_manager& mpm) const
{
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions(mpm.get_module());
mpm.run_pass(dead_code_elimination{});
mpm.run_pass(eliminate_contiguous{"contiguous"});
mpm.run_pass(dead_code_elimination{});
remove_layout(mpm.get_module(), output_layouts);
mpm.run_pass(dead_code_elimination{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name) ...@@ -789,6 +789,22 @@ static std::string cpp_var_name(const std::string& name)
return to_c_id("x_" + replace_string(name, ":", "_module_")); return to_c_id("x_" + replace_string(name, ":", "_module_"));
} }
static void print_py_op(std::ostream& os, const operation& op)
{
auto v = op.to_value();
os << "migraphx.op(" << enclose_name(op.name());
auto default_values = make_op(op.name()).to_value();
for(auto&& x : v)
{
auto name = x.get_key();
if(default_values[name] == x)
continue;
os << ", " << name << "=" << to_json_string(x.without_key());
}
os << ")";
}
static void print_make_op(std::ostream& os, const operation& op) static void print_make_op(std::ostream& os, const operation& op)
{ {
auto v = op.to_value(); auto v = op.to_value();
...@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op) ...@@ -804,6 +820,14 @@ static void print_make_op(std::ostream& os, const operation& op)
os << ")"; os << ")";
} }
static void print_py_shape(std::ostream& os, const migraphx::shape& s)
{
os << "migraphx.shape(" << s.type_string() << ", lens=" << to_json_string(s.lens());
if(not s.standard())
os << ", strides=" << to_json_string(s.strides());
os << ")";
}
static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
{ {
os << "migraphx::shape{migraphx::shape::" << s.type_string(); os << "migraphx::shape{migraphx::shape::" << s.type_string();
...@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s) ...@@ -813,6 +837,68 @@ static void print_cpp_shape(std::ostream& os, const migraphx::shape& s)
os << "}"; os << "}";
} }
std::unordered_map<instruction_ref, std::string>
module::print_py(std::ostream& os,
const std::string& mname,
std::unordered_map<instruction_ref, std::string> names) const
{
// cppcheck-suppress variableScope
unsigned long seed = names.size();
auto last = std::prev(this->end());
names = this->print(
[&](auto ins, auto ins_names) {
std::vector<std::string> input_vars;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(input_vars),
[&](auto input) { return cpp_var_name(ins_names.at(input)); });
if(ins != last)
os << cpp_var_name(ins_names.at(ins)) << " = ";
if(ins->name() == "@literal")
{
os << mname << ".add_literal(";
bool use_abs = false;
ins->get_literal().visit([&](auto v) {
use_abs = std::none_of(v.begin(), v.end(), [](auto x) { return x < 0; });
});
// Disable abs for now
use_abs = false;
if(use_abs)
os << "migraphx.abs_literal(";
os << "migraphx.generate_literal(";
print_py_shape(os, ins->get_shape());
os << ", " << seed << ")";
if(use_abs)
os << ")";
os << ")" << std::endl;
seed++;
}
else if(ins->name() == "@param")
{
std::string name = any_cast<builtin::param>(ins->get_operator()).parameter;
os << mname << ".add_parameter(" << enclose_name(name) << ",";
print_py_shape(os, ins->get_shape());
os << ")" << std::endl;
}
else if(ins->name() == "@return")
{
os << mname << ".add_return([" << join_strings(input_vars, ", ") << "])"
<< std::endl;
}
else
{
assert(ins->name().front() != '@');
os << mname << ".add_instruction(";
print_py_op(os, ins->get_operator());
os << ", [" << join_strings(input_vars, ", ") << "]";
os << ")" << std::endl;
}
},
names);
return names;
}
std::unordered_map<instruction_ref, std::string> std::unordered_map<instruction_ref, std::string>
module::print_cpp(std::ostream& os, module::print_cpp(std::ostream& os,
const std::string& mname, const std::string& mname,
...@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os, ...@@ -874,6 +960,8 @@ module::print_cpp(std::ostream& os,
return names; return names;
} }
void module::print_py(std::ostream& os) const { this->print_py(os, this->name(), {}); }
void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); } void module::print_cpp(std::ostream& os) const { this->print_cpp(os, this->name(), {}); }
void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) const
......
...@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const ...@@ -393,18 +393,31 @@ literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
{ {
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end()); std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
if(not t.external_data().empty()) auto type = get_type(t.data_type());
shape tensor_shape(type, dims);
auto external_data = t.external_data();
if(not external_data.empty())
{
const std::string& data_file = external_data.at(0).value();
size_t num_data_fields = external_data.size();
size_t offset = 0;
size_t nbytes = tensor_shape.bytes();
if(num_data_fields > 1) // if offset field is present
{
offset = std::stoul(t.external_data().at(1).value());
}
if(num_data_fields > 2) // if nbytes field is present
{ {
const std::string& data_file = t.external_data().at(0).value(); nbytes = std::stoul(t.external_data().at(2).value());
auto raw_buffer = read_buffer(path + "/" + data_file); }
auto raw_buffer = read_buffer(path + "/" + data_file, offset, nbytes);
std::string s(raw_buffer.begin(), raw_buffer.end()); std::string s(raw_buffer.begin(), raw_buffer.end());
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
if(t.has_raw_data()) if(t.has_raw_data())
{ {
const std::string& s = t.raw_data(); const std::string& s = t.raw_data();
auto type = get_type(t.data_type());
return create_literal(type, dims, s.data()); return create_literal(type, dims, s.data());
} }
......
...@@ -57,6 +57,12 @@ struct parse_binary_op : op_parser<parse_binary_op> ...@@ -57,6 +57,12 @@ struct parse_binary_op : op_parser<parse_binary_op>
parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>(); parser.parse_value(info.attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0) if(broadcasted != 0)
{ {
if(std::any_of(
args.cbegin(), args.cend(), [](auto a) { return a->get_shape().dynamic(); }))
{
MIGRAPHX_THROW(
"Binary op broadcast attribute not supported for dynamic input shapes");
}
uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>(); uint64_t axis = parser.parse_value(info.attributes.at("axis")).at<uint64_t>();
auto l = info.add_instruction( auto l = info.add_instruction(
make_op("broadcast", make_op("broadcast",
......
...@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -47,52 +47,42 @@ struct parse_pooling : op_parser<parse_pooling>
{"GlobalLpPool", "lpnorm"}}; {"GlobalLpPool", "lpnorm"}};
} }
instruction_ref parse(const op_desc& opd, value handle_values(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const const shape& in_shape,
value values) const
{ {
const std::unordered_map<std::string, op::pooling_mode> mode_map = { auto kdims = in_shape.ndim() - 2;
{"max", op::pooling_mode::max}, if(starts_with(opd.onnx_name, "Global"))
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
std::string mode = opd.op_name;
if(not contains(mode_map, mode))
{ {
MIGRAPHX_THROW("onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]"); // if spatial dimensions are dynamic use dyn_global flag
if(in_shape.dynamic() and std::any_of(in_shape.dyn_dims().cbegin() + 2,
in_shape.dyn_dims().cend(),
[](auto dd) { return not dd.is_fixed(); }))
{
values["dyn_global"] = true;
values["lengths"] = std::vector<size_t>();
} }
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}}); else
value values = op.to_value();
auto l0 = args[0];
auto in_lens = l0->get_shape().lens();
assert(in_lens.size() > 2);
auto kdims = in_lens.size() - 2;
if(starts_with(opd.onnx_name, "Global"))
{ {
values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end()); // works with static and fixed dynamic shape
auto m_lens = in_shape.max_lens();
values["lengths"] = std::vector<size_t>(m_lens.begin() + 2, m_lens.end());
}
} }
// does not support ceil_mode
if(contains(info.attributes, "ceil_mode")) if(contains(info.attributes, "ceil_mode"))
{ {
values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i()); values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
} }
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
count_include_pad = info.attributes.at("count_include_pad").i();
}
if(contains(info.attributes, "strides")) if(contains(info.attributes, "strides"))
{ {
values["stride"].clear(); values["stride"].clear();
copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"])); copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides"); check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
} }
if(contains(info.attributes, "kernel_shape")) if(contains(info.attributes, "kernel_shape"))
{ {
values["lengths"].clear(); values["lengths"].clear();
...@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -110,6 +100,46 @@ struct parse_pooling : op_parser<parse_pooling>
// ensure pads availabe only when auto_pad is "NOT_SET" // ensure pads availabe only when auto_pad is "NOT_SET"
check_padding_mode(info, "POOLING"); check_padding_mode(info, "POOLING");
return values;
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
std::string mode = opd.op_name;
const std::unordered_map<std::string, op::pooling_mode> mode_map = {
{"max", op::pooling_mode::max},
{"average", op::pooling_mode::average},
{"lpnorm", op::pooling_mode::lpnorm}};
if(not contains(mode_map, mode))
{
MIGRAPHX_THROW(
"PARSE_POOLING: onnx pooling mode must be [\"max\", \"average\", \"lpnorm\"]");
}
operation op = make_op("pooling", {{"mode", mode_map.at(mode)}});
value values = op.to_value();
auto l0 = args[0];
auto in_shape = l0->get_shape();
assert(in_shape.ndim() > 2);
auto kdims = in_shape.ndim() - 2;
values = handle_values(opd, info, in_shape, values);
// count include padding, if count include pad is 1, we always use
// explicit pad
int count_include_pad = 0;
if(contains(info.attributes, "count_include_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW("PARSE_POOLING: count_include_pad attribute is not supported for "
"dynamic input shape");
}
count_include_pad = info.attributes.at("count_include_pad").i();
}
std::vector<int64_t> paddings; std::vector<int64_t> paddings;
float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f); float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
...@@ -122,6 +152,13 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -122,6 +152,13 @@ struct parse_pooling : op_parser<parse_pooling>
} }
if(contains(info.attributes, "auto_pad")) if(contains(info.attributes, "auto_pad"))
{
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: Auto padding pooling with dynamic input shape not supported");
}
else
{ {
values["padding"].clear(); values["padding"].clear();
// return paddings could be empty, then setting to 0 for no padding // return paddings could be empty, then setting to 0 for no padding
...@@ -129,9 +166,10 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -129,9 +166,10 @@ struct parse_pooling : op_parser<parse_pooling>
values, values,
values["lengths"].to_vector<std::size_t>(), values["lengths"].to_vector<std::size_t>(),
{1, 1}, {1, 1},
in_lens, in_shape.lens(),
paddings); paddings);
} }
}
if(paddings.size() != 2 * kdims) if(paddings.size() != 2 * kdims)
{ {
...@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -150,6 +188,7 @@ struct parse_pooling : op_parser<parse_pooling>
values["stride"].resize(kdims); values["stride"].resize(kdims);
std::fill_n(values["stride"].begin(), kdims, 1); std::fill_n(values["stride"].begin(), kdims, 1);
} }
// used to calculate the supposed output shape // used to calculate the supposed output shape
std::vector<int64_t> orig_padding = paddings; std::vector<int64_t> orig_padding = paddings;
...@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling> ...@@ -159,6 +198,11 @@ struct parse_pooling : op_parser<parse_pooling>
if(not slice_start.empty()) if(not slice_start.empty())
{ {
if(in_shape.dynamic())
{
MIGRAPHX_THROW(
"PARSE_POOLING: asymmetric padding not supported for dynamic input shape");
}
// calculate expected output shape // calculate expected output shape
orig_padding.insert(orig_padding.begin() + kdims, 2, 0); orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
orig_padding.insert(orig_padding.begin(), 2, 0); orig_padding.insert(orig_padding.begin(), 2, 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment