"vscode:/vscode.git/clone" did not exist on "681ab11cc827cde278b610426f0f371b9a509892"
Unverified Commit 07bef2a0 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into enable_navi_32_ci

parents fc60486e dcc7b0a5
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -29,7 +29,8 @@ ...@@ -29,7 +29,8 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
#include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -45,8 +46,6 @@ struct reshape ...@@ -45,8 +46,6 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape dyn_compute_shape(shape s0) const shape dyn_compute_shape(shape s0) const
...@@ -110,27 +109,9 @@ struct reshape ...@@ -110,27 +109,9 @@ struct reshape
return it; return it;
} }
template <class DimIterator, class StrideIterator> // This will attempt to alias the dimensions of the input shape to the lens of
static auto can_strides_merge(DimIterator dim_start, // `rdims`. Unlike reshape_lazy though we can modify memory layout with copies and this
DimIterator dim_last, // can remove previous nullopts that were sent back for the alias case
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will reshape the dimesions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims) static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{ {
if(input.standard()) if(input.standard())
...@@ -155,13 +136,8 @@ struct reshape ...@@ -155,13 +136,8 @@ struct reshape
{ {
auto start = idims.begin() + i; auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim); auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start; auto n = it - start;
assert((i + n) <= istrides.size()); assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n; i += n;
rstrides.push_back(istrides[i]); rstrides.push_back(istrides[i]);
} }
...@@ -170,8 +146,7 @@ struct reshape ...@@ -170,8 +146,7 @@ struct reshape
{ {
auto start = rdims.begin() + i; auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim); auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start; auto n = it - start;
assert((r + n) <= rdims.size()); assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim; auto stride = istrides[i] * idim;
...@@ -191,15 +166,11 @@ struct reshape ...@@ -191,15 +166,11 @@ struct reshape
auto stride = rstrides.back(); auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end())) for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{ {
if(d != 1) (void)d;
return nullopt;
rstrides.push_back(stride); rstrides.push_back(stride);
} }
} }
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides}; return shape{input.type(), rdims, rstrides};
} }
...@@ -233,25 +204,24 @@ struct reshape ...@@ -233,25 +204,24 @@ struct reshape
} }
auto s = reshape_dims(inputs.front(), rdims); auto s = reshape_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("Reshape on axis that is not packed.");
if(s->elements() != inputs.front().elements()) if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " + std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
assert(s->bytes() == inputs.front().bytes());
return *s; return *s;
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1) if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
auto s0 = inputs.front();
if(s0.dynamic()) if(s0.dynamic())
{ {
return dyn_compute_shape(s0); return dyn_compute_shape(s0);
...@@ -264,10 +234,14 @@ struct reshape ...@@ -264,10 +234,14 @@ struct reshape
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(dyn_out.computed_shape); assert(dyn_out.computed_shape.standard());
} argument result{dyn_out.computed_shape};
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
}; };
} // namespace op } // namespace op
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#define MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reshape_lazy
{
std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
}
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape_lazy"; }
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
auto num_not_fixed = std::count_if(
dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
if(num_not_fixed != 1)
{
MIGRAPHX_THROW("reshape_lazy: Only supports one non-fixed dynamic_dimension");
}
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(dims[i] != 0 and dims[i] != -1)
{
MIGRAPHX_THROW(
"reshape_lazy: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("reshape_lazy: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(),
dims.cend(),
dyn_dims.cbegin(),
output_dyn_dims.begin(),
[](std::size_t dim, auto dyn_dim) {
if(not dyn_dim.is_fixed())
return dyn_dim;
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims};
}
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will attempt to alias the dimensions of the input shape to the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_lazy_dims(const shape& input,
const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start;
assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if(d != 1)
return nullopt;
rstrides.push_back(stride);
}
}
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
// since rdims using size_t type, -1 is the max value
// is size_t that cause later compuation incorrect
if(dims[i] == -1)
rdims[i] = 1;
}
if(n_neg_dims > 0)
{
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
auto s = reshape_lazy_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("reshape_lazy on axis that is not packed.");
if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW(
"reshape_lazy: Wrong number of elements for reshape_lazy: reshape_lazy has " +
std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
assert(s->bytes() == inputs.front().bytes());
return *s;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape_lazy: Dimensions for reshape_lazy can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -29,10 +29,13 @@ ...@@ -29,10 +29,13 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <assert.h>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace verify { namespace verify {
...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2) ...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2)
return std::numeric_limits<range_value<R1>>::max(); return std::numeric_limits<range_value<R1>>::max();
} }
template <class R>
double get_rms_tol(const R&, std::size_t tolerance = 80)
{
double threshold = std::numeric_limits<range_value<R>>::epsilon() * tolerance;
return threshold;
}
/*
C++ doesn't support named arguments, this is just wrapper that helps distinguish between actual
results v/s expected results arguments.
*/
template <class T>
struct expected
{
expected() = default;
explicit expected(const T& input) : x(&input) {}
const T& data() const
{
assert(x != nullptr);
return *x;
}
private:
const T* x = nullptr;
};
// deduction guide for templated expected class
template <class T>
expected(const T&) -> expected<T>;
struct tolerance
{
double rms_tol = 0.001;
double atol = 0.001;
double rtol = 0.001;
};
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, tolerance tols)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
{
auto idx = mismatch_idx(r1, r2, [&](auto x, auto y) {
return abs_diff(double(x), double(y)) < tols.atol + tols.rtol * std::abs(double(y));
});
return idx >= range_distance(r1);
}
return false;
}
template <class R1, class R2> template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr) bool verify_rms_range(const R1& r1,
const R2& r2,
std::size_t tolerance = 80,
double* out_rms_error = nullptr)
{ {
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; double threshold = get_rms_tol(r1, tolerance);
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
if(out_error != nullptr) if(out_rms_error != nullptr)
*out_error = error; *out_rms_error = error;
return error <= threshold; return error <= threshold;
} }
template <class R1, class R2>
bool verify_range_with_tolerance(const R1& r1,
const expected<R2>& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
auto rms_error = rms_range(r1, r2.data());
// disable ewise_verify by default for now, it requires lot of tests to be fixed
bool ewise_verify = true;
if(enabled(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE{}))
{
ewise_verify = allclose(r1, r2.data(), tols);
}
if(out_rms_error != nullptr)
*out_rms_error = rms_error;
return rms_error <= tols.rms_tol and ewise_verify;
}
// expected argument should be passed as second, but if it is passed as the first by mistake then
// flip the order
template <class R1, class R2>
bool verify_range_with_tolerance(const expected<R1>& r1,
const R2& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
return verify_rms_range(r2, r1, tols, out_rms_error);
}
} // namespace verify } // namespace verify
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -31,11 +31,15 @@ ...@@ -31,11 +31,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT MIGRAPHX_EXPORT bool verify_args(const std::string& name,
bool verify_args(const std::string& name, const argument& target_arg,
const argument& ref_arg, const verify::expected<argument>& ref_arg,
const argument& target_arg, verify::tolerance);
double tolerance = 80);
MIGRAPHX_EXPORT bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const ...@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const
continue; continue;
if(ins->inputs().empty()) if(ins->inputs().empty())
continue; continue;
auto&& s = ins->inputs().front()->get_shape(); auto&& s = ins->inputs().front()->get_shape();
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; })) if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue; continue;
...@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const ...@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const
auto lens = s.lens(); auto lens = s.lens();
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end())) if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue; continue;
std::int64_t n = s.lens()[0]; std::vector<std::int64_t> axes(lens.size() - 2);
std::int64_t c = s.lens()[1]; std::iota(axes.begin(), axes.end(), 2);
auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == op::pooling_mode::average) if(op.mode == op::pooling_mode::average)
{ {
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
} }
// max pooling // max pooling
else else
{ {
pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
} }
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
} }
} }
......
...@@ -122,6 +122,11 @@ struct find_nop_reshapes ...@@ -122,6 +122,11 @@ struct find_nop_reshapes
reshapes.insert("pad"); reshapes.insert("pad");
reshapes.insert("slice"); reshapes.insert("slice");
reshapes.insert("transpose"); reshapes.insert("transpose");
reshapes.insert("reduce_mean");
reshapes.insert("reduce_max");
reshapes.insert("reduce_min");
reshapes.insert("reduce_sum");
reshapes.insert("reduce_prod");
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
......
...@@ -23,6 +23,10 @@ ...@@ -23,6 +23,10 @@
# #################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm) list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
endif()
find_package(miopen) find_package(miopen)
# rocblas # rocblas
......
...@@ -283,9 +283,9 @@ struct find_mlir_fused_ops ...@@ -283,9 +283,9 @@ struct find_mlir_fused_ops
names.end(), names.end(),
ins->inputs().begin(), ins->inputs().begin(),
std::inserter(param_map, param_map.end()), std::inserter(param_map, param_map.end()),
[&, &anchor_op = anchor_op](auto name, auto input) { [&, &anchor = anchor_op](auto name, auto input) {
if(input == x_ins) if(input == x_ins)
return std::make_pair(pm->get_parameter(name), anchor_op); return std::make_pair(pm->get_parameter(name), anchor);
return std::make_pair(pm->get_parameter(name), return std::make_pair(pm->get_parameter(name),
mm->add_parameter(name, input->get_shape())); mm->add_parameter(name, input->get_shape()));
}); });
...@@ -327,12 +327,12 @@ struct find_mlir_standalone_op ...@@ -327,12 +327,12 @@ struct find_mlir_standalone_op
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{ {
auto matcher() const { return match::name("convolution"); } auto matcher() const { return is_mlir_conv; }
}; };
struct find_mlir_standalone_dot_op : find_mlir_standalone_op struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{ {
auto matcher() const { return match::name("dot"); } auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
}; };
/** /**
...@@ -365,7 +365,7 @@ bool is_enabled(std::string_view op_name, context* ctx) ...@@ -365,7 +365,7 @@ bool is_enabled(std::string_view op_name, context* ctx)
{ {
return true; return true;
} }
else if(op_name == "convolution") else if(op_name == "convolution" or op_name == "quant_convolution")
{ {
if(ctx == nullptr) if(ctx == nullptr)
{ {
......
...@@ -790,22 +790,26 @@ struct find_layernorm_pointwise ...@@ -790,22 +790,26 @@ struct find_layernorm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")(match::arg(0)( return precompile_name("pointwise")(match::any_of[match::inputs()](
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm"))); precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto pw_ins = r.result;
auto layernorm = r.instructions["layernorm"]; auto layernorm = r.instructions["layernorm"];
if(not layernorm->module_inputs().empty()) if(not layernorm->module_inputs().empty())
return; return;
auto* pm = ins->module_inputs().front(); auto* pm = pw_ins->module_inputs().front();
auto pw_inputs = pw_ins->inputs();
auto ln_pos = std::find(pw_inputs.begin(), pw_inputs.end(), layernorm);
assert(ln_pos != pw_inputs.end());
pw_inputs.erase(ln_pos);
auto inputs = layernorm->inputs(); auto inputs = layernorm->inputs();
inputs.pop_back(); inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end()); inputs.insert(inputs.end(), pw_inputs.begin(), pw_inputs.end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm}); m.replace_instruction(pw_ins, layernorm->get_operator(), inputs, {pm});
} }
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/reshape.hpp> #include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/reshape_lazy.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp> #include <migraphx/gpu/lowering.hpp>
...@@ -89,7 +90,6 @@ struct miopen_apply ...@@ -89,7 +90,6 @@ struct miopen_apply
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false; offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
add_extend_op("argmax"); add_extend_op("argmax");
add_extend_op("argmin"); add_extend_op("argmin");
add_extend_op("logsoftmax"); add_extend_op("logsoftmax");
...@@ -115,6 +115,7 @@ struct miopen_apply ...@@ -115,6 +115,7 @@ struct miopen_apply
add_neg_op(); add_neg_op();
add_nms_op(); add_nms_op();
add_select_module_op(); add_select_module_op();
add_reshape_lazy_op();
} }
void copy_params() const void copy_params() const
...@@ -376,6 +377,32 @@ struct miopen_apply ...@@ -376,6 +377,32 @@ struct miopen_apply
return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs()); return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
}); });
} }
/**
* Adds reshape lazy to reshape ops that can be aliased instead of copied.
* `gpu::contiguous` are added before and after the reshape; these contiguous
* instructions can be removed by the eliminate_contiguous pass.
*/
void add_reshape_lazy_op()
{
apply_map.emplace("reshape", [=](instruction_ref ins) {
std::vector<instruction_ref> before_contiguous_args = ins->inputs();
auto before_alloc = insert_allocation(ins, std::prev(ins)->get_shape());
before_contiguous_args.push_back(before_alloc);
auto before_contig =
mod->insert_instruction(ins, make_op("gpu::contiguous"), {before_contiguous_args});
auto new_lazy_reshape = mod->insert_instruction(
ins,
make_op("reshape_lazy", {{"dims", {ins->get_operator().to_value().at("dims")}}}),
before_contig);
std::vector<instruction_ref> after_contiguous_args = {new_lazy_reshape};
auto after_alloc = insert_allocation(new_lazy_reshape, new_lazy_reshape->get_shape());
after_contiguous_args.push_back(after_alloc);
return mod->replace_instruction(ins, make_op("gpu::contiguous"), after_contiguous_args);
});
}
}; };
void lowering::apply(module_pass_manager& mpm) const void lowering::apply(module_pass_manager& mpm) const
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "migraphx/make_op.hpp" #include "migraphx/make_op.hpp"
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
#include <ostream>
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
#include <mlir-c/IR.h> #include <mlir-c/IR.h>
...@@ -34,6 +35,7 @@ ...@@ -34,6 +35,7 @@
#include <mlir-c/Dialect/Rock.h> #include <mlir-c/Dialect/Rock.h>
#include <mlir-c/IntegerSet.h> #include <mlir-c/IntegerSet.h>
#include <mlir-c/Pass.h> #include <mlir-c/Pass.h>
#include <mlir-c/Support.h>
#include <mutex> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling" #warning "Incompatible version of rocMLIR library used, disabling"
...@@ -180,13 +182,85 @@ std::string mlir_print(F f, T x) ...@@ -180,13 +182,85 @@ std::string mlir_print(F f, T x)
return ss.str(); return ss.str();
} }
struct mlir_logger
{
std::stringstream ss;
mlir_context* ctx;
std::optional<MlirDiagnosticHandlerID> id;
mlir_logger() : ctx(nullptr), id(std::nullopt) {}
mlir_logger(mlir_context* context) : ctx(context)
{
id =
mlirContextAttachDiagnosticHandler(ctx->get(), mlir_diagnostic_print_cb, this, nullptr);
}
~mlir_logger()
{
if(id.has_value())
mlirContextDetachDiagnosticHandler(ctx->get(), *id);
}
mlir_logger(const mlir_logger& other) = delete;
mlir_logger& operator=(const mlir_logger& other) = delete;
mlir_logger(mlir_logger&& other) noexcept
: ss(std::move(other.ss)), ctx(other.ctx), id(other.id)
{
other.ctx = nullptr;
other.id = std::nullopt;
}
mlir_logger& operator=(mlir_logger other) noexcept
{
std::swap(ss, other.ss);
std::swap(ctx, other.ctx);
std::swap(id, other.id);
return *this;
}
std::string str() const { return ss.str(); }
void clear() { ss = std::stringstream{}; }
static MlirLogicalResult mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger);
MlirLogicalResult handle(MlirDiagnostic diag);
};
MlirLogicalResult mlir_logger::mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger)
{
return reinterpret_cast<mlir_logger*>(logger)->handle(diag);
}
MlirLogicalResult mlir_logger::handle(MlirDiagnostic diag)
{
MlirDiagnosticSeverity sev = mlirDiagnosticGetSeverity(diag);
switch(sev)
{
case MlirDiagnosticSeverity::MlirDiagnosticError: ss << "Error: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticWarning: ss << "Warning: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticNote: ss << "Note: "; break;
case MlirDiagnosticSeverity::MlirDiagnosticRemark: ss << "Remark: "; break;
}
mlir_print(mlirDiagnosticPrint, diag, [&](auto s) { ss << s; });
ss << std::endl;
for(intptr_t i = 0, e = mlirDiagnosticGetNumNotes(diag); i < e; ++i)
{
(void)handle(mlirDiagnosticGetNote(diag, i));
}
return mlirLogicalResultSuccess();
}
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program()
: ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(), : ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
/*threadingEnable=*/false)), /*threadingEnable=*/false)),
location(mlirLocationUnknownGet(ctx.get())), location(mlirLocationUnknownGet(ctx.get())),
mmodule(mlirModuleCreateEmpty(location)) mmodule(mlirModuleCreateEmpty(location)),
logger(&ctx)
{ {
mlirContextSetThreadPool(ctx.get(), get_thread_pool().get()); mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
mlirContextLoadAllAvailableDialects(ctx.get()); mlirContextLoadAllAvailableDialects(ctx.get());
...@@ -614,21 +688,49 @@ struct mlir_program ...@@ -614,21 +688,49 @@ struct mlir_program
} }
} }
void run_high_level_pipeline() MIGRAPHX_TIDY_CONST void run_high_level_pipeline()
{ {
mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddHighLevelPipeline(pm_front.get()); mlirMIGraphXAddHighLevelPipeline(pm_front.get());
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get())); logger.clear();
if(mlirLogicalResultIsFailure(
mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()))))
{
std::string error = "Invalid MLIR created: " + logger.str();
if(enabled(MIGRAPHX_TRACE_MLIR{}))
{
std::cout << error << std::endl;
}
MIGRAPHX_THROW(error);
}
} }
void run_backend_pipeline() MIGRAPHX_TIDY_CONST void run_backend_pipeline()
{ {
mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())}; mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str()); mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get())); logger.clear();
const size_t trace = value_of(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
auto mod_op = mlirModuleGetOperation(mmodule.get());
if(trace >= 2)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
if(mlirLogicalResultIsFailure(mlirPassManagerRunOnOp(pm_back.get(), mod_op)))
{
std::string error = "MLIR backend compilation failed: " + logger.str();
if(enabled(MIGRAPHX_TRACE_MLIR{}))
{
std::cout << error << std::endl;
}
MIGRAPHX_THROW(error);
}
} }
code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST code_object_op compile(const value& solution)
{ {
// 1st pipeline to call // 1st pipeline to call
run_high_level_pipeline(); run_high_level_pipeline();
...@@ -682,7 +784,7 @@ struct mlir_program ...@@ -682,7 +784,7 @@ struct mlir_program
MIGRAPHX_THROW("Failed setting tuning key: " + *str); MIGRAPHX_THROW("Failed setting tuning key: " + *str);
} }
tuning_config get_tuning_config(bool exhaustive) MIGRAPHX_TIDY_CONST tuning_config get_tuning_config(bool exhaustive)
{ {
tuning_config tc; tuning_config tc;
run_high_level_pipeline(); run_high_level_pipeline();
...@@ -702,7 +804,8 @@ struct mlir_program ...@@ -702,7 +804,8 @@ struct mlir_program
if(perf_key_bytes > perf_key.size()) if(perf_key_bytes > perf_key.size())
MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) + MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) +
" bytes and thus too long"); " bytes and thus too long");
tc.solutions.emplace_back(perf_key.begin(), perf_key.begin() + perf_key_bytes); tc.solutions.emplace_back(
std::string(perf_key.begin(), perf_key.begin() + perf_key_bytes));
} }
std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key; std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key;
size_t tuning_key_bytes = size_t tuning_key_bytes =
...@@ -809,6 +912,7 @@ struct mlir_program ...@@ -809,6 +912,7 @@ struct mlir_program
mlir_context ctx; mlir_context ctx;
MlirLocation location; MlirLocation location;
mlir_module mmodule; mlir_module mmodule;
mlir_logger logger;
problem_params pp; problem_params pp;
std::deque<std::string> strings{}; std::deque<std::string> strings{};
std::string target_arch = ""; std::string target_arch = "";
......
...@@ -28,19 +28,20 @@ namespace migraphx { ...@@ -28,19 +28,20 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool verify_args(const std::string& name, bool verify_args(const std::string& name,
const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
double tolerance) const verify::expected<argument>& ref_arg,
verify::tolerance tols)
{ {
bool passed = true; bool passed = true;
visit_all(ref_arg, target_arg)([&](auto ref, auto target) { visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
double error; double rms_error;
passed = verify::verify_range(ref, target, tolerance, &error); passed =
verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
if(not passed) if(not passed)
{ {
// TODO: Check for nans // TODO: Check for nans
std::cout << "FAILED: " << name << std::endl; std::cout << "FAILED: " << name << std::endl;
std::cout << "error: " << error << std::endl; std::cout << "RMS Error: " << rms_error << std::endl;
if(ref.size() < 32) if(ref.size() < 32)
std::cout << "ref:" << ref << std::endl; std::cout << "ref:" << ref << std::endl;
if(target.size() < 32) if(target.size() < 32)
...@@ -93,5 +94,16 @@ bool verify_args(const std::string& name, ...@@ -93,5 +94,16 @@ bool verify_args(const std::string& name,
return passed; return passed;
} }
bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance)
{
double rms_tol = 0.001;
target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); });
verify::tolerance tols{rms_tol};
return verify_args(name, target_arg, ref_arg, tols);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -158,6 +158,31 @@ TEST_CASE(two_transpose_gather) ...@@ -158,6 +158,31 @@ TEST_CASE(two_transpose_gather)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(standard_reshape_lazy)
{
migraphx::module m1;
{
auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m1.add_instruction(migraphx::make_op("add"), data, data);
auto r =
m1.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), add);
m1.add_return({r});
}
run_pass(m1);
migraphx::module m2;
{
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add);
auto r =
m2.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r});
}
EXPECT(m1 == m2);
}
TEST_CASE(standard_reshape) TEST_CASE(standard_reshape)
{ {
migraphx::module m1; migraphx::module m1;
...@@ -173,8 +198,7 @@ TEST_CASE(standard_reshape) ...@@ -173,8 +198,7 @@ TEST_CASE(standard_reshape)
{ {
auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}}); auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
auto add = m2.add_instruction(migraphx::make_op("add"), data, data); auto add = m2.add_instruction(migraphx::make_op("add"), data, data);
auto ca = m2.add_instruction(migraphx::make_op("contiguous"), add); auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add);
auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
m2.add_return({r}); m2.add_return({r});
} }
......
...@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test) ...@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
migraphx::target gpu_t = migraphx::make_target("gpu"); migraphx::target gpu_t = migraphx::make_target("gpu");
run_prog(p, gpu_t, m, gpu_result); run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "make_precompile_op.hpp"
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_ops{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(layernorm_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto create_program = [=](bool first_arg_layernorm) {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x, y}, single_pointwise("add"));
auto add1 =
mm->add_instruction(make_precompile_op("pointwise"), {x, y, alloc_ins}, {pw_add1});
auto alloc_ins2 = mm->add_instruction(alloc);
auto layernorm_ins =
mm->add_instruction(make_precompile_op("gpu::prelayernorm"), add1, alloc_ins2);
std::vector<migraphx::instruction_ref> pw_inputs = {layernorm_ins, z};
if(not first_arg_layernorm)
{
pw_inputs = {z, layernorm_ins};
}
auto* pw_add2 =
create_pointwise_module(p, "main:pointwise1", pw_inputs, single_pointwise("add"));
auto alloc_ins3 = mm->add_instruction(alloc);
pw_inputs.push_back(alloc_ins3);
auto add2 = mm->add_instruction(make_precompile_op("pointwise"), pw_inputs, {pw_add2});
mm->add_return({add2});
return p;
};
auto create_fused_program = [=]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x, y}, single_pointwise("add"));
auto add1 =
mm->add_instruction(make_precompile_op("pointwise"), {x, y, alloc_ins}, {pw_add1});
auto alloc_ins2 = mm->add_instruction(alloc);
auto* pw_add2 =
create_pointwise_module(p, "main:pointwise1", {x, z}, single_pointwise("add"));
auto layernorm_ins = mm->add_instruction(
make_precompile_op("gpu::prelayernorm"), {add1, z, alloc_ins2}, {pw_add2});
mm->add_return({layernorm_ins});
return p;
};
{
migraphx::program p1 = create_program(true);
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}
{
migraphx::program p1 = create_program(false);
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -53,7 +53,6 @@ TEST_CASE(host_same_buffer_copy) ...@@ -53,7 +53,6 @@ TEST_CASE(host_same_buffer_copy)
migraphx::parameter_map pp; migraphx::parameter_map pp;
std::vector<float> a_vec(ss.elements(), -1); std::vector<float> a_vec(ss.elements(), -1);
std::vector<float> b_vec(ss.elements(), 2); std::vector<float> b_vec(ss.elements(), 2);
std::vector<float> c_vec(ss.elements(), 0);
pp["a"] = migraphx::argument(ss, a_vec.data()); pp["a"] = migraphx::argument(ss, a_vec.data());
pp["b"] = migraphx::argument(ss, b_vec.data()); pp["b"] = migraphx::argument(ss, b_vec.data());
std::vector<float> gpu_result; std::vector<float> gpu_result;
...@@ -64,7 +63,8 @@ TEST_CASE(host_same_buffer_copy) ...@@ -64,7 +63,8 @@ TEST_CASE(host_same_buffer_copy)
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> results_vector(ss.elements(), -1); std::vector<float> results_vector(ss.elements(), -1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(c_vec, results_vector)); std::vector<float> gold_vec(ss.elements(), 0);
EXPECT(migraphx::verify::verify_rms_range(results_vector, gold_vec));
} }
TEST_CASE(arguments_lifetime) TEST_CASE(arguments_lifetime)
......
...@@ -133,7 +133,8 @@ bool verify_mlir(const migraphx::module& mmlir) ...@@ -133,7 +133,8 @@ bool verify_mlir(const migraphx::module& mmlir)
auto inputs = generate_params(ref); auto inputs = generate_params(ref);
auto mlir = create_program_from_mlir(mmlir); auto mlir = create_program_from_mlir(mmlir);
return migraphx::verify_args("mlir", run_ref(ref, inputs), run_gpu(mlir, inputs)); return migraphx::verify_args_with_tolerance(
"mlir", run_gpu(mlir, inputs), migraphx::verify::expected{run_ref(ref, inputs)});
} }
TEST_CASE(conv) TEST_CASE(conv)
......
...@@ -40,7 +40,6 @@ ...@@ -40,7 +40,6 @@
TEST_CASE(gpu_target_copy) TEST_CASE(gpu_target_copy)
{ {
migraphx::target gpu_t = migraphx::make_target("gpu"); migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::target ref_t = migraphx::make_target("ref");
migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L); auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L);
...@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy) ...@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy)
std::vector<int8_t> val_final; std::vector<int8_t> val_final;
ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); }); ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify::verify_range(val_orig, val_final)); EXPECT(migraphx::verify::verify_rms_range(val_orig, val_final));
} }
TEST_CASE(int8_quantization) TEST_CASE(int8_quantization)
...@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization) ...@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5)); EXPECT(migraphx::verify::verify_range_with_tolerance(
gpu_result,
migraphx::verify::expected{ref_result},
migraphx::verify::tolerance{0.01}));
else else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
} }
} }
......
...@@ -24,16 +24,16 @@ ...@@ -24,16 +24,16 @@
#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP #ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP #define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <class F> template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p, migraphx::module_ref create_pointwise_module(migraphx::program& p,
migraphx::module_ref mm, const std::string& name,
const std::string& name, std::vector<migraphx::instruction_ref> inputs,
std::vector<migraphx::instruction_ref> inputs, F f)
F f)
{ {
auto* pm = p.create_module(name); auto* pm = p.create_module(name);
pm->set_bypass(); pm->set_bypass();
...@@ -44,6 +44,17 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p, ...@@ -44,6 +44,17 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
}); });
auto r = f(pm, params); auto r = f(pm, params);
pm->add_return({r}); pm->add_return({r});
return pm;
}
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
migraphx::module_ref mm,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = create_pointwise_module(p, name, inputs, f);
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm}); return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
} }
......
...@@ -47,7 +47,7 @@ TEST_CASE(averagepool_notset_test) ...@@ -47,7 +47,7 @@ TEST_CASE(averagepool_notset_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {12}; std::vector<float> gold = {12};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(averagepool_nt_cip_test) TEST_CASE(averagepool_nt_cip_test)
...@@ -65,7 +65,7 @@ TEST_CASE(averagepool_nt_cip_test) ...@@ -65,7 +65,7 @@ TEST_CASE(averagepool_nt_cip_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {8.33333}; std::vector<float> gold = {8.33333};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_flat_test) TEST_CASE(batch_norm_flat_test)
...@@ -76,15 +76,15 @@ TEST_CASE(batch_norm_flat_test) ...@@ -76,15 +76,15 @@ TEST_CASE(batch_norm_flat_test)
migraphx::shape x_shape{migraphx::shape::float_type, {10}}; migraphx::shape x_shape{migraphx::shape::float_type, {10}};
migraphx::shape c_shape(migraphx::shape::float_type, {1}); migraphx::shape c_shape(migraphx::shape::float_type, {1});
std::vector<float> x_data = {1.6524342, std::vector<float> x_data = {1.6524342,
-0.51048076, -0.51048076,
0.32543048, 0.32543048,
2.4410043, 2.4410043,
2.0833702, 2.0833702,
0.44981122, 0.44981122,
1.0044622, 1.0044622,
-0.24006313, -0.24006313,
-0.43065986, -0.43065986,
0.07626268}; 0.07626268};
std::vector<float> scale_data = {-0.02927135}; std::vector<float> scale_data = {-0.02927135};
std::vector<float> bias_data = {0.42347777}; std::vector<float> bias_data = {0.42347777};
std::vector<float> mean_data = {-0.00449735}; std::vector<float> mean_data = {-0.00449735};
...@@ -111,7 +111,7 @@ TEST_CASE(batch_norm_flat_test) ...@@ -111,7 +111,7 @@ TEST_CASE(batch_norm_flat_test)
0.43305403, 0.43305403,
0.4408022, 0.4408022,
0.42019472}; 0.42019472};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_rank_2_test) TEST_CASE(batch_norm_rank_2_test)
...@@ -148,7 +148,7 @@ TEST_CASE(batch_norm_rank_2_test) ...@@ -148,7 +148,7 @@ TEST_CASE(batch_norm_rank_2_test)
9.89948504, 9.89948504,
9.89948504, 9.89948504,
12.72790933}; 12.72790933};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_1d_test) TEST_CASE(batch_norm_1d_test)
...@@ -184,7 +184,7 @@ TEST_CASE(batch_norm_1d_test) ...@@ -184,7 +184,7 @@ TEST_CASE(batch_norm_1d_test)
0.4927, 0.771, -1.956, -2.123, -0.664, -0.583, -0.7207, -0.5127}; 0.4927, 0.771, -1.956, -2.123, -0.664, -0.583, -0.7207, -0.5127};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()}; std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_2d_test) TEST_CASE(batch_norm_2d_test)
...@@ -250,7 +250,7 @@ TEST_CASE(batch_norm_2d_test) ...@@ -250,7 +250,7 @@ TEST_CASE(batch_norm_2d_test)
-2.76707697e+00, 1.47579327e+01, 4.94736385e+00, 2.68847847e+01, -6.49254417e+00, -2.76707697e+00, 1.47579327e+01, 4.94736385e+00, 2.68847847e+01, -6.49254417e+00,
1.94286156e+00, -7.19223642e+00, -3.70413971e+00, -4.04303551e-01, -1.01827660e+01, 1.94286156e+00, -7.19223642e+00, -3.70413971e+00, -4.04303551e-01, -1.01827660e+01,
1.49476433e+00}; 1.49476433e+00};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(batch_norm_3d_test) TEST_CASE(batch_norm_3d_test)
...@@ -292,7 +292,7 @@ TEST_CASE(batch_norm_3d_test) ...@@ -292,7 +292,7 @@ TEST_CASE(batch_norm_3d_test)
6.098, 11.03, 2.81, 2.81, 2.81, 12.125, 3.143, 8.53, 17.52, 4.938, 15.71, 6.098, 11.03, 2.81, 2.81, 2.81, 12.125, 3.143, 8.53, 17.52, 4.938, 15.71,
1.347, 4.938, 1.167, 6.098, 12.67, 12.67, 4.453, 4.453, -0.4768, 12.67}; 1.347, 4.938, 1.167, 6.098, 12.67, 12.67, 4.453, 4.453, -0.4768, 12.67};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()}; std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(celu_verify_test) TEST_CASE(celu_verify_test)
...@@ -309,12 +309,12 @@ TEST_CASE(celu_verify_test) ...@@ -309,12 +309,12 @@ TEST_CASE(celu_verify_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct(6); std::vector<float> gold(6);
float alpha = 0.5; float alpha = 0.5;
std::transform(data.begin(), data.end(), correct.begin(), [&](auto x) { std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) {
return std::max(0.0f, x) + std::min(0.0f, alpha * std::expm1(x / alpha)); return std::max(0.0f, x) + std::min(0.0f, alpha * std::expm1(x / alpha));
}); });
EXPECT(migraphx::verify::verify_range(result_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(clip_args_type_mismatch) TEST_CASE(clip_args_type_mismatch)
...@@ -330,7 +330,7 @@ TEST_CASE(clip_args_type_mismatch) ...@@ -330,7 +330,7 @@ TEST_CASE(clip_args_type_mismatch)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7}; std::vector<float> gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(depthtospace_simple_test) TEST_CASE(depthtospace_simple_test)
...@@ -348,7 +348,7 @@ TEST_CASE(depthtospace_simple_test) ...@@ -348,7 +348,7 @@ TEST_CASE(depthtospace_simple_test)
std::vector<float> gold = {0, 12, 1, 13, 2, 14, 24, 36, 25, 37, 26, 38, 3, 15, 4, 16, std::vector<float> gold = {0, 12, 1, 13, 2, 14, 24, 36, 25, 37, 26, 38, 3, 15, 4, 16,
5, 17, 27, 39, 28, 40, 29, 41, 6, 18, 7, 19, 8, 20, 30, 42, 5, 17, 27, 39, 28, 40, 29, 41, 6, 18, 7, 19, 8, 20, 30, 42,
31, 43, 32, 44, 9, 21, 10, 22, 11, 23, 33, 45, 34, 46, 35, 47}; 31, 43, 32, 44, 9, 21, 10, 22, 11, 23, 33, 45, 34, 46, 35, 47};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(spacetodepth_simple_test) TEST_CASE(spacetodepth_simple_test)
...@@ -366,7 +366,7 @@ TEST_CASE(spacetodepth_simple_test) ...@@ -366,7 +366,7 @@ TEST_CASE(spacetodepth_simple_test)
std::vector<float> gold = {0, 2, 4, 12, 14, 16, 24, 26, 28, 36, 38, 40, 1, 3, 5, 13, std::vector<float> gold = {0, 2, 4, 12, 14, 16, 24, 26, 28, 36, 38, 40, 1, 3, 5, 13,
15, 17, 25, 27, 29, 37, 39, 41, 6, 8, 10, 18, 20, 22, 30, 32, 15, 17, 25, 27, 29, 37, 39, 41, 6, 8, 10, 18, 20, 22, 30, 32,
34, 42, 44, 46, 7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47}; 34, 42, 44, 46, 7, 9, 11, 19, 21, 23, 31, 33, 35, 43, 45, 47};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(spacetodepth_depthtospace_test) TEST_CASE(spacetodepth_depthtospace_test)
...@@ -374,11 +374,11 @@ TEST_CASE(spacetodepth_depthtospace_test) ...@@ -374,11 +374,11 @@ TEST_CASE(spacetodepth_depthtospace_test)
// space to depth // space to depth
auto p1 = migraphx::parse_onnx("spacetodepth_simple_test.onnx"); auto p1 = migraphx::parse_onnx("spacetodepth_simple_test.onnx");
p1.compile(migraphx::make_target("ref")); p1.compile(migraphx::make_target("ref"));
std::vector<float> data_in(48); std::vector<float> gold_data_in(48);
std::iota(std::begin(data_in), std::end(data_in), 0); std::iota(std::begin(gold_data_in), std::end(gold_data_in), 0);
migraphx::shape s_x_1{migraphx::shape::float_type, {1, 2, 4, 6}}; migraphx::shape s_x_1{migraphx::shape::float_type, {1, 2, 4, 6}};
migraphx::parameter_map pp1; migraphx::parameter_map pp1;
pp1["x"] = migraphx::argument(s_x_1, data_in.data()); pp1["x"] = migraphx::argument(s_x_1, gold_data_in.data());
auto result1 = p1.eval(pp1).back(); auto result1 = p1.eval(pp1).back();
// depth to space // depth to space
auto p2 = migraphx::parse_onnx("depthtospace_simple_test.onnx"); auto p2 = migraphx::parse_onnx("depthtospace_simple_test.onnx");
...@@ -388,7 +388,7 @@ TEST_CASE(spacetodepth_depthtospace_test) ...@@ -388,7 +388,7 @@ TEST_CASE(spacetodepth_depthtospace_test)
auto result2 = p2.eval(pp2).back(); auto result2 = p2.eval(pp2).back();
std::vector<float> result_vector2; std::vector<float> result_vector2;
result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); }); result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vector2, data_in)); EXPECT(migraphx::verify::verify_rms_range(result_vector2, gold_data_in));
} }
TEST_CASE(eyelike_verify_test) TEST_CASE(eyelike_verify_test)
...@@ -405,8 +405,8 @@ TEST_CASE(eyelike_verify_test) ...@@ -405,8 +405,8 @@ TEST_CASE(eyelike_verify_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.}; std::vector<float> gold_eyelike_mat = {0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.};
EXPECT(migraphx::verify::verify_range(result_vector, eyelike_mat)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_eyelike_mat));
} }
TEST_CASE(eyelike_verify_negk_test) TEST_CASE(eyelike_verify_negk_test)
...@@ -423,8 +423,8 @@ TEST_CASE(eyelike_verify_negk_test) ...@@ -423,8 +423,8 @@ TEST_CASE(eyelike_verify_negk_test)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.}; std::vector<float> gold_eyelike_mat = {0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.};
EXPECT(migraphx::verify::verify_range(result_vector, eyelike_mat)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold_eyelike_mat));
} }
TEST_CASE(gather_elements) TEST_CASE(gather_elements)
...@@ -447,7 +447,7 @@ TEST_CASE(gather_elements) ...@@ -447,7 +447,7 @@ TEST_CASE(gather_elements)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-0.125, 0.5625, -0.9375, 0.25, 0.5625, 0.9375}; std::vector<float> gold = {-0.125, 0.5625, -0.9375, 0.25, 0.5625, 0.9375};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(gemm_test) TEST_CASE(gemm_test)
...@@ -491,7 +491,7 @@ TEST_CASE(gemm_test) ...@@ -491,7 +491,7 @@ TEST_CASE(gemm_test)
0.8098607, 1.2157929, 1.1010075, 1.0706307, 1.0429881, 1.1771785, 1.2362702, 0.8098607, 1.2157929, 1.1010075, 1.0706307, 1.0429881, 1.1771785, 1.2362702,
0.8239243, 1.1112559, 0.9639262, 1.0813537, 0.8825792, 1.121141, 1.1885703, 0.8239243, 1.1112559, 0.9639262, 1.0813537, 0.8825792, 1.121141, 1.1885703,
1.2227502, 1.4568202, 1.1388762, 1.55058, 1.0958102, 1.4637487, 1.5756242}; 1.2227502, 1.4568202, 1.1388762, 1.55058, 1.0958102, 1.4637487, 1.5756242};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(gemm_half_test) TEST_CASE(gemm_half_test)
...@@ -535,7 +535,7 @@ TEST_CASE(gemm_half_test) ...@@ -535,7 +535,7 @@ TEST_CASE(gemm_half_test)
2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155, 2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155,
1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394}; 1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394};
std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()}; std::vector<migraphx::half> gold{tmp.cbegin(), tmp.cend()};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(greaterorequal_test) TEST_CASE(greaterorequal_test)
...@@ -556,7 +556,7 @@ TEST_CASE(greaterorequal_test) ...@@ -556,7 +556,7 @@ TEST_CASE(greaterorequal_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0, 1.0, 0.0}; std::vector<float> gold = {1.0, 1.0, 0.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(hardsigmoid_verify_test) TEST_CASE(hardsigmoid_verify_test)
...@@ -580,7 +580,7 @@ TEST_CASE(hardsigmoid_verify_test) ...@@ -580,7 +580,7 @@ TEST_CASE(hardsigmoid_verify_test)
std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) { std::transform(data.begin(), data.end(), gold.begin(), [&](auto x) {
return std::max(0.0f, std::min(x * alpha + beta, 1.0f)); return std::max(0.0f, std::min(x * alpha + beta, 1.0f));
}); });
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_else_test) TEST_CASE(if_else_test)
...@@ -602,7 +602,7 @@ TEST_CASE(if_else_test) ...@@ -602,7 +602,7 @@ TEST_CASE(if_else_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683}; std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_else_test_inlined) TEST_CASE(if_else_test_inlined)
...@@ -621,7 +621,7 @@ TEST_CASE(if_else_test_inlined) ...@@ -621,7 +621,7 @@ TEST_CASE(if_else_test_inlined)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472}; std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_then_test) TEST_CASE(if_then_test)
...@@ -644,7 +644,7 @@ TEST_CASE(if_then_test) ...@@ -644,7 +644,7 @@ TEST_CASE(if_then_test)
// onnx adds ones so result should be just + 1.0 // onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375}; std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_then_test_inlined) TEST_CASE(if_then_test_inlined)
...@@ -663,7 +663,7 @@ TEST_CASE(if_then_test_inlined) ...@@ -663,7 +663,7 @@ TEST_CASE(if_then_test_inlined)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375}; std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_literal_test) TEST_CASE(if_literal_test)
...@@ -688,14 +688,14 @@ TEST_CASE(if_literal_test) ...@@ -688,14 +688,14 @@ TEST_CASE(if_literal_test)
{ {
auto result_vector = run_prog(true); auto result_vector = run_prog(true);
std::vector<float> gold = {1, 2, 3, 4, 5}; std::vector<float> gold = {1, 2, 3, 4, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
// else branch // else branch
{ {
auto result_vector = run_prog(false); auto result_vector = run_prog(false);
std::vector<float> gold = {5, 4, 3, 2, 1}; std::vector<float> gold = {5, 4, 3, 2, 1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
} }
...@@ -726,7 +726,7 @@ TEST_CASE(if_then_else_multi_output_shapes_inlined_test) ...@@ -726,7 +726,7 @@ TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
std::vector<float> gold = { std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125}; 1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_then_else_multi_output_shapes_test) TEST_CASE(if_then_else_multi_output_shapes_test)
...@@ -757,7 +757,7 @@ TEST_CASE(if_then_else_multi_output_shapes_test) ...@@ -757,7 +757,7 @@ TEST_CASE(if_then_else_multi_output_shapes_test)
std::vector<float> gold = { std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125}; 1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
...@@ -789,14 +789,14 @@ TEST_CASE(if_pl_test) ...@@ -789,14 +789,14 @@ TEST_CASE(if_pl_test)
{ {
auto result_vector = run_prog(true); auto result_vector = run_prog(true);
std::vector<float> gold = {2, 3, 4, 5, 6, 7}; std::vector<float> gold = {2, 3, 4, 5, 6, 7};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
// else branch // else branch
{ {
auto result_vector = run_prog(false); auto result_vector = run_prog(false);
std::vector<float> gold = {1, 2, 3, 4, 5, 6}; std::vector<float> gold = {1, 2, 3, 4, 5, 6};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
} }
...@@ -835,8 +835,8 @@ TEST_CASE(if_tuple_test) ...@@ -835,8 +835,8 @@ TEST_CASE(if_tuple_test)
auto results = run_prog(true); auto results = run_prog(true);
std::vector<float> gold0(4, 2.0f); std::vector<float> gold0(4, 2.0f);
std::vector<float> gold1(12, 4.0f); std::vector<float> gold1(12, 4.0f);
EXPECT(migraphx::verify::verify_range(results.at(0), gold0)); EXPECT(migraphx::verify::verify_rms_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_range(results.at(1), gold1)); EXPECT(migraphx::verify::verify_rms_range(results.at(1), gold1));
} }
// else branch // else branch
...@@ -844,8 +844,8 @@ TEST_CASE(if_tuple_test) ...@@ -844,8 +844,8 @@ TEST_CASE(if_tuple_test)
auto results = run_prog(false); auto results = run_prog(false);
std::vector<float> gold0(4, 3.0f); std::vector<float> gold0(4, 3.0f);
std::vector<float> gold1(12, 5.0f); std::vector<float> gold1(12, 5.0f);
EXPECT(migraphx::verify::verify_range(results.at(0), gold0)); EXPECT(migraphx::verify::verify_rms_range(results.at(0), gold0));
EXPECT(migraphx::verify::verify_range(results.at(1), gold1)); EXPECT(migraphx::verify::verify_rms_range(results.at(1), gold1));
} }
} }
...@@ -876,7 +876,7 @@ TEST_CASE(instance_norm_test) ...@@ -876,7 +876,7 @@ TEST_CASE(instance_norm_test)
2.54919, 2.54919,
3.32379, 3.32379,
4.09838}; 4.09838};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(instance_norm_dyn_batch_test) TEST_CASE(instance_norm_dyn_batch_test)
...@@ -918,7 +918,7 @@ TEST_CASE(instance_norm_dyn_batch_test) ...@@ -918,7 +918,7 @@ TEST_CASE(instance_norm_dyn_batch_test)
2.54919, 2.54919,
3.32379, 3.32379,
4.09838}; 4.09838};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(instance_norm_3d_test) TEST_CASE(instance_norm_3d_test)
...@@ -947,7 +947,7 @@ TEST_CASE(instance_norm_3d_test) ...@@ -947,7 +947,7 @@ TEST_CASE(instance_norm_3d_test)
3.18218, 3.18218,
4.05505}; 4.05505};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(lessorequal_test) TEST_CASE(lessorequal_test)
...@@ -968,7 +968,7 @@ TEST_CASE(lessorequal_test) ...@@ -968,7 +968,7 @@ TEST_CASE(lessorequal_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 0, 1}; std::vector<float> gold = {1, 0, 1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(lpnormalization_1norm) TEST_CASE(lpnormalization_1norm)
...@@ -996,7 +996,7 @@ TEST_CASE(lpnormalization_1norm) ...@@ -996,7 +996,7 @@ TEST_CASE(lpnormalization_1norm)
3.f / 7.f, 3.f / 7.f,
0.f, 0.f,
0.f}; 0.f};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(lpnormalization_2norm) TEST_CASE(lpnormalization_2norm)
...@@ -1012,19 +1012,19 @@ TEST_CASE(lpnormalization_2norm) ...@@ -1012,19 +1012,19 @@ TEST_CASE(lpnormalization_2norm)
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> correct{0.f, std::vector<float> gold{0.f,
2.f / 3.f, 2.f / 3.f,
-2.f / 3.f, -2.f / 3.f,
1.f / 3.f, 1.f / 3.f,
1.f / 6.f, 1.f / 6.f,
-5.f / 6.f, -5.f / 6.f,
3.f / 6.f, 3.f / 6.f,
-1.f / 6.f, -1.f / 6.f,
-4.f / 5.f, -4.f / 5.f,
3.f / 5.f, 3.f / 5.f,
0.f, 0.f,
0.f}; 0.f};
EXPECT(migraphx::verify::verify_range(result_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mean_broadcast_test) TEST_CASE(mean_broadcast_test)
...@@ -1055,7 +1055,7 @@ TEST_CASE(mean_broadcast_test) ...@@ -1055,7 +1055,7 @@ TEST_CASE(mean_broadcast_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(24, 3); std::vector<float> gold(24, 3);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mean_test) TEST_CASE(mean_test)
...@@ -1082,7 +1082,7 @@ TEST_CASE(mean_test) ...@@ -1082,7 +1082,7 @@ TEST_CASE(mean_test)
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data; const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data;
std::vector<double> gold(num_elms, mean); std::vector<double> gold(num_elms, mean);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mean_integral_test) TEST_CASE(mean_integral_test)
...@@ -1109,7 +1109,7 @@ TEST_CASE(mean_integral_test) ...@@ -1109,7 +1109,7 @@ TEST_CASE(mean_integral_test)
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0) / num_data; const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0) / num_data;
std::vector<int> gold(num_elms, mean); std::vector<int> gold(num_elms, mean);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test) TEST_CASE(mod_test)
...@@ -1136,7 +1136,7 @@ TEST_CASE(mod_test) ...@@ -1136,7 +1136,7 @@ TEST_CASE(mod_test)
std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2,
5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5}; 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test_different_types) TEST_CASE(mod_test_different_types)
...@@ -1164,7 +1164,7 @@ TEST_CASE(mod_test_different_types) ...@@ -1164,7 +1164,7 @@ TEST_CASE(mod_test_different_types)
std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, std::vector<int32_t> gold = {0, -2, 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2,
5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5}; 5, 0, 2, 3, 0, -2, 5, 0, 2, 3, 0, -2, 5};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test_fmod) TEST_CASE(mod_test_fmod)
...@@ -1193,7 +1193,7 @@ TEST_CASE(mod_test_fmod) ...@@ -1193,7 +1193,7 @@ TEST_CASE(mod_test_fmod)
10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2, 10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2,
7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1}; 7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(mod_test_fmod_different_types) TEST_CASE(mod_test_fmod_different_types)
...@@ -1223,7 +1223,7 @@ TEST_CASE(mod_test_fmod_different_types) ...@@ -1223,7 +1223,7 @@ TEST_CASE(mod_test_fmod_different_types)
10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2, 10.7, 11.2, 12.3, 13.9, -14.2, 15.8, 1.6, 3.9, 5.2,
7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1}; 7.0, 9.0, 1.0, -4.0, 7.0, -3.0, 1.2, 1.3, 3.1};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(nonzero_test) TEST_CASE(nonzero_test)
...@@ -1242,7 +1242,7 @@ TEST_CASE(nonzero_test) ...@@ -1242,7 +1242,7 @@ TEST_CASE(nonzero_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 0, 1, 0, 0, 1, 0, 0}; std::vector<float> gold = {0, 0, 1, 0, 0, 1, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_downsample_f_test) TEST_CASE(resize_downsample_f_test)
...@@ -1263,7 +1263,7 @@ TEST_CASE(resize_downsample_f_test) ...@@ -1263,7 +1263,7 @@ TEST_CASE(resize_downsample_f_test)
std::vector<float> gold = {0.0f, 3.0f}; std::vector<float> gold = {0.0f, 3.0f};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_upsample_linear_ac_test) TEST_CASE(resize_upsample_linear_ac_test)
...@@ -1298,7 +1298,7 @@ TEST_CASE(resize_upsample_linear_ac_test) ...@@ -1298,7 +1298,7 @@ TEST_CASE(resize_upsample_linear_ac_test)
11.0f / 3, 11.0f / 3,
4}; 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_upsample_linear_test) TEST_CASE(resize_upsample_linear_test)
...@@ -1319,7 +1319,7 @@ TEST_CASE(resize_upsample_linear_test) ...@@ -1319,7 +1319,7 @@ TEST_CASE(resize_upsample_linear_test)
std::vector<float> gold = { std::vector<float> gold = {
1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4}; 1, 1.25, 1.75, 2, 1.5, 1.75, 2.25, 2.5, 2.5, 2.75, 3.25, 3.5, 3, 3.25, 3.75, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(resize_upsample_pf_test) TEST_CASE(resize_upsample_pf_test)
...@@ -1340,7 +1340,7 @@ TEST_CASE(resize_upsample_pf_test) ...@@ -1340,7 +1340,7 @@ TEST_CASE(resize_upsample_pf_test)
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}; 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(reversesequence_4D_verify_test) TEST_CASE(reversesequence_4D_verify_test)
...@@ -1361,7 +1361,7 @@ TEST_CASE(reversesequence_4D_verify_test) ...@@ -1361,7 +1361,7 @@ TEST_CASE(reversesequence_4D_verify_test)
std::vector<float> gold = { std::vector<float> gold = {
8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0}; 8.0, 9.0, 10.0, 11.0, 4.0, 5.0, 6.0, 7.0, 0.0, 1.0, 2.0, 3.0, 12.0, 13.0, 14.0, 15.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(reversesequence_batch_verify_test) TEST_CASE(reversesequence_batch_verify_test)
...@@ -1382,7 +1382,7 @@ TEST_CASE(reversesequence_batch_verify_test) ...@@ -1382,7 +1382,7 @@ TEST_CASE(reversesequence_batch_verify_test)
std::vector<float> gold = { std::vector<float> gold = {
0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0}; 0.0, 1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 10.0, 9.0, 8.0, 11.0, 15.0, 14.0, 13.0, 12.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(reversesequence_time_verify_test) TEST_CASE(reversesequence_time_verify_test)
...@@ -1403,7 +1403,7 @@ TEST_CASE(reversesequence_time_verify_test) ...@@ -1403,7 +1403,7 @@ TEST_CASE(reversesequence_time_verify_test)
std::vector<float> gold = { std::vector<float> gold = {
3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0}; 3.0, 6.0, 9.0, 12.0, 2.0, 5.0, 8.0, 13.0, 1.0, 4.0, 10.0, 14.0, 0.0, 7.0, 11.0, 15.0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(selu_test) TEST_CASE(selu_test)
...@@ -1423,7 +1423,7 @@ TEST_CASE(selu_test) ...@@ -1423,7 +1423,7 @@ TEST_CASE(selu_test)
std::vector<float> gold = {0.55, 1.05, 0, -0.10912, -0.149251, 6}; std::vector<float> gold = {0.55, 1.05, 0, -0.10912, -0.149251, 6};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(size_verify_test) TEST_CASE(size_verify_test)
...@@ -1457,7 +1457,7 @@ TEST_CASE(slice_test) ...@@ -1457,7 +1457,7 @@ TEST_CASE(slice_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2, 3}; std::vector<float> gold = {2, 3};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(slice_5arg_test) TEST_CASE(slice_5arg_test)
...@@ -1477,7 +1477,7 @@ TEST_CASE(slice_5arg_test) ...@@ -1477,7 +1477,7 @@ TEST_CASE(slice_5arg_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18}; std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(slice_reverse_test) TEST_CASE(slice_reverse_test)
...@@ -1497,7 +1497,7 @@ TEST_CASE(slice_reverse_test) ...@@ -1497,7 +1497,7 @@ TEST_CASE(slice_reverse_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16}; std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(slice_step_test) TEST_CASE(slice_step_test)
...@@ -1517,7 +1517,7 @@ TEST_CASE(slice_step_test) ...@@ -1517,7 +1517,7 @@ TEST_CASE(slice_step_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 12}; std::vector<float> gold = {14, 12};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(softplus_test) TEST_CASE(softplus_test)
...@@ -1538,7 +1538,7 @@ TEST_CASE(softplus_test) ...@@ -1538,7 +1538,7 @@ TEST_CASE(softplus_test)
std::transform( std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return std::log1p(std::exp(x)); }); data.begin(), data.end(), gold.begin(), [](auto x) { return std::log1p(std::exp(x)); });
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(softsign_test) TEST_CASE(softsign_test)
...@@ -1559,7 +1559,7 @@ TEST_CASE(softsign_test) ...@@ -1559,7 +1559,7 @@ TEST_CASE(softsign_test)
std::transform( std::transform(
data.begin(), data.end(), gold.begin(), [](auto x) { return x / (1.0 + std::abs(x)); }); data.begin(), data.end(), gold.begin(), [](auto x) { return x / (1.0 + std::abs(x)); });
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(upsample_test) TEST_CASE(upsample_test)
...@@ -1578,7 +1578,7 @@ TEST_CASE(upsample_test) ...@@ -1578,7 +1578,7 @@ TEST_CASE(upsample_test)
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4}; 3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(where_test) TEST_CASE(where_test)
...@@ -1620,7 +1620,7 @@ TEST_CASE(where_test) ...@@ -1620,7 +1620,7 @@ TEST_CASE(where_test)
2.0f, 2.0f,
1.0f, 1.0f,
2.0f}; 2.0f};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p) std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
...@@ -1645,7 +1645,7 @@ TEST_CASE(trilu_test) ...@@ -1645,7 +1645,7 @@ TEST_CASE(trilu_test)
std::vector<float> gold = {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12}; std::vector<float> gold = {1, 2, 3, 4, 0, 6, 7, 8, 0, 0, 11, 12};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_batch_diff_k_test) TEST_CASE(trilu_batch_diff_k_test)
...@@ -1656,7 +1656,7 @@ TEST_CASE(trilu_batch_diff_k_test) ...@@ -1656,7 +1656,7 @@ TEST_CASE(trilu_batch_diff_k_test)
std::vector<float> gold = {0, 0, 3, 0, 0, 0, 0, 0, 9, 0, 0, 0}; std::vector<float> gold = {0, 0, 3, 0, 0, 0, 0, 0, 9, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_lower_test) TEST_CASE(trilu_lower_test)
...@@ -1667,7 +1667,7 @@ TEST_CASE(trilu_lower_test) ...@@ -1667,7 +1667,7 @@ TEST_CASE(trilu_lower_test)
std::vector<float> gold = {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0}; std::vector<float> gold = {0, 0, 0, 0, 5, 0, 0, 0, 9, 10, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_out_k_test) TEST_CASE(trilu_out_k_test)
...@@ -1678,7 +1678,7 @@ TEST_CASE(trilu_out_k_test) ...@@ -1678,7 +1678,7 @@ TEST_CASE(trilu_out_k_test)
std::vector<float> gold(12, 0); std::vector<float> gold(12, 0);
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(trilu_row_one_test) TEST_CASE(trilu_row_one_test)
...@@ -1689,7 +1689,7 @@ TEST_CASE(trilu_row_one_test) ...@@ -1689,7 +1689,7 @@ TEST_CASE(trilu_row_one_test)
std::vector<float> gold = {0, 2, 3, 4}; std::vector<float> gold = {0, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment