Commit 6711780a authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents c0563b9e d1abf06f
...@@ -27,20 +27,18 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_BINARY_DIR}") ...@@ -27,20 +27,18 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_BINARY_DIR}")
message(FATAL_ERROR "The binary and source directroy cannot be the same") message(FATAL_ERROR "The binary and source directroy cannot be the same")
endif() endif()
get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) # Setup valid strings for build type
if (NOT CMAKE_CONFIGURATION_TYPES)
set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING "Configs")
endif()
get_property(MIGRAPHX_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
# This has to be initialized before the project() command appears # This has to be initialized before the project() command appears
# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE
if(_GENERATOR_IS_MULTI_CONFIG) if(NOT MIGRAPHX_GENERATOR_IS_MULTI_CONFIG)
if (NOT CMAKE_CONFIGURATION_TYPES) set(CMAKE_BUILD_TYPE Release CACHE STRING
set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
"Available build types (configurations) on multi-config generators") set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS ${CMAKE_CONFIGURATION_TYPES})
endif()
else()
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING
"Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
endif()
endif() endif()
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
...@@ -87,7 +85,7 @@ include(ROCMSetupVersion) ...@@ -87,7 +85,7 @@ include(ROCMSetupVersion)
option(BUILD_DEV "Build for development purpose only" OFF) option(BUILD_DEV "Build for development purpose only" OFF)
rocm_setup_version(VERSION 2.8.0) rocm_setup_version(VERSION 2.8.0)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......
...@@ -109,12 +109,15 @@ def rocmnode(name, body) { ...@@ -109,12 +109,15 @@ def rocmnode(name, body) {
} }
} }
rocmtest clang_debug: rocmnode('cdna') { cmake_build -> rocmtest clang_debug: rocmnode('mi100+') { cmake_build ->
stage('hipRTC Debug') { stage('hipRTC Debug') {
def sanitizers = "undefined" // Disable MLIR since it doesnt work with all ub sanitizers
def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}" withEnv(['MIGRAPHX_DISABLE_MLIR=1']) {
def gpu_targets = getgputargets() def sanitizers = "undefined"
cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'", gpu_debug: true) def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
def gpu_targets = getgputargets()
cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}' -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'", gpu_debug: true)
}
} }
}, clang_release: rocmnode('mi100+') { cmake_build -> }, clang_release: rocmnode('mi100+') { cmake_build ->
stage('Hip Clang Release') { stage('Hip Clang Release') {
...@@ -126,14 +129,14 @@ rocmtest clang_debug: rocmnode('cdna') { cmake_build -> ...@@ -126,14 +129,14 @@ rocmtest clang_debug: rocmnode('cdna') { cmake_build ->
// stage('Hidden symbols') { // stage('Hidden symbols') {
// cmake_build(flags: "-DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_VISIBILITY_PRESET=hidden -DCMAKE_C_VISIBILITY_PRESET=hidden") // cmake_build(flags: "-DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DCMAKE_CXX_VISIBILITY_PRESET=hidden -DCMAKE_C_VISIBILITY_PRESET=hidden")
// } // }
}, all_targets_debug : rocmnode('cdna') { cmake_build -> }, all_targets_debug : rocmnode('mi100+') { cmake_build ->
stage('All targets Release') { stage('All targets Release') {
def gpu_targets = getgputargets() def gpu_targets = getgputargets()
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DMIGRAPHX_ENABLE_FPGA=On -DGPU_TARGETS='${gpu_targets}'") cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_ENABLE_GPU=On -DMIGRAPHX_ENABLE_CPU=On -DMIGRAPHX_ENABLE_FPGA=On -DGPU_TARGETS='${gpu_targets}'")
} }
}, mlir_debug: rocmnode('cdna') { cmake_build -> }, mlir_debug: rocmnode('mi100+') { cmake_build ->
stage('MLIR Debug') { stage('MLIR Debug') {
withEnv(['MIGRAPHX_ENABLE_MLIR=1']) { withEnv(['MIGRAPHX_ENABLE_EXTRA_MLIR=1']) {
def sanitizers = "undefined" def sanitizers = "undefined"
// Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS. // Note: the -fno-sanitize= is copied from upstream LLVM_UBSAN_FLAGS.
def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}" def debug_flags_cxx = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize=vptr,function -fno-sanitize-recover=${sanitizers}"
...@@ -144,7 +147,7 @@ rocmtest clang_debug: rocmnode('cdna') { cmake_build -> ...@@ -144,7 +147,7 @@ rocmtest clang_debug: rocmnode('cdna') { cmake_build ->
} }
}, ck_hiprtc: rocmnode('mi100+') { cmake_build -> }, ck_hiprtc: rocmnode('mi100+') { cmake_build ->
stage('CK hipRTC') { stage('CK hipRTC') {
withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1']) { withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1', 'MIGRAPHX_DISABLE_MLIR=1']) {
def gpu_targets = getgputargets() def gpu_targets = getgputargets()
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'") cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On -DGPU_TARGETS='${gpu_targets}'")
} }
......
...@@ -35,7 +35,7 @@ fastjsonschema==2.16.3 ...@@ -35,7 +35,7 @@ fastjsonschema==2.16.3
# via rocm-docs-core # via rocm-docs-core
gitdb==4.0.10 gitdb==4.0.10
# via gitpython # via gitpython
gitpython==3.1.32 gitpython==3.1.37
# via rocm-docs-core # via rocm-docs-core
idna==3.4 idna==3.4
# via requests # via requests
...@@ -87,7 +87,7 @@ requests==2.28.2 ...@@ -87,7 +87,7 @@ requests==2.28.2
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==0.24.2 rocm-docs-core==0.26.0
# via -r requirements.in # via -r requirements.in
smmap==5.0.0 smmap==5.0.0
# via gitdb # via gitdb
...@@ -130,7 +130,7 @@ sphinxcontrib-serializinghtml==1.1.5 ...@@ -130,7 +130,7 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx # via sphinx
typing-extensions==4.5.0 typing-extensions==4.5.0
# via pydata-sphinx-theme # via pydata-sphinx-theme
urllib3==1.26.15 urllib3==1.26.18
# via requests # via requests
wrapt==1.15.0 wrapt==1.15.0
# via deprecated # via deprecated
...@@ -28,5 +28,5 @@ ROCmSoftwarePlatform/half@rocm-5.6.0 ...@@ -28,5 +28,5 @@ ROCmSoftwarePlatform/half@rocm-5.6.0
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@a22e479b8e1557961039db2d5c5ff89cff35e86b -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCmSoftwarePlatform/composable_kernel@70eefcf4f263aa5c25f3c9ff0db8f6f199ef0fb9 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/rocMLIR@12748a3402c069f733ea7f2ba1f8d8a070b3622a -DBUILD_FAT_LIBROCKCOMPILER=On ROCmSoftwarePlatform/rocMLIR@507bb94ce7873786486d296ec81d2eadaab49003 -DBUILD_FAT_LIBROCKCOMPILER=On
\ No newline at end of file
...@@ -187,6 +187,13 @@ struct value_parser ...@@ -187,6 +187,13 @@ struct value_parser
} }
}; };
// Specialization of value_parser for std::optional<T>.
// Parsing yields the wrapped type T directly; assigning the parsed T into an
// std::optional<T> target engages the optional, so the caller can distinguish
// "flag supplied" from "flag absent".
template <class T>
struct value_parser<std::optional<T>>
{
    static T apply(const std::string& x)
    {
        // Delegate to the parser for the underlying type.
        return value_parser<T>::apply(x);
    }
};
struct argument_parser struct argument_parser
{ {
struct argument struct argument
......
...@@ -540,19 +540,17 @@ struct params : command<params> ...@@ -540,19 +540,17 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
compiler c; compiler c;
migraphx::verify::tolerance tols; std::optional<double> rms_tol;
std::optional<double> atol;
std::optional<double> rtol;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
c.parse(ap); c.parse(ap);
ap(tols.rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error (Default: 0.001)")); ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
ap(tols.atol, ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
{"--atol"}, ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
ap.help("Tolerance for the elementwise absolute difference (Default: 0.001)"));
ap(tols.rtol,
{"--rtol"},
ap.help("Tolerance for the elementwise relative difference (Default: 0.001)"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
...@@ -571,9 +569,18 @@ struct verify : command<verify> ...@@ -571,9 +569,18 @@ struct verify : command<verify>
auto quantize = precision::fp32; auto quantize = precision::fp32;
if(c.to_fp16) if(c.to_fp16)
{
quantize = precision::fp16; quantize = precision::fp16;
}
if(c.to_int8) if(c.to_int8)
{
quantize = precision::int8; quantize = precision::int8;
}
auto tols = get_tolerances(p, quantize, rms_tol, atol, rtol);
std::cout << "rms_tol: " << tols.rms_tol << std::endl;
std::cout << "atol: " << tols.atol << std::endl;
std::cout << "rtol: " << tols.rtol << std::endl;
if(per_instruction) if(per_instruction)
{ {
......
...@@ -36,6 +36,42 @@ namespace migraphx { ...@@ -36,6 +36,42 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/**
 * Computes verification tolerances by layering user-supplied values
 * (`rms_tol`, `atol`, `rtol`) over defaults. Looser fp16 defaults are applied
 * when the requested quantization is fp16 or when any instruction in the
 * program produces a half-precision value.
 */
verify::tolerance get_tolerances(const program& p,
                                 precision quantize,
                                 std::optional<double> rms_tol,
                                 std::optional<double> atol,
                                 std::optional<double> rtol)
{
    // Scan every module for any instruction whose output type is half.
    const bool fp16_model = any_of(p.get_modules(), [](auto&& mod) {
        return any_of(*mod,
                      [](auto&& ins) { return ins.get_shape().type() == shape::half_type; });
    });
    migraphx::verify::tolerance tols{};
    if(fp16_model or quantize == precision::fp16)
    {
        // Reduced precision warrants looser error bounds.
        tols.rms_tol = 8e-2;
        tols.atol    = 4e-2;
        tols.rtol    = 4e-2;
    }
    // Explicit user-provided values always override the defaults.
    tols.rms_tol = rms_tol.value_or(tols.rms_tol);
    tols.atol    = atol.value_or(tols.atol);
    tols.rtol    = rtol.value_or(tols.rtol);
    return tols;
}
std::vector<argument> run_ref(program p, const parameter_map& inputs) std::vector<argument> run_ref(program p, const parameter_map& inputs)
{ {
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
......
...@@ -32,6 +32,12 @@ namespace migraphx { ...@@ -32,6 +32,12 @@ namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol);
void verify_program(const std::string& name, void verify_program(const std::string& name,
const program& p, const program& p,
const target& t, const target& t,
......
...@@ -46,7 +46,7 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument> ...@@ -46,7 +46,7 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument>
{ {
argument() = default; argument() = default;
argument(const shape& s); explicit argument(const shape& s);
template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})> template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
argument(shape s, F d) argument(shape s, F d)
......
...@@ -57,7 +57,7 @@ struct instruction_ref : std::list<instruction>::iterator ...@@ -57,7 +57,7 @@ struct instruction_ref : std::list<instruction>::iterator
std::is_same<U, instruction_ref>{})> std::is_same<U, instruction_ref>{})>
friend bool operator!=(const T& x, const U& y) friend bool operator!=(const T& x, const U& y)
{ {
return !(x == y); return not(x == y);
} }
}; };
#else #else
......
...@@ -88,13 +88,13 @@ struct allocate ...@@ -88,13 +88,13 @@ struct allocate
{ {
if(args.empty()) if(args.empty())
{ {
return {output_shape}; return argument{output_shape};
} }
else else
{ {
std::vector<std::size_t> output_dims(output_shape.ndim()); std::vector<std::size_t> output_dims(output_shape.ndim());
args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); }); args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); });
return {shape{buf_type, output_dims}}; return argument{shape{buf_type, output_dims}};
} }
} }
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp> #include <migraphx/dyn_output.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -38,12 +39,13 @@ namespace op { ...@@ -38,12 +39,13 @@ namespace op {
struct argmax struct argmax
{ {
int64_t axis = 0; int64_t axis = 0;
bool select_last_index = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"), f(self.select_last_index, "select_last_index"));
} }
value attributes() const value attributes() const
...@@ -87,6 +89,10 @@ struct argmax ...@@ -87,6 +89,10 @@ struct argmax
max_val = cur_val; max_val = cur_val;
max_index = i; max_index = i;
} }
else if(select_last_index and float_equal(max_val, cur_val))
{
max_index = i;
}
} }
return max_index; return max_index;
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/float_equal.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -38,11 +39,12 @@ namespace op { ...@@ -38,11 +39,12 @@ namespace op {
struct argmin struct argmin
{ {
int64_t axis = 0; int64_t axis = 0;
bool select_last_index = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"), f(self.select_last_index, "select_last_index"));
} }
value attributes() const value attributes() const
...@@ -78,6 +80,10 @@ struct argmin ...@@ -78,6 +80,10 @@ struct argmin
min_val = cur_val; min_val = cur_val;
min_index = i; min_index = i;
} }
else if(select_last_index and float_equal(min_val, cur_val))
{
min_index = i;
}
} }
return min_index; return min_index;
......
...@@ -411,7 +411,7 @@ struct pooling ...@@ -411,7 +411,7 @@ struct pooling
// for dynamic GlobalPooling, there's no padding // for dynamic GlobalPooling, there's no padding
kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end()); kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
output_shape = dyn_out.computed_shape; output_shape = dyn_out.computed_shape;
result = dyn_out.computed_shape; result = argument{dyn_out.computed_shape};
} }
else if((padding_mode != op::padding_mode_t::default_)) else if((padding_mode != op::padding_mode_t::default_))
{ {
...@@ -439,7 +439,7 @@ struct pooling ...@@ -439,7 +439,7 @@ struct pooling
{ {
kernel_dims = this->lengths; kernel_dims = this->lengths;
output_shape = dyn_out.computed_shape; output_shape = dyn_out.computed_shape;
result = dyn_out.computed_shape; result = argument{dyn_out.computed_shape};
} }
// Perform the computation and populate result // Perform the computation and populate result
......
...@@ -36,6 +36,22 @@ namespace migraphx { ...@@ -36,6 +36,22 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* 1 input version:
* reshape(input_data)
* this.dims = output_dims
* Makes a copy of input_data to the output shape.
*
* 2 input version:
* reshape(input_data, output_buffer)
* this.dims = unset
* Copies input_data to output_buffer; output_buffer already has the output shape.
* This version will not fail gracefully if the input shape and output_buffer shape are
* incompatible. There's a throw that will catch when the number of elements do not match at
* runtime. This version should only be used for dynamic reshapes (output dimensions only known at
* runtime). If output_buffer has a static shape during compile/parse, you can use the 1 input
* version.
*/
struct reshape struct reshape
{ {
std::vector<int64_t> dims; std::vector<int64_t> dims;
...@@ -215,32 +231,56 @@ struct reshape ...@@ -215,32 +231,56 @@ struct reshape
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1); check_shapes{inputs, *this, true}.has(1, 2);
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1); auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1) if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim"); MIGRAPHX_THROW("reshape: Dimensions for reshape can only have one -1 dim");
auto s0 = inputs.front(); auto s0 = inputs.front();
if(s0.dynamic()) if(inputs.size() == 1)
{ {
return dyn_compute_shape(s0); if(s0.dynamic())
{
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
} }
else else
{ {
return static_compute_shape(inputs, n_neg_dims); return inputs.back();
} }
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
assert(dyn_out.computed_shape.standard()); assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape}; if(args.size() == 1)
{
argument result{dyn_out.computed_shape};
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin()); std::copy(input.begin(), input.end(), output.begin());
}); });
return result; return result;
}
else
{
// 2 arg
if(args[0].get_shape().elements() != args[1].get_shape().elements())
{
MIGRAPHX_THROW("Reshape: Number of elements must match at runtime. Input: " +
std::to_string(args[0].get_shape().elements()) +
" Output buffer: " + std::to_string(args[1].get_shape().elements()));
}
visit_all(args[1], args[0])([&](auto output, auto input) {
std::copy(input.begin(), input.end(), output.begin());
});
return args[1];
}
} }
}; };
......
...@@ -66,7 +66,7 @@ struct scatter : op_name<Derived> ...@@ -66,7 +66,7 @@ struct scatter : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).standard(); check_shapes{inputs, *this}.has(3);
// If non-packed, this converts to a packed output while preserving permutation of tensor // If non-packed, this converts to a packed output while preserving permutation of tensor
return inputs.front().with_lens(inputs.front().lens()); return inputs.front().with_lens(inputs.front().lens());
} }
......
...@@ -47,7 +47,7 @@ void cal_auto_padding_size(onnx_parser::node_info info, ...@@ -47,7 +47,7 @@ void cal_auto_padding_size(onnx_parser::node_info info,
return; return;
} }
auto auto_pad = info.attributes["auto_pad"].s(); auto auto_pad = to_upper(info.attributes["auto_pad"].s());
if(auto_pad.find("SAME") != std::string::npos) if(auto_pad.find("SAME") != std::string::npos)
{ {
bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos); bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op> ...@@ -50,14 +50,25 @@ struct parse_arg_op : op_parser<parse_arg_op>
keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>(); keep_dims = parser.parse_value(info.attributes.at("keepdims")).at<int>();
} }
bool select_last_index = false;
if(contains(info.attributes, "select_last_index"))
{
select_last_index = static_cast<bool>(
parser.parse_value(info.attributes.at("select_last_index")).at<int>());
}
if(keep_dims == 0) if(keep_dims == 0)
{ {
auto ins = info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args); auto ins = info.add_instruction(
make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
args);
return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins); return info.add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
} }
else else
{ {
return info.add_instruction(make_op(opd.op_name, {{"axis", axis}}), args); return info.add_instruction(
make_op(opd.op_name, {{"axis", axis}, {"select_last_index", select_last_index}}),
args);
} }
} }
}; };
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses the ONNX GroupNormalization operator by lowering it to
// reshape + reduce_mean + elementwise primitives. The channel axis C is split
// into (num_groups, C / num_groups); mean/variance are reduced per group,
// then the normalized tensor is scaled, shifted, and reshaped back.
struct parse_groupnorm : op_parser<parse_groupnorm>
{
    std::vector<op_desc> operators() const { return {{"GroupNormalization"}}; }

    // Builds the instruction graph for one GroupNormalization node.
    // args: x (N x C x D1..Dn), scale (1-D, length num_groups),
    //       bias (1-D, length num_groups).
    // Returns the normalized output in the original layout; throws on any
    // malformed attribute or input shape.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Guards the rsqrt against zero variance; ONNX default is 1e-5.
        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
        }
        // num_groups has no default — it must be present in the node.
        size_t num_groups;
        if(contains(info.attributes, "num_groups"))
        {
            num_groups = parser.parse_value(info.attributes.at("num_groups")).at<size_t>();
        }
        else
        {
            MIGRAPHX_THROW("PARSE_GROUPNORM: num_groups must be available");
        }

        // Exactly three inputs are required: data, scale, bias.
        if(args.size() != 3)
        {
            MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input count");
        }
        auto x     = args.at(0);
        auto scale = args.at(1);
        auto bias  = args.at(2);

        auto x_shape = x->get_shape();
        auto x_dtype = x_shape.type();
        auto x_dims  = x_shape.lens();
        // Need at least one spatial dimension beyond N and C.
        if(x_shape.ndim() <= 2)
        {
            MIGRAPHX_THROW("PARSE_GROUPNORM: invalid input shape");
        }

        // The channel count must split evenly into the requested groups.
        auto c = x_shape.lens().at(1);
        if(c % num_groups != 0)
        {
            MIGRAPHX_THROW(
                "PARSE_GROUPNORM: num_groups should be a divisor of the number of channels");
        }
        auto group_size = c / num_groups;
        // scale and bias each carry one value per group (1-D of length num_groups).
        if(scale->get_shape().ndim() != 1 or scale->get_shape().lens().at(0) != num_groups)
        {
            MIGRAPHX_THROW("PARSE_GROUPNORM: scale tensor shape should be num_groups");
        }
        if(bias->get_shape().ndim() != 1 or bias->get_shape().lens().at(0) != num_groups)
        {
            MIGRAPHX_THROW("PARSE_GROUPNORM: bias tensor shape should be num_groups");
        }

        // Original shape: N x C x D1 x ... x Dn
        // New shape: N x num_groups x C // num_groups x D1 x ... x Dn
        std::vector<size_t> dims = {x_dims.at(0), num_groups, group_size};
        std::copy(x_dims.begin() + 2, x_dims.end(), std::back_inserter(dims));
        auto x_reshaped = info.add_instruction(make_op("reshape", {{"dims", dims}}), x);

        // Axes for D1 x ... x Dn — i.e. reduce over everything inside a group
        // (group_size plus all spatial dims): axes 2..rank-1 of the reshaped tensor.
        std::vector<size_t> axes(dims.size() - 2);
        std::iota(axes.begin(), axes.end(), 2);

        // y = (x - mean) * rsqrt(variance + epsilon) * scale + bias
        // mean = reduce_mean({D1, D2, ... Dk}, x)
        // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)
        auto mean          = info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_reshaped);
        auto x_sub_mean    = info.add_common_op("sub", x_reshaped, mean);
        auto x_sqdiff_mean = info.add_common_op("sqdiff", x_reshaped, mean);
        auto variance =
            info.add_instruction(make_op("reduce_mean", {{"axes", axes}}), x_sqdiff_mean);
        // A very small epsilon would vanish in fp16, so clamp it up to 1e-7.
        // NOTE(review): std::abs on float — presumably <cmath> arrives
        // transitively through the project headers; confirm.
        epsilon =
            (x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
        auto eps     = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
        auto var_eps = info.add_common_op("add", variance, eps);
        auto rsqrt   = info.add_instruction(make_op("rsqrt"), var_eps);
        auto result  = info.add_common_op("mul", x_sub_mean, rsqrt);
        // scale/bias apply per group: broadcast along axis 1 of the reshaped
        // (N, num_groups, group_size, D...) tensor.
        auto scale_bcast =
            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), scale);
        auto bias_bcast =
            info.add_instruction(make_op("broadcast", {{"axis", 1}, {"out_lens", dims}}), bias);
        auto scaled = info.add_instruction(make_op("mul"), result, scale_bcast);
        auto y      = info.add_instruction(make_op("add"), scaled, bias_bcast);
        // Restore the original N x C x D1..Dn layout.
        auto y_reshaped = info.add_instruction(make_op("reshape", {{"dims", x_dims}}), y);
        return y_reshaped;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Lowers ONNX LayerNormalization into reduce_mean / elementwise primitives.
// Returns the three ONNX outputs: Y, Mean, and InvStdDev (rsqrt(var + eps)).
struct parse_layernorm : op_parser<parse_layernorm>
{
    std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        // ONNX defaults: axis = -1 (normalize over the last dim), epsilon = 1e-5.
        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
        }
        float eps_val = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            eps_val = parser.parse_value(info.attributes.at("epsilon")).at<float>();
        }
        // stash_type is not implemented here; warn and continue.
        if(contains(info.attributes, "stash_type"))
        {
            std::cerr << "WARNING: LAYERNORM does not support stash_type, it will be ignored.\n";
        }

        // Two or three inputs: x, scale, and an optional bias.
        if(args.size() < 2 or args.size() > 3)
        {
            MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input count");
        }
        auto input          = args.at(0);
        auto scale_arg      = args.at(1);
        const bool has_bias = (args.size() == 3);
        instruction_ref bias_arg;
        if(has_bias)
        {
            bias_arg = args.at(2);
        }

        auto input_shape = input->get_shape();
        auto dtype       = input_shape.type();
        int64_t rank     = input_shape.ndim();
        if(rank < 2)
        {
            MIGRAPHX_THROW("PARSE_LAYERNORM: invalid input shape");
        }
        // If rank(X) is r, the allowed attribute range is [-r, r).
        if(axis < -rank or axis >= rank)
        {
            MIGRAPHX_THROW("PARSE_LAYERNORM: invalid axis");
        }
        // Fold a negative axis into [0, r).
        if(axis < 0)
        {
            axis += rank;
        }

        // Reduce over the trailing dimensions [axis, rank):
        //   mean     = reduce_mean(x)
        //   variance = reduce_mean((x - mean)^2)
        //   y        = (x - mean) * rsqrt(variance + epsilon) * scale + bias
        const auto n_reduce_dims = rank - axis;
        std::vector<int64_t> reduce_axes(n_reduce_dims);
        std::iota(reduce_axes.begin(), reduce_axes.end(), axis);
        const auto leading_dims = rank - n_reduce_dims;

        auto mean =
            info.add_instruction(make_op("reduce_mean", {{"axes", reduce_axes}}), input);
        auto centered = info.add_common_op("sub", input, mean);
        auto sq_diff  = info.add_common_op("sqdiff", input, mean);
        auto variance =
            info.add_instruction(make_op("reduce_mean", {{"axes", reduce_axes}}), sq_diff);
        // A very small epsilon would vanish in fp16, so clamp it up to 1e-7.
        if(dtype == migraphx::shape::half_type and std::abs(eps_val) < 1e-7)
        {
            eps_val = 1e-7;
        }
        auto eps     = info.add_literal(migraphx::literal{migraphx::shape{dtype}, {eps_val}});
        auto var_eps = info.add_common_op("add", variance, eps);
        auto inv_std = info.add_instruction(make_op("rsqrt"), var_eps);
        auto normed  = info.add_common_op("mul", centered, inv_std);

        // scale/bias span only the normalized dims; broadcast them across the
        // leading (batch-like) dims when any exist.
        instruction_ref scale_bcast = scale_arg;
        instruction_ref bias_bcast  = bias_arg;
        if(leading_dims > 0)
        {
            auto input_dims = input_shape.lens();
            scale_bcast     = info.add_instruction(
                make_op("broadcast", {{"axis", leading_dims}, {"out_lens", input_dims}}),
                scale_arg);
            if(has_bias)
            {
                bias_bcast = info.add_instruction(
                    make_op("broadcast", {{"axis", leading_dims}, {"out_lens", input_dims}}),
                    bias_arg);
            }
        }
        auto scaled = info.add_instruction(make_op("mul"), normed, scale_bcast);
        auto y =
            has_bias ? info.add_instruction(make_op("add"), scaled, bias_bcast) : scaled;
        return {y, mean, inv_std};
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment