Unverified Commit 38a62ed2 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Normalize_attributes dynamic shapes update (#1953)

throw on use_len with non-fixed dynamic dimensions
change normalize_attributes to use input shape rather than input dimensions
parent 68a9a23f
...@@ -43,7 +43,7 @@ template <class T, class... Ts> ...@@ -43,7 +43,7 @@ template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type; using dependent_type = typename select_dependent_type<T, Ts...>::type;
MIGRAPHX_EXPORT MIGRAPHX_EXPORT
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens); bool normalize_attributes(operation& op, const shape& input_shape);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) ...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
if(inputs.empty()) if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name()); MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens()); normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
......
...@@ -467,7 +467,7 @@ operation instruction::normalized_operator() const ...@@ -467,7 +467,7 @@ operation instruction::normalized_operator() const
if(this->need_normalization()) if(this->need_normalization())
{ {
auto s = this->inputs().front()->get_shape(); auto s = this->inputs().front()->get_shape();
if(not normalize_attributes(o, s.max_lens())) if(not normalize_attributes(o, s))
return this->get_operator(); return this->get_operator();
} }
return o; return o;
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,8 +35,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,8 +35,9 @@ inline namespace MIGRAPHX_INLINE_NS {
* vec: the vector attribute to normalize * vec: the vector attribute to normalize
* axes: the operator's axes attribute if it exists, empty otherwise * axes: the operator's axes attribute if it exists, empty otherwise
* val: the normalize_axes key and options. Ex: normalize["axes"] = * val: the normalize_axes key and options. Ex: normalize["axes"] =
* value::array{normalize_attribute::include_min}; lens: shape dimensions passed when calling * value::array{normalize_attribute::include_min};
* normalize_attributes(op&, lens) * input_shape: input shape passed when calling
* normalize_attributes(op&, input_shape)
* *
* See normalize_attribute.hpp for explaining the options. * See normalize_attribute.hpp for explaining the options.
*/ */
...@@ -44,11 +45,11 @@ template <class Message> ...@@ -44,11 +45,11 @@ template <class Message>
auto tune_attribute(const std::vector<int64_t>& vec, auto tune_attribute(const std::vector<int64_t>& vec,
const std::vector<int64_t>& axes, const std::vector<int64_t>& axes,
const value& val, const value& val,
const std::vector<std::size_t>& lens, const shape& input_shape,
Message m) Message m)
{ {
std::vector<int64_t> result(vec); std::vector<int64_t> result(vec);
int64_t n_rank = lens.size(); int64_t n_rank = input_shape.ndim();
std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>(); std::vector<op::normalize_attribute> vec_attrs = val.to_vector<op::normalize_attribute>();
if(contains(vec_attrs, op::normalize_attribute::use_output)) if(contains(vec_attrs, op::normalize_attribute::use_output))
{ {
...@@ -56,9 +57,28 @@ auto tune_attribute(const std::vector<int64_t>& vec, ...@@ -56,9 +57,28 @@ auto tune_attribute(const std::vector<int64_t>& vec,
} }
std::vector<int64_t> max_vals(vec.size(), n_rank); std::vector<int64_t> max_vals(vec.size(), n_rank);
if(contains(vec_attrs, op::normalize_attribute::use_len)) if(contains(vec_attrs, op::normalize_attribute::use_len))
{ {
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) { return lens[i]; }); if(input_shape.dynamic())
{
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
const auto& dd = input_shape.dyn_dims().at(i);
if(not dd.is_fixed())
{
MIGRAPHX_THROW(
"NORMALIZE_ATTR: 'use_lens' on a non-fixed dynamic dimension, axis=" +
std::to_string(i));
}
return dd.max;
});
}
else
{
std::transform(axes.begin(), axes.end(), max_vals.begin(), [&](auto i) {
return input_shape.lens().at(i);
});
}
} }
if(contains(vec_attrs, op::normalize_attribute::clip_max)) if(contains(vec_attrs, op::normalize_attribute::clip_max))
...@@ -159,9 +179,9 @@ auto tune_pad_attribute(const value& val) ...@@ -159,9 +179,9 @@ auto tune_pad_attribute(const value& val)
/** /**
* Assumptions: * Assumptions:
* Dimensions to pad start from the third dimension (index 2). * Dimensions to pad start from the third dimension (index 2).
* Called by compute_shape_op() with the `lens` of the first input. * Called by compute_shape_op() with the shape of the first input.
*/ */
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) bool normalize_attributes(operation& op, const shape& input_shape)
{ {
bool tuned = false; bool tuned = false;
auto attrs = op.attributes(); auto attrs = op.attributes();
...@@ -172,9 +192,9 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -172,9 +192,9 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
auto padding_size = padding.size(); auto padding_size = padding.size();
auto padding_start = 2; auto padding_start = 2;
if(padding_size == 2 * (lens.size() - padding_start)) if(padding_size == 2 * (input_shape.ndim() - padding_start))
tuned = true; tuned = true;
else if(padding_size != (lens.size() - padding_start)) else if(padding_size != (input_shape.ndim() - padding_start))
MIGRAPHX_THROW("inconsistent padding size"); MIGRAPHX_THROW("inconsistent padding size");
else else
{ {
...@@ -205,7 +225,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -205,7 +225,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes = val.at("axes").without_key().to_vector<int64_t>(); axes = val.at("axes").without_key().to_vector<int64_t>();
} }
auto vec = vv.to_vector<int64_t>(); auto vec = vv.to_vector<int64_t>();
auto result = tune_attribute(vec, axes, rv.without_key(), lens, message); auto result = tune_attribute(vec, axes, rv.without_key(), input_shape, message);
val[key] = result; val[key] = result;
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
...@@ -214,7 +234,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens) ...@@ -214,7 +234,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else else
{ {
auto num = vv.to<int64_t>(); auto num = vv.to<int64_t>();
auto result = tune_attribute({num}, {num}, rv.without_key(), lens, message); auto result = tune_attribute({num}, {num}, rv.without_key(), input_shape, message);
val[key] = result.front(); val[key] = result.front();
op.from_value(val); op.from_value(val);
val = op.to_value(); val = op.to_value();
......
...@@ -45,7 +45,7 @@ void normalize_ops::apply(module& m) const ...@@ -45,7 +45,7 @@ void normalize_ops::apply(module& m) const
auto s = inputs[0]->get_shape(); auto s = inputs[0]->get_shape();
migraphx::operation tuned_op = ins->get_operator(); migraphx::operation tuned_op = ins->get_operator();
if(normalize_attributes(tuned_op, s.max_lens())) if(normalize_attributes(tuned_op, s))
{ {
m.replace_instruction(ins, tuned_op, inputs); m.replace_instruction(ins, tuned_op, inputs);
ins->set_normalized(); ins->set_normalized();
......
...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) ...@@ -143,7 +143,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
if(inputs.empty()) if(inputs.empty())
MIGRAPHX_THROW("At least one input is required for " + x.name()); MIGRAPHX_THROW("At least one input is required for " + x.name());
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].max_lens()); normalize_attributes(y, inputs[0]);
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment