"vscode:/vscode.git/clone" did not exist on "a97bfd3f0b363cf269765cd8fbadfaad506c17ad"
Unverified Commit 0b2bcf2c authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into add_parity_check_ci

parents fbacefca dcc7b0a5
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "precision.hpp" #include "precision.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
...@@ -37,18 +38,18 @@ void verify_program(const std::string& name, ...@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 100); verify::tolerance tols = verify::tolerance{});
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -81,6 +81,7 @@ struct MIGRAPHX_EXPORT instruction ...@@ -81,6 +81,7 @@ struct MIGRAPHX_EXPORT instruction
const std::vector<module_ref>& module_inputs() const; const std::vector<module_ref>& module_inputs() const;
/// Where this instruction is used as an input to another instruction
const std::vector<instruction_ref>& outputs() const; const std::vector<instruction_ref>& outputs() const;
friend bool operator==(const instruction& x, const instruction& y); friend bool operator==(const instruction& x, const instruction& y);
......
...@@ -49,17 +49,22 @@ struct allocate ...@@ -49,17 +49,22 @@ struct allocate
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
migraphx::check_shapes{inputs, *this, true}.has(0, 1);
// check if shape attribute is not default
if(s != shape()) if(s != shape())
{ {
if(inputs.size() == 1)
{
migraphx::check_shapes{inputs, *this, false}.only_dims(1);
}
else
{
migraphx::check_shapes{inputs, *this, false}.has(0);
}
return s; return s;
} }
else else
{ {
migraphx::check_shapes{inputs, *this, false}.has(1).only_dims(1);
const auto& out_dims = inputs.at(0); const auto& out_dims = inputs.at(0);
assert(not out_dims.dynamic());
assert(out_dims.ndim() == 1);
std::size_t max_val = std::numeric_limits<std::size_t>::max(); std::size_t max_val = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0), std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0),
shape::dynamic_dimension{0, max_val}); shape::dynamic_dimension{0, max_val});
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -29,7 +29,8 @@ ...@@ -29,7 +29,8 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
#include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -45,8 +46,6 @@ struct reshape ...@@ -45,8 +46,6 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape dyn_compute_shape(shape s0) const shape dyn_compute_shape(shape s0) const
...@@ -110,27 +109,9 @@ struct reshape ...@@ -110,27 +109,9 @@ struct reshape
return it; return it;
} }
template <class DimIterator, class StrideIterator> // This will attempt to alias the dimensions of the input shape to the lens of
static auto can_strides_merge(DimIterator dim_start, // `rdims`. Unlike reshape_lazy though we can modify memory layout with copies and this
DimIterator dim_last, // can remove previous nullopts that were sent back for the alias case
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will reshape the dimesions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims) static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{ {
if(input.standard()) if(input.standard())
...@@ -155,13 +136,8 @@ struct reshape ...@@ -155,13 +136,8 @@ struct reshape
{ {
auto start = idims.begin() + i; auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim); auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start; auto n = it - start;
assert((i + n) <= istrides.size()); assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n; i += n;
rstrides.push_back(istrides[i]); rstrides.push_back(istrides[i]);
} }
...@@ -170,8 +146,7 @@ struct reshape ...@@ -170,8 +146,7 @@ struct reshape
{ {
auto start = rdims.begin() + i; auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim); auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start; auto n = it - start;
assert((r + n) <= rdims.size()); assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim; auto stride = istrides[i] * idim;
...@@ -191,15 +166,11 @@ struct reshape ...@@ -191,15 +166,11 @@ struct reshape
auto stride = rstrides.back(); auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end())) for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{ {
if(d != 1) (void)d;
return nullopt;
rstrides.push_back(stride); rstrides.push_back(stride);
} }
} }
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides}; return shape{input.type(), rdims, rstrides};
} }
...@@ -233,25 +204,24 @@ struct reshape ...@@ -233,25 +204,24 @@ struct reshape
} }
auto s = reshape_dims(inputs.front(), rdims); auto s = reshape_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("Reshape on axis that is not packed.");
if(s->elements() != inputs.front().elements()) if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " + std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
assert(s->bytes() == inputs.front().bytes());
return *s; return *s;
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1) if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
auto s0 = inputs.front();
if(s0.dynamic()) if(s0.dynamic())
{ {
return dyn_compute_shape(s0); return dyn_compute_shape(s0);
...@@ -264,10 +234,14 @@ struct reshape ...@@ -264,10 +234,14 @@ struct reshape
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(dyn_out.computed_shape); assert(dyn_out.computed_shape.standard());
} argument result{dyn_out.computed_shape};
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
}; };
} // namespace op } // namespace op
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#define MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reshape_lazy
{
std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
}
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape_lazy"; }
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
auto num_not_fixed = std::count_if(
dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
if(num_not_fixed != 1)
{
MIGRAPHX_THROW("reshape_lazy: Only supports one non-fixed dynamic_dimension");
}
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(dims[i] != 0 and dims[i] != -1)
{
MIGRAPHX_THROW(
"reshape_lazy: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("reshape_lazy: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(),
dims.cend(),
dyn_dims.cbegin(),
output_dyn_dims.begin(),
[](std::size_t dim, auto dyn_dim) {
if(not dyn_dim.is_fixed())
return dyn_dim;
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims};
}
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will attempt to alias the dimensions of the input shape to the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_lazy_dims(const shape& input,
const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start;
assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if(d != 1)
return nullopt;
rstrides.push_back(stride);
}
}
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
// since rdims using size_t type, -1 is the max value
// is size_t that cause later compuation incorrect
if(dims[i] == -1)
rdims[i] = 1;
}
if(n_neg_dims > 0)
{
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
auto s = reshape_lazy_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("reshape_lazy on axis that is not packed.");
if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW(
"reshape_lazy: Wrong number of elements for reshape_lazy: reshape_lazy has " +
std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
assert(s->bytes() == inputs.front().bytes());
return *s;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape_lazy: Dimensions for reshape_lazy can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -124,7 +124,7 @@ struct roialign ...@@ -124,7 +124,7 @@ struct roialign
{ {
xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] + xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] +
(i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii]; (i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii];
xy[ii] = (coord_trans_mode == "output_half_pixel") ? (xy[ii] - 0.5f) : xy[ii]; xy[ii] = (coord_trans_mode == "half_pixel") ? (xy[ii] - 0.5f) : xy[ii];
if(xy[ii] < -1.0 or xy[ii] > dims[ii]) if(xy[ii] < -1.0 or xy[ii] > dims[ii])
{ {
results[index] = pos_weight{}; results[index] = pos_weight{};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYN_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYN_OPS_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Convert dynamic ops to their static version if possible.
* Should be run after the split_single_dyn_dims pass.
*/
struct MIGRAPHX_EXPORT simplify_dyn_ops
{
std::string name() const { return "simplify_dyn_ops"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -29,10 +29,13 @@ ...@@ -29,10 +29,13 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <assert.h>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace verify { namespace verify {
...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2) ...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2)
return std::numeric_limits<range_value<R1>>::max(); return std::numeric_limits<range_value<R1>>::max();
} }
template <class R>
double get_rms_tol(const R&, std::size_t tolerance = 80)
{
double threshold = std::numeric_limits<range_value<R>>::epsilon() * tolerance;
return threshold;
}
/*
C++ doesn't support named arguments, this is just wrapper that helps distinguish between actual
results v/s expected results arguments.
*/
template <class T>
struct expected
{
expected() = default;
explicit expected(const T& input) : x(&input) {}
const T& data() const
{
assert(x != nullptr);
return *x;
}
private:
const T* x = nullptr;
};
// deduction guide for templated expected class
template <class T>
expected(const T&) -> expected<T>;
struct tolerance
{
double rms_tol = 0.001;
double atol = 0.001;
double rtol = 0.001;
};
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, tolerance tols)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
{
auto idx = mismatch_idx(r1, r2, [&](auto x, auto y) {
return abs_diff(double(x), double(y)) < tols.atol + tols.rtol * std::abs(double(y));
});
return idx >= range_distance(r1);
}
return false;
}
template <class R1, class R2> template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr) bool verify_rms_range(const R1& r1,
const R2& r2,
std::size_t tolerance = 80,
double* out_rms_error = nullptr)
{ {
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; double threshold = get_rms_tol(r1, tolerance);
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
if(out_error != nullptr) if(out_rms_error != nullptr)
*out_error = error; *out_rms_error = error;
return error <= threshold; return error <= threshold;
} }
template <class R1, class R2>
bool verify_range_with_tolerance(const R1& r1,
const expected<R2>& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
auto rms_error = rms_range(r1, r2.data());
// disable ewise_verify by default for now, it requires lot of tests to be fixed
bool ewise_verify = true;
if(enabled(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE{}))
{
ewise_verify = allclose(r1, r2.data(), tols);
}
if(out_rms_error != nullptr)
*out_rms_error = rms_error;
return rms_error <= tols.rms_tol and ewise_verify;
}
// expected argument should be passed as second, but if it is passed as the first by mistake then
// flip the order
template <class R1, class R2>
bool verify_range_with_tolerance(const expected<R1>& r1,
const R2& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
return verify_rms_range(r2, r1, tols, out_rms_error);
}
} // namespace verify } // namespace verify
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -31,11 +31,15 @@ ...@@ -31,11 +31,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT MIGRAPHX_EXPORT bool verify_args(const std::string& name,
bool verify_args(const std::string& name, const argument& target_arg,
const argument& ref_arg, const verify::expected<argument>& ref_arg,
const argument& target_arg, verify::tolerance);
double tolerance = 80);
MIGRAPHX_EXPORT bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant> ...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const const std::vector<instruction_ref>& /*args*/) const
{ {
literal v = parser.parse_value(info.attributes.at("value")); static const std::vector<std::string> attributes = {
"value", "value_float", "value_floats", "value_int", "value_ints"};
std::vector<std::string> present_attributes;
std::copy_if(attributes.begin(),
attributes.end(),
std::back_inserter(present_attributes),
[&](const std::string& a) { return contains(info.attributes, a); });
if(present_attributes.empty())
{
MIGRAPHX_THROW("Constant node does not contain any supported attribute");
}
if(present_attributes.size() > 1)
{
MIGRAPHX_THROW("Constant contains multiple attributes: " +
join_strings(std::move(present_attributes), ", "));
}
// cppcheck-suppress accessMoved
auto&& attr = info.attributes[present_attributes[0]];
literal v = parser.parse_value(attr);
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return info.add_literal(literal{v.get_shape().type()}); return info.add_literal(literal{v.get_shape().type()});
} }
auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(attr.has_t() and attr.t().dims_size() == 0)
{ {
migraphx::shape scalar_shape{v.get_shape().type()}; migraphx::shape scalar_shape{v.get_shape().type()};
return info.add_literal(migraphx::literal{scalar_shape, v.data()}); return info.add_literal(migraphx::literal{scalar_shape, v.data()});
......
...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std::vector<op_desc> operators() const { return {{"RoiAlign"}}; } std::vector<op_desc> operators() const { return {{"RoiAlign"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode =
if(contains(info.attributes, "coordinate_transformation_mode")) parser.opset_version >= 16 ? "half_pixel" : "output_half_pixel";
if(const auto* a = "coordinate_transformation_mode"; contains(info.attributes, a))
{ {
coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s(); coord_trans_mode = info.attributes.at(a).s();
} }
if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode)) if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode))
{ {
MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode + MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode +
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const ...@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
mpm.run_pass(simplify_reshapes{}); // loop to further optimize after initial transformations
mpm.run_pass(simplify_algebra{}); for(int j = 0; j < 2; j++)
{
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{});
}
mpm.run_pass(eliminate_common_subexpression{}); mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{}); mpm.run_pass(propagate_constant{});
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins) bool skip_propagate(instruction_ref ins)
{ {
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front()); return skip_propagate(ins->inputs().front());
auto&& s = ins->get_shape(); auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar()) if(s.broadcasted() and not s.scalar())
return true; return true;
...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propagate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
......
...@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const ...@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const
continue; continue;
if(ins->inputs().empty()) if(ins->inputs().empty())
continue; continue;
auto&& s = ins->inputs().front()->get_shape(); auto&& s = ins->inputs().front()->get_shape();
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; })) if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue; continue;
...@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const ...@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const
auto lens = s.lens(); auto lens = s.lens();
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end())) if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue; continue;
std::int64_t n = s.lens()[0]; std::vector<std::int64_t> axes(lens.size() - 2);
std::int64_t c = s.lens()[1]; std::iota(axes.begin(), axes.end(), 2);
auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == op::pooling_mode::average) if(op.mode == op::pooling_mode::average)
{ {
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
} }
// max pooling // max pooling
else else
{ {
pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
} }
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
} }
} }
......
...@@ -1325,48 +1325,59 @@ struct find_split_reshape ...@@ -1325,48 +1325,59 @@ struct find_split_reshape
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front();
// Only apply simplification when slices are on a single axis
auto axes = any_cast<op::slice>(slc->get_operator()).axes;
if(axes.size() > 1)
{
return;
}
auto input = slc->inputs().front();
auto split_outputs = get_splits(input); auto split_outputs = get_splits(input);
if(split_outputs.empty()) if(split_outputs.empty())
{ {
return; return;
} }
// Only want to apply this optimization if each split output is followed by // Find all the reshapes (similar to rsp) that can be simplified
// a contiguous op and a reshape std::vector<instruction_ref> conts;
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) { std::vector<instruction_ref> vec_rsp;
if(i->outputs().size() == 1)
{ // Iterate through slice and contiguous outputs to allow simplifications when
auto cont = i->outputs().front(); // slice is followed by multiple reshapes
return cont->outputs().size() != 1; for(auto& i : split_outputs)
}
return false;
}))
{ {
return; std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(conts),
[](auto j) { return j->name() == "contiguous"; });
} }
std::vector<instruction_ref> vec_rsp(split_outputs.size()); for(auto& i : conts)
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) { {
auto cont = i->outputs().front(); std::copy_if(i->outputs().begin(),
return cont->outputs().front(); i->outputs().end(),
}); std::back_inserter(vec_rsp),
[&](auto j) { return j->get_operator() == rsp->get_operator(); });
}
// all outputs are reshape and of the same shape // No simplification needed if there is only one slice -> cont -> reshape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims; if(vec_rsp.size() <= 1)
if(not same_ops(vec_rsp))
{ {
return; return;
} }
// ensure reshape happens after the axis dimension // ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0]; auto axis = axes[0];
auto slc_lens = slc->get_shape().lens(); auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate( auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>()); slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
auto input_lens = input->get_shape().lens();
auto input_size = input->get_shape().elements();
auto slc_axis_len = input_lens[axis];
// search the reshape output (standard shape) to decide which axis are // search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size // in its output corresponding to the slc_dim_size
...@@ -1393,16 +1404,67 @@ struct find_split_reshape ...@@ -1393,16 +1404,67 @@ struct find_split_reshape
{ {
rsp_axis = std::distance(rsp_strides.begin(), ait); rsp_axis = std::distance(rsp_strides.begin(), ait);
} }
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { // Calculate reshape output shape
return is->get_shape().lens()[rsp_axis]; // Need to find a reshape such that data represented by instructions in vec_rsp can be
}); // written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = 1;
auto rsp_fixed_size = std::accumulate(
rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies<std::size_t>());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); // cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
return;
}
auto rsp_axis_len = input_size / rsp_fixed_size;
rsp_out_lens[rsp_axis] = rsp_axis_len;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std::vector<int64_t> new_starts(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).starts[0] * rsp_axis_len /
slc_axis_len;
});
std::vector<int64_t> new_ends(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).ends[0] * rsp_axis_len /
slc_axis_len;
});
// insert the reshape instruction and add contiguous if needed // insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard()) if(not input->get_shape().standard())
...@@ -1413,16 +1475,14 @@ struct find_split_reshape ...@@ -1413,16 +1475,14 @@ struct find_split_reshape
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice // replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i) for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{ {
m.replace_instruction( m.replace_instruction(
vec_rsp[i], vec_rsp[i],
make_op( make_op(
"slice", "slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}), {{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}),
rsp_ins); rsp_ins);
start += vec_dims[i];
} }
} }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct find_static_2in_broadcasts
{
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct find_const_3in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(3),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
if(not starts_arg.empty() and not ends_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto slice_val = ins->get_operator().to_value();
auto axes_vec = slice_val.at("axes").to_vector<int64_t>();
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct find_const_4in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(4),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()),
match::arg(3)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
argument axes_arg = inputs.at(3)->eval();
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -122,6 +122,11 @@ struct find_nop_reshapes ...@@ -122,6 +122,11 @@ struct find_nop_reshapes
reshapes.insert("pad"); reshapes.insert("pad");
reshapes.insert("slice"); reshapes.insert("slice");
reshapes.insert("transpose"); reshapes.insert("transpose");
reshapes.insert("reduce_mean");
reshapes.insert("reduce_max");
reshapes.insert("reduce_min");
reshapes.insert("reduce_sum");
reshapes.insert("reduce_prod");
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
...@@ -627,6 +632,30 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -627,6 +632,30 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("transpose")(
match::arg(0)(match::name("multibroadcast").bind("bcast_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
if(not input->get_shape().scalar())
return;
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
}
};
struct find_slice_transpose struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
...@@ -799,6 +828,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -799,6 +828,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_slice{}, find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{}, find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
......
...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes) ...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max}; dds_it->max};
} }
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/** /**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if` * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those * and `loop` instructions, depending on how the submodules for those
...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens}); dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return // redirect to select_module operator and return
......
...@@ -23,6 +23,10 @@ ...@@ -23,6 +23,10 @@
# #################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm) list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
endif()
find_package(miopen) find_package(miopen)
# rocblas # rocblas
......
...@@ -283,9 +283,9 @@ struct find_mlir_fused_ops ...@@ -283,9 +283,9 @@ struct find_mlir_fused_ops
names.end(), names.end(),
ins->inputs().begin(), ins->inputs().begin(),
std::inserter(param_map, param_map.end()), std::inserter(param_map, param_map.end()),
[&, &anchor_op = anchor_op](auto name, auto input) { [&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins) if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor_op); return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape())); mm->add_parameter(name, input->get_shape()));
}); });
...@@ -327,12 +327,12 @@ struct find_mlir_standalone_op ...@@ -327,12 +327,12 @@ struct find_mlir_standalone_op
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{ {
auto matcher() const { return match::name("convolution"); } auto matcher() const { return is_mlir_conv; }
}; };
struct find_mlir_standalone_dot_op : find_mlir_standalone_op struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{ {
auto matcher() const { return match::name("dot"); } auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
}; };
/** /**
...@@ -365,7 +365,7 @@ bool is_enabled(std::string_view op_name, context* ctx) ...@@ -365,7 +365,7 @@ bool is_enabled(std::string_view op_name, context* ctx)
{ {
return true; return true;
} }
else if(op_name == "convolution") else if(op_name == "convolution" or op_name == "quant_convolution")
{ {
if(ctx == nullptr) if(ctx == nullptr)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment