Commit ac04f3cc authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

manual_merge

parents d39c3343 d8011adf
......@@ -46,11 +46,11 @@ Trim instructions from the end (Default: 0)
Dim of a parameter (format: "@name d1 d2 dn")
.. options:: --dyn-input-dim [std::vector<std::string>]
.. option:: --dyn-input-dim [std::vector<std::string>]
Set dynamic dimensions of a parameter using JSON formatting (format "@name" "dynamic_dimension_json")
.. options:: --default-dyn-dim
.. option:: --default-dyn-dim
Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
......
......@@ -95,7 +95,7 @@ shape
:rtype: bool
dynamic_dimension
--------
-----------------
.. py:class:: dynamic_dimension(min, max, optimals)
......@@ -326,7 +326,7 @@ op
parse_onnx
----------
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false, max_loop_iterations=10)
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false, max_loop_iterations=10, limit_max_iterations=65535)
Load and parse an onnx file.
......@@ -337,7 +337,8 @@ parse_onnx
:param list[dynamic_dimension] map_dyn_input_dims: Explicitly specify the dynamic_dimensions of an input.
:param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found.
:param str print_program_on_error: Print program if an error occurs.
:param int max_loop_iterations: Maximum iteration number for the loop operator.
:param int max_loop_iterations: Maximum iteration number for the loop operator if trip count is not set.
:param int limit_max_iterations: Maximum iteration limit for the loop operator.
:rtype: program
parse_tf
......
......@@ -32,7 +32,7 @@
#define MIGRAPHX_MIOPEN_ASSERT(x) (assert((x) == miopenStatusSuccess))
#define MIGRAPHX_HIP_ASSERT(x) (assert((x) == hipSuccess))
inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s, bool pack = false)
inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s)
{
miopenTensorDescriptor_t t;
MIGRAPHX_MIOPEN_ASSERT(miopenCreateTensorDescriptor(&t));
......@@ -49,23 +49,9 @@ inline miopenTensorDescriptor_t make_miopen_tensor(const migraphx::shape& s, boo
else if(s.type() == migraphx_shape_int32_type)
d = miopenInt32;
else if(s.type() == migraphx_shape_int8_type)
{
if(pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
}
d = miopenInt8;
else
{
throw("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t, d, s_lens.size(), lens.data(), strides.data());
return t;
}
......
......@@ -55,7 +55,9 @@ See below for a comprehensive list of commands and option arguments, as well as
| --exhaustive-tune | Enable exhaustive search to find fastest kernel |
| --fp16 | Quantize for fp16 |
| --int8 | Quantize for int8 |
| --tolerance | Tolerance for errors |
| --rms-tol | Tolerance for the RMS error (Default: 0.001) |
| --atol | Tolerance for elementwise absolute difference (Default: 0.001) |
| --rtol | Tolerance for elementwise relative difference (Default: 0.001) |
| --per-instruction \| -i | Verify each instruction |
| --reduce \| -r | Reduce program and verify |
| --iterations \| -n | Number of iterations to run for perf report |
......@@ -147,9 +149,6 @@ gpu::gelu
gpu::gelu_new
gpu::gemm
gpu::greater
gpu::int8_conv_pack
gpu::int8_gemm_pack_a
gpu::int8_gemm_pack_b
gpu::layernorm
gpu::leaky_relu
gpu::less
......
......@@ -6,7 +6,7 @@ ARG PREFIX=/usr/local
RUN dpkg --add-architecture i386
# Add rocm repository
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.6/ focal main > /etc/apt/sources.list.d/rocm.list'
RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/5.7/ focal main > /etc/apt/sources.list.d/rocm.list'
# Install dependencies
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
......@@ -60,6 +60,3 @@ RUN pip3 install cmake==3.22.1
COPY ./tools/install_prereqs.sh /
RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh
# Install MLIR
ADD mlir-requirements.txt /mlir-requirements.txt
RUN cget -p /usr/local install -f /mlir-requirements.txt
......@@ -29,3 +29,12 @@ define =
CMAKE_CXX_COMPILER_LAUNCHER=${deps_dir}/bin/ccache
MIGRAPHX_ENABLE_CPU=On
BUILD_DEV=On
[cibuild]
cxx = ${rocm_path}/llvm/bin/clang++
cc = ${rocm_path}/llvm/bin/clang
deps =
-f dev-requirements.txt
define =
CMAKE_C_COMPILER_LAUNCHER=${deps_dir}/bin/ccache
CMAKE_CXX_COMPILER_LAUNCHER=${deps_dir}/bin/ccache
......@@ -21,11 +21,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
google/protobuf@v3.11.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
google/protobuf@v3.19.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/json@v3.8.0
live-clones/blaze@v3.8 -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@5172ec5280f14974beee2acf1af1db3b2670244c -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@13f6c2a69cfe80a575c6b241ec7353d1e953cb12 -DBUILD_FAT_LIBROCKCOMPILER=On
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -28,6 +28,7 @@ include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
include(RegisterOp)
include(CheckCXXLinkerFlag)
add_library(migraphx
adjust_allocation.cpp
......@@ -36,6 +37,7 @@ add_library(migraphx
argument.cpp
auto_contiguous.cpp
common.cpp
common_dims.cpp
compile_src.cpp
convert_to_json.cpp
cpp_generator.cpp
......@@ -95,6 +97,7 @@ add_library(migraphx
serialize.cpp
shape.cpp
simplify_algebra.cpp
simplify_dyn_ops.cpp
simplify_reshapes.cpp
split_single_dyn_dim.cpp
target.cpp
......@@ -141,6 +144,7 @@ register_migraphx_ops(
equal
erf
exp
fill
flatten
floor
fmod
......@@ -152,6 +156,7 @@ register_migraphx_ops(
identity
if_op
im2col
isinf
isnan
layout
leaky_relu
......@@ -171,6 +176,7 @@ register_migraphx_ops(
mul
multibroadcast
multinomial
nearbyint
neg
nonmaxsuppression
nonzero
......@@ -184,6 +190,8 @@ register_migraphx_ops(
quant_convolution
quant_dot
quantizelinear
random_uniform
random_seed
recip
reduce_max
reduce_mean
......@@ -192,13 +200,13 @@ register_migraphx_ops(
reduce_sum
relu
reshape
reshape_lazy
reverse
rnn
rnn_last_cell_output
rnn_last_hs_output
rnn_var_sl_last_output
roialign
round
rsqrt
run_on_target
scalar
......@@ -255,9 +263,8 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
target_link_libraries(migraphx PRIVATE nlohmann_json::nlohmann_json)
migraphx_generate_export_header(migraphx)
find_package(PkgConfig)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3)
find_package(SQLite3 REQUIRED)
target_link_libraries(migraphx PRIVATE SQLite::SQLite3)
find_package(msgpackc-cxx QUIET)
if(NOT msgpackc-cxx_FOUND)
......@@ -276,7 +283,9 @@ add_subdirectory(driver)
add_subdirectory(onnx)
add_subdirectory(tf)
if(MIGRAPHX_ENABLE_PYTHON)
add_subdirectory(py)
endif()
add_subdirectory(targets/ref)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref)
if(MIGRAPHX_ENABLE_CPU)
......
......@@ -32,6 +32,10 @@ migraphx_generate_export_header(migraphx_c DIRECTORY migraphx/api)
# bumped when binary compatibility is broken.
rocm_set_soversion(migraphx_c 3.0)
if(BUILD_TESTING)
target_compile_definitions(migraphx_c PRIVATE MIGRAPHX_BUILD_TESTING)
endif()
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx)
......
......@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
#ifdef MIGRAPHX_BUILD_TESTING
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
#endif
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
#ifdef MIGRAPHX_BUILD_TESTING
if(disable_exception_catch)
{
f();
}
else
{
#endif
try
{
f();
......@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return migraphx_status_unknown_error;
}
#ifdef MIGRAPHX_BUILD_TESTING
}
#endif
return migraphx_status_success;
}
......@@ -156,6 +164,11 @@ void set_default_loop_iterations(onnx_options& options, int64_t value)
options.max_loop_iterations = value;
}
void set_limit_loop_iterations(onnx_options& options, int64_t value)
{
options.limit_max_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......@@ -1896,6 +1909,17 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_limit_loop_iterations(migraphx_onnx_options_t onnx_options, int64_t value)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_limit_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
......
......@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
......@@ -513,6 +514,9 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_valu
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations(
migraphx_onnx_options_t onnx_options, int64_t value);
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_limit_loop_iterations(
migraphx_onnx_options_t onnx_options, int64_t value);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_destroy(migraphx_file_options_t file_options);
......
......@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
#if defined(_MSC_VER) && !defined(__clang__)
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
......@@ -1321,6 +1321,12 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value);
}
/// Set max iteration limit for the loop operator
void set_limit_loop_iterations(int64_t value)
{
call(&migraphx_onnx_options_set_limit_loop_iterations, this->get_handle_ptr(), value);
}
};
/// Parse an onnx file into a migraphx program
......
......@@ -349,6 +349,11 @@ def onnx_options(h):
api.params(value='int64_t'),
invoke='migraphx::set_default_loop_iterations($@)',
)
h.method(
'set_limit_loop_iterations',
api.params(value='int64_t'),
invoke='migraphx::set_limit_loop_iterations($@)',
)
@auto_handle()
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator>
static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
{
std::size_t x = 1;
auto it = std::find_if(start, last, [&](auto i) {
x *= i;
return x > dim;
});
if(x < dim)
return start;
return it;
}
template <class Range>
static auto elements(const Range& r)
{
return std::accumulate(r.begin(), r.end(), std::size_t{1}, std::multiplies<>{});
}
struct common_dim_state
{
common_dim_state(const std::vector<std::size_t>& pdims,
std::vector<std::vector<std::size_t>>& paxes_map)
: dims(&pdims), axes_map(&paxes_map), it(dims->begin())
{
}
const std::vector<std::size_t>* dims = nullptr;
std::vector<std::vector<std::size_t>>* axes_map = nullptr;
std::vector<std::size_t>::const_iterator it{};
std::size_t rem = 1;
std::size_t get() const { return *it / rem; }
bool is_end() const { return it == dims->end(); }
void next(std::size_t i = 1) { it += i; }
auto dims_for(std::size_t d) const
{
auto dim_end = compute_end_dim(it, dims->end(), d);
return range(it, dim_end);
}
void add_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
{
auto axes = compute_axes(naxes, start);
axes_map->push_back(std::move(axes));
}
void add_multi_axes(std::size_t naxes, std::size_t start) MIGRAPHX_TIDY_CONST
{
auto axes = compute_axes(naxes, start);
std::transform(axes.begin(),
axes.end(),
std::back_inserter(*axes_map),
[&](auto axis) -> std::vector<std::size_t> { return {axis}; });
}
std::vector<std::size_t> compute_axes(std::size_t naxes, std::size_t start) const
{
if(rem != 1)
{
assert(start > 0);
naxes++;
start--;
}
std::vector<std::size_t> axes(naxes);
std::iota(axes.begin(), axes.end(), start);
return axes;
}
};
static bool compute_common_dim(std::vector<std::size_t>& cd_dims,
common_dim_state& state1,
common_dim_state& state2)
{
assert(state1.get() <= state2.get());
auto d2 = state2.get();
auto dims = state1.dims_for(d2);
auto n = elements(dims);
auto naxes = distance(dims);
if(naxes == 0)
return false;
// If not divisible then we can't compute a common dim
if((d2 % n) != 0)
return false;
auto rem = d2 / n;
state1.add_multi_axes(naxes, cd_dims.size());
state2.add_axes(rem == 1 ? naxes : naxes + 1, cd_dims.size());
state1.rem = rem;
state2.rem = 1;
cd_dims.insert(cd_dims.end(), dims.begin(), dims.end());
if(state1.rem != 1)
cd_dims.push_back(state1.rem);
state1.next(distance(dims));
state2.next();
return true;
}
common_dims common_dims::compute(const std::vector<std::size_t>& dims1,
const std::vector<std::size_t>& dims2)
{
assert(elements(dims1) > 0);
assert(elements(dims1) == elements(dims2));
common_dims cd;
common_dim_state state1{dims1, cd.axes_map1};
common_dim_state state2{dims2, cd.axes_map2};
while(not state1.is_end() and not state2.is_end())
{
auto d1 = state1.get();
auto d2 = state2.get();
if(d1 <= d2)
{
if(not compute_common_dim(cd.dims, state1, state2))
return {};
}
else // if(d1 > d2)
{
if(not compute_common_dim(cd.dims, state2, state1))
return {};
}
}
assert(elements(dims1) == elements(cd.dims));
return cd;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
fs::path full_path = td.path / src.path;
fs::path parent_path = full_path.parent_path();
fs::create_directories(parent_path);
write_buffer(full_path.string(), src.content.first, src.len());
write_buffer(full_path.string(), src.content.data(), src.content.size());
if(src.path.extension().string() == ".cpp")
{
params += " " + src.path.filename().string();
......
......@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
ins->get_literal().visit([&](auto v) {
assert(v.size() == 1);
auto x = v.front();
if(std::isinf(x))
if(std::isinf(static_cast<double>(x)))
{
string_literal = "__builtin_huge_val()";
if(x < 0)
string_literal = "-__builtin_huge_val()";
}
else if(std::isnan(x))
else if(std::isnan(static_cast<double>(x)))
string_literal = "__builtin_nan()";
else
string_literal = ins->get_literal().to_string();
......
......@@ -45,7 +45,15 @@ if(NOT WIN32)
endif()
rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py)
file(STRINGS "${CMAKE_SOURCE_DIR}/test/onnx/.onnxrt-commit" String_output)
target_compile_definitions(driver PUBLIC MIGRAPHX_ORT_SHA1="${String_output}")
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
if(MIGRAPHX_ENABLE_PYTHON)
target_link_libraries(driver migraphx_py)
target_compile_definitions(driver PRIVATE MIGRAPHX_ENABLE_PYTHON)
endif()
rocm_install_targets(
TARGETS driver
......
......@@ -187,6 +187,13 @@ struct value_parser
}
};
// version for std::optional object
template <class T>
struct value_parser<std::optional<T>>
{
static T apply(const std::string& x) { return value_parser<T>::apply(x); }
};
struct argument_parser
{
struct argument
......
......@@ -32,7 +32,9 @@
#include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp>
#ifdef MIGRAPHX_ENABLE_PYTHON
#include <migraphx/py.hpp>
#endif
#include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp>
......@@ -57,6 +59,13 @@ namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
inline std::string get_version()
{
return "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + "." +
std::to_string(MIGRAPHX_VERSION_MINOR) + "." + std::to_string(MIGRAPHX_VERSION_PATCH) +
"." MIGRAPHX_VERSION_TWEAK;
}
struct loader
{
std::string model;
......@@ -281,10 +290,12 @@ struct loader
options.format = "json";
p = migraphx::load(file, options);
}
#ifdef MIGRAPHX_ENABLE_PYTHON
else if(file_type == "py")
{
p = migraphx::load_py(file);
}
#endif
else if(file_type == "migraphx")
{
p = migraphx::load(file);
......@@ -475,13 +486,15 @@ struct compiler
{
if(is_offload_copy_set(p) and not co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled with offload_copy set, Try "
"passing "
"`--enable-offload-copy` if program run fails.\n";
std::cout
<< "[WARNING]: MIGraphX program was likely compiled with offload_copy "
"set, Try "
"passing "
"`--enable-offload-copy` if program run fails.\n";
}
else if(co.offload_copy)
{
std::cout << "MIGraphX program was likely compiled without "
std::cout << "[WARNING]: MIGraphX program was likely compiled without "
"offload_copy set, Try "
"removing "
"`--enable-offload-copy` flag if passed to driver, if program run "
......@@ -534,13 +547,17 @@ struct params : command<params>
struct verify : command<verify>
{
compiler c;
double tolerance = 80;
std::optional<double> rms_tol;
std::optional<double> atol;
std::optional<double> rtol;
bool per_instruction = false;
bool reduce = false;
void parse(argument_parser& ap)
{
c.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
ap(per_instruction,
{"-i", "--per-instruction"},
ap.help("Verify each instruction"),
......@@ -559,36 +576,34 @@ struct verify : command<verify>
auto quantize = precision::fp32;
if(c.to_fp16)
{
quantize = precision::fp16;
}
if(c.to_int8)
{
quantize = precision::int8;
}
auto tols = get_tolerances(p, quantize, rms_tol, atol, rtol);
std::cout << "rms_tol: " << tols.rms_tol << std::endl;
std::cout << "atol: " << tols.atol << std::endl;
std::cout << "rtol: " << tols.rtol << std::endl;
if(per_instruction)
{
verify_instructions(p, t, c.co, quantize, tolerance);
verify_instructions(p, t, c.co, quantize, tols);
}
else if(reduce)
{
verify_reduced_program(p, t, c.co, quantize, m, tolerance);
verify_reduced_program(p, t, c.co, quantize, m, tols);
}
else
{
verify_program(c.l.file, p, t, c.co, quantize, m, tolerance);
verify_program(c.l.file, p, t, c.co, quantize, m, tols);
}
}
};
struct version : command<version>
{
void parse(const argument_parser&) {}
void run() const
{
std::cout << "MIGraphX Version: " << MIGRAPHX_VERSION_MAJOR << "." << MIGRAPHX_VERSION_MINOR
<< "." << MIGRAPHX_VERSION_PATCH << "."
<< MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK) << std::endl;
}
};
struct compile : command<compile>
{
compiler c;
......@@ -741,16 +756,14 @@ struct main_command
}
void parse(argument_parser& ap)
{
std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
"." + std::to_string(MIGRAPHX_VERSION_MINOR) + "." +
std::to_string(MIGRAPHX_VERSION_PATCH) + "." +
MIGRAPHX_STRINGIZE(MIGRAPHX_VERSION_TWEAK);
std::string version_str = get_version();
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
ap(nullptr,
{"-v", "--version"},
ap.help("Show MIGraphX version"),
ap.show_help(version_str));
ap(nullptr, {"--ort-sha"}, ap.help("Show MIGraphX onnx runtime SHA"));
// Trim command off of exe name
ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
......@@ -793,7 +806,6 @@ using namespace migraphx::driver; // NOLINT
int main(int argc, const char* argv[])
{
std::vector<std::string> args(argv + 1, argv + argc);
// no argument, print the help infomration by default
if(args.empty())
{
......@@ -802,9 +814,28 @@ int main(int argc, const char* argv[])
auto&& m = get_commands();
auto cmd = args.front();
if(cmd == "--ort-sha")
{
std::cout << MIGRAPHX_ORT_SHA1 << std::endl;
return 0;
}
if(cmd == "-v" or cmd == "--version")
{
std::cout << get_version() << std::endl;
return 0;
}
if(m.count(cmd) > 0)
{
m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
std::string driver_invocation =
std::string(argv[0]) + " " + migraphx::to_string_range(args, " ");
std::cout << "Running [ " << get_version() << " ]: " << driver_invocation << std::endl;
m.at(cmd)(argv[0],
{args.begin() + 1, args.end()}); // run driver command found in commands map
std::cout << "[ " << get_version() << " ] Complete: " << driver_invocation << std::endl;
}
else
{
......
......@@ -30,11 +30,48 @@
#include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
* model.
*/
verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol)
{
bool has_fp16 = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); });
});
migraphx::verify::tolerance result{};
if(has_fp16 or quantize == precision::fp16)
{
result.rms_tol = 8e-2;
result.atol = 4e-2;
result.rtol = 4e-2;
}
if(rms_tol)
{
result.rms_tol = *rms_tol;
}
if(atol)
{
result.atol = *atol;
}
if(rtol)
{
result.rtol = *rtol;
}
return result;
}
std::vector<argument> run_ref(program p, const parameter_map& inputs)
{
p.compile(migraphx::make_target("ref"));
......@@ -76,15 +113,25 @@ void verify_program(const std::string& name,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
verify::tolerance tols)
{
auto x = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs);
auto ref_outs = run_ref(p, inputs);
auto target_outs = run_target(p, t, options, quantize, inputs);
std::size_t output_num = x.size();
std::size_t output_num = ref_outs.size();
for(std::size_t i = 0; i < output_num; ++i)
{
verify_args(name, x[i], y[i], tolerance);
if(ref_outs[i].get_shape().type() != target_outs[i].get_shape().type() or
ref_outs[i].get_shape().lens() != target_outs[i].get_shape().lens())
{
std::cout << "FAILED: " << name << std::endl;
std::cout << "Shape mismatch {" << ref_outs[i].get_shape() << "} != {"
<< target_outs[i].get_shape() << "}" << std::endl;
}
else
{
verify_args(name, target_outs[i], verify::expected{ref_outs[i]}, tols);
}
}
}
......@@ -92,7 +139,7 @@ void verify_instructions(const program& prog,
const target& t,
compile_options options,
precision quantize,
double tolerance)
verify::tolerance tols)
{
const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog))
......@@ -123,8 +170,7 @@ void verify_instructions(const program& prog,
{
std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl;
verify_program(
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance);
verify_program(ins.name(), p, t, options, quantize, create_param_map(p, false), tols);
}
catch(...)
{
......@@ -140,14 +186,22 @@ void verify_reduced(program p,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
verify::tolerance tols)
{
auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1);
auto last = std::prev(mm->end(), n);
mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance);
try
{
verify_program(std::to_string(n), p, t, options, quantize, inputs, tols);
}
catch(const std::exception& e)
{
std::cout << "FAILED: " << n << std::endl;
std::cout << "Exception: " << e.what() << std::endl;
}
}
void verify_reduced_program(const program& p,
......@@ -155,14 +209,20 @@ void verify_reduced_program(const program& p,
compile_options options,
precision quantize,
const parameter_map& inputs,
double tolerance)
verify::tolerance tols)
{
const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++)
for(std::size_t i = 1; i < n; i++)
{
verify_reduced(p, i, t, options, quantize, inputs, tolerance);
auto last = std::prev(mm->end(), i + 1);
if(contains({"@literal", "@param"}, last->name()))
{
std::cout << "Skip: " << i << std::endl;
continue;
}
verify_reduced(p, i, t, options, quantize, inputs, tols);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment