Commit 711ff872 authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge remote-tracking branch 'origin/develop' into fp8_rocblas

parents 8bfb5e56 200c7038
...@@ -67,7 +67,7 @@ The following is a list of prerequisites for building MIGraphX. ...@@ -67,7 +67,7 @@ The following is a list of prerequisites for building MIGraphX.
3. Build MIGraphX source code: 3. Build MIGraphX source code:
```bash ```bash
rbuild build -d depend -B build rbuild build -d depend -B build -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*')
``` ```
Once completed, all prerequisites are in the `depend` folder and MIGraphX is in the `build` directory. Once completed, all prerequisites are in the `depend` folder and MIGraphX is in the `build` directory.
...@@ -106,7 +106,7 @@ the folder to `PATH`, or add the option `--prefix /usr/local` in the pip3 comman ...@@ -106,7 +106,7 @@ the folder to `PATH`, or add the option `--prefix /usr/local` in the pip3 comman
3. Configure CMake. If the prerequisites are installed at the default location `/usr/local`, use: 3. Configure CMake. If the prerequisites are installed at the default location `/usr/local`, use:
```bash ```bash
CXX=/opt/rocm/llvm/bin/clang++ cmake .. CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DGPU_TARGETS=$(/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*')
``` ```
Otherwise, you need to set `-DCMAKE_PREFIX_PATH=$your_loc` to configure CMake. Otherwise, you need to set `-DCMAKE_PREFIX_PATH=$your_loc` to configure CMake.
......
Contributor Guide Contributor Guide
=============== =================
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
:caption: Contents: :caption: Contents:
dev_intro dev/dev_intro
dev/data dev/data
dev/operators dev/operators
dev/program dev/program
......
MIGraphX Fundamentals Developer Introduction
====================== ======================
MIGraphX provides an optimized execution engine for deep learning neural networks. MIGraphX provides an optimized execution engine for deep learning neural networks.
......
MIGraphX Driver MIGraphX Driver
=============== ===============
The MIGraphX driver is a tool that allows you to utilize many of the core functions of MIGraphX without having to write your own program. It can read, compile, run, and test the performance of a model with randomized data.
read read
---- ----
...@@ -17,6 +19,7 @@ compile ...@@ -17,6 +19,7 @@ compile
Compiles and prints input graph. Compiles and prints input graph.
.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst .. include:: ./driver/compile.rst
run run
...@@ -26,6 +29,7 @@ run ...@@ -26,6 +29,7 @@ run
Loads and prints input graph. Loads and prints input graph.
.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst .. include:: ./driver/compile.rst
perf perf
...@@ -35,6 +39,7 @@ perf ...@@ -35,6 +39,7 @@ perf
Compiles and runs input graph then prints performance report. Compiles and runs input graph then prints performance report.
.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst .. include:: ./driver/compile.rst
.. option:: --iterations, -n [unsigned int] .. option:: --iterations, -n [unsigned int]
...@@ -48,6 +53,7 @@ verify ...@@ -48,6 +53,7 @@ verify
Runs reference and CPU or GPU implementations and checks outputs for consistency. Runs reference and CPU or GPU implementations and checks outputs for consistency.
.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst .. include:: ./driver/compile.rst
.. option:: --rms-tol [double] .. option:: --rms-tol [double]
...@@ -71,7 +77,7 @@ Verify each instruction ...@@ -71,7 +77,7 @@ Verify each instruction
Reduce program and verify Reduce program and verify
roctx roctx
---- -----
.. program:: migraphx-driver roctx .. program:: migraphx-driver roctx
...@@ -86,4 +92,5 @@ An example command line combined with rocprof for tracing purposes is given belo ...@@ -86,4 +92,5 @@ An example command line combined with rocprof for tracing purposes is given belo
After `rocprof` is run, the output directory will contain trace information for HIP, HCC and ROCTX in seperate `.txt` files. After `rocprof` is run, the output directory will contain trace information for HIP, HCC and ROCTX in seperate `.txt` files.
To understand the interactions between API calls, it is recommended to utilize `roctx.py` helper script as desribed in :ref:`dev/tools:rocTX` section. To understand the interactions between API calls, it is recommended to utilize `roctx.py` helper script as desribed in :ref:`dev/tools:rocTX` section.
.. include:: ./driver/compile.rst .. include:: ./driver/read.rst
\ No newline at end of file .. include:: ./driver/compile.rst
.. include:: ./driver/read.rst
.. option:: --fill0 [std::vector<std::string>] .. option:: --fill0 [std::vector<std::string>]
Fill parameter with 0s Fill parameter with 0s
......
...@@ -46,11 +46,11 @@ Trim instructions from the end (Default: 0) ...@@ -46,11 +46,11 @@ Trim instructions from the end (Default: 0)
Dim of a parameter (format: "@name d1 d2 dn") Dim of a parameter (format: "@name d1 d2 dn")
.. options:: --dyn-input-dim [std::vector<std::string>] .. option:: --dyn-input-dim [std::vector<std::string>]
Set dynamic dimensions of a parameter using JSON formatting (format "@name" "dynamic_dimension_json") Set dynamic dimensions of a parameter using JSON formatting (format "@name" "dynamic_dimension_json")
.. options:: --default-dyn-dim .. option:: --default-dyn-dim
Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]}) Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
......
...@@ -95,7 +95,7 @@ shape ...@@ -95,7 +95,7 @@ shape
:rtype: bool :rtype: bool
dynamic_dimension dynamic_dimension
-------- -----------------
.. py:class:: dynamic_dimension(min, max, optimals) .. py:class:: dynamic_dimension(min, max, optimals)
......
...@@ -175,6 +175,7 @@ register_migraphx_ops( ...@@ -175,6 +175,7 @@ register_migraphx_ops(
mul mul
multibroadcast multibroadcast
multinomial multinomial
nearbyint
neg neg
nonmaxsuppression nonmaxsuppression
nonzero nonzero
...@@ -205,7 +206,6 @@ register_migraphx_ops( ...@@ -205,7 +206,6 @@ register_migraphx_ops(
rnn_last_hs_output rnn_last_hs_output
rnn_var_sl_last_output rnn_var_sl_last_output
roialign roialign
round
rsqrt rsqrt
run_on_target run_on_target
scalar scalar
......
...@@ -63,7 +63,7 @@ inline std::string get_version() ...@@ -63,7 +63,7 @@ inline std::string get_version()
{ {
return "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + "." + return "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + "." +
std::to_string(MIGRAPHX_VERSION_MINOR) + "." + std::to_string(MIGRAPHX_VERSION_PATCH) + std::to_string(MIGRAPHX_VERSION_MINOR) + "." + std::to_string(MIGRAPHX_VERSION_PATCH) +
"." + MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK); "." MIGRAPHX_VERSION_TWEAK;
} }
struct loader struct loader
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -21,24 +21,28 @@ ...@@ -21,24 +21,28 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_OPERATORS_ROUND_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_NEARBYINT_HPP
#define MIGRAPHX_GUARD_OPERATORS_ROUND_HPP #define MIGRAPHX_GUARD_OPERATORS_NEARBYINT_HPP
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <fenv.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct nearbyint : unary<nearbyint>
struct round : unary<round>
{ {
auto apply() const auto apply() const
{ {
return [](auto x) { return std::round(x); }; return [](auto x) {
auto rounding_mode = fegetround();
fesetround(FE_TONEAREST);
return std::nearbyint(x);
fesetround(rounding_mode);
};
} }
}; };
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -30,11 +30,11 @@ ...@@ -30,11 +30,11 @@
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <cmath> #include <cmath>
#include <fenv.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct quantizelinear struct quantizelinear
{ {
std::string name() const { return "quantizelinear"; } std::string name() const { return "quantizelinear"; }
...@@ -71,26 +71,26 @@ struct quantizelinear ...@@ -71,26 +71,26 @@ struct quantizelinear
{ {
y_zero_point = args.at(2); y_zero_point = args.at(2);
} }
argument result{output_shape}; argument result{output_shape};
auto rounding_mode = fegetround();
fesetround(FE_TONEAREST);
visit_all(result, y_zero_point)([&](auto output, auto zero_pts) { visit_all(result, y_zero_point)([&](auto output, auto zero_pts) {
visit_all(x, y_scale)([&](auto input, auto scales) { visit_all(x, y_scale)([&](auto input, auto scales) {
using quant_type = typename decltype(output)::value_type; using quant_type = typename decltype(output)::value_type;
auto min_value = std::numeric_limits<quant_type>::min(); auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max(); auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::round(input[i] / scales[i])) + int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]); static_cast<int64_t>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value), output[i] = std::max(static_cast<int64_t>(min_value),
std::min(static_cast<int64_t>(max_value), quantized)); std::min(static_cast<int64_t>(max_value), quantized));
}); });
}); });
}); });
fesetround(rounding_mode);
return result; return result;
} }
}; };
} // namespace op } // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -84,6 +84,7 @@ ...@@ -84,6 +84,7 @@
#include <migraphx/op/mod.hpp> #include <migraphx/op/mod.hpp>
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/op/multibroadcast.hpp> #include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/nearbyint.hpp>
#include <migraphx/op/neg.hpp> #include <migraphx/op/neg.hpp>
#include <migraphx/op/nonmaxsuppression.hpp> #include <migraphx/op/nonmaxsuppression.hpp>
#include <migraphx/op/nonzero.hpp> #include <migraphx/op/nonzero.hpp>
...@@ -110,7 +111,6 @@ ...@@ -110,7 +111,6 @@
#include <migraphx/op/rnn_variable_seq_lens.hpp> #include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp> #include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/roialign.hpp> #include <migraphx/op/roialign.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter_add.hpp> #include <migraphx/op/scatter_add.hpp>
......
...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"}, {"Neg", "neg"},
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "nearbyint"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
......
...@@ -144,16 +144,15 @@ struct parse_slice : op_parser<parse_slice> ...@@ -144,16 +144,15 @@ struct parse_slice : op_parser<parse_slice>
sd.op.axes = axes; sd.op.axes = axes;
} }
if(not sd.steps.empty()) if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return s != 1; }))
{ {
if(sd.op.starts.empty() or sd.op.ends.empty()) if(sd.op.starts.empty() or sd.op.ends.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported"); MIGRAPHX_THROW(
"PARSE_SLICE: steps and variable starts and/or ends is not supported");
if(sd.op.axes.empty()) if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported"); MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
} }
assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
// If any axes have negative step, prepare to add a "reverse" op // If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(sd.steps.size())) for(auto i : range(sd.steps.size()))
{ {
......
...@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -47,7 +47,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x); ins, make_op("convert", {{"target_type", y_scale->get_shape().type()}}), x);
} }
auto div = m.insert_instruction(ins, make_op("div"), x, y_scale); auto div = m.insert_instruction(ins, make_op("div"), x, y_scale);
auto add_zero_point = m.insert_instruction(ins, make_op("round"), div); auto add_zero_point = m.insert_instruction(ins, make_op("nearbyint"), div);
if(ins->inputs().size() == 3) if(ins->inputs().size() == 3)
{ {
......
...@@ -120,6 +120,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor) ...@@ -120,6 +120,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan) MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(isinf, ::isinf) MIGRAPHX_DEVICE_MATH(isinf, ::isinf)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(nearbyint, ::nearbyint)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(remainder, ::remainder) MIGRAPHX_DEVICE_MATH(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
...@@ -169,6 +170,7 @@ MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan) ...@@ -169,6 +170,7 @@ MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh) MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh) MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(nearbyint, ::nearbyint)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder) MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
...@@ -283,6 +285,7 @@ MIGRAPHX_DEVICE_MATH_VEC(isnan) ...@@ -283,6 +285,7 @@ MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min) MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(nearbyint)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder) MIGRAPHX_DEVICE_MATH_VEC(remainder)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
......
...@@ -64,7 +64,7 @@ TEST_CASE(mul_literal_round_test) ...@@ -64,7 +64,7 @@ TEST_CASE(mul_literal_round_test)
auto l1 = mm->add_literal(1 / 0.00787402f); auto l1 = mm->add_literal(1 / 0.00787402f);
auto mul = mm->add_instruction(migraphx::make_op("mul"), l0, l1); auto mul = mm->add_instruction(migraphx::make_op("mul"), l0, l1);
auto round = mm->add_instruction(migraphx::make_op("round"), mul); auto round = mm->add_instruction(migraphx::make_op("nearbyint"), mul);
mm->add_return({round}); mm->add_return({round});
......
...@@ -7087,6 +7087,16 @@ def roialign_test(): ...@@ -7087,6 +7087,16 @@ def roialign_test():
return ([node], [x, roi, bi], [y]) return ([node], [x, roi, bi], [y])
@onnx_test()
def round_half_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [4, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [4, 4])
node = onnx.helper.make_node('Round', inputs=['x'], outputs=['y'])
return ([node], [x], [y])
@onnx_test() @onnx_test()
def scatter_add_test(): def scatter_add_test():
x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6]) x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
...@@ -8006,6 +8016,32 @@ def slice_var_input_dyn1(): ...@@ -8006,6 +8016,32 @@ def slice_var_input_dyn1():
return ([node], [data, starts, ends, axes], [output]) return ([node], [data, starts, ends, axes], [output])
@onnx_test()
def slice_var_input_default_steps():
step = np.array([1, 1])
step_tensor = helper.make_tensor(name="step",
data_type=TensorProto.INT64,
dims=step.shape,
vals=step.astype(int))
arg_step = helper.make_node("Constant",
inputs=[],
outputs=['arg_step'],
value=step_tensor)
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [None, 2])
starts = helper.make_tensor_value_info('starts', TensorProto.INT64, [2])
ends = helper.make_tensor_value_info('ends', TensorProto.INT64, [2])
axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
node = onnx.helper.make_node(
'Slice',
inputs=['data', 'starts', 'ends', 'axes', 'arg_step'],
outputs=['output'])
return ([arg_step, node], [data, starts, ends, axes], [output])
@onnx_test() @onnx_test()
def slice_var_input_steps_error(): def slice_var_input_steps_error():
step = np.array([2, 1]) step = np.array([2, 1])
...@@ -8019,9 +8055,9 @@ def slice_var_input_steps_error(): ...@@ -8019,9 +8055,9 @@ def slice_var_input_steps_error():
value=step_tensor) value=step_tensor)
data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2]) data = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 2])
starts = helper.make_tensor_value_info('starts', TensorProto.FLOAT, [2]) starts = helper.make_tensor_value_info('starts', TensorProto.INT64, [2])
ends = helper.make_tensor_value_info('ends', TensorProto.FLOAT, [2]) ends = helper.make_tensor_value_info('ends', TensorProto.INT64, [2])
axes = helper.make_tensor_value_info('axes', TensorProto.FLOAT, [2]) axes = helper.make_tensor_value_info('axes', TensorProto.INT64, [2])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2]) output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2])
node = onnx.helper.make_node( node = onnx.helper.make_node(
......
...@@ -5788,9 +5788,9 @@ TEST_CASE(quantizelinear_test) ...@@ -5788,9 +5788,9 @@ TEST_CASE(quantizelinear_test)
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto nearbyint = mm->add_instruction(migraphx::make_op("nearbyint"), div);
auto s = round->get_shape(); auto s = nearbyint->get_shape();
auto clip = insert_quantizelinear_clip(*mm, div, round, s, 0, 255); auto clip = insert_quantizelinear_clip(*mm, div, nearbyint, s, 0, 255);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...@@ -5813,9 +5813,9 @@ TEST_CASE(quantizelinear_int32_test) ...@@ -5813,9 +5813,9 @@ TEST_CASE(quantizelinear_int32_test)
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0); l0);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto nearbyint = mm->add_instruction(migraphx::make_op("nearbyint"), div);
auto s = round->get_shape(); auto s = nearbyint->get_shape();
auto clip = insert_quantizelinear_clip(*mm, div, round, s, 0, 255); auto clip = insert_quantizelinear_clip(*mm, div, nearbyint, s, 0, 255);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...@@ -5835,7 +5835,7 @@ TEST_CASE(quantizelinear_zero_point_test) ...@@ -5835,7 +5835,7 @@ TEST_CASE(quantizelinear_zero_point_test)
auto l1_mbcast = auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("nearbyint"), div);
auto l2_mbcast = auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction( l2_mbcast = mm->add_instruction(
...@@ -5868,7 +5868,7 @@ migraphx::program make_quantizelinear_axis_prog() ...@@ -5868,7 +5868,7 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_bcast); auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_bcast);
auto round = mm->add_instruction(migraphx::make_op("round"), div); auto round = mm->add_instruction(migraphx::make_op("nearbyint"), div);
auto l2_bcast = mm->add_instruction( auto l2_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2);
l2_bcast = mm->add_instruction( l2_bcast = mm->add_instruction(
...@@ -6997,7 +6997,7 @@ TEST_CASE(round_test) ...@@ -6997,7 +6997,7 @@ TEST_CASE(round_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}}); auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
mm->add_instruction(migraphx::make_op("round"), input); mm->add_instruction(migraphx::make_op("nearbyint"), input);
auto prog = optimize_onnx("round_test.onnx"); auto prog = optimize_onnx("round_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -7653,6 +7653,25 @@ TEST_CASE(slice_var_input_dyn1) ...@@ -7653,6 +7653,25 @@ TEST_CASE(slice_var_input_dyn1)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(slice_var_input_default_steps)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {{3, 8}, {2, 2}}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int64_type, {2}});
auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int64_type, {2}});
auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {2}});
mm->add_literal({{migraphx::shape::int64_type, {2}}, {1, 1}});
auto ret = mm->add_instruction(migraphx::make_op("slice"), data, starts, ends, axes);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {3, 8};
auto prog = parse_onnx("slice_var_input_default_steps.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_steps_error) TEST_CASE(slice_var_input_steps_error)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); }));
......
reshape_variable_input_test0:q

0
12"Reshapereshape_variable_input_test0Z
0



Z
1

b
2


B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment