Unverified Commit f6b08c2a authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into threads_register_target

parents d8f1ebbc 4188c38e
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
...@@ -76,15 +77,25 @@ void verify_program(const std::string& name, ...@@ -76,15 +77,25 @@ void verify_program(const std::string& name,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
auto x = run_ref(p, inputs); auto ref_outs = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs); auto target_outs = run_target(p, t, options, quantize, inputs);
std::size_t output_num = x.size(); std::size_t output_num = ref_outs.size();
for(std::size_t i = 0; i < output_num; ++i) for(std::size_t i = 0; i < output_num; ++i)
{ {
verify_args(name, x[i], y[i], tolerance); if(ref_outs[i].get_shape().type() != target_outs[i].get_shape().type() or
ref_outs[i].get_shape().lens() != target_outs[i].get_shape().lens())
{
std::cout << "FAILED: " << name << std::endl;
std::cout << "Shape mismatch {" << ref_outs[i].get_shape() << "} != {"
<< target_outs[i].get_shape() << "}" << std::endl;
}
else
{
verify_args(name, target_outs[i], verify::expected{ref_outs[i]}, tols);
}
} }
} }
...@@ -92,7 +103,7 @@ void verify_instructions(const program& prog, ...@@ -92,7 +103,7 @@ void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize, precision quantize,
double tolerance) verify::tolerance tols)
{ {
const auto* mm_prog = prog.get_main_module(); const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog)) for(auto&& ins : (*mm_prog))
...@@ -123,8 +134,7 @@ void verify_instructions(const program& prog, ...@@ -123,8 +134,7 @@ void verify_instructions(const program& prog,
{ {
std::cout << "Verify: " << ins.name() << std::endl; std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program( verify_program(ins.name(), p, t, options, quantize, create_param_map(p, false), tols);
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance);
} }
catch(...) catch(...)
{ {
...@@ -140,14 +150,22 @@ void verify_reduced(program p, ...@@ -140,14 +150,22 @@ void verify_reduced(program p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1); auto last = std::prev(mm->end(), n);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << n << std::endl; std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance); try
{
verify_program(std::to_string(n), p, t, options, quantize, inputs, tols);
}
catch(const std::exception& e)
{
std::cout << "FAILED: " << n << std::endl;
std::cout << "Exception: " << e.what() << std::endl;
}
} }
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
...@@ -155,14 +173,20 @@ void verify_reduced_program(const program& p, ...@@ -155,14 +173,20 @@ void verify_reduced_program(const program& p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl; std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 1; i < n; i++)
{ {
verify_reduced(p, i, t, options, quantize, inputs, tolerance); auto last = std::prev(mm->end(), i + 1);
if(contains({"@literal", "@param"}, last->name()))
{
std::cout << "Skip: " << i << std::endl;
continue;
}
verify_reduced(p, i, t, options, quantize, inputs, tols);
} }
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "precision.hpp" #include "precision.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
...@@ -37,18 +38,18 @@ void verify_program(const std::string& name, ...@@ -37,18 +38,18 @@ void verify_program(const std::string& name,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 100); verify::tolerance tols = verify::tolerance{});
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -81,6 +81,7 @@ struct MIGRAPHX_EXPORT instruction ...@@ -81,6 +81,7 @@ struct MIGRAPHX_EXPORT instruction
const std::vector<module_ref>& module_inputs() const; const std::vector<module_ref>& module_inputs() const;
/// Where this instruction is used as an input to another instruction
const std::vector<instruction_ref>& outputs() const; const std::vector<instruction_ref>& outputs() const;
friend bool operator==(const instruction& x, const instruction& y); friend bool operator==(const instruction& x, const instruction& y);
......
...@@ -49,17 +49,22 @@ struct allocate ...@@ -49,17 +49,22 @@ struct allocate
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
migraphx::check_shapes{inputs, *this, true}.has(0, 1);
// check if shape attribute is not default
if(s != shape()) if(s != shape())
{ {
if(inputs.size() == 1)
{
migraphx::check_shapes{inputs, *this, false}.only_dims(1);
}
else
{
migraphx::check_shapes{inputs, *this, false}.has(0);
}
return s; return s;
} }
else else
{ {
migraphx::check_shapes{inputs, *this, false}.has(1).only_dims(1);
const auto& out_dims = inputs.at(0); const auto& out_dims = inputs.at(0);
assert(not out_dims.dynamic());
assert(out_dims.ndim() == 1);
std::size_t max_val = std::numeric_limits<std::size_t>::max(); std::size_t max_val = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0), std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0),
shape::dynamic_dimension{0, max_val}); shape::dynamic_dimension{0, max_val});
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -29,7 +29,8 @@ ...@@ -29,7 +29,8 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
#include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -45,8 +46,6 @@ struct reshape ...@@ -45,8 +46,6 @@ struct reshape
return pack(f(self.dims, "dims")); return pack(f(self.dims, "dims"));
} }
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape"; } std::string name() const { return "reshape"; }
shape dyn_compute_shape(shape s0) const shape dyn_compute_shape(shape s0) const
...@@ -110,27 +109,9 @@ struct reshape ...@@ -110,27 +109,9 @@ struct reshape
return it; return it;
} }
template <class DimIterator, class StrideIterator> // This will attempt to alias the dimensions of the input shape to the lens of
static auto can_strides_merge(DimIterator dim_start, // `rdims`. Unlike reshape_lazy though we can modify memory layout with copies and this
DimIterator dim_last, // can remove previous nullopts that were sent back for the alias case
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will reshape the dimesions of the input shape to use the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims) static optional<shape> reshape_dims(const shape& input, const std::vector<std::size_t>& rdims)
{ {
if(input.standard()) if(input.standard())
...@@ -155,13 +136,8 @@ struct reshape ...@@ -155,13 +136,8 @@ struct reshape
{ {
auto start = idims.begin() + i; auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim); auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start; auto n = it - start;
assert((i + n) <= istrides.size()); assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n; i += n;
rstrides.push_back(istrides[i]); rstrides.push_back(istrides[i]);
} }
...@@ -170,8 +146,7 @@ struct reshape ...@@ -170,8 +146,7 @@ struct reshape
{ {
auto start = rdims.begin() + i; auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim); auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start; auto n = it - start;
assert((r + n) <= rdims.size()); assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim; auto stride = istrides[i] * idim;
...@@ -191,15 +166,11 @@ struct reshape ...@@ -191,15 +166,11 @@ struct reshape
auto stride = rstrides.back(); auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end())) for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{ {
if(d != 1) (void)d;
return nullopt;
rstrides.push_back(stride); rstrides.push_back(stride);
} }
} }
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides}; return shape{input.type(), rdims, rstrides};
} }
...@@ -233,25 +204,24 @@ struct reshape ...@@ -233,25 +204,24 @@ struct reshape
} }
auto s = reshape_dims(inputs.front(), rdims); auto s = reshape_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("Reshape on axis that is not packed.");
if(s->elements() != inputs.front().elements()) if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " + MIGRAPHX_THROW("reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(s->elements()) + " elements whereas the input has " + std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements())); std::to_string(inputs.front().elements()));
assert(s->bytes() == inputs.front().bytes());
return *s; return *s;
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1) if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim"); MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs[0];
auto s0 = inputs.front();
if(s0.dynamic()) if(s0.dynamic())
{ {
return dyn_compute_shape(s0); return dyn_compute_shape(s0);
...@@ -264,10 +234,14 @@ struct reshape ...@@ -264,10 +234,14 @@ struct reshape
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(dyn_out.computed_shape); assert(dyn_out.computed_shape.standard());
} argument result{dyn_out.computed_shape};
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return result;
}
}; };
} // namespace op } // namespace op
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#define MIGRAPHX_GUARD_OPERATORS_RESHAPE_LAZY_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/optional.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct reshape_lazy
{
std::vector<int64_t> dims;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.dims, "dims"));
}
value attributes() const { return {{"require_std_shape", true}}; }
std::string name() const { return "reshape_lazy"; }
shape dyn_compute_shape(shape s0) const
{
auto dyn_dims = s0.dyn_dims();
auto num_not_fixed = std::count_if(
dyn_dims.cbegin(), dyn_dims.cend(), [](auto dd) { return not dd.is_fixed(); });
if(num_not_fixed != 1)
{
MIGRAPHX_THROW("reshape_lazy: Only supports one non-fixed dynamic_dimension");
}
// track number of fixed elements in input and output
std::size_t num_dims_ele = 1;
std::size_t num_dd_ele = 1;
for(std::size_t i = 0; i < dyn_dims.size(); ++i)
{
if(dyn_dims[i].is_fixed())
{
num_dims_ele *= dims[i];
num_dd_ele *= dyn_dims[i].min;
}
else
{
if(dims[i] != 0 and dims[i] != -1)
{
MIGRAPHX_THROW(
"reshape_lazy: Non-fixed dynamic_dimension doesn't match with 0 or -1 "
"output dimension");
}
}
}
if(num_dims_ele != num_dd_ele)
{
MIGRAPHX_THROW("reshape_lazy: Number of fixed elements must match. Input: " +
std::to_string(num_dd_ele) + " Output: " + std::to_string(num_dims_ele));
}
// construct output dynamic shape from dims attribute
std::vector<shape::dynamic_dimension> output_dyn_dims(dims.size());
std::transform(dims.cbegin(),
dims.cend(),
dyn_dims.cbegin(),
output_dyn_dims.begin(),
[](std::size_t dim, auto dyn_dim) {
if(not dyn_dim.is_fixed())
return dyn_dim;
return shape::dynamic_dimension{dim, dim};
});
return {s0.type(), output_dyn_dims};
}
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x >= dim;
});
if(x != dim)
return start;
return it;
}
template <class DimIterator, class StrideIterator>
static auto can_strides_merge(DimIterator dim_start,
DimIterator dim_last,
StrideIterator stride_start,
StrideIterator stride_last)
{
assert(std::distance(dim_start, dim_last) == std::distance(stride_start, stride_last));
auto cstride = *std::prev(stride_last);
return std::equal(std::make_reverse_iterator(dim_last),
std::make_reverse_iterator(dim_start + 1),
std::make_reverse_iterator(stride_last - 1),
std::make_reverse_iterator(stride_start),
[&](auto dim, auto stride) {
cstride *= dim;
return stride == cstride;
});
}
// This will attempt to alias the dimensions of the input shape to the lens of
// `rdims`. If this can't be done without changing memory layout then it
// will return nullopt
static optional<shape> reshape_lazy_dims(const shape& input,
const std::vector<std::size_t>& rdims)
{
if(input.standard())
return shape{input.type(), rdims};
const auto& idims = input.lens();
const auto& istrides = input.strides();
std::vector<std::size_t> rstrides;
std::size_t i = 0;
std::size_t r = 0;
while(i < idims.size() and r < rdims.size())
{
auto idim = idims[i];
auto rdim = rdims[r];
if(rdim == idim)
{
rstrides.push_back(istrides[i]);
}
// squeeze
else if(rdim > idim)
{
auto start = idims.begin() + i;
auto it = compute_end_dim(start, idims.end(), rdim);
if(it == start)
return nullopt;
auto n = it - start;
assert((i + n) <= istrides.size());
if(not can_strides_merge(
start, it + 1, istrides.begin() + i, istrides.begin() + i + n + 1))
return nullopt;
i += n;
rstrides.push_back(istrides[i]);
}
// unsqueeze
else // if(rdim < idim)
{
auto start = rdims.begin() + i;
auto it = compute_end_dim(start, rdims.end(), idim);
if(it == start)
return nullopt;
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {
stride /= dim;
rstrides.push_back(stride);
});
r += n;
}
i++;
r++;
}
// Handle trailing 1s
if(rstrides.size() < rdims.size() and not rstrides.empty())
{
auto stride = rstrides.back();
for(auto d : range(rdims.begin() + rstrides.size(), rdims.end()))
{
if(d != 1)
return nullopt;
rstrides.push_back(stride);
}
}
if(rdims.size() != rstrides.size())
return nullopt;
return shape{input.type(), rdims, rstrides};
}
shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
rdims[i] = idims[i];
// since rdims using size_t type, -1 is the max value
// is size_t that cause later compuation incorrect
if(dims[i] == -1)
rdims[i] = 1;
}
if(n_neg_dims > 0)
{
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == -1)
rdims[i] = missing_dim;
}
}
auto s = reshape_lazy_dims(inputs.front(), rdims);
if(not s.has_value())
MIGRAPHX_THROW("reshape_lazy on axis that is not packed.");
if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW(
"reshape_lazy: Wrong number of elements for reshape_lazy: reshape_lazy has " +
std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));
assert(s->bytes() == inputs.front().bytes());
return *s;
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape_lazy: Dimensions for reshape_lazy can only have one -1 dim");
auto s0 = inputs[0];
if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
return args[0].reshape(dyn_out.computed_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -124,7 +124,7 @@ struct roialign ...@@ -124,7 +124,7 @@ struct roialign
{ {
xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] + xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] +
(i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii]; (i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii];
xy[ii] = (coord_trans_mode == "output_half_pixel") ? (xy[ii] - 0.5f) : xy[ii]; xy[ii] = (coord_trans_mode == "half_pixel") ? (xy[ii] - 0.5f) : xy[ii];
if(xy[ii] < -1.0 or xy[ii] > dims[ii]) if(xy[ii] < -1.0 or xy[ii] > dims[ii])
{ {
results[index] = pos_weight{}; results[index] = pos_weight{};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYN_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_DYN_OPS_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Convert dynamic ops to their static version if possible.
* Should be run after the split_single_dyn_dims pass.
*/
struct MIGRAPHX_EXPORT simplify_dyn_ops
{
std::string name() const { return "simplify_dyn_ops"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -29,10 +29,13 @@ ...@@ -29,10 +29,13 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <assert.h>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/env.hpp>
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE)
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace verify { namespace verify {
...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2) ...@@ -187,16 +190,103 @@ double rms_range(const R1& r1, const R2& r2)
return std::numeric_limits<range_value<R1>>::max(); return std::numeric_limits<range_value<R1>>::max();
} }
template <class R>
double get_rms_tol(const R&, std::size_t tolerance = 80)
{
double threshold = std::numeric_limits<range_value<R>>::epsilon() * tolerance;
return threshold;
}
/*
C++ doesn't support named arguments, this is just wrapper that helps distinguish between actual
results v/s expected results arguments.
*/
template <class T>
struct expected
{
expected() = default;
explicit expected(const T& input) : x(&input) {}
const T& data() const
{
assert(x != nullptr);
return *x;
}
private:
const T* x = nullptr;
};
// deduction guide for templated expected class
template <class T>
expected(const T&) -> expected<T>;
struct tolerance
{
double rms_tol = 0.001;
double atol = 0.001;
double rtol = 0.001;
};
/*
MIGraphX implementation of numpy's np.allclose() which checks if elementwise absolute diff is within
tolerance using this formula: abs(a - b) < atol + rtol(abs(b))
*/
template <class R1, class R2>
bool allclose(const R1& r1, const R2& r2, tolerance tols)
{
std::size_t n = range_distance(r1);
if(n == range_distance(r2))
{
auto idx = mismatch_idx(r1, r2, [&](auto x, auto y) {
return abs_diff(double(x), double(y)) < tols.atol + tols.rtol * std::abs(double(y));
});
return idx >= range_distance(r1);
}
return false;
}
template <class R1, class R2> template <class R1, class R2>
bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out_error = nullptr) bool verify_rms_range(const R1& r1,
const R2& r2,
std::size_t tolerance = 80,
double* out_rms_error = nullptr)
{ {
double threshold = std::numeric_limits<range_value<R1>>::epsilon() * tolerance; double threshold = get_rms_tol(r1, tolerance);
auto error = rms_range(r1, r2); auto error = rms_range(r1, r2);
if(out_error != nullptr) if(out_rms_error != nullptr)
*out_error = error; *out_rms_error = error;
return error <= threshold; return error <= threshold;
} }
template <class R1, class R2>
bool verify_range_with_tolerance(const R1& r1,
const expected<R2>& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
auto rms_error = rms_range(r1, r2.data());
// disable ewise_verify by default for now, it requires lot of tests to be fixed
bool ewise_verify = true;
if(enabled(MIGRAPHX_VERIFY_ENABLE_ALLCLOSE{}))
{
ewise_verify = allclose(r1, r2.data(), tols);
}
if(out_rms_error != nullptr)
*out_rms_error = rms_error;
return rms_error <= tols.rms_tol and ewise_verify;
}
// expected argument should be passed as second, but if it is passed as the first by mistake then
// flip the order
template <class R1, class R2>
bool verify_range_with_tolerance(const expected<R1>& r1,
const R2& r2,
tolerance tols = tolerance{},
double* out_rms_error = nullptr)
{
return verify_rms_range(r2, r1, tols, out_rms_error);
}
} // namespace verify } // namespace verify
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -31,11 +31,15 @@ ...@@ -31,11 +31,15 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT MIGRAPHX_EXPORT bool verify_args(const std::string& name,
bool verify_args(const std::string& name, const argument& target_arg,
const argument& ref_arg, const verify::expected<argument>& ref_arg,
const argument& target_arg, verify::tolerance);
double tolerance = 80);
MIGRAPHX_EXPORT bool verify_args_with_tolerance(const std::string& name,
const argument& target_arg,
const verify::expected<argument>& ref_arg,
std::size_t tolerance = 80);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant> ...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const const std::vector<instruction_ref>& /*args*/) const
{ {
literal v = parser.parse_value(info.attributes.at("value")); static const std::vector<std::string> attributes = {
"value", "value_float", "value_floats", "value_int", "value_ints"};
std::vector<std::string> present_attributes;
std::copy_if(attributes.begin(),
attributes.end(),
std::back_inserter(present_attributes),
[&](const std::string& a) { return contains(info.attributes, a); });
if(present_attributes.empty())
{
MIGRAPHX_THROW("Constant node does not contain any supported attribute");
}
if(present_attributes.size() > 1)
{
MIGRAPHX_THROW("Constant contains multiple attributes: " +
join_strings(std::move(present_attributes), ", "));
}
// cppcheck-suppress accessMoved
auto&& attr = info.attributes[present_attributes[0]];
literal v = parser.parse_value(attr);
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return info.add_literal(literal{v.get_shape().type()}); return info.add_literal(literal{v.get_shape().type()});
} }
auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(attr.has_t() and attr.t().dims_size() == 0)
{ {
migraphx::shape scalar_shape{v.get_shape().type()}; migraphx::shape scalar_shape{v.get_shape().type()};
return info.add_literal(migraphx::literal{scalar_shape, v.data()}); return info.add_literal(migraphx::literal{scalar_shape, v.data()});
......
...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std::vector<op_desc> operators() const { return {{"RoiAlign"}}; } std::vector<op_desc> operators() const { return {{"RoiAlign"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode =
if(contains(info.attributes, "coordinate_transformation_mode")) parser.opset_version >= 16 ? "half_pixel" : "output_half_pixel";
if(const auto* a = "coordinate_transformation_mode"; contains(info.attributes, a))
{ {
coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s(); coord_trans_mode = info.attributes.at(a).s();
} }
if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode)) if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode))
{ {
MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode + MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode +
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const ...@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
mpm.run_pass(simplify_reshapes{}); // loop to further optimize after initial transformations
mpm.run_pass(simplify_algebra{}); for(int j = 0; j < 2; j++)
{
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{});
}
mpm.run_pass(eliminate_common_subexpression{}); mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{}); mpm.run_pass(propagate_constant{});
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins) bool skip_propagate(instruction_ref ins)
{ {
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front()); return skip_propagate(ins->inputs().front());
auto&& s = ins->get_shape(); auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar()) if(s.broadcasted() and not s.scalar())
return true; return true;
...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propagate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
......
...@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const ...@@ -43,9 +43,7 @@ void rewrite_pooling::apply(module& m) const
continue; continue;
if(ins->inputs().empty()) if(ins->inputs().empty())
continue; continue;
auto&& s = ins->inputs().front()->get_shape(); auto&& s = ins->inputs().front()->get_shape();
if(not s.standard())
continue;
auto&& op = any_cast<op::pooling>(ins->get_operator()); auto&& op = any_cast<op::pooling>(ins->get_operator());
if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; })) if(not std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
continue; continue;
...@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const ...@@ -54,27 +52,18 @@ void rewrite_pooling::apply(module& m) const
auto lens = s.lens(); auto lens = s.lens();
if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end())) if(not std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
continue; continue;
std::int64_t n = s.lens()[0]; std::vector<std::int64_t> axes(lens.size() - 2);
std::int64_t c = s.lens()[1]; std::iota(axes.begin(), axes.end(), 2);
auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == op::pooling_mode::average) if(op.mode == op::pooling_mode::average)
{ {
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape); m.replace_instruction(ins, make_op("reduce_mean", {{"axes", axes}}), ins->inputs());
} }
// max pooling // max pooling
else else
{ {
pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); m.replace_instruction(ins, make_op("reduce_max", {{"axes", axes}}), ins->inputs());
} }
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
} }
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Convert 2 input static shape broadcast/multibroadcast into 1 input version.
* Some compiler passes (ex. simplify_algebra) only support the 1 input versions
* of the broadcasting operators.
*/
struct find_static_2in_broadcasts
{
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
/**
* Simplify slice with variable `starts` and `ends` to the constant version if
* the `input_starts` and `input_ends` inputs are constant.
*/
struct find_const_3in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(3),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
if(not starts_arg.empty() and not ends_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
auto slice_val = ins->get_operator().to_value();
auto axes_vec = slice_val.at("axes").to_vector<int64_t>();
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
/**
* Simplify slice with variable `starts`, `ends`, and `input_axes` to the constant version if
* the `input_starts`, `input_ends`, and `input_axes` inputs are constant.
*/
struct find_const_4in_slice
{
auto matcher() const
{
return match::name("slice")(match::nargs(4),
match::arg(1)(match::is_constant()),
match::arg(2)(match::is_constant()),
match::arg(3)(match::is_constant()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
argument starts_arg = inputs.at(1)->eval();
argument ends_arg = inputs.at(2)->eval();
argument axes_arg = inputs.at(3)->eval();
if(not starts_arg.empty() and not ends_arg.empty() and not axes_arg.empty())
{
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
starts_arg.visit([&](auto output) { starts_vec.assign(output.begin(), output.end()); });
ends_arg.visit([&](auto output) { ends_vec.assign(output.begin(), output.end()); });
axes_arg.visit([&](auto output) { axes_vec.assign(output.begin(), output.end()); });
m.replace_instruction(
ins,
make_op("slice", {{"starts", starts_vec}, {"ends", ends_vec}, {"axes", axes_vec}}),
inputs.at(0));
}
}
};
void simplify_dyn_ops::apply(module& m) const
{
match::find_matches(
m, find_static_2in_broadcasts{}, find_const_3in_slice{}, find_const_4in_slice{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -122,6 +122,11 @@ struct find_nop_reshapes ...@@ -122,6 +122,11 @@ struct find_nop_reshapes
reshapes.insert("pad"); reshapes.insert("pad");
reshapes.insert("slice"); reshapes.insert("slice");
reshapes.insert("transpose"); reshapes.insert("transpose");
reshapes.insert("reduce_mean");
reshapes.insert("reduce_max");
reshapes.insert("reduce_min");
reshapes.insert("reduce_sum");
reshapes.insert("reduce_prod");
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
...@@ -627,6 +632,30 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -627,6 +632,30 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("transpose")(
match::arg(0)(match::name("multibroadcast").bind("bcast_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
if(not input->get_shape().scalar())
return;
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
}
};
struct find_slice_transpose struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
...@@ -799,6 +828,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -799,6 +828,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_slice{}, find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{}, find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
......
...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes) ...@@ -68,37 +68,6 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max}; dds_it->max};
} }
namespace {
struct find_static_2in_broadcasts
{
// Convert 2 input static shape broadcast/multibroadcast into 1 input version.
// Some compiler passes (ex. simplify_algebra) only support the 1 input versions
// of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/** /**
* Makes all the shapes in the dynamic_dimension range. Probably won't work for `if` * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
* and `loop` instructions, depending on how the submodules for those * and `loop` instructions, depending on how the submodules for those
...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const ...@@ -135,7 +104,6 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens}); dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins); auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs}); submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod); submodules.push_back(submod);
} }
// redirect to select_module operator and return // redirect to select_module operator and return
......
...@@ -23,6 +23,10 @@ ...@@ -23,6 +23,10 @@
# #################################################################################### # ####################################################################################
list(APPEND CMAKE_PREFIX_PATH /opt/rocm) list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
if(NOT GPU_TARGETS)
message(FATAL_ERROR "HIP package is broken and has no GPU_TARGETS, please pass -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*') to cmake to build for your gpu.")
endif()
find_package(miopen) find_package(miopen)
# rocblas # rocblas
...@@ -194,7 +198,7 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp ...@@ -194,7 +198,7 @@ register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION}) rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
set(MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL "") set(MIGRAPHX_ENABLE_MLIR ON CACHE BOOL "")
if(MIGRAPHX_ENABLE_MLIR) if(MIGRAPHX_ENABLE_MLIR)
# Find package rocMLIR # Find package rocMLIR
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <cassert> #include <cassert>
#include <iostream> #include <iostream>
#include <deque>
#ifdef MIGRAPHX_USE_HIPRTC #ifdef MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h> #include <hip/hiprtc.h>
...@@ -92,7 +93,7 @@ struct hiprtc_program ...@@ -92,7 +93,7 @@ struct hiprtc_program
{ {
struct string_array struct string_array
{ {
std::vector<std::string> strings{}; std::deque<std::string> strings{};
std::vector<const char*> c_strs{}; std::vector<const char*> c_strs{};
string_array() {} string_array() {}
...@@ -209,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr ...@@ -209,7 +210,6 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<hiprtc_sr
options.push_back("-Wno-gnu-line-marker"); options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast"); options.push_back("-Wno-old-style-cast");
} }
if(enabled(MIGRAPHX_GPU_DEBUG{})) if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG"); options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) { if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment