Unverified Commit 77df49b8 authored by Chris Austen, committed by GitHub

Merge branch 'develop' into jenkins_reorder

parents e4e19b1d 60b8b097
@@ -81,6 +81,7 @@ struct MIGRAPHX_EXPORT instruction
     const std::vector<module_ref>& module_inputs() const;

+    /// Where this instruction is used as an input to another instruction
     const std::vector<instruction_ref>& outputs() const;

     friend bool operator==(const instruction& x, const instruction& y);
...
@@ -49,17 +49,22 @@ struct allocate
     shape compute_shape(const std::vector<shape>& inputs) const
     {
+        migraphx::check_shapes{inputs, *this, true}.has(0, 1);
+        // check if shape attribute is not default
         if(s != shape())
         {
+            if(inputs.size() == 1)
+            {
+                migraphx::check_shapes{inputs, *this, false}.only_dims(1);
+            }
+            else
+            {
+                migraphx::check_shapes{inputs, *this, false}.has(0);
+            }
             return s;
         }
         else
         {
+            migraphx::check_shapes{inputs, *this, false}.has(1).only_dims(1);
             const auto& out_dims = inputs.at(0);
+            assert(not out_dims.dynamic());
+            assert(out_dims.ndim() == 1);
             std::size_t max_val = std::numeric_limits<std::size_t>::max();
             std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0),
                                                            shape::dynamic_dimension{0, max_val});
...
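A note on the new convention above (a standalone sketch with assumed values, not part of the commit): when the shape attribute is not set, allocate's single input is a 1-D tensor holding the output dimensions. Only its length (the output rank) is known when shapes are computed, so compute_shape produces one dynamic_dimension spanning {0, SIZE_MAX} per output dimension, exactly as the dyn_dims construction above does. With standard types standing in for the migraphx ones:

#include <cstddef>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

int main()
{
    // Stand-in for shape::dynamic_dimension{min, max}; the real type lives in migraphx.
    using dyn_dim = std::pair<std::size_t, std::size_t>;

    // allocate's single input is a 1-D tensor of output dimensions; only its
    // length (the output rank, assumed 4 here) is known at shape-computation time.
    std::size_t out_ndim = 4;

    std::size_t max_val = std::numeric_limits<std::size_t>::max();
    std::vector<dyn_dim> dyn_dims(out_ndim, dyn_dim{0, max_val});

    for(const auto& dd : dyn_dims)
        std::cout << "{" << dd.first << ", " << dd.second << "}\n";
}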
/*
 * The MIT License (MIT)
 *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
@@ -29,7 +29,8 @@
 #include <migraphx/config.hpp>
 #include <migraphx/value.hpp>
 #include <migraphx/dyn_output.hpp>
+#include <migraphx/optional.hpp>
+#include <algorithm>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -45,8 +46,6 @@ struct reshape
         return pack(f(self.dims, "dims"));
     }

-    value attributes() const { return {{"require_std_shape", true}}; }
-
     std::string name() const { return "reshape"; }

     shape dyn_compute_shape(shape s0) const
@@ -110,27 +109,9 @@ struct reshape
         return it;
     }

-    template <class DimIterator, class StrideIterator>
-    static auto can_strides_merge(DimIterator dim_start,
-                                  DimIterator dim_last,
-                                  StrideIterator stride_start,
-                                  StrideIterator stride_last)
-    {
-        assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
-        auto cstride = *std::prev(stride_last);
-        return std::equal(std::make_reverse_iterator(dim_last),
-                          std::make_reverse_iterator(dim_start + 1),
-                          std::make_reverse_iterator(stride_last - 1),
-                          std::make_reverse_iterator(stride_start),
-                          [&](auto dim, auto stride) {
-                              cstride *= dim;
-                              return stride == cstride;
-                          });
-    }
-
-    // This will reshape the dimensions of the input shape to use the lens of
-    // `rdims`. If this can't be done without changing memory layout then it
-    // will return nullopt
+    // This will attempt to alias the dimensions of the input shape to the lens of
+    // `rdims`. Unlike reshape_lazy, though, we can modify the memory layout with copies, and this
+    // can remove the nullopts that were previously sent back for the alias case
     static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
     {
         if(input.standard())
@@ -155,13 +136,8 @@ struct reshape
         {
             auto start = idims.begin() + i;
             auto it    = compute_end_dim(start, idims.end(), rdim);
-            if(it == start)
-                return nullopt;
             auto n = it - start;
             assert((i + n) <= istrides.size());
-            if(not can_strides_merge(
-                   start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
-                return nullopt;
             i += n;
             rstrides.push_back(istrides[i]);
         }
@@ -170,8 +146,7 @@ struct reshape
         {
             auto start = rdims.begin() + i;
             auto it    = compute_end_dim(start, rdims.end(), idim);
-            if(it == start)
-                return nullopt;
             auto n = it - start;
             assert((r + n) <= rdims.size());
             auto stride = istrides[i] * idim;
@@ -191,15 +166,11 @@ struct reshape
         auto stride = rstrides.back();
         for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
         {
-            if(d != 1)
-                return nullopt;
+            (void)d;
             rstrides.push_back(stride);
         }
     }

-    if(rdims.size() != rstrides.size())
-        return nullopt;
     return shape{input.type(), rdims, rstrides};
 }
@@ -233,25 +204,24 @@ struct reshape
         }

         auto s = reshape_dims(inputs.front(), rdims);
-        if(not s.has_value())
-            MIGRAPHX_THROW("Reshape on axis that is not packed.");
-
         if(s->elements() != inputs.front().elements())
-            MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
+            MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
                            std::to_string(s->elements()) + " elements whereas the input has " +
                            std::to_string(inputs.front().elements()));
+        assert(s->bytes() == inputs.front().bytes());
         return *s;
     }
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this, true}.has(1);
         auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
         if(n_neg_dims > 1)
-            MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim");
-        auto s0 = inputs[0];
+            MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
+        auto s0 = inputs.front();
         if(s0.dynamic())
         {
             return dyn_compute_shape(s0);
@@ -264,10 +234,14 @@ struct reshape
     argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        return args[0].reshape(dyn_out.computed_shape);
-    }
-
-    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
+        assert(dyn_out.computed_shape.standard());
+        argument result{dyn_out.computed_shape};
+        visit_all(result, args[0])([&](auto output, auto input) {
+            std::copy(input.begin(), input.end(), output.begin());
+        });
+        return result;
+    }
 };
 } // namespace op
...
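To see what the retained stride logic computes, a worked example with assumed dims (a standalone sketch, not part of the commit): a standard {2, 3, 4} tensor has strides {12, 4, 1}; reshaping to {6, 4} merges the first two dimensions and keeps strides {4, 1}, so for standard inputs no element has to move. The helper below computes row-major strides the way the input.standard() fast path implies:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Row-major (standard) strides for a list of dimensions, matching what the
// input.standard() fast path of reshape_dims produces.
std::vector<std::size_t> standard_strides(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    // stride[i] = product of lens[i+1..n)
    std::partial_sum(
        lens.rbegin(), lens.rend() - 1, strides.rbegin() + 1, std::multiplies<>{});
    return strides;
}

int main()
{
    for(auto s : standard_strides({2, 3, 4}))
        std::cout << s << ' '; // 12 4 1
    std::cout << '\n';
    for(auto s : standard_strides({6, 4}))
        std::cout << s << ' '; // 4 1
    std::cout << '\n';
}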
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#define MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reshape_lazy
{
std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
}
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape_lazy"; }
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
auto num_not_fixed = std::count_if(
dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
if(num_not_fixed != 1)
{
MIGRAPHX_THROW("reshape_lazy: Only supports one non-fixed dynamic_dimension");
}
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(dims[i] != 0 and dims[i] != -1)
{
MIGRAPHX_THROW(
"reshape_lazy: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("reshape_lazy: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(),
dims.cend(),
dyn_dims.cbegin(),
output_dyn_dims.begin(),
[](std::size_t dim, auto dyn_dim) {
if(not dyn_dim.is_fixed())
return dyn_dim;
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims};
}
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will attempt to alias the dimensions of the input shape to the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_lazy_dims(const shape& input,
const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start;
assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if(d != 1)
return nullopt;
rstrides.push_back(stride);
}
}
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
// since rdims uses the size_t type, -1 wraps to the max value of size_t,
// which would make the later computation incorrect
if(dims[i] == -1)
rdims[i] = 1;
}
if(n_neg_dims > 0)
{
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
auto s = reshape_lazy_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("reshape_lazy on axis that is not packed.");
if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW(
"reshape_lazy: Wrong number of elements for reshape_lazy: reshape_lazy has " +
std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
assert(s->bytes() == inputs.front().bytes());
return *s;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape_lazy: Dimensions for reshape_lazy can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
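For intuition on when reshape_lazy_dims returns nullopt, a standalone sketch (simplified from can_strides_merge above, with assumed dims; not part of the commit): a run of dimensions can collapse into one only if its strides form a contiguous row-major chain, and a transposed layout breaks that chain, so the operator refuses to alias and a copying reshape (or the surrounding gpu::contiguous inserted by lowering) has to take over.

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified version of the can_strides_merge test: a run of dimensions can
// collapse into one only if the strides form a contiguous row-major chain.
bool mergeable(const std::vector<std::size_t>& dims, const std::vector<std::size_t>& strides)
{
    std::size_t expected = strides.back();
    for(std::size_t i = dims.size(); i-- > 0;)
    {
        if(strides[i] != expected)
            return false;
        expected *= dims[i];
    }
    return true;
}

int main()
{
    std::cout << mergeable({3, 4}, {4, 1}) << '\n'; // 1: {3,4} -> {12} can alias
    std::cout << mergeable({3, 4}, {1, 3}) << '\n'; // 0: transposed, must copy
}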
@@ -29,10 +29,13 @@
 #include <functional>
 #include <iostream>
 #include <numeric>
+#include <assert.h>
 #include <migraphx/float_equal.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/env.hpp>
+
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE)

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace verify {
@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2)
     return std::numeric_limits<range_value<R1>>::max();
 }

+template <class R>
+double get_rms_tol(const R&, std::size_t tolerance = 80)
+{
+    double threshold = std::numeric_limits<range_value<R>>::epsilon() * tolerance;
+    return threshold;
+}
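With the default tolerance of 80, this threshold works out to roughly 9.5e-6 for float (epsilon is about 1.19e-7) and 1.8e-14 for double, matching the old verify_range default below.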
+/*
+C++ doesn't support named arguments; this is just a wrapper that helps distinguish between
+actual-results and expected-results arguments.
+*/
+template <class T>
+struct expected
+{
+    expected() = default;
+    explicit expected(const T& input) : x(&input) {}
+
+    const T& data() const
+    {
+        assert(x != nullptr);
+        return *x;
+    }
+
+    private:
+    const T* x = nullptr;
+};
+
+// deduction guide for the templated expected class
+template <class T>
+expected(const T&) -> expected<T>;
+
+struct tolerance
+{
+    double rms_tol = 0.001;
+    double atol    = 0.001;
+    double rtol    = 0.001;
+};
+
+/*
+MIGraphX implementation of numpy's np.allclose(), which checks whether the elementwise absolute
+difference is within tolerance using this formula: abs(a - b) < atol + rtol * abs(b)
+*/
+template <class R1, class R2>
+bool allclose(const R1& r1, const R2& r2, tolerance tols)
+{
+    std::size_t n = range_distance(r1);
+    if(n == range_distance(r2))
+    {
+        auto idx = mismatch_idx(r1, r2, [&](auto x, auto y) {
+            return abs_diff(double(x), double(y)) < tols.atol + tols.rtol * std::abs(double(y));
+        });
+        return idx >= range_distance(r1);
+    }
+    return false;
+}
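A quick standalone check of that formula with assumed values (plain std::vector rather than the range helpers above; not part of the commit): with atol = rtol = 0.001, an expected value of 100.0 tolerates an absolute difference just under 0.101, while an expected value of 0.0 tolerates only 0.001.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// np.allclose-style elementwise check: |a - b| < atol + rtol * |b|,
// where b is the expected value.
bool allclose(const std::vector<double>& actual,
              const std::vector<double>& expected,
              double atol = 0.001,
              double rtol = 0.001)
{
    if(actual.size() != expected.size())
        return false;
    for(std::size_t i = 0; i < actual.size(); i++)
    {
        if(not(std::fabs(actual[i] - expected[i]) < atol + rtol * std::fabs(expected[i])))
            return false;
    }
    return true;
}

int main()
{
    std::cout << allclose({100.05, 0.0005}, {100.0, 0.0}) << '\n'; // 1
    std::cout << allclose({100.2, 0.0}, {100.0, 0.0}) << '\n';     // 0
}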
 template <class R1, class R2>
-bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr)
+bool verify_rms_range(const R1& r1,
+                      const R2& r2,
+                      std::size_t tolerance = 80,
+                      double* out_rms_error = nullptr)
 {
-    double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance;
+    double threshold = get_rms_tol(r1, tolerance);
     auto error = rms_range(r1, r2);
-    if(out_error != nullptr)
-        *out_error = error;
+    if(out_rms_error != nullptr)
+        *out_rms_error = error;
     return error <= threshold;
 }
+template <class R1, class R2>
+bool verify_range_with_tolerance(const R1& r1,
+                                 const expected<R2>& r2,
+                                 tolerance tols = tolerance{},
+                                 double* out_rms_error = nullptr)
+{
+    auto rms_error = rms_range(r1, r2.data());
+    // disable elementwise verification by default for now; it requires a lot of tests to be fixed
+    bool ewise_verify = true;
+    if(enabled(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE{}))
+    {
+        ewise_verify = allclose(r1, r2.data(), tols);
+    }
+    if(out_rms_error != nullptr)
+        *out_rms_error = rms_error;
+    return rms_error <= tols.rms_tol and ewise_verify;
+}
+
+// The expected argument should be passed second, but if it is passed first by mistake then
+// flip the order
+template <class R1, class R2>
+bool verify_range_with_tolerance(const expected<R1>& r1,
+                                 const R2& r2,
+                                 tolerance tols = tolerance{},
+                                 double* out_rms_error = nullptr)
+{
+    return verify_range_with_tolerance(r2, r1, tols, out_rms_error);
+}
 } // namespace verify
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -31,11 +31,15 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-MIGRAPHX_EXPORT
-bool verify_args(const std::string& name,
-                 const argument& ref_arg,
-                 const argument& target_arg,
-                 double tolerance = 80);
+MIGRAPHX_EXPORT bool verify_args(const std::string& name,
+                                 const argument& target_arg,
+                                 const verify::expected<argument>& ref_arg,
+                                 verify::tolerance);
+
+MIGRAPHX_EXPORT bool verify_args_with_tolerance(const std::string& name,
+                                                const argument& target_arg,
+                                                const verify::expected<argument>& ref_arg,
+                                                std::size_t tolerance = 80);

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -49,6 +49,8 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
         {
             MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
         }
+        // convert to a scalar literal
+        l_val = literal(shape{l_val.get_shape().type(), {1}, {0}}, l_val.data());
     }
     else
     {
@@ -64,30 +66,37 @@ struct parse_constant_of_shape : op_parser<parse_constant_of_shape>
         migraphx::shape s;
         // input is empty, output is a scalar
         auto type = l_val.get_shape().type();
-        // empty input tensor, output is a scalar
-        if(args[0]->get_shape().elements() == 0)
+        migraphx::argument input = args[0]->eval();
+        if(not input.empty())
         {
-            s = migraphx::shape{type, {1}, {0}};
+            // empty input tensor, output is a scalar
+            if(args[0]->get_shape().elements() == 0)
+            {
+                s = migraphx::shape{type, {1}, {0}};
+            }
+            else
+            {
+                std::vector<std::size_t> dims;
+                input.visit([&](auto ia) { dims.assign(ia.begin(), ia.end()); });
+                s = migraphx::shape{type, dims};
+            }
+            literal l_out{};
+            l_val.visit([&](auto val) {
+                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
+                // l_val contains only one element
+                std::vector<val_type> out_vec(s.elements(), val.front());
+                l_out = literal(s, out_vec);
+            });
+            return info.add_literal(l_out);
         }
+        // has variable input (dynamic shape buffer)
         else
         {
-            migraphx::argument in = args[0]->eval();
-            check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
-
-            std::vector<std::size_t> dims;
-            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
-            s = migraphx::shape{type, dims};
+            auto dv_lit    = info.add_literal(l_val);
+            auto alloc_ins =
+                info.add_instruction(make_op("allocate", {{"buf_type", type}}), args[0]);
+            return info.add_instruction(make_op("fill"), dv_lit, alloc_ins);
         }
-
-        literal l_out{};
-        l_val.visit([&](auto val) {
-            using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
-            // l_val contains only one element
-            std::vector<val_type> out_vec(s.elements(), val.front());
-            l_out = literal(s, out_vec);
-        });
-        return info.add_literal(l_out);
     }
 };
...
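A hedged sketch of what the new dynamic branch emits (schematic only, not literal MIGraphX IR): instead of rejecting a shape tensor that cannot be evaluated at parse time, the parser now defers the work to run time as

value = literal(scalar taken from the attribute)
buf   = allocate(shape_tensor)    // buf_type = value's element type
out   = fill(value, buf)

so the output buffer is sized by the runtime shape tensor and then filled with the scalar.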
@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const
             continue;
         if(ins->inputs().empty())
             continue;
         auto&& s = ins->inputs().front()->get_shape();
-        if(not s.standard())
-            continue;
         auto&& op = any_cast<op::pooling>(ins->get_operator());
         if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
             continue;
@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const
         auto lens = s.lens();
         if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
             continue;
-        std::int64_t n = s.lens()[0];
-        std::int64_t c = s.lens()[1];
-        auto reshape   = m.insert_instruction(
-            ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
-        instruction_ref pooling{};
+        std::vector<std::int64_t> axes(lens.size() - 2);
+        std::iota(axes.begin(), axes.end(), 2);
         // average pooling
         if(op.mode == op::pooling_mode::average)
         {
-            pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
+            m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
         }
         // max pooling
         else
         {
-            pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
+            m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
         }
-        std::vector<int64_t> rsp_lens(lens.size(), 1);
-        rsp_lens[0] = n;
-        rsp_lens[1] = c;
-        m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
     }
 }
...
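The rewrite rests on a simple identity: with zero padding and kernel lengths equal to the spatial extents, pooling reduces each (n, c) plane to a single value, which is exactly reduce_mean (or reduce_max) over axes 2..N-1. A standalone numeric check of the average case with an assumed 1x1x2x2 input (not part of the commit):

#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    // One NCHW tensor with N = C = 1 and H = W = 2.
    std::vector<double> hw = {1.0, 2.0, 3.0, 4.0};

    // Global average pooling with lengths {2, 2} produces a single value per
    // (n, c): the mean over the whole spatial plane...
    double pooled = std::accumulate(hw.begin(), hw.end(), 0.0) / hw.size();

    // ...which is what reduce_mean over axes {2, 3} computes directly.
    std::cout << pooled << '\n'; // 2.5
}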
@@ -122,6 +122,11 @@ struct find_nop_reshapes
         reshapes.insert("pad");
         reshapes.insert("slice");
         reshapes.insert("transpose");
+        reshapes.insert("reduce_mean");
+        reshapes.insert("reduce_max");
+        reshapes.insert("reduce_min");
+        reshapes.insert("reduce_sum");
+        reshapes.insert("reduce_prod");
         return match::name(reshapes)(match::same_shape(match::arg(0)));
     }
...
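For example, a reduce_mean whose axes all have extent 1 produces an output with the same shape as its input, so it now matches same_shape(arg(0)) and is eliminated like any other no-op reshape, which presumably also mops up the identity reductions that the rewrite_pooling change above can produce when the spatial extents are all 1.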
@@ -23,6 +23,10 @@
 # ####################################################################################
 list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+find_package(hip)
+if(NOT GPU_TARGETS)
+    message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
+endif()
 find_package(miopen)

 # rocblas
@@ -194,7 +198,7 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
 rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_gpu)

-set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "")
+set(MIGRAPHX_ENABLE_MLIR ON CACHE BOOL "")

 if(MIGRAPHX_ENABLE_MLIR)
     # Find package rocMLIR
...
@@ -28,6 +28,7 @@
 #include <migraphx/env.hpp>
 #include <cassert>
 #include <iostream>
+#include <deque>

 #ifdef MIGRAPHX_USE_HIPRTC
 #include <hip/hiprtc.h>
@@ -92,7 +93,7 @@ struct hiprtc_program
 {
     struct string_array
     {
-        std::vector<std::string> strings{};
+        std::deque<std::string> strings{};
         std::vector<const char*> c_strs{};

         string_array() {}
@@ -209,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
         options.push_back("-Wno-gnu-line-marker");
         options.push_back("-Wno-old-style-cast");
     }
-
     if(enabled(MIGRAPHX_GPU_DEBUG{}))
         options.push_back("-DMIGRAPHX_DEBUG");
     if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
...
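A likely motivation for switching strings from std::vector to std::deque here: c_strs stores const char* pointers obtained from those strings, and a deque, unlike a vector, never relocates existing elements when it grows, so pointers handed out earlier stay valid as more entries are appended.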
@@ -81,6 +81,14 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
         using f_type = decltype(f);
         dim3 nblocks(global / local);
         dim3 nthreads(local);
+        /*
+        hipGetLastError() returns the error from the most recent failed HIP call, which may have
+        happened much earlier. MIGraphX calls into various backend libraries, and failed HIP
+        calls can also happen there. Calling hipGetLastError() here resets the error code to
+        hipSuccess, so that a failed call to hipLaunchKernelGGL() inside MIGraphX can be captured.
+        */
+        hipError_t flush_call = hipGetLastError();
+        (void)(flush_call);
         // cppcheck-suppress UseDeviceLaunch
         hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
         hipError_t kernel_launch_status = hipGetLastError();
...
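The same flush-then-check pattern in isolation (a sketch under the HIP runtime, not part of the commit; error handling trimmed): clear any stale error left by earlier calls, launch, then read the status that belongs to this launch alone.

#include <hip/hip_runtime.h>
#include <iostream>

__global__ void noop() {}

int main()
{
    // Discard any error code left over from earlier HIP or library calls so
    // the check below reflects only this launch.
    (void)hipGetLastError();

    noop<<<1, 1>>>();

    hipError_t status = hipGetLastError();
    if(status != hipSuccess)
        std::cout << "launch failed: " << hipGetErrorString(status) << '\n';
}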
@@ -283,9 +283,9 @@ struct find_mlir_fused_ops
                        names.end(),
                        ins->inputs().begin(),
                        std::inserter(param_map, param_map.end()),
-                       [&, &anchor_op = anchor_op](auto name, auto input) {
+                       [&, &anchor = anchor_op](auto name, auto input) {
                            if(input == x_ins)
-                               return std::make_pair(pm->get_parameter(name), anchor_op);
+                               return std::make_pair(pm->get_parameter(name), anchor);
                            return std::make_pair(pm->get_parameter(name),
                                                  mm->add_parameter(name, input->get_shape()));
                        });
@@ -327,12 +327,12 @@ struct find_mlir_standalone_op

 struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
 {
-    auto matcher() const { return match::name("convolution"); }
+    auto matcher() const { return is_mlir_conv; }
 };

 struct find_mlir_standalone_dot_op : find_mlir_standalone_op
 {
-    auto matcher() const { return match::name("dot"); }
+    auto matcher() const { return match::any_of(match::name("dot"), match::name("quant_dot")); }
 };

 /**
@@ -365,7 +365,7 @@ bool is_enabled(std::string_view op_name, context* ctx)
     {
         return true;
     }
-    else if(op_name == "convolution")
+    else if(op_name == "convolution" or op_name == "quant_convolution")
     {
         if(ctx == nullptr)
         {
...
@@ -790,22 +790,26 @@ struct find_layernorm_pointwise
 {
     auto matcher() const
     {
-        return precompile_name("pointwise")(match::arg(0)(
+        return precompile_name("pointwise")(match::any_of[match::inputs()](
             precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
     }

     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins       = r.result;
+        auto pw_ins    = r.result;
         auto layernorm = r.instructions["layernorm"];
         if(not layernorm->module_inputs().empty())
             return;

-        auto* pm = ins->module_inputs().front();
+        auto* pm       = pw_ins->module_inputs().front();
+        auto pw_inputs = pw_ins->inputs();
+        auto ln_pos    = std::find(pw_inputs.begin(), pw_inputs.end(), layernorm);
+        assert(ln_pos != pw_inputs.end());
+        pw_inputs.erase(ln_pos);
         auto inputs = layernorm->inputs();
         inputs.pop_back();
-        inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
+        inputs.insert(inputs.end(), pw_inputs.begin(), pw_inputs.end());

-        m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
+        m.replace_instruction(pw_ins, layernorm->get_operator(), inputs, {pm});
     }
 };
...
/*
 * The MIT License (MIT)
 *
- * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
@@ -40,6 +40,7 @@
 #include <migraphx/op/if_op.hpp>
 #include <migraphx/op/reshape.hpp>
 #include <migraphx/op/quant_dot.hpp>
+#include <migraphx/op/reshape_lazy.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/lowering.hpp>
@@ -89,7 +90,6 @@ struct miopen_apply
         offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;

         add_generic_op("contiguous");
-
         add_extend_op("argmax");
         add_extend_op("argmin");
         add_extend_op("logsoftmax");
@@ -115,6 +115,7 @@ struct miopen_apply
         add_neg_op();
         add_nms_op();
         add_select_module_op();
+        add_reshape_lazy_op();
     }

     void copy_params() const
@@ -376,6 +377,32 @@ struct miopen_apply
             return mod->replace_instruction(ins, ins->get_operator(), inputs, ins->module_inputs());
         });
     }
+
+    /**
+     * Adds reshape_lazy to reshape ops that can be aliased instead of copied.
+     * `gpu::contiguous` instructions are added before and after the reshape; these contiguous
+     * instructions can be removed by the eliminate_contiguous pass.
+     */
+    void add_reshape_lazy_op()
+    {
+        apply_map.emplace("reshape", [=](instruction_ref ins) {
+            std::vector<instruction_ref> before_contiguous_args = ins->inputs();
+            auto before_alloc = insert_allocation(ins, std::prev(ins)->get_shape());
+            before_contiguous_args.push_back(before_alloc);
+            auto before_contig =
+                mod->insert_instruction(ins, make_op("gpu::contiguous"), {before_contiguous_args});
+
+            auto new_lazy_reshape = mod->insert_instruction(
+                ins,
+                make_op("reshape_lazy", {{"dims", ins->get_operator().to_value().at("dims")}}),
+                before_contig);
+
+            std::vector<instruction_ref> after_contiguous_args = {new_lazy_reshape};
+            auto after_alloc = insert_allocation(new_lazy_reshape, new_lazy_reshape->get_shape());
+            after_contiguous_args.push_back(after_alloc);
+            return mod->replace_instruction(ins, make_op("gpu::contiguous"), after_contiguous_args);
+        });
+    }
 };

 void lowering::apply(module_pass_manager& mpm) const
...
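Schematically (hedged, not literal MIGraphX IR), the lowering above turns one reshape into

c1  = gpu::contiguous(x, alloc1)   // make the input packed so aliasing is legal
r   = reshape_lazy(c1)             // aliases c1's buffer, no copy
out = gpu::contiguous(r, alloc2)   // restore a standard layout for consumers

and eliminate_contiguous later removes whichever of the two copies turn out to be unnecessary.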
@@ -24,6 +24,7 @@
 #include "migraphx/make_op.hpp"
 #include <migraphx/stringutils.hpp>
 #include <migraphx/gpu/mlir.hpp>
+#include <ostream>

 #ifdef MIGRAPHX_MLIR
 #include <mlir-c/IR.h>
@@ -34,6 +35,7 @@
 #include <mlir-c/Dialect/Rock.h>
 #include <mlir-c/IntegerSet.h>
 #include <mlir-c/Pass.h>
+#include <mlir-c/Support.h>
 #include <mutex>

 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
 #warning "Incompatible version of rocMLIR library used, disabling"
@@ -180,13 +182,85 @@ std::string mlir_print(F f, T x)
     return ss.str();
 }
+struct mlir_logger
+{
+    std::stringstream ss;
+    mlir_context* ctx;
+    std::optional<MlirDiagnosticHandlerID> id;
+
+    mlir_logger() : ctx(nullptr), id(std::nullopt) {}
+    mlir_logger(mlir_context* context) : ctx(context)
+    {
+        id =
+            mlirContextAttachDiagnosticHandler(ctx->get(), mlir_diagnostic_print_cb, this, nullptr);
+    }
+
+    ~mlir_logger()
+    {
+        if(id.has_value())
+            mlirContextDetachDiagnosticHandler(ctx->get(), *id);
+    }
+
+    mlir_logger(const mlir_logger& other) = delete;
+    mlir_logger& operator=(const mlir_logger& other) = delete;
+
+    mlir_logger(mlir_logger&& other) noexcept
+        : ss(std::move(other.ss)), ctx(other.ctx), id(other.id)
+    {
+        other.ctx = nullptr;
+        other.id  = std::nullopt;
+    }
+
+    mlir_logger& operator=(mlir_logger other) noexcept
+    {
+        std::swap(ss, other.ss);
+        std::swap(ctx, other.ctx);
+        std::swap(id, other.id);
+        return *this;
+    }
+
+    std::string str() const { return ss.str(); }
+
+    void clear() { ss = std::stringstream{}; }
+
+    static MlirLogicalResult mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger);
+
+    MlirLogicalResult handle(MlirDiagnostic diag);
+};
+
+MlirLogicalResult mlir_logger::mlir_diagnostic_print_cb(MlirDiagnostic diag, void* logger)
+{
+    return reinterpret_cast<mlir_logger*>(logger)->handle(diag);
+}
+
+MlirLogicalResult mlir_logger::handle(MlirDiagnostic diag)
+{
+    MlirDiagnosticSeverity sev = mlirDiagnosticGetSeverity(diag);
+    switch(sev)
+    {
+    case MlirDiagnosticSeverity::MlirDiagnosticError: ss << "Error: "; break;
+    case MlirDiagnosticSeverity::MlirDiagnosticWarning: ss << "Warning: "; break;
+    case MlirDiagnosticSeverity::MlirDiagnosticNote: ss << "Note: "; break;
+    case MlirDiagnosticSeverity::MlirDiagnosticRemark: ss << "Remark: "; break;
+    }
+    mlir_print(mlirDiagnosticPrint, diag, [&](auto s) { ss << s; });
+    ss << std::endl;
+    for(intptr_t i = 0, e = mlirDiagnosticGetNumNotes(diag); i < e; ++i)
+    {
+        (void)handle(mlirDiagnosticGetNote(diag, i));
+    }
+    return mlirLogicalResultSuccess();
+}
 struct mlir_program
 {
     mlir_program()
         : ctx(mlirContextCreateWithRegistry(get_dialect_registry().get(),
                                             /*threadingEnable=*/false)),
           location(mlirLocationUnknownGet(ctx.get())),
-          mmodule(mlirModuleCreateEmpty(location))
+          mmodule(mlirModuleCreateEmpty(location)),
+          logger(&ctx)
     {
         mlirContextSetThreadPool(ctx.get(), get_thread_pool().get());
         mlirContextLoadAllAvailableDialects(ctx.get());
@@ -614,21 +688,49 @@ struct mlir_program
         }
     }

-    void run_high_level_pipeline() MIGRAPHX_TIDY_CONST
+    void run_high_level_pipeline()
     {
         mlir_pass_manager pm_front{mlirPassManagerCreate(ctx.get())};
         mlirMIGraphXAddHighLevelPipeline(pm_front.get());
-        mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()));
+        logger.clear();
+        if(mlirLogicalResultIsFailure(
+               mlirPassManagerRunOnOp(pm_front.get(), mlirModuleGetOperation(mmodule.get()))))
+        {
+            std::string error = "Invalid MLIR created: " + logger.str();
+            if(enabled(MIGRAPHX_TRACE_MLIR{}))
+            {
+                std::cout << error << std::endl;
+            }
+            MIGRAPHX_THROW(error);
+        }
     }

-    void run_backend_pipeline() MIGRAPHX_TIDY_CONST
+    void run_backend_pipeline()
     {
         mlir_pass_manager pm_back{mlirPassManagerCreate(ctx.get())};
         mlirMIGraphXAddBackendPipeline(pm_back.get(), target_arch.c_str());
-        mlirPassManagerRunOnOp(pm_back.get(), mlirModuleGetOperation(mmodule.get()));
+        logger.clear();
+        const size_t trace = value_of(MIGRAPHX_TRACE_MLIR{});
+        static std::mutex mutex;
+        auto mod_op = mlirModuleGetOperation(mmodule.get());
+        if(trace >= 2)
+        {
+            const std::lock_guard<std::mutex> lock(mutex);
+            std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
+        }
+        if(mlirLogicalResultIsFailure(mlirPassManagerRunOnOp(pm_back.get(), mod_op)))
+        {
+            std::string error = "MLIR backend compilation failed: " + logger.str();
+            if(enabled(MIGRAPHX_TRACE_MLIR{}))
+            {
+                std::cout << error << std::endl;
+            }
+            MIGRAPHX_THROW(error);
+        }
     }

-    code_object_op compile(const value& solution) MIGRAPHX_TIDY_CONST
+    code_object_op compile(const value& solution)
     {
         // 1st pipeline to call
         run_high_level_pipeline();
@@ -682,7 +784,7 @@ struct mlir_program
         MIGRAPHX_THROW("Failed setting tuning key: " + *str);
     }

-    tuning_config get_tuning_config(bool exhaustive) MIGRAPHX_TIDY_CONST
+    tuning_config get_tuning_config(bool exhaustive)
     {
         tuning_config tc;
         run_high_level_pipeline();
@@ -702,7 +804,8 @@ struct mlir_program
             if(perf_key_bytes > perf_key.size())
                 MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) +
                                " bytes and thus too long");
-            tc.solutions.emplace_back(perf_key.begin(), perf_key.begin() + perf_key_bytes);
+            tc.solutions.emplace_back(
+                std::string(perf_key.begin(), perf_key.begin() + perf_key_bytes));
         }
         std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key;
         size_t tuning_key_bytes =
@@ -809,6 +912,7 @@ struct mlir_program
     mlir_context ctx;
     MlirLocation location;
     mlir_module mmodule;
+    mlir_logger logger;
     problem_params pp;
     std::deque<std::string> strings{};
     std::string target_arch = "";
...
@@ -28,19 +28,20 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 bool verify_args(const std::string& name,
-                 const argument& ref_arg,
                  const argument& target_arg,
-                 double tolerance)
+                 const verify::expected<argument>& ref_arg,
+                 verify::tolerance tols)
 {
     bool passed = true;
-    visit_all(ref_arg, target_arg)([&](auto ref, auto target) {
-        double error;
-        passed = verify::verify_range(ref, target, tolerance, &error);
+    visit_all(ref_arg.data(), target_arg)([&](auto ref, auto target) {
+        double rms_error;
+        passed =
+            verify::verify_range_with_tolerance(target, verify::expected{ref}, tols, &rms_error);
         if(not passed)
         {
             // TODO: Check for nans
             std::cout << "FAILED: " << name << std::endl;
-            std::cout << "error: " << error << std::endl;
+            std::cout << "RMS Error: " << rms_error << std::endl;
             if(ref.size() < 32)
                 std::cout << "ref:" << ref << std::endl;
             if(target.size() < 32)
@@ -93,5 +94,16 @@ bool verify_args(const std::string& name,
     return passed;
 }

+bool verify_args_with_tolerance(const std::string& name,
+                                const argument& target_arg,
+                                const verify::expected<argument>& ref_arg,
+                                std::size_t tolerance)
+{
+    double rms_tol = 0.001;
+    target_arg.visit([&](auto ta) { rms_tol = verify::get_rms_tol(ta, tolerance); });
+    verify::tolerance tols{rms_tol};
+    return verify_args(name, target_arg, ref_arg, tols);
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
@@ -158,6 +158,31 @@ TEST_CASE(two_transpose_gather)
     EXPECT(m1 == m2);
 }

+TEST_CASE(standard_reshape_lazy)
+{
+    migraphx::module m1;
+    {
+        auto data = m1.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
+        auto add  = m1.add_instruction(migraphx::make_op("add"), data, data);
+        auto r =
+            m1.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), add);
+        m1.add_return({r});
+    }
+    run_pass(m1);
+
+    migraphx::module m2;
+    {
+        auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
+        auto add  = m2.add_instruction(migraphx::make_op("add"), data, data);
+        auto ca   = m2.add_instruction(migraphx::make_op("contiguous"), add);
+        auto r =
+            m2.add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {2, 1, 12, 5}}}), ca);
+        m2.add_return({r});
+    }
+
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(standard_reshape)
 {
     migraphx::module m1;
@@ -173,8 +198,7 @@ TEST_CASE(standard_reshape)
     {
         auto data = m2.add_parameter("2x2", {migraphx::shape::float_type, {2, 3, 4, 5}});
         auto add  = m2.add_instruction(migraphx::make_op("add"), data, data);
-        auto ca   = m2.add_instruction(migraphx::make_op("contiguous"), add);
-        auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), ca);
+        auto r = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 12, 5}}}), add);
         m2.add_return({r});
     }
...
@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
     migraphx::target gpu_t = migraphx::make_target("gpu");
     run_prog(p, gpu_t, m, gpu_result);

-    EXPECT(migraphx::verify::verify_range(ref_result, gpu_result));
+    EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
 }

 int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "make_precompile_op.hpp"
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_ops{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(layernorm_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto create_program = [=](bool first_arg_layernorm) {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x, y}, single_pointwise("add"));
auto add1 =
mm->add_instruction(make_precompile_op("pointwise"), {x, y, alloc_ins}, {pw_add1});
auto alloc_ins2 = mm->add_instruction(alloc);
auto layernorm_ins =
mm->add_instruction(make_precompile_op("gpu::prelayernorm"), add1, alloc_ins2);
std::vector<migraphx::instruction_ref> pw_inputs = {layernorm_ins, z};
if(not first_arg_layernorm)
{
pw_inputs = {z, layernorm_ins};
}
auto* pw_add2 =
create_pointwise_module(p, "main:pointwise1", pw_inputs, single_pointwise("add"));
auto alloc_ins3 = mm->add_instruction(alloc);
pw_inputs.push_back(alloc_ins3);
auto add2 = mm->add_instruction(make_precompile_op("pointwise"), pw_inputs, {pw_add2});
mm->add_return({add2});
return p;
};
auto create_fused_program = [=]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x, y}, single_pointwise("add"));
auto add1 =
mm->add_instruction(make_precompile_op("pointwise"), {x, y, alloc_ins}, {pw_add1});
auto alloc_ins2 = mm->add_instruction(alloc);
auto* pw_add2 =
create_pointwise_module(p, "main:pointwise1", {x, z}, single_pointwise("add"));
auto layernorm_ins = mm->add_instruction(
make_precompile_op("gpu::prelayernorm"), {add1, z, alloc_ins2}, {pw_add2});
mm->add_return({layernorm_ins});
return p;
};
{
migraphx::program p1 = create_program(true);
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}
{
migraphx::program p1 = create_program(false);
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }