Unverified commit dae94657 authored by Chris Austen, committed by GitHub

Merge branch 'develop' into jit-reduce-reg

parents b013d991 56c43445
@@ -28,6 +28,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/gemm.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -38,41 +39,69 @@ struct dot
     std::string name() const { return "dot"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.same_type().has(2);
+        check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
         const shape& a = inputs.at(0);
         const shape& b = inputs.at(1);
         auto t         = a.type();

-        if(not std::all_of(
-               inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
+        if(not std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.ndim() >= 2; }))
         {
-            MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
+            MIGRAPHX_THROW("DOT: dot only accepts operands with 2 or more dimensions");
         }

-        // only handle the case that the batch size of a and b are the same
-        if(not std::equal(
-               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
-        {
-            MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
-        }
-
-        std::size_t dim_0 = a.lens().size() - 2;
-        std::size_t dim_1 = a.lens().size() - 1;
-        if(a.lens()[dim_1] != b.lens()[dim_0])
-        {
-            MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
-                           "} x {" + to_string_range(b.lens()) + "}");
-        }
-
-        auto out_lens   = a.lens();
-        out_lens[dim_1] = b.lens()[dim_1];
-        return {t, out_lens};
+        if(a.dynamic() or b.dynamic())
+        {
+            auto s0 = a.to_dynamic();
+            auto s1 = b.to_dynamic();
+            if(not std::equal(s0.dyn_dims().rbegin() + 2,
+                              s0.dyn_dims().rend(),
+                              s1.dyn_dims().rbegin() + 2,
+                              s1.dyn_dims().rend()))
+            {
+                MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
+                               to_string_range(s0.dyn_dims()) + "} x {" +
+                               to_string_range(s1.dyn_dims()) + "}");
+            }
+            std::size_t dim_0 = s0.ndim() - 2;
+            std::size_t dim_1 = s0.ndim() - 1;
+            if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
+            {
+                MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
+                               to_string_range(s0.dyn_dims()) + "} x {" +
+                               to_string_range(s1.dyn_dims()) + "}");
+            }
+            auto out_dyn_dims   = s0.dyn_dims();
+            out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
+            return {t, out_dyn_dims};
+        }
+        else
+        {
+            // only handle the case that all the dimensions except the last two are the same
+            if(not std::equal(
+                   a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
+            {
+                MIGRAPHX_THROW("DOT: static outer dimensions of A and B mismatch: {" +
+                               to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
+                               "}");
+            }
+            std::size_t dim_0 = a.ndim() - 2;
+            std::size_t dim_1 = a.ndim() - 1;
+            if(a.lens()[dim_1] != b.lens()[dim_0])
+            {
+                MIGRAPHX_THROW("DOT: static inner dimensions do not match: {" +
+                               to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
+                               "}");
+            }
+            auto out_lens   = a.lens();
+            out_lens[dim_1] = b.lens()[dim_1];
+            return {t, out_lens};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result = argument{output_shape};
+        argument result = argument{dyn_out.computed_shape};
         visit_all(result, args[0], args[1])(
             [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
         return result;
...
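Aside: the dynamic branch above is pure bookkeeping over {min, max, opt} triples — the output copies A's dims and swaps in B's last dim. A standalone sketch of that rule (DynDim is a hypothetical stand-in for shape::dynamic_dimension, not the library type):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for migraphx::shape::dynamic_dimension: {min, max, opt}.
struct DynDim
{
    std::size_t min, max, opt;
};

// Mirrors the dynamic branch of dot::compute_shape: output dims are A's dims
// with the last one replaced by B's last dim (outer/inner checks omitted).
std::vector<DynDim> dot_out_dims(std::vector<DynDim> a, const std::vector<DynDim>& b)
{
    a.back() = b.back();
    return a;
}

int main()
{
    // A: [{1,4}, 6, 5] x B: [{1,4}, 5, 8] -> C: [{1,4}, 6, 8]
    std::vector<DynDim> a{{1, 4, 2}, {6, 6, 6}, {5, 5, 5}};
    std::vector<DynDim> b{{1, 4, 2}, {5, 5, 5}, {8, 8, 8}};
    auto c = dot_out_dims(a, b);
    assert(c[1].max == 6 and c[2].max == 8);
}
```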
@@ -32,14 +32,13 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

-struct elu
+struct elu : unary<elu>
 {
-    std::string name() const { return "elu"; }
     float alpha = 1;

-    shape compute_shape(std::vector<shape> inputs) const
+    std::string point_op() const
     {
-        check_shapes{inputs, *this}.has(1);
-        return inputs.front();
+        return "${function:where}(${0} > 0, ${0}, ${alpha} * (${function:exp}(${0}) - 1))";
     }

     template <class Self, class F>
@@ -47,6 +46,11 @@ struct elu
     {
         return pack(f(self.alpha, "alpha"));
     }
+
+    auto apply() const
+    {
+        return [&](auto x) { return x > 0 ? x : alpha * std::expm1(x); };
+    }
 };

 } // namespace op
...
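The apply() lambda gives the scalar semantics of the op; a minimal check of the same formula with plain std::expm1 (values are illustrative only):

```cpp
#include <cassert>
#include <cmath>

// elu(x) = x for x > 0, else alpha * (exp(x) - 1); expm1 keeps precision near 0.
double elu(double x, double alpha = 1.0) { return x > 0 ? x : alpha * std::expm1(x); }

int main()
{
    assert(elu(2.0) == 2.0);                    // identity on the positive side
    assert(elu(0.0) == 0.0);                    // expm1(0) == 0
    assert(elu(-1.0) < 0 and elu(-1.0) > -1.0); // bounded below by -alpha
}
```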
@@ -55,17 +55,47 @@ struct flatten
     std::string name() const { return "flatten"; }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
-        auto&& lens = inputs.front().lens();
-        auto x =
-            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
-        auto y =
-            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
-        return {inputs.at(0).type(), {x, y}};
+        check_shapes{inputs, *this, true}.has(1);
+        auto s = inputs[0];
+        if(s.dynamic())
+        {
+            auto min_lens = s.min_lens();
+            auto max_lens = s.max_lens();
+            auto opt_lens = s.opt_lens();
+            // If any of the opt values is 0, output opt will be 0
+            shape::dynamic_dimension x = {
+                std::accumulate(
+                    min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(
+                    max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(opt_lens.begin(),
+                                opt_lens.begin() + axis,
+                                std::size_t{1},
+                                std::multiplies<>{})};
+            shape::dynamic_dimension y = {
+                std::accumulate(
+                    min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(
+                    max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
+                std::accumulate(
+                    opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
+            };
+            return {s.type(), {x, y}};
+        }
+        else
+        {
+            check_shapes{inputs, *this}.standard();
+            auto&& lens = s.lens();
+            auto x      = std::accumulate(
+                lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
+            auto y = std::accumulate(
+                lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
+            return {s.type(), {x, y}};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
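The dynamic case above folds each of min/max/opt lens into a two-element product split at axis. The same fold in isolation, including the opt = 0 propagation noted in the diff's comment:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Product of the dims before and after the flatten axis, as in the diff.
std::pair<std::size_t, std::size_t> split_product(const std::vector<std::size_t>& lens,
                                                  std::size_t axis)
{
    auto x =
        std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
    auto y = std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
    return {x, y};
}

int main()
{
    assert((split_product({2, 3, 4}, 1) == std::pair<std::size_t, std::size_t>{2, 12}));
    // a 0 anywhere in opt_lens zeroes the corresponding output opt value
    assert(split_product({2, 0, 4}, 1).second == 0);
}
```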
@@ -21,50 +21,48 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#ifndef MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
-#define MIGRAPHX_GUARD_OPERATORS_BATCH_NORM_HPP
+#ifndef MIGRAPHX_GUARD_OP_LAYOUT_HPP
+#define MIGRAPHX_GUARD_OP_LAYOUT_HPP

+#include <migraphx/check_shapes.hpp>
 #include <migraphx/config.hpp>
-#include <array>
-#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
+#include <migraphx/op/unary.hpp>
 #include <cmath>
-#include <utility>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

-struct batch_norm_inference
+struct layout : unary<layout>
 {
-    float epsilon  = 1.0e-6f;
-    float momentum = 0.9f;
-
-    std::string name() const { return "batch_norm_inference"; }
-
-    enum bn_infer_mode_t
-    {
-        per_activation,
-        spatial,
-    };
-
-    bn_infer_mode_t bn_mode = spatial;
+    std::vector<int64_t> permutation;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(
-            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
+        return pack(f(self.permutation, "permutation"));
     }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(5);
-        check_shapes{inputs.data(), inputs.data() + 1, *this}.same_ndims();
-        check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape();
-        return inputs.front();
+        check_shapes{inputs, *this}.has(1).only_dims(permutation.size());
+        auto lens = inputs.at(0).lens();
+        auto t    = inputs.at(0).type();
+        return shape::from_permutation(t, lens, permutation);
+    }
+
+    auto apply() const
+    {
+        return [](auto x) { return x; };
     }
 };

 } // namespace op
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx

-#endif
+#endif // MIGRAPHX_GUARD_OP_LAYOUT_HPP
@@ -26,12 +26,13 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/op/unary.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

-struct leaky_relu
+struct leaky_relu : unary<leaky_relu>
 {
     float alpha = 0.01;
@@ -41,11 +42,13 @@ struct leaky_relu
         return pack(f(self.alpha, "alpha"));
     }

+    std::string point_op() const { return "${function:where}(${0} > 0, ${0}, ${alpha} * ${0})"; }
     std::string name() const { return "leaky_relu"; }
-    shape compute_shape(std::vector<shape> inputs) const
+
+    auto apply() const
     {
-        check_shapes{inputs, *this}.has(1);
-        return inputs.front();
+        return [&](auto x) { return x > 0 ? x : x * alpha; };
     }
 };
...
@@ -26,64 +26,105 @@
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/argument.hpp>
+#include <migraphx/dyn_output.hpp>
+#include <migraphx/common.hpp>
 #include <migraphx/config.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace op {

+/**
+ * Broadcast multiple dimensions between two tensors.
+ * Two versions of this operator: one input and two inputs.
+ * One input version uses output_lens attribute and broadcasts to it.
+ * Two inputs version broadcasts both inputs to the common shape at evaluation time.
+ */
 struct multibroadcast
 {
-    std::vector<std::size_t> output_lens;
+    std::vector<std::size_t> output_lens = {};
+    // optional attribute
+    std::vector<shape::dynamic_dimension> output_dyn_dims = {};

     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
-        return pack(f(self.output_lens, "out_lens"));
+        return pack(f(self.output_lens, "out_lens"), f(self.output_dyn_dims, "out_dyn_dims"));
     }

     std::string name() const { return "multibroadcast"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1, 2);
         auto t  = inputs.at(0).type();
-        auto input = inputs.at(0);
+        auto s0 = inputs.at(0);

-        if(input.lens().empty())
+        if(s0.max_lens().empty())
         {
-            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
+            MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should be > 0");
         }

-        if(input.lens().size() > output_lens.size())
-        {
-            MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
-        }
-
-        auto offset = output_lens.size() - input.lens().size();
-        for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
-        {
-            if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
-            {
-                MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
-                               "} cannot be broadcasted to {" + to_string_range(output_lens) +
-                               "}!");
-            }
-        }
-
-        std::vector<size_t> bcast_strides(output_lens.size(), 0);
-        for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
-        {
-            if(output_lens[i + offset] == input.lens()[i])
-            {
-                bcast_strides[i + offset] = input.strides()[i];
-            }
-        }
-        return {t, output_lens, bcast_strides};
+        auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
+            std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
+            for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
+            {
+                if(bcast_lens[i + offset] == s0.lens()[i])
+                {
+                    bcast_strides[i + offset] = s0.strides()[i];
+                }
+            }
+            return bcast_strides;
+        };
+
+        if(inputs.size() == 1)
+        {
+            if(s0.lens().size() > output_lens.size())
+            {
+                MIGRAPHX_THROW("MULTIBROADCAST: input dimensions should <= output size");
+            }
+
+            auto offset = output_lens.size() - s0.lens().size();
+            for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
+            {
+                if(output_lens[i + offset] != s0.lens()[i] and s0.lens()[i] != 1)
+                {
+                    MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(s0.lens()) +
                                   "} cannot be broadcasted to {" + to_string_range(output_lens) +
                                   "}!");
+                }
+            }
+            auto bcast_strides = make_bcast_strides(output_lens, offset);
+            return {t, output_lens, std::move(bcast_strides)};
+        }
+        else
+        {
+            // two inputs
+            auto s1 = inputs.at(1);
+            if(s0.dynamic() or s1.dynamic())
+            {
+                if(not output_dyn_dims.empty())
+                {
+                    return {t, output_dyn_dims};
+                }
+                return {t, compute_broadcasted_dyn_dims(s0, s1)};
+            }
+            else
+            {
+                auto bcast_lens    = compute_broadcasted_lens(s0.lens(), s1.lens());
+                auto offset        = bcast_lens.size() - s0.lens().size();
+                auto bcast_strides = make_bcast_strides(bcast_lens, offset);
+                return {t, std::move(bcast_lens), std::move(bcast_strides)};
+            }
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
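compute_broadcasted_lens itself is not shown in this diff (it comes from common.hpp); what follows is an assumed sketch of the numpy-style rule such a helper conventionally implements, for illustration only:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Numpy-style broadcasting: align from the right; dims must match or be 1.
std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> a, std::vector<std::size_t> b)
{
    if(a.size() < b.size())
        std::swap(a, b); // keep the longer shape as the output template
    auto offset = a.size() - b.size();
    for(std::size_t i = 0; i < b.size(); i++)
    {
        auto& out = a[i + offset];
        if(out == b[i] or b[i] == 1)
            continue;
        if(out == 1)
            out = b[i];
        else
            throw std::runtime_error("incompatible broadcast");
    }
    return a;
}

int main()
{
    auto lens = broadcast_lens({2, 1, 4}, {3, 1});
    assert((lens == std::vector<std::size_t>{2, 3, 4}));
}
```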
@@ -31,7 +31,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/par_for.hpp>
 #include <migraphx/shape_for_each.hpp>
-#include <migraphx/int_divide.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <cmath>
 #include <utility>
@@ -49,6 +49,9 @@ struct pooling
     bool ceil_mode = false;
     int lp_order   = 2;

+    // Global pooling with dynamic shape input
+    bool dyn_global = false;
+
     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
@@ -57,59 +60,120 @@ struct pooling
                     f(self.stride, "stride"),
                     f(self.lengths, "lengths"),
                     f(self.ceil_mode, "ceil_mode"),
-                    f(self.lp_order, "lp_order"));
+                    f(self.lp_order, "lp_order"),
+                    f(self.dyn_global, "dyn_global"));
     }

     std::string name() const { return "pooling"; }

     void check_attribute_size() const
     {
-        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
-               stride.size() == lengths.size()))
+        if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
+           (not dyn_global and stride.size() != lengths.size()))
         {
             MIGRAPHX_THROW("POOLING: inconsistent attribute sizes");
         }
     }

+    size_t kdims() const
+    {
+        check_attribute_size();
+        return stride.size();
+    }
+
     value attributes() const { return {{"normalize_padding", "padding"}}; }

+    std::vector<std::size_t> calc_spatial_dim_out(const std::vector<std::size_t>& input_lens,
+                                                  std::size_t kdims) const
+    {
+        std::vector<std::size_t> output_lens{};
+        for(size_t i = 0; i < kdims; ++i)
+        {
+            if(input_lens[i + 2] == 0)
+            {
+                // handle opt = 0
+                output_lens.push_back(0);
+            }
+            else
+            {
+                std::size_t padding_factor = 2 * padding[i];
+                if(padding.size() == 2 * kdims)
+                    padding_factor = padding[i] + padding[i + kdims];
+                assert(input_lens[i + 2] + padding_factor >= lengths[i]);
+                std::size_t dim_size = input_lens[i + 2] + padding_factor - lengths[i];
+                std::size_t len =
+                    (ceil_mode)
+                        ? dim_size / stride[i] +
+                              static_cast<std::size_t>((dim_size % stride[i] != 0)) // ceil uint divide
+                        : dim_size / stride[i]; // floor divide
+                output_lens.push_back(len + 1);
+            }
+        }
+        return output_lens;
+    }
+
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
+        check_attribute_size();
         const shape& input = inputs.at(0);
-        auto input_lens    = input.lens();
-        size_t kdims       = input_lens.size() - 2;
-
-        auto input_size   = inputs[0].lens().size();
         auto padding_size = padding.size();
-        if(not(input_size == padding_size / 2 + 2 or input_size == padding_size + 2))
+        size_t kdims      = input.ndim() - 2;
+        if(input.ndim() != padding_size / 2 + 2 and input.ndim() != padding_size + 2)
         {
             MIGRAPHX_THROW("POOLING: input and attribute size mismatch!");
         }

-        std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
-        for(size_t i = 0; i < kdims; i++)
+        if(input.dynamic())
         {
-            std::ptrdiff_t dim_size;
-            auto padding_factor = 2 * padding[i];
-            if(padding_size == 2 * kdims)
-                padding_factor = padding[i] + padding[i + kdims];
-            dim_size = input_lens[i + 2] + padding_factor - lengths[i];
-            assert(dim_size >= 0);
-            std::size_t len = (ceil_mode) ? ceil_divide<std::ptrdiff_t>(dim_size, stride[i])
-                                          : floor_divide<std::ptrdiff_t>(dim_size, stride[i]);
-            output_lens.push_back(std::size_t(std::max<std::ptrdiff_t>(1, len + 1)));
+            auto input_dyn_dims = input.dyn_dims();
+            std::vector<shape::dynamic_dimension> output_dyn_dims(input_dyn_dims.begin(),
+                                                                  input_dyn_dims.begin() + 2);
+            if(dyn_global)
+            {
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_dyn_dims.push_back(shape::dynamic_dimension{1, 1, 1});
+                }
+                return {input.type(), output_dyn_dims};
+            }
+            else
+            {
+                auto min_spatial_dims = calc_spatial_dim_out(input.min_lens(), kdims);
+                auto max_spatial_dims = calc_spatial_dim_out(input.max_lens(), kdims);
+                auto opt_spatial_dims = calc_spatial_dim_out(input.opt_lens(), kdims);
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_dyn_dims.push_back(shape::dynamic_dimension{
+                        min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]});
+                }
+                return {input.type(), output_dyn_dims};
+            }
         }
-        return inputs[0].with_lens(output_lens);
-    }
-
-    size_t kdims() const
-    {
-        check_attribute_size();
-        return stride.size();
+        else
+        {
+            auto input_lens = input.lens();
+            std::vector<std::size_t> output_lens(input_lens.begin(), input_lens.begin() + 2);
+            // Used for when normalize_compute_shape() is called again at model eval time
+            // for an originally dynamic shape. Since kernel shape is not used with dyn_global.
+            if(dyn_global)
+            {
+                for(size_t i = 0; i < kdims; ++i)
+                {
+                    output_lens.push_back(1);
+                }
+                return {input.type(), output_lens};
+            }
+            else
+            {
+                auto output_spatial_lens = calc_spatial_dim_out(input_lens, kdims);
+                output_lens.insert(
+                    output_lens.end(), output_spatial_lens.begin(), output_spatial_lens.end());
+                return inputs[0].with_lens(output_lens);
+            }
+        }
     }

     struct lpnorm_pool
@@ -158,7 +222,11 @@ struct pooling
     };

     template <class Type, class Out, class In, class Op>
-    void calc_pooling(const shape& output_shape, Out& output, const In& input, Op op) const
+    void calc_pooling(const shape& output_shape,
+                      Out& output,
+                      const In& input,
+                      const std::vector<std::size_t>& kernel_dims,
+                      Op op) const
     {
         auto in_s    = input.get_shape();
         auto in_lens = in_s.lens();
@@ -172,7 +240,7 @@ struct pooling
                 auto d_2 = dim - 2;
                 int start =
                     static_cast<int>(idx_o[dim] * stride[d_2]) - static_cast<int>(padding[d_2]);
-                int end = std::min(start + lengths[d_2], in_lens[dim]);
+                int end = std::min(start + kernel_dims[d_2], in_lens[dim]);
                 start   = std::max(start, 0);
                 win_start.push_back(start);
                 win_size.push_back(end - start);
@@ -198,21 +266,32 @@ struct pooling
         });
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
+        auto input_lens = args[0].get_shape().lens();
+        std::vector<std::size_t> kernel_dims;
+        if(dyn_global)
+        {
+            kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
+        }
+        else
+        {
+            kernel_dims = this->lengths;
+        }
         visit_all(result, args[0])([&](auto output, auto input) {
             using type = typename decltype(output)::value_type;
             switch(mode)
             {
             case migraphx::op::pooling_mode::average:
-                calc_pooling<type>(output_shape, output, input, avg_pool{});
+                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, avg_pool{});
                 break;
             case migraphx::op::pooling_mode::max:
-                calc_pooling<type>(output_shape, output, input, max_pool{});
+                calc_pooling<type>(dyn_out.computed_shape, output, input, kernel_dims, max_pool{});
                 break;
             case migraphx::op::pooling_mode::lpnorm:
-                calc_pooling<type>(output_shape, output, input, lpnorm_pool{lp_order});
+                calc_pooling<type>(
                    dyn_out.computed_shape, output, input, kernel_dims, lpnorm_pool{lp_order});
                 break;
             }
        });
...
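calc_spatial_dim_out above is the usual pooling size formula with unsigned floor/ceil division. The same computation for a single spatial dimension, with worked numbers:

```cpp
#include <cassert>
#include <cstddef>

// out = divide(in + pad_total - kernel, stride) + 1, floor or ceil per ceil_mode.
std::size_t pool_out(
    std::size_t in, std::size_t pad_total, std::size_t kernel, std::size_t stride, bool ceil_mode)
{
    std::size_t dim_size = in + pad_total - kernel;
    std::size_t len =
        ceil_mode ? dim_size / stride + static_cast<std::size_t>(dim_size % stride != 0)
                  : dim_size / stride;
    return len + 1;
}

int main()
{
    assert(pool_out(7, 0, 3, 2, false) == 3); // floor(4 / 2) + 1
    assert(pool_out(7, 0, 2, 2, true) == 4);  // ceil(5 / 2) + 1
    assert(pool_out(5, 2, 3, 1, false) == 5); // padding preserves the length
}
```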
@@ -43,7 +43,6 @@ struct quant_convolution
     padding_mode_t padding_mode = default_;
     int group                   = 1;
-    bool use_dynamic_same_auto_pad = false;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -52,8 +51,7 @@ struct quant_convolution
                     f(self.stride, "stride"),
                     f(self.dilation, "dilation"),
                     f(self.padding_mode, "padding_mode"),
-                    f(self.group, "group"),
-                    f(self.use_dynamic_same_auto_pad, "use_dynamic_same_auto_pad"));
+                    f(self.group, "group"));
     }

     value attributes() const
@@ -65,8 +63,8 @@ struct quant_convolution
     void check_attribute_size() const
     {
-        if(not((padding.size() == stride.size() or (padding.size() / 2) == stride.size()) and
-               stride.size() == dilation.size()))
+        if((padding.size() != stride.size() and (padding.size() / 2) != stride.size()) or
+           stride.size() != dilation.size())
         {
             MIGRAPHX_THROW("QUANT_CONVOLUTION: inconsistent attribute sizes");
         }
...
@@ -53,15 +53,15 @@ struct softmax
     std::string name() const { return "softmax"; }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
-        if(inputs.at(0).packed())
+        check_shapes{inputs, *this, true}.has(1);
+        auto s0 = inputs[0];
+        if(s0.dynamic() or s0.packed())
         {
-            return inputs.at(0);
+            return s0;
         }
         else
         {
-            auto lens = inputs.at(0).lens();
-            return {inputs.at(0).type(), lens};
+            return {s0.type(), s0.lens()};
         }
     }
...
@@ -29,6 +29,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -54,14 +55,46 @@ struct squeeze
     std::string name() const { return "squeeze"; }

     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
         auto input_shape = inputs[0];
+        if(input_shape.dynamic())
+        {
+            if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
+                   return input_shape.dyn_dims()[axis] != 1;
+               }))
+            {
+                MIGRAPHX_THROW(
+                    "SQUEEZE: dynamic axis dimension should be equal to {1, 1, 0} or {1, 1, 1}");
+            }
+            std::vector<shape::dynamic_dimension> dyn_dims = {};
+            if(axes.empty())
+            {
+                std::copy_if(input_shape.dyn_dims().cbegin(),
+                             input_shape.dyn_dims().cend(),
+                             std::back_inserter(dyn_dims),
+                             [&](auto dd) { return dd != 1; });
+            }
+            else
+            {
+                for(auto i : range(input_shape.ndim()))
+                {
+                    if(std::find(axes.begin(), axes.end(), i) == axes.end())
+                    {
+                        dyn_dims.push_back(input_shape.dyn_dims()[i]);
+                    }
+                }
+            }
+            return {input_shape.type(), dyn_dims};
+        }
+        else
+        {
            auto type        = input_shape.type();
            auto old_lens    = input_shape.lens();
            auto old_strides = input_shape.strides();
-        if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
+            if(std::any_of(
+                   axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
            {
-            MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
+                MIGRAPHX_THROW("SQUEEZE: static axis dimension should be equal to 1");
            }
            std::vector<std::size_t> new_lens;
            std::vector<std::size_t> new_strides;
@@ -96,10 +129,11 @@ struct squeeze
            return shape{type, new_lens, new_strides};
        }
+    }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
@@ -29,6 +29,7 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/op/normalize_attribute.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -45,17 +46,15 @@ struct transpose
     }
     std::string name() const { return "transpose"; }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this, true}.has(1);
         auto input = inputs.at(0);
-        auto input_lens    = input.lens();
-        auto input_strides = input.strides();
-        auto t             = input.type();

-        if(dims.size() != input_lens.size())
+        if(dims.size() != input.ndim())
         {
-            MIGRAPHX_THROW("Permutation has wrong number of axes");
+            MIGRAPHX_THROW("TRANSPOSE: Permutation has wrong number of axes");
         }
         std::vector<int64_t> axes(dims.size());
         std::iota(axes.begin(), axes.end(), 0);
@@ -63,19 +62,36 @@ struct transpose
         {
             MIGRAPHX_THROW("TRANSPOSE: Invalid permutation");
         }
-        std::vector<size_t> output_lens(input_lens.size());
-        std::vector<size_t> output_strides(input_lens.size());
-        for(std::size_t i = 0; i < output_lens.size(); i++)
-        {
-            output_lens[i]    = input_lens[dims[i]];
-            output_strides[i] = input_strides[dims[i]];
-        }
-        return {t, output_lens, output_strides};
+
+        if(input.dynamic())
+        {
+            std::vector<shape::dynamic_dimension> output_dyn_dims(input.ndim());
+            std::transform(dims.cbegin(), dims.cend(), output_dyn_dims.begin(), [&](auto dim) {
+                return input.dyn_dims()[dim];
+            });
+            return {input.type(), output_dyn_dims};
+        }
+        else
+        {
+            auto input_lens    = input.lens();
+            auto input_strides = input.strides();
+            std::vector<size_t> output_lens(input.ndim());
+            std::vector<size_t> output_strides(input.ndim());
+            for(std::size_t i = 0; i < input.ndim(); i++)
+            {
+                output_lens[i]    = input_lens[dims[i]];
+                output_strides[i] = input_strides[dims[i]];
+            }
+            return {input.type(), output_lens, output_strides};
+        }
     }

-    argument compute(shape output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(output_shape);
+        return args[0].reshape(dyn_out.computed_shape);
     }
     std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
...
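Both branches permute per-dimension records the same way; the static case in isolation:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// output[i] = input[dims[i]] for lens and strides alike, as in the diff.
std::vector<std::size_t> permute(const std::vector<std::size_t>& in,
                                 const std::vector<std::size_t>& dims)
{
    std::vector<std::size_t> out(in.size());
    for(std::size_t i = 0; i < in.size(); i++)
        out[i] = in[dims[i]];
    return out;
}

int main()
{
    // NCHW -> NHWC: dims = {0, 2, 3, 1}
    auto lens = permute({1, 3, 32, 32}, {0, 2, 3, 1});
    assert((lens == std::vector<std::size_t>{1, 32, 32, 3}));
}
```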
@@ -30,6 +30,7 @@
 #include <migraphx/argument.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/value.hpp>
+#include <migraphx/dyn_output.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -62,9 +63,9 @@ struct unary : op_name<Derived>
     value attributes() const { return base_attributes(); }

     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(1);
+        check_shapes{inputs, static_cast<const Derived&>(*this), true}.has(1);
         auto s = inputs.at(0);
-        if(s.scalar())
+        if(s.dynamic() or s.scalar())
         {
             return s;
         }
@@ -78,9 +79,9 @@ struct unary : op_name<Derived>
         }
     }

-    argument compute(const shape& output_shape, std::vector<argument> args) const
+    argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         result.visit([&](auto output) {
             args[0].visit([&](auto input) {
                 std::transform(input.begin(),
...
...@@ -29,11 +29,20 @@ ...@@ -29,11 +29,20 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -56,8 +65,33 @@ struct unsqueeze ...@@ -56,8 +65,33 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
if(input_shape.dynamic())
{
if(not steps.empty())
{
MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute");
}
std::vector<shape::dynamic_dimension> dyn_dims = {};
auto new_ndim = input_shape.ndim() + axes.size();
std::size_t k = 0;
for(auto i : range(new_ndim))
{
if(std::find(axes.begin(), axes.end(), i) != axes.end())
{
dyn_dims.push_back({1, 1, 0});
}
else
{
dyn_dims.push_back(input_shape.dyn_dims().at(k++));
}
}
return {input_shape.type(), dyn_dims};
}
else
{
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides(); auto old_strides = input_shape.strides();
...@@ -110,9 +144,10 @@ struct unsqueeze ...@@ -110,9 +144,10 @@ struct unsqueeze
} }
return shape{type, new_lens, new_strides}; return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
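The dynamic branch's index bookkeeping, run standalone on plain lens. Note that without steps the axis-2 slot becomes 1 and the 10 stays whole, unlike the steps example in the doc comment:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Dynamic-branch rule from the diff: new axes get a fixed dim of 1, remaining
// slots take the input dims in order (steps are not handled).
std::vector<std::size_t> unsqueeze_lens(const std::vector<std::size_t>& in,
                                        const std::vector<int64_t>& axes)
{
    std::vector<std::size_t> out;
    std::size_t k = 0;
    for(std::size_t i = 0; i < in.size() + axes.size(); i++)
    {
        if(std::find(axes.begin(), axes.end(), static_cast<int64_t>(i)) != axes.end())
            out.push_back(1);
        else
            out.push_back(in.at(k++));
    }
    return out;
}

int main()
{
    // [3, 4, 10] with axes {2, 4, 5} -> [3, 4, 1, 10, 1, 1]
    auto lens = unsqueeze_lens({3, 4, 10}, {2, 4, 5});
    assert((lens == std::vector<std::size_t>{3, 4, 1, 10, 1, 1}));
}
```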
@@ -32,6 +32,8 @@
 #include <utility>
 #include <unordered_map>
 #include <migraphx/reflect.hpp>
+#include <migraphx/dyn_output.hpp>
+#include <migraphx/functional.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/normalize_attributes.hpp>
 #include <migraphx/argument.hpp>
@@ -199,9 +201,12 @@ auto compute_op(rank<1>,
                 context& ctx,
                 const shape& output_shape,
                 const std::vector<argument>& input)
-    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
+    -> decltype(x.compute(auto_any_cast(ctx),
+                          make_compute_output_shape(pack(x, output_shape, input)),
+                          input))
 {
-    return x.compute(auto_any_cast(ctx), output_shape, input);
+    return x.compute(
+        auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
 }

 template <class T>
@@ -220,9 +225,9 @@
 template <class T>
 auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(output_shape, input))
+    -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
 {
-    return x.compute(output_shape, input);
+    return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
 }

 template <class T>
@@ -244,9 +249,11 @@ auto compute_op(rank<1>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(output, inputs, module_args, f))
+                F f)
+    -> decltype(
+        x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
 {
-    return x.compute(output, inputs, module_args, f);
+    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
 }

 template <class T, class F>
@@ -278,9 +285,17 @@ auto compute_op(rank<4>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
+                F f) -> decltype(x.compute(auto_any_cast(ctx),
+                                           make_compute_output_shape(pack(x, output, inputs)),
+                                           inputs,
+                                           module_args,
+                                           f))
 {
-    return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
+    return x.compute(auto_any_cast(ctx),
+                     make_compute_output_shape(pack(x, output, inputs)),
+                     inputs,
+                     module_args,
+                     f);
 }

 template <class T, class F>
@@ -290,9 +305,11 @@ auto compute_op(rank<3>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(output, inputs, module_args, f))
+                F f)
+    -> decltype(
+        x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
 {
-    return x.compute(output, inputs, module_args, f);
+    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
 }

 template <class T, class F>
@@ -302,9 +319,10 @@ auto compute_op(rank<2>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>&,
-                F) -> decltype(x.compute(output, inputs))
+                F)
+    -> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
 {
-    return x.compute(output, inputs);
+    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
 }

 template <class T, class F>
@@ -314,9 +332,12 @@ auto compute_op(rank<1>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>&,
-                F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+                F) -> decltype(x.compute(auto_any_cast(ctx),
+                                         make_compute_output_shape(pack(x, output, inputs)),
+                                         inputs))
 {
-    return x.compute(auto_any_cast(ctx), output, inputs);
+    return x.compute(
+        auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
 }

 template <class T, class F>
@@ -348,7 +369,8 @@ auto is_context_free_op(rank<1>,
                         const T& x,
                         const shape& output_shape,
                         const std::vector<argument>& input)
-    -> decltype(x.compute(output_shape, input), std::true_type{});
+    -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input),
+                std::true_type{});

 template <class T>
 auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
...
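All the compute_op overloads above use the same rank-based tag dispatch plus expression SFINAE. A minimal self-contained sketch of the selection mechanism (rank, compute_op, and the two structs here are simplified stand-ins, not the library definitions):

```cpp
#include <iostream>

// Simplified stand-in for migraphx::rank: rank<N> derives from rank<N-1>, so
// an argument of rank<1>{} prefers the rank<1> overload when it is viable.
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};

// Preferred overload: only viable (via the trailing decltype) when T has compute().
template <class T>
auto compute_op(rank<1>, const T& x) -> decltype(x.compute(), void())
{
    x.compute();
}

// Fallback: chosen when the rank<1> overload is discarded by SFINAE.
template <class T>
void compute_op(rank<0>, const T&)
{
    std::cout << "no compute() available\n";
}

struct has_compute
{
    void compute() const { std::cout << "compute() called\n"; }
};
struct no_compute
{
};

int main()
{
    compute_op(rank<1>{}, has_compute{}); // picks the rank<1> overload
    compute_op(rank<1>{}, no_compute{});  // falls back to rank<0>
}
```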
@@ -35,7 +35,6 @@
 #include <migraphx/op/as_shape.hpp>
 #include <migraphx/op/atan.hpp>
 #include <migraphx/op/atanh.hpp>
-#include <migraphx/op/batch_norm_inference.hpp>
 #include <migraphx/op/binary.hpp>
 #include <migraphx/op/broadcast.hpp>
 #include <migraphx/op/capture.hpp>
...
@@ -24,9 +24,10 @@
 #ifndef MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP
 #define MIGRAPHX_GUARD_OPERATORS_PAD_CALC_HPP

-#include <migraphx/config.hpp>
 #include <cstdint>
 #include <vector>
+#include <migraphx/config.hpp>
+#include <migraphx/shape.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -42,18 +43,21 @@ void calculate_padding(int64_t idx,
 /*!
  * Calculate the padding for auto_padding. Used for dynamic shapes
  * where the padding calculation must be done at evaluation time.
+ * \param input_lens input tensor image shape
+ * \param wei_lens weights kernel shape
+ * \param strides strides for the kernel
+ * \param dilations dilations for the kernel
+ * \param use_upper put odd padding on upper or lower side
  * \return padding in the form of {x0_begin, x1_begin, ... x0_end , x1_end, ...}
  */
-std::vector<std::size_t> calc_dyn_auto_pad(std::vector<std::size_t> tensor_lens,
-                                           std::vector<std::size_t> k_lens,
-                                           std::vector<std::size_t> strides,
-                                           std::vector<std::size_t> dilations,
-                                           bool use_upper = true);
+std::vector<std::size_t> calc_dyn_auto_pad(const std::vector<std::size_t>& input_lens,
+                                           const std::vector<std::size_t>& wei_lens,
+                                           const std::vector<std::size_t>& strides,
+                                           const std::vector<std::size_t>& dilations,
+                                           bool use_upper);
+
+// Used for dynamic auto padding of convolution operators since padding needs to be computed at
+// evaluation time.
+shape compute_padded_shape(const shape& input,
+                           const shape& weights,
+                           const std::vector<std::size_t>& padding,
+                           const std::vector<std::size_t>& stride,
+                           const std::vector<std::size_t>& dilation);

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -115,6 +115,7 @@ struct program
               print_func) const;

     void print_graph(std::ostream& os, bool brief = false) const;
+    void print_py(std::ostream& os) const;
     void print_cpp(std::ostream& os) const;

     void dry_run(parameter_map params) const;
...
@@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
 }

 template <class T>
-auto reflectable_impl(rank<1>, T&& x)
+auto reflectable_impl(rank<1>, const T& x)
     -> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});

 template <class T>
-auto reflectable_impl(rank<0>, T &&) -> decltype(std::false_type{});
+auto reflectable_impl(rank<0>, const T&) -> decltype(std::false_type{});

 template <class T>
 struct remove_rvalue_reference
@@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
 template <class T>
 auto reflect_tie(T& x)
 {
-    return reflect(x, [](auto&& y, auto&&...) { return detail::wrap<decltype(y)>(y); })(
-        [](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
+    return reflect(x, [](auto&& y, auto&&...) {
+        // cppcheck-suppress UnnecessaryElseStatement
+        if constexpr(is_reflectable<decltype(y)>{})
+        {
+            auto t = reflect_tie(y);
+            return detail::wrap<decltype(t)>(t);
+        }
+        else
+        {
+            return detail::wrap<decltype(y)>(y);
+        }
+    })([](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
 }

 template <class T, class F>
...
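The reflect_tie change recurses into members that are themselves reflectable. A simplified illustration with std::tie of the effect on fieldwise comparison (these structs and their tie() helpers are stand-ins for the reflection machinery):

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>

// Stand-in for a reflectable member such as shape::dynamic_dimension.
struct dynamic_dimension
{
    std::size_t min = 0, max = 0, opt = 0;
    auto tie() const { return std::tie(min, max, opt); }
};

struct attrs
{
    dynamic_dimension dd;
    int axis = 0;
    // Recursive tie: expand the nested member into its own tuple, mirroring
    // what the updated reflect_tie now does for reflectable members.
    auto tie() const { return std::tuple_cat(dd.tie(), std::tie(axis)); }
};

int main()
{
    attrs a{{1, 4, 2}, 1};
    attrs b{{1, 4, 2}, 1};
    assert(a.tie() == b.tie()); // fieldwise equality through the nested tuple
}
```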
@@ -30,6 +30,7 @@
 #include <numeric>
 #include <memory>

+#include <migraphx/functional.hpp>
 #include <migraphx/errors.hpp>
 #include <migraphx/half.hpp>
 #include <migraphx/config.hpp>
@@ -89,7 +90,10 @@ struct shape
         std::size_t opt = 0;

         template <class Self, class F>
-        static auto reflect(Self& self, F f);
+        static auto reflect(Self& self, F f)
+        {
+            return pack(f(self.min, "min"), f(self.max, "max"), f(self.opt, "opt"));
+        }

         bool is_fixed() const;
         bool has_optimal() const;
@@ -97,6 +101,12 @@ struct shape
         friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y);
         friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y);
         friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x);
+
+        // compare to fixed std::size_t dimension
+        friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
+        friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
+        friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
+        friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
     };

     static const std::vector<type_t>& types();
@@ -115,6 +125,12 @@ struct shape
     shape(type_t t, std::vector<dynamic_dimension> dims);

+    // Construct a dynamic shape from three sets of lengths (of the same rank)
+    shape(type_t t,
+          std::vector<std::size_t> mins,
+          std::vector<std::size_t> maxes,
+          std::vector<std::size_t> opts);
+
     template <class Range>
     shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
     {
@@ -136,6 +152,12 @@ struct shape
     const std::vector<std::size_t>& lens() const;
     const std::vector<std::size_t>& strides() const;

+    /*!
+     * The number of dimensions in the shape.
+     * Same as the number of indices required to get a data value.
+     */
+    std::size_t ndim() const;
+
     /*!
      * Return the number of elements in the tensor.
      */
@@ -221,6 +243,9 @@ struct shape
     shape with_type(type_t t) const;

+    // convert the shape to an equivalent dynamic shape
+    shape to_dynamic() const;
+
     friend bool operator==(const shape& x, const shape& y);
     friend bool operator!=(const shape& x, const shape& y);
     friend std::ostream& operator<<(std::ostream& os, const shape& x);
...
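A usage sketch for the declarations above (assumes the MIGraphX library is linked; the constructor forms are the ones declared in this header, and the two shapes are expected to compare equal):

```cpp
#include <migraphx/shape.hpp>

int main()
{
    using migraphx::shape;
    // one {min, max, opt} dynamic_dimension per dimension; opt = 0 means no hint
    shape s1{shape::float_type,
             {shape::dynamic_dimension{1, 10, 0}, shape::dynamic_dimension{3, 3, 3}}};
    // the three-vector constructor added here: mins, maxes, opts of equal rank
    shape s2{shape::float_type, {1, 3}, {10, 3}, {0, 3}};
    return (s1.dynamic() and s1.ndim() == 2 and s1 == s2) ? 0 : 1;
}
```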
@@ -31,6 +31,9 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

+/**
+ * Iterates the given function over the indices from the shape in order.
+ */
 template <class F>
 void shape_for_each(const migraphx::shape& s, F f)
 {
@@ -51,7 +54,6 @@ void shape_for_each(const migraphx::shape& s, F f)
         call(indices);
     }
 }
-
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
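A usage sketch for the documented helper (assumes MIGraphX is linked; the callback receives the vector of indices for each element):

```cpp
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>
#include <iostream>

int main()
{
    migraphx::shape s{migraphx::shape::int32_type, {2, 3}};
    migraphx::shape_for_each(s, [](const auto& idx) {
        // visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) in order
        std::cout << idx[0] << ", " << idx[1] << "\n";
    });
}
```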