Commit cd4ab535 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual merge

parents 3891ee58 a0fa3742
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
#include <migraphx/config.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
std::size_t hash_value(const T& v)
{
return std::hash<T>{}(v);
}
template <class T>
void hash_combine(std::size_t& seed, const T& v)
{
seed ^= hash_value(v) + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_HASH_HPP
...@@ -136,6 +136,9 @@ struct instruction ...@@ -136,6 +136,9 @@ struct instruction
operation normalized_operator() const; operation normalized_operator() const;
std::size_t get_target_id() const;
void set_target_id(std::size_t tid);
void debug_print() const; void debug_print() const;
static void print(std::ostream& os, static void print(std::ostream& os,
...@@ -172,7 +175,8 @@ struct instruction ...@@ -172,7 +175,8 @@ struct instruction
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
std::vector<module_ref> module_args; std::vector<module_ref> module_args;
literal lit; literal lit;
bool normalized = false; bool normalized = false;
std::size_t target_id = 0;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -35,6 +35,10 @@ ...@@ -35,6 +35,10 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#ifndef MIGRAPHX_USE_TYPE_ERASED_MATCHERS
#define MIGRAPHX_USE_TYPE_ERASED_MATCHERS 0
#endif
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -103,6 +107,13 @@ struct predicate_matcher ...@@ -103,6 +107,13 @@ struct predicate_matcher
} }
}; };
/// Convert a predicate function into a matcher
template <class P>
predicate_matcher<P> make_predicate_matcher(P p)
{
return {p};
}
/// Convert a function into a matcher /// Convert a function into a matcher
template <class F> template <class F>
struct function_matcher struct function_matcher
...@@ -124,14 +135,14 @@ template <class M> ...@@ -124,14 +135,14 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[=, name = std::move(name)](matcher_context& ctx, [=, m_name = std::move(name)](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins); auto result = m.match(ctx, ins);
if(result) if(result)
{ {
if(not ctx.has_instruction(ins)) if(not ctx.has_instruction(ins))
return nullopt; return nullopt;
ctx.instructions[name] = ins; ctx.instructions[m_name] = ins;
} }
return result; return result;
}); });
...@@ -183,14 +194,26 @@ struct id_matcher ...@@ -183,14 +194,26 @@ struct id_matcher
template <class M> template <class M>
struct basic_matcher; struct basic_matcher;
struct any_matcher;
template <class M>
struct type_erased_matcher
{
#if MIGRAPHX_USE_TYPE_ERASED_MATCHERS
using type = any_matcher;
#else
using type = basic_matcher<M>;
#endif
};
template <class M> template <class M>
basic_matcher<M> make_basic_matcher(M m); typename type_erased_matcher<M>::type make_basic_matcher(M m);
template <class F> template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f); auto make_basic_fun_matcher(F f);
template <class P> template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p); auto make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
template <class M> template <class M>
...@@ -222,38 +245,38 @@ struct basic_matcher ...@@ -222,38 +245,38 @@ struct basic_matcher
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); } auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
}; };
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
{
}
};
/// Create a basic matcher from a matcher /// Create a basic matcher from a matcher
template <class M> template <class M>
basic_matcher<M> make_basic_matcher(M m) typename type_erased_matcher<M>::type make_basic_matcher(M m)
{ {
return {m}; return {m};
} }
/// Create a basic matcher from a function /// Create a basic matcher from a function
template <class F> template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f) auto make_basic_fun_matcher(F f)
{ {
return {{f}}; return make_basic_matcher(make_function_matcher(f));
} }
/// Create a basic matcher from a predicate function /// Create a basic matcher from a predicate function
template <class P> template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) auto make_basic_pred_matcher(P p)
{ {
return {{p}}; return make_basic_matcher(make_predicate_matcher(p));
} }
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher /// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \ #define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
...@@ -347,31 +370,49 @@ match::matcher_result find_match(module& modl, M&& m) ...@@ -347,31 +370,49 @@ match::matcher_result find_match(module& modl, M&& m)
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
/// Find matches for an instruction in the module /// Find matches for an instruction in the module for per section of matchers
template <class Mod, class... Ms> template <class Mod, class... Ms>
void find_matches(Mod& mod, instruction_ref ins, Ms&&... ms) void find_matches(size_t trace_pass, Mod& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
#endif #endif
int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
bool match = false; #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
bool match = false;
each_args( each_args(
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
if(trace > 1) if(trace > 1 or trace_pass > 1)
std::cout << "Match: " << get_type_name(m) << std::endl; std::cout << "Match: " << get_type_name(m) << std::endl;
auto r = match_instruction(get_module(mod), ins, m.matcher()); auto r = match_instruction(get_module(mod), ins, m.matcher());
if(r.result == get_module(mod).end()) if(r.result == get_module(mod).end())
return; return;
if(trace > 0) if(trace > 0 or trace_pass > 0)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
get_module(mod).debug_print(ins); get_module(mod).debug_print(ins);
} }
// If its already invalid dont validate it again
bool invalidated = validate and get_module(mod).validate() != get_module(mod).end();
m.apply(mod, r); m.apply(mod, r);
if(validate and not invalidated)
{
auto invalid = get_module(mod).validate();
if(invalid != get_module(mod).end())
{
std::cout << "Invalid program from match: " << get_type_name(m) << std::endl;
std::cout << "Invalid instructions: " << std::endl;
get_module(mod).debug_print(invalid->inputs());
get_module(mod).debug_print(invalid);
}
}
match = true; match = true;
}, },
ms...); ms...);
...@@ -383,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms) ...@@ -383,7 +424,17 @@ void find_matches(Mod& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(get_module(mod))) for(auto ins : iterator_for(get_module(mod)))
{ {
find_matches(mod, ins, ms...); find_matches(0, mod, ins, ms...);
}
}
/// Find matches in a pass
template <class Mod, class... Ms>
void find_matches(size_t trace_pass, Mod& mod, Ms&&... ms)
{
for(auto ins : iterator_for(get_module(mod)))
{
find_matches(trace_pass, mod, ins, ms...);
} }
} }
...@@ -520,6 +571,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins) ...@@ -520,6 +571,8 @@ MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{ {
return not ins->get_shape().standard(); return not ins->get_shape().standard();
} }
MIGRAPHX_PRED_MATCHER(dynamic_shape, instruction_ref ins) { return ins->get_shape().dynamic(); }
MIGRAPHX_PRED_MATCHER(static_shape, instruction_ref ins) { return not ins->get_shape().dynamic(); }
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins) MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{ {
return ins->get_shape().broadcasted(); return ins->get_shape().broadcasted();
...@@ -612,9 +665,9 @@ auto skip_output(Ms... ms) ...@@ -612,9 +665,9 @@ auto skip_output(Ms... ms)
inline auto var(std::string s) inline auto var(std::string s)
{ {
return make_basic_fun_matcher( return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx, [=, m_s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> { instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s); auto it = ctx.instructions.find(m_s);
if(it == ctx.instructions.end()) if(it == ctx.instructions.end())
return nullopt; return nullopt;
return it->second; return it->second;
...@@ -624,7 +677,7 @@ inline auto var(std::string s) ...@@ -624,7 +677,7 @@ inline auto var(std::string s)
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
[=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; }); [=, m_s = std::move(s)](instruction_ref ins) { return ins->name() == m_s; });
} }
inline auto name_contains(const std::string& name) inline auto name_contains(const std::string& name)
...@@ -635,8 +688,8 @@ inline auto name_contains(const std::string& name) ...@@ -635,8 +688,8 @@ inline auto name_contains(const std::string& name)
inline auto name(std::unordered_set<std::string> names) inline auto name(std::unordered_set<std::string> names)
{ {
return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) { return make_basic_pred_matcher([=, m_names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0; return m_names.count(ins->name()) > 0;
}); });
} }
......
...@@ -178,6 +178,8 @@ struct module ...@@ -178,6 +178,8 @@ struct module
bool has_instruction(instruction_ref ins) const; bool has_instruction(instruction_ref ins) const;
std::vector<instruction_ref> get_returns() const;
std::size_t size() const; std::size_t size() const;
instruction_ref begin() const; instruction_ref begin() const;
instruction_ref end() const; instruction_ref end() const;
......
...@@ -26,10 +26,12 @@ ...@@ -26,10 +26,12 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <functional>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void to_msgpack(const value& v, std::function<void(const char*, std::size_t)> writer);
std::vector<char> to_msgpack(const value& v); std::vector<char> to_msgpack(const value& v);
value from_msgpack(const std::vector<char>& buffer); value from_msgpack(const std::vector<char>& buffer);
value from_msgpack(const char* buffer, std::size_t size); value from_msgpack(const char* buffer, std::size_t size);
......
...@@ -37,7 +37,7 @@ struct onnx_options ...@@ -37,7 +37,7 @@ struct onnx_options
std::size_t default_dim_value = 0; std::size_t default_dim_value = 0;
/// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set /// Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value set
/// parser throws) /// parser throws)
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0}; shape::dynamic_dimension default_dyn_dim_value = {1, 1};
/// Explicitly specify the dims of an input /// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims /// Explicitly specify dynamic dims of an input (if both map_input_dims and map_dyn_input_dims
......
...@@ -62,7 +62,7 @@ struct argmax ...@@ -62,7 +62,7 @@ struct argmax
if(s0.dynamic()) if(s0.dynamic())
{ {
auto dyn_dims = s0.dyn_dims(); auto dyn_dims = s0.dyn_dims();
dyn_dims[axis] = {1, 1, 0}; dyn_dims[axis] = {1, 1};
return {shape::int64_type, dyn_dims}; return {shape::int64_type, dyn_dims};
} }
else else
......
...@@ -37,10 +37,13 @@ namespace op { ...@@ -37,10 +37,13 @@ namespace op {
* 1 input version: * 1 input version:
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of * Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension * broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank * that stays the same.
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would * ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1.
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3] *
* with axis = 0 * For higher rank input shapes, axis is an offset parameter for the broadcasting.
* Such that this operator would work in the opposite direction of NumPy broadcasting
* (left-most to rightwards element-wise comparison)
* ex: broadcasting shape [2, 2] -> [2, 2, 3] with axis = 0
* *
* 2 input version: * 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter. * Broadcast the first input 1D shape into the second input shape based on the axis parameter.
...@@ -68,6 +71,9 @@ struct broadcast ...@@ -68,6 +71,9 @@ struct broadcast
{ {
// the ONNX broadcast op is deprecated now, so not handling the negative // the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore // value of axis anymore
if(s0.dynamic())
MIGRAPHX_THROW(
"BROADCAST: Single dynamic input shape not supported. Use two inputs.");
if(axis >= broadcast_lens.size()) if(axis >= broadcast_lens.size())
{ {
MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) + MIGRAPHX_THROW("BROADCAST : axis " + migraphx::to_string(axis) +
......
...@@ -134,7 +134,7 @@ struct concat ...@@ -134,7 +134,7 @@ struct concat
} }
auto new_dims = inputs[0].dyn_dims(); auto new_dims = inputs[0].dyn_dims();
new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max, 0}; new_dims[axis] = migraphx::shape::dynamic_dimension{new_min, new_max};
return {inputs[0].type(), new_dims}; return {inputs[0].type(), new_dims};
} }
else else
......
...@@ -48,7 +48,7 @@ struct contiguous ...@@ -48,7 +48,7 @@ struct contiguous
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front(); auto s0 = inputs.front();
if(s0.dynamic() or s0.standard()) if(s0.dynamic())
{ {
return s0; return s0;
} }
......
...@@ -66,7 +66,17 @@ struct convert : unary<convert> ...@@ -66,7 +66,17 @@ struct convert : unary<convert>
auto type = target_type; auto type = target_type;
return [type](auto x) { return [type](auto x) {
auto y = x; auto y = x;
shape::visit(type, [&](auto as) { y = std::min(std::max(as(x), as.min()), as.max()); }); shape::visit(type, [&](auto as) {
// clamping value between target_type's max and min doesn't work for NaNs,
if(std::isnan(x))
{
y = as.nan();
}
else
{
y = std::min(std::max(as(x), as.min()), as.max());
}
});
return y; return y;
}; };
} }
......
...@@ -24,9 +24,12 @@ ...@@ -24,9 +24,12 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP #define MIGRAPHX_GUARD_OPERATORS_CONVOLUTION_HPP
#include <migraphx/argument.hpp>
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -35,6 +38,10 @@ namespace migraphx { ...@@ -35,6 +38,10 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Convolution operator. Does not support optimal dimensions for spatial dimensions. Returns empty
* optimals.
*/
struct convolution struct convolution
{ {
std::vector<std::size_t> padding = {0, 0}; std::vector<std::size_t> padding = {0, 0};
...@@ -145,7 +152,7 @@ struct convolution ...@@ -145,7 +152,7 @@ struct convolution
else else
{ {
auto l = input_shape.lens().at(0); auto l = input_shape.lens().at(0);
output_dyn_dims.push_back({l, l, 0}); output_dyn_dims.push_back({l, l});
} }
}; };
...@@ -162,25 +169,30 @@ struct convolution ...@@ -162,25 +169,30 @@ struct convolution
if(x_shape.dynamic()) if(x_shape.dynamic())
{ {
auto x = x_shape.dyn_dims()[i + 2]; auto x = x_shape.dyn_dims()[i + 2];
output_dyn_dims.push_back(shape::dynamic_dimension{ std::set<std::size_t> optimals{};
ceil_div(x.min, s), ceil_div(x.max, s), ceil_div(x.opt, s)}); std::transform(x.optimals.begin(),
x.optimals.end(),
std::inserter(optimals, optimals.begin()),
[&](auto o) { return ceil_div(o, s); });
output_dyn_dims.push_back(
shape::dynamic_dimension{ceil_div(x.min, s), ceil_div(x.max, s), optimals});
} }
else else
{ {
auto od = ceil_div(x_shape.lens()[i + 2], s); auto od = ceil_div(x_shape.lens()[i + 2], s);
output_dyn_dims.push_back(shape::dynamic_dimension{od, od, 0}); output_dyn_dims.push_back(shape::dynamic_dimension{od, od});
} }
} }
} }
else else
{ {
// Does not compute for optimals
auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens()); auto min_spatial_dims = calc_conv_lens(x_shape.min_lens(), w_shape.max_lens());
auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens()); auto max_spatial_dims = calc_conv_lens(x_shape.max_lens(), w_shape.min_lens());
auto opt_spatial_dims = calc_conv_lens(x_shape.opt_lens(), w_shape.opt_lens());
for(size_t i = 0; i < num_spatial_dims; ++i) for(size_t i = 0; i < num_spatial_dims; ++i)
{ {
output_dyn_dims.push_back(shape::dynamic_dimension{ output_dyn_dims.push_back(
min_spatial_dims[i], max_spatial_dims[i], opt_spatial_dims[i]}); shape::dynamic_dimension{min_spatial_dims[i], max_spatial_dims[i], {}});
} }
} }
return shape{x_shape.type(), output_dyn_dims}; return shape{x_shape.type(), output_dyn_dims};
...@@ -201,6 +213,37 @@ struct convolution ...@@ -201,6 +213,37 @@ struct convolution
check_attribute_size(); check_attribute_size();
return stride.size(); return stride.size();
} }
argument compute(shape output_shape, std::vector<argument> args) const
{
std::vector<std::size_t> new_padding;
if(padding_mode != op::padding_mode_t::default_)
{
auto input_lens = args[0].get_shape().lens();
auto weights_lens = args[1].get_shape().lens();
new_padding =
padding_mode == op::same_upper
? calc_dyn_auto_pad(input_lens, weights_lens, stride, dilation, true)
: calc_dyn_auto_pad(input_lens, weights_lens, stride, dilation, false);
output_shape = compute_padded_shape(
args[0].get_shape(), args[1].get_shape(), new_padding, stride, dilation);
}
else
{
new_padding = padding;
if(output_shape.dynamic())
{
output_shape =
normalize_compute_shape({args.at(0).get_shape(), args.at(1).get_shape()});
}
}
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
migraphx::convolution(output, input, weights, new_padding, stride, group);
});
return result;
}
}; };
} // namespace op } // namespace op
......
...@@ -37,10 +37,23 @@ namespace op { ...@@ -37,10 +37,23 @@ namespace op {
struct dequantizelinear struct dequantizelinear
{ {
value attributes() const
{
// Note: point_op attribute is not used in this op. Instead, in
// gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return {{"pointwise", true}};
}
std::string name() const { return "dequantizelinear"; } std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.same_dims(); check_shapes{inputs, *this}.same_dims().has(2, 3);
if(inputs.size() == 3 and inputs[0].type() != inputs[2].type())
{
MIGRAPHX_THROW("DEQUANTIZELINEAR: Zero point and input should be the same type.");
}
return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()}; return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
} }
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -59,27 +60,22 @@ struct flatten ...@@ -59,27 +60,22 @@ struct flatten
auto s = inputs[0]; auto s = inputs[0];
if(s.dynamic()) if(s.dynamic())
{ {
// Doesn't handle optimals
auto min_lens = s.min_lens(); auto min_lens = s.min_lens();
auto max_lens = s.max_lens(); auto max_lens = s.max_lens();
auto opt_lens = s.opt_lens();
// If any of the opt values is 0, output opt will be 0 // If any of the opt values is 0, output opt will be 0
shape::dynamic_dimension x = { shape::dynamic_dimension x = {
std::accumulate( std::accumulate(
min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}), min_lens.begin(), min_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate( std::accumulate(
max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}), max_lens.begin(), max_lens.begin() + axis, std::size_t{1}, std::multiplies<>{}),
std::accumulate(opt_lens.begin(), {}};
opt_lens.begin() + axis,
std::size_t{1},
std::multiplies<>{})};
shape::dynamic_dimension y = { shape::dynamic_dimension y = {
std::accumulate( std::accumulate(
min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}), min_lens.begin() + axis, min_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate( std::accumulate(
max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}), max_lens.begin() + axis, max_lens.end(), std::size_t{1}, std::multiplies<>{}),
std::accumulate( {}};
opt_lens.begin() + axis, opt_lens.end(), std::size_t{1}, std::multiplies<>{}),
};
return {s.type(), {x, y}}; return {s.type(), {x, y}};
} }
else else
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment