"src/targets/gpu/device/reduce_min.cpp" did not exist on "0c1df49c7252cca61b705cf411fa92c2b4a7828e"
Unverified Commit 40fbef9b authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'develop' into threaded_nms

parents d164b151 aeb9f78c
...@@ -35,7 +35,7 @@ struct module; ...@@ -35,7 +35,7 @@ struct module;
/** /**
* Rewrite pooling to reduce_mean * Rewrite pooling to reduce_mean
*/ */
struct rewrite_pooling struct MIGRAPHX_EXPORT rewrite_pooling
{ {
std::string name() const { return "rewrite_pooling"; } std::string name() const { return "rewrite_pooling"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -35,7 +35,7 @@ struct module; ...@@ -35,7 +35,7 @@ struct module;
/** /**
* Rewrite quantization ops to equivalent operators * Rewrite quantization ops to equivalent operators
*/ */
struct rewrite_quantization struct MIGRAPHX_EXPORT rewrite_quantization
{ {
std::string name() const { return "rewrite_quantization"; } std::string name() const { return "rewrite_quantization"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -39,7 +39,7 @@ struct module; ...@@ -39,7 +39,7 @@ struct module;
/** /**
* Rewrite rnn to gemm and add. * Rewrite rnn to gemm and add.
*/ */
struct rewrite_rnn struct MIGRAPHX_EXPORT rewrite_rnn
{ {
std::string name() const { return "rewrite_rnn"; } std::string name() const { return "rewrite_rnn"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -37,7 +37,7 @@ struct module; ...@@ -37,7 +37,7 @@ struct module;
/** /**
* Schedule instructions for concurrent execution * Schedule instructions for concurrent execution
*/ */
struct schedule struct MIGRAPHX_EXPORT schedule
{ {
schedule_model model{}; schedule_model model{};
bool enable = true; bool enable = true;
......
...@@ -63,7 +63,7 @@ struct schedule_model ...@@ -63,7 +63,7 @@ struct schedule_model
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct schedule_model struct MIGRAPHX_EXPORT schedule_model
{ {
// //
std::size_t concurrency() const; std::size_t concurrency() const;
...@@ -99,7 +99,7 @@ struct schedule_model ...@@ -99,7 +99,7 @@ struct schedule_model
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -274,7 +274,7 @@ struct schedule_model ...@@ -274,7 +274,7 @@ struct schedule_model
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -43,7 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -43,7 +43,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct value; struct value;
struct shape_impl; struct shape_impl;
struct shape struct MIGRAPHX_EXPORT shape
{ {
// Add new types here // Add new types here
...@@ -85,7 +85,7 @@ struct shape ...@@ -85,7 +85,7 @@ struct shape
{ {
}; };
struct dynamic_dimension struct MIGRAPHX_EXPORT dynamic_dimension
{ {
std::size_t min = 0; std::size_t min = 0;
std::size_t max = 0; std::size_t max = 0;
...@@ -100,22 +100,28 @@ struct shape ...@@ -100,22 +100,28 @@ struct shape
bool is_fixed() const; bool is_fixed() const;
bool has_optimal() const; bool has_optimal() const;
friend bool operator==(const dynamic_dimension& x, const dynamic_dimension& y); MIGRAPHX_EXPORT friend bool operator==(const dynamic_dimension& x,
friend bool operator!=(const dynamic_dimension& x, const dynamic_dimension& y); const dynamic_dimension& y);
friend std::ostream& operator<<(std::ostream& os, const dynamic_dimension& x); MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x,
const dynamic_dimension& y);
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os,
const dynamic_dimension& x);
// compare to fixed std::size_t dimension // compare to fixed std::size_t dimension
friend bool operator==(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend bool operator==(const dynamic_dimension& x, const std::size_t& y);
friend bool operator==(const std::size_t& x, const dynamic_dimension& y); MIGRAPHX_EXPORT friend bool operator==(const std::size_t& x, const dynamic_dimension& y);
friend bool operator!=(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend bool operator!=(const dynamic_dimension& x, const std::size_t& y);
friend bool operator!=(const std::size_t& x, const dynamic_dimension& y); MIGRAPHX_EXPORT friend bool operator!=(const std::size_t& x, const dynamic_dimension& y);
// add and subtract fixed std::size_t dimension // add and subtract fixed std::size_t dimension
dynamic_dimension& operator+=(const std::size_t& x); dynamic_dimension& operator+=(const std::size_t& x);
dynamic_dimension& operator-=(const std::size_t& x); dynamic_dimension& operator-=(const std::size_t& x);
friend dynamic_dimension operator+(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend dynamic_dimension operator+(const dynamic_dimension& x,
friend dynamic_dimension operator+(const std::size_t& x, const dynamic_dimension& y); const std::size_t& y);
friend dynamic_dimension operator-(const dynamic_dimension& x, const std::size_t& y); MIGRAPHX_EXPORT friend dynamic_dimension operator+(const std::size_t& x,
const dynamic_dimension& y);
MIGRAPHX_EXPORT friend dynamic_dimension operator-(const dynamic_dimension& x,
const std::size_t& y);
}; };
static const std::vector<type_t>& types(); static const std::vector<type_t>& types();
...@@ -156,14 +162,34 @@ struct shape ...@@ -156,14 +162,34 @@ struct shape
shape(const std::vector<shape>& subs); shape(const std::vector<shape>& subs);
/**
* Creates an output shape with dimensions equal to the input lengths and strides determined
* by the permutation argument such that find_permutation() of the output shape returns the
* inputted permuation.
*
* 2D example:
* parameters:
* l = [2, 3], perm = [1, 0]
* therefore:
* "original" shape = {lens = [3, 2], strides = [2, 1]}
* output_shape = {lens = [2, 3], strides = [1, 2]
*
* 3D example:
* parameters:
* l = [2, 3, 4], perm = [1, 2, 0]
* therefore:
* "original" shape = {lens = [3, 4, 2], strides = [8, 2, 1]}
* output_shape = {lens = [2, 3, 4], strides = [1, 8, 2]}
*/
static shape static shape
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm); from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
/*! /*!
* The number of dimensions in the shape. * The number of dimensions in the shape, either static or dynamic.
* Same as the number of indices required to get a data value. * Same as the number of indices required to get a data value.
*/ */
std::size_t ndim() const; std::size_t ndim() const;
...@@ -214,6 +240,10 @@ struct shape ...@@ -214,6 +240,10 @@ struct shape
template <class Iterator> template <class Iterator>
std::size_t index(Iterator start, Iterator last) const std::size_t index(Iterator start, Iterator last) const
{ {
if(this->dynamic())
{
MIGRAPHX_THROW("SHAPE: index() called on dynamic shape");
}
assert(std::distance(start, last) <= this->lens().size()); assert(std::distance(start, last) <= this->lens().size());
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); // NOLINT return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); // NOLINT
...@@ -222,11 +252,15 @@ struct shape ...@@ -222,11 +252,15 @@ struct shape
/// Map element index to space index /// Map element index to space index
std::size_t index(std::size_t i) const; std::size_t index(std::size_t i) const;
std::vector<std::size_t> multi(std::size_t i) const; /// Map element index to multi-dimensional index
void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const; std::vector<std::size_t> multi(std::size_t idx) const;
/// Map element index to multi-dimensional index and put them them into location provided by
/// pointers
void multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const;
/// Returns true if the shape is packed (number of elements and buffer size the same) with no /// Returns true if the shape is packed (number of elements and buffer size the same) with
/// padding /// no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending /// Returns true is the shape has been transposed. That is the strides are not in descending
...@@ -262,9 +296,9 @@ struct shape ...@@ -262,9 +296,9 @@ struct shape
// convert the shape to a static one setting any non-fixed dynamic_dimensions to x // convert the shape to a static one setting any non-fixed dynamic_dimensions to x
shape to_static(std::size_t x) const; shape to_static(std::size_t x) const;
friend bool operator==(const shape& x, const shape& y); MIGRAPHX_EXPORT friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y); MIGRAPHX_EXPORT friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const shape& x);
template <class T> template <class T>
struct as struct as
...@@ -275,6 +309,8 @@ struct shape ...@@ -275,6 +309,8 @@ struct shape
type min() const { return std::numeric_limits<type>::lowest(); } type min() const { return std::numeric_limits<type>::lowest(); }
type nan() const { return std::numeric_limits<type>::quiet_NaN(); }
template <class U> template <class U>
type operator()(U u) const type operator()(U u) const
{ {
...@@ -370,8 +406,8 @@ struct shape ...@@ -370,8 +406,8 @@ struct shape
std::shared_ptr<const shape_impl> impl; std::shared_ptr<const shape_impl> impl;
}; };
void migraphx_to_value(value& v, const shape& s); MIGRAPHX_EXPORT void migraphx_to_value(value& v, const shape& s);
void migraphx_from_value(const value& v, shape& s); MIGRAPHX_EXPORT void migraphx_from_value(const value& v, shape& s);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -35,7 +35,7 @@ struct module; ...@@ -35,7 +35,7 @@ struct module;
/** /**
* Simplify many algebraic instructions to more efficient versions. * Simplify many algebraic instructions to more efficient versions.
*/ */
struct simplify_algebra struct MIGRAPHX_EXPORT simplify_algebra
{ {
std::string name() const { return "simplify_algebra"; } std::string name() const { return "simplify_algebra"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -36,7 +36,7 @@ struct module; ...@@ -36,7 +36,7 @@ struct module;
* Inserts quantized operators in place of dq->quantizable_op->q * Inserts quantized operators in place of dq->quantizable_op->q
* then removes remaining fake quantization (q->dq pairs) * then removes remaining fake quantization (q->dq pairs)
*/ */
struct simplify_qdq struct MIGRAPHX_EXPORT simplify_qdq
{ {
std::string name() const { return "simplify_qdq"; } std::string name() const { return "simplify_qdq"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -36,7 +36,7 @@ struct module; ...@@ -36,7 +36,7 @@ struct module;
/** /**
* Eliminate redundant reshapes. * Eliminate redundant reshapes.
*/ */
struct simplify_reshapes struct MIGRAPHX_EXPORT simplify_reshapes
{ {
std::string name() const { return "simplify_reshapes"; } std::string name() const { return "simplify_reshapes"; }
void apply(module& m) const; void apply(module& m) const;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
#include <migraphx/config.hpp>
#if defined(CPPCHECK)
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#elif defined(__has_include)
#if __has_include(<source_location>) && __cplusplus >= 202003L
#define MIGRAPHX_HAS_SOURCE_LOCATION 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#endif
#if __has_include(<experimental/source_location>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 1
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#else
#define MIGRAPHX_HAS_SOURCE_LOCATION 0
#define MIGRAPHX_HAS_SOURCE_LOCATION_TS 0
#endif
#if MIGRAPHX_HAS_SOURCE_LOCATION
#include <source_location>
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
#include <experimental/source_location>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_SOURCE_LOCATION
using source_location = std::source_location;
#elif MIGRAPHX_HAS_SOURCE_LOCATION_TS
using source_location = std::experimental::source_location;
#else
struct source_location
{
static constexpr source_location current() noexcept { return source_location{}; }
constexpr std::uint_least32_t line() const noexcept { return 0; }
constexpr std::uint_least32_t column() const noexcept { return 0; }
constexpr const char* file_name() const noexcept { return ""; }
constexpr const char* function_name() const noexcept { return ""; }
};
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_SOURCE_LOCATION_HPP
...@@ -36,7 +36,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -36,7 +36,7 @@ inline namespace MIGRAPHX_INLINE_NS {
* Split dynamic dimension over submodules if exactly one dimension in the parameter list is * Split dynamic dimension over submodules if exactly one dimension in the parameter list is
* dynamic. * dynamic.
*/ */
struct split_single_dyn_dim struct MIGRAPHX_EXPORT split_single_dyn_dim
{ {
std::string name() const { return "split_single_dyn_dim"; } std::string name() const { return "split_single_dyn_dim"; }
void apply(module_pass_manager&) const; void apply(module_pass_manager&) const;
......
...@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,7 +35,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct sqlite_impl; struct sqlite_impl;
struct sqlite struct MIGRAPHX_EXPORT sqlite
{ {
sqlite() = default; sqlite() = default;
static sqlite read(const fs::path& p); static sqlite read(const fs::path& p);
......
...@@ -62,7 +62,7 @@ struct stream_model ...@@ -62,7 +62,7 @@ struct stream_model
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct stream_model struct MIGRAPHX_EXPORT stream_model
{ {
// //
std::size_t get_nstream() const; std::size_t get_nstream() const;
...@@ -100,7 +100,7 @@ struct stream_model ...@@ -100,7 +100,7 @@ struct stream_model
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -288,7 +288,7 @@ struct stream_model ...@@ -288,7 +288,7 @@ struct stream_model
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
...@@ -45,6 +45,8 @@ ...@@ -45,6 +45,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct value;
#ifdef DOXYGEN #ifdef DOXYGEN
/// An interface for a compilation target /// An interface for a compilation target
...@@ -125,7 +127,7 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric) ...@@ -125,7 +127,7 @@ supported_segments target_find_supported(T&, const_module_ref, support_metric)
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for: // Type-erased interface for:
struct target struct MIGRAPHX_EXPORT target
{ {
// //
std::string name() const; std::string name() const;
...@@ -165,7 +167,7 @@ struct target ...@@ -165,7 +167,7 @@ struct target
{ {
using std::swap; using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>(); auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique()) if(derived and private_detail_te_handle_mem_var.use_count() == 1)
{ {
*derived = std::forward<PrivateDetailTypeErasedT>(value); *derived = std::forward<PrivateDetailTypeErasedT>(value);
} }
...@@ -426,7 +428,7 @@ struct target ...@@ -426,7 +428,7 @@ struct target
private_detail_te_handle_base_type& private_detail_te_get_handle() private_detail_te_handle_base_type& private_detail_te_get_handle()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if(not private_detail_te_handle_mem_var.unique()) if(private_detail_te_handle_mem_var.use_count() > 1)
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x) ...@@ -467,6 +469,9 @@ inline const ValueType& any_cast(const target& x)
#endif #endif
void migraphx_to_value(value& v, const target& t);
void migraphx_from_value(const value& v, target& t);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/tf/export.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,7 +42,10 @@ struct tf_options ...@@ -41,7 +42,10 @@ struct tf_options
}; };
/// Create a program from a tf pb file (default is nhwc format) /// Create a program from a tf pb file (default is nhwc format)
program parse_tf(const std::string& name, const tf_options& options = tf_options{}); MIGRAPHX_TF_EXPORT program parse_tf(const std::string& name,
const tf_options& options = tf_options{});
MIGRAPHX_TF_EXPORT std::vector<std::string> get_tf_operators();
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct tmp_dir struct MIGRAPHX_EXPORT tmp_dir
{ {
fs::path path; fs::path path;
tmp_dir(const std::string& prefix = ""); tmp_dir(const std::string& prefix = "");
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <memory> #include <memory>
#include <cstdint>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
#include <tuple> #include <tuple>
...@@ -140,7 +141,7 @@ To try_convert_value(const From& x) ...@@ -140,7 +141,7 @@ To try_convert_value(const From& x)
return detail::try_convert_value_impl<To>(rank<3>{}, x); return detail::try_convert_value_impl<To>(rank<3>{}, x);
} }
struct value struct MIGRAPHX_EXPORT value
{ {
// clang-format off // clang-format off
#define MIGRAPHX_VISIT_VALUE_TYPES(m) \ #define MIGRAPHX_VISIT_VALUE_TYPES(m) \
...@@ -392,8 +393,8 @@ struct value ...@@ -392,8 +393,8 @@ struct value
return; \ return; \
} }
MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE) MIGRAPHX_VISIT_VALUE_TYPES(MIGRAPHX_VALUE_GENERATE_CASE_VALUE)
MIGRAPHX_VALUE_GENERATE_CASE(array, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(array, )
MIGRAPHX_VALUE_GENERATE_CASE(object, ) MIGRAPHX_VALUE_GENERATE_CASE_VALUE(object, )
} }
MIGRAPHX_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
} }
...@@ -452,14 +453,16 @@ struct value ...@@ -452,14 +453,16 @@ struct value
std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()}); std::vector<literal_to_string<To>>{default_value.begin(), default_value.end()});
} }
friend bool operator==(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator==(const value& x, const value& y);
friend bool operator!=(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator!=(const value& x, const value& y);
friend bool operator<(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator<(const value& x, const value& y);
friend bool operator<=(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator<=(const value& x, const value& y);
friend bool operator>(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator>(const value& x, const value& y);
friend bool operator>=(const value& x, const value& y); MIGRAPHX_EXPORT friend bool operator>=(const value& x, const value& y);
friend std::ostream& operator<<(std::ostream& os, const value& d); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const value& d);
std::size_t hash() const;
void debug_print(bool show_type = false) const; void debug_print(bool show_type = false) const;
...@@ -481,4 +484,15 @@ struct value ...@@ -481,4 +484,15 @@ struct value
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
namespace std {
template <>
struct hash<migraphx::value>
{
using argument_type = migraphx::value;
using result_type = std::size_t;
result_type operator()(const migraphx::value& x) const { return x.hash(); }
};
} // namespace std
#endif #endif
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace verify {
// Compute the value of a range // Compute the value of a range
template <class R> template <class R>
...@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out ...@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
return error <= threshold; return error <= threshold;
} }
} // namespace verify
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_EXPORT
bool verify_args(const std::string& name, bool verify_args(const std::string& name,
const argument& ref_arg, const argument& ref_arg,
const argument& target_arg, const argument& target_arg,
......
...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r) ...@@ -64,10 +64,7 @@ void instruction::replace(const shape& r)
result = r; result = r;
for(auto&& ins : output) for(auto&& ins : output)
{ {
if(ins->name() == "@return") assert(ins->name() == "@return" or ins->name().front() != '@');
continue;
assert(ins->name().front() != '@');
ins->recompute_shape(); ins->recompute_shape();
} }
} }
...@@ -122,10 +119,6 @@ bool instruction::valid() const ...@@ -122,10 +119,6 @@ bool instruction::valid() const
{ {
computed = result; computed = result;
} }
else if(op.name() == "@return")
{
computed = {};
}
else else
{ {
try try
...@@ -145,6 +138,7 @@ bool instruction::valid() const ...@@ -145,6 +138,7 @@ bool instruction::valid() const
} }
shape instruction::get_shape() const { return result; } shape instruction::get_shape() const { return result; }
const literal& instruction::get_literal() const const literal& instruction::get_literal() const
{ {
assert(op.name() == "@literal"); assert(op.name() == "@literal");
...@@ -406,6 +400,9 @@ void instruction::print(std::ostream& os, ...@@ -406,6 +400,9 @@ void instruction::print(std::ostream& os,
// skip return instruction shape // skip return instruction shape
if(ins->name() != "@return") if(ins->name() != "@return")
os << " -> " << ins->get_shape(); os << " -> " << ins->get_shape();
// print tid
os << ", target_id=" << ins->target_id;
} }
static void debug_name(std::ostream& os, const instruction& ins) static void debug_name(std::ostream& os, const instruction& ins)
...@@ -464,11 +461,14 @@ operation instruction::normalized_operator() const ...@@ -464,11 +461,14 @@ operation instruction::normalized_operator() const
if(this->need_normalization()) if(this->need_normalization())
{ {
auto s = this->inputs().front()->get_shape(); auto s = this->inputs().front()->get_shape();
if(not normalize_attributes(o, s.max_lens())) if(not normalize_attributes(o, s))
return this->get_operator(); return this->get_operator();
} }
return o; return o;
} }
std::size_t instruction::get_target_id() const { return target_id; }
void instruction::set_target_id(std::size_t tid) { this->target_id = tid; }
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args) std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment