Commit bb390b65 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'activationOperators' into addLogExpOperators

parents adfa1a93 682ad83a
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
......@@ -19,7 +19,7 @@ struct fwd_conv_batchnorm_rewrite
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/literal.hpp>
......@@ -8,9 +8,9 @@
#include <random>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class T, MIGRAPH_REQUIRES(is_floating_point<T>{})>
template <class T, MIGRAPHX_REQUIRES(is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
{
if(z == 0)
......@@ -22,7 +22,7 @@ constexpr T normalize(unsigned long z)
return T(result);
}
template <class T, MIGRAPH_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
{
const auto max = std::numeric_limits<T>::max();
......@@ -30,7 +30,7 @@ constexpr T normalize(unsigned long z)
return half_max - (z % max);
}
template <class T, MIGRAPH_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
constexpr T normalize(unsigned long z)
{
const auto max = std::numeric_limits<T>::max();
......@@ -93,7 +93,7 @@ literal generate_literal(shape s, unsigned long seed = 0);
literal abs(literal l);
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -5,14 +5,14 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH_GUARD_RTGLIB_HALF_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#define MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#include <half.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
using half = half_float::half;
......@@ -33,7 +33,7 @@ struct deduce<half_float::detail::expr>
template <class T>
using deduce = typename detail::deduce<T>::type;
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#include <migraphx/literal.hpp>
#include <migraphx/shape.hpp>
......@@ -11,7 +11,7 @@
#include <utility>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
......@@ -90,7 +90,7 @@ struct instruction
std::vector<instruction_ref> arguments;
literal lit;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
namespace std {
......
#ifndef MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPH_GUARD_INSTRUCTION_REF_HPP
#ifndef MIGRAPHX_GUARD_INSTRUCTION_REF_HPP
#define MIGRAPHX_GUARD_INSTRUCTION_REF_HPP
#include <list>
#include <functional>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct instruction;
using instruction_ref = std::list<instruction>::iterator;
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPHX_GUARD_RTGLIB_ITERATOR_FOR_HPP
#include <cassert>
#include <type_traits>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
struct iterator_for_range
......@@ -39,7 +39,7 @@ iterator_for_range<T> iterator_for(T& x)
return {&x};
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_LITERAL_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_LITERAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_LITERAL_HPP
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>
......@@ -12,7 +12,7 @@
#include <memory>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
/**
* @brief Represents a raw literal
......@@ -124,7 +124,7 @@ literal transform(literal l1, literal l2, F f)
return result;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_MAKE_SHARED_ARRAY_HPP
#include <memory>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <typename T>
std::shared_ptr<T> make_shared_array(size_t size)
......@@ -13,7 +13,7 @@ std::shared_ptr<T> make_shared_array(size_t size)
return std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#define MIGRAPH_GUARD_MIGRAPH_MANAGE_PTR_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHX_MANAGE_PTR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_MANAGE_PTR_HPP
#include <memory>
#include <type_traits>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class F, F f> // NOLINT
struct manage_deleter
......@@ -51,10 +51,10 @@ shared<T> share(T p)
return shared<T>{std::move(p)};
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#define MIGRAPH_MANAGE_PTR(T, F) \
#define MIGRAPHX_MANAGE_PTR(T, F) \
migraphx::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_MATCHER_HPP
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
......@@ -10,7 +10,7 @@
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {
......@@ -169,7 +169,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
}
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPH_BASIC_MATCHER(name, ...) \
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
......@@ -178,7 +178,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...) \
#define MIGRAPHX_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
......@@ -266,22 +266,22 @@ auto any_of(Ts... ms)
});
}
MIGRAPH_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPH_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPH_PRED_MATCHER(broadcast_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
}
MIGRAPH_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins->outputs().front();
return ctx.not_found();
}
MIGRAPH_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins;
......@@ -340,7 +340,7 @@ inline auto either_arg(std::size_t i, std::size_t j)
}
} // namespace match
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPH_GUARD_RTGLIB_MEMORY_COLORING_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_HPP
#define MIGRAPHX_GUARD_RTGLIB_MEMORY_COLORING_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
......@@ -20,7 +20,7 @@ struct memory_coloring
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_ONNX_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_ONNX_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_ONNX_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
/// Create a program from an onnx file
program parse_onnx(const std::string& name);
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <cassert>
#include <string>
......@@ -16,7 +16,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
......@@ -103,7 +103,7 @@ template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
std::string name = x.name();
MIGRAPH_THROW("Not computable: " + name);
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
......@@ -387,7 +387,7 @@ inline bool operator!=(const operation& x, const operation& y) { return !(x == y
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_HPP
#define MIGRAPHX_GUARD_OPERATORS_HPP
#include <array>
#include <migraphx/operation.hpp>
......@@ -11,14 +11,14 @@
#include <utility>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct not_computable
{
argument compute(context&, const shape&, const std::vector<argument>&) const
{
MIGRAPH_THROW("not computable");
MIGRAPHX_THROW("not computable");
}
};
......@@ -124,7 +124,7 @@ struct convolution
}
else
{
MIGRAPH_THROW("Invalid padding mode");
MIGRAPHX_THROW("Invalid padding mode");
}
}
};
......@@ -163,7 +163,7 @@ struct im2col
auto kernel_width = weights.lens()[3];
check_shapes{inputs, *this}.has(2);
if(batch_size != 1)
MIGRAPH_THROW("im2col only support batch_size 1");
MIGRAPHX_THROW("im2col only support batch_size 1");
auto output_height = std::size_t(std::max<std::ptrdiff_t>(
1,
(input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
......@@ -279,13 +279,13 @@ struct transpose
auto t = input.type();
if(dims.size() != input_lens.size())
{
MIGRAPH_THROW("Permutation has wrong number of axes");
MIGRAPHX_THROW("Permutation has wrong number of axes");
}
std::vector<int64_t> axes(dims.size());
std::iota(axes.begin(), axes.end(), 0);
if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
{
MIGRAPH_THROW("Invalid permutation");
MIGRAPHX_THROW("Invalid permutation");
}
std::vector<size_t> output_lens(input_lens.size());
std::vector<size_t> output_strides(input_lens.size());
......@@ -303,6 +303,12 @@ struct transpose
int output_alias(const std::vector<shape>&) const { return 0; }
};
/// The contiguous operator takes a non-standard input tensor and returns
/// the same tensor but in standard form. For example, if input tensor A which has lens = (4,5)
/// is first transposed, i.e. lens = (5,4), this tensor's data layout remained the same
/// during the transpose operation; only it's shape lengths and strides were changed.
/// This leaves the tensor in a non-standard form. The contiguous operator copies the
/// underlying data such that resulting tensor is returned to a standard form.
struct contiguous
{
std::string name() const { return "contiguous"; }
......@@ -336,7 +342,7 @@ struct concat
{
if(inputs.empty())
{
MIGRAPH_THROW("Number of input tensors should exceed 0");
MIGRAPHX_THROW("Number of input tensors should exceed 0");
}
const auto& first_shape_lens = inputs.front().lens();
......@@ -349,7 +355,7 @@ struct concat
return s.lens()[l] == first_shape_lens[l];
}))
{
MIGRAPH_THROW("Non-axis dimensions should match");
MIGRAPHX_THROW("Non-axis dimensions should match");
}
}
}
......@@ -418,18 +424,9 @@ struct slice
auto t = input_shape.type();
const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides();
// std::vector<int64_t> t_axes(old_lens.size());
// if(axes.size() == 0)
// {
// std::iota(t_axes.begin(), t_axes.end(), 0);
// }
// else
// {
// std::copy(axes.begin(), axes.end(), t_axes.begin());
// }
if(starts.size() != axes.size() || axes.size() != ends.size())
{
MIGRAPH_THROW("inconsistent sizes");
MIGRAPHX_THROW("inconsistent sizes");
}
std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++)
......@@ -468,7 +465,7 @@ struct squeeze
if(std::any_of(
axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
{
MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
}
std::vector<std::size_t> new_lens;
if(axes.empty())
......@@ -554,7 +551,7 @@ struct reshape
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
if(n_neg_dims > 1)
MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
MIGRAPHX_THROW("Dimensions for reshape can only have one -1 dim");
for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == 0)
......@@ -578,7 +575,7 @@ struct reshape
}
shape s{inputs.front().type(), rdims};
if(s.elements() != inputs.front().elements())
MIGRAPH_THROW("Wrong number of elements for reshape");
MIGRAPHX_THROW("Wrong number of elements for reshape");
return s;
}
argument compute(context&, shape output_shape, std::vector<argument> args) const
......@@ -608,8 +605,8 @@ struct dot
auto t = a.type();
if(a.lens()[1] != b.lens()[0])
MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" +
to_string_range(b.lens()) + "}");
MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
return {t, {a.lens()[0], b.lens()[1]}};
}
};
......@@ -737,7 +734,7 @@ struct flatten
if(axis > lens.size())
{
MIGRAPH_THROW("axis for flatten must be less than tensor rank");
MIGRAPHX_THROW("axis for flatten must be less than tensor rank");
}
auto x =
std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
......@@ -751,6 +748,15 @@ struct flatten
}
int output_alias(const std::vector<shape>&) const { return 0; }
};
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
struct broadcast
{
uint64_t axis = 0;
......@@ -775,7 +781,7 @@ struct broadcast
}))
{
if(axis != 0)
MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0");
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
}
else
......@@ -783,7 +789,7 @@ struct broadcast
assert(broadcast_shape.lens().size() - axis >= input.lens().size());
if(!std::equal(
input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
MIGRAPH_THROW("when broadcasting success sizes must match");
MIGRAPHX_THROW("when broadcasting success sizes must match");
std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
return {t, broadcast_shape.lens(), std::move(bcast_strides)};
}
......@@ -814,10 +820,10 @@ struct multibroadcast
auto input = inputs.at(0);
if(input.lens().empty())
MIGRAPH_THROW("inputs dimensions should be > 0");
MIGRAPHX_THROW("inputs dimensions should be > 0");
if(input.lens().size() > output_lens.size())
MIGRAPH_THROW("inputs dimensions should <= output size");
MIGRAPHX_THROW("inputs dimensions should <= output size");
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size();
......@@ -937,7 +943,7 @@ struct outline
};
} // namespace op
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#ifndef MIGRAPHX_GUARD_PASS_HPP
#define MIGRAPHX_GUARD_PASS_HPP
#include <cassert>
#include <string>
......@@ -10,7 +10,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
......@@ -218,7 +218,7 @@ inline const ValueType& any_cast(const pass& x)
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_PASS_CONFIG_HPP
#define MIGRAPH_GUARD_PASS_CONFIG_HPP
#ifndef MIGRAPHX_GUARD_PASS_CONFIG_HPP
#define MIGRAPHX_GUARD_PASS_CONFIG_HPP
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_DISABLE_MEMORY_COLORING)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MEMORY_COLORING)
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPH_GUARD_PASS_CONFIG_HPP
#endif // MIGRAPHX_GUARD_PASS_CONFIG_HPP
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_PROGRAM_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PROGRAM_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_PROGRAM_HPP
#include <list>
#include <unordered_map>
......@@ -14,7 +14,7 @@
#include <iostream>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct program_impl;
......@@ -109,7 +109,7 @@ struct program
std::unique_ptr<program_impl> impl;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm>
#include <initializer_list>
......@@ -7,7 +7,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
......@@ -106,7 +106,7 @@ iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
return {p.first, p.second};
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_RANK_HPP
#define MIGRAPH_GUARD_RTGLIB_RANK_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_RANK_HPP
#define MIGRAPHX_GUARD_RTGLIB_RANK_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <int N>
struct rank : rank<N - 1>
......@@ -16,7 +16,7 @@ struct rank<0>
{
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RAW_DATA_HPP
#define MIGRAPH_GUARD_RAW_DATA_HPP
#ifndef MIGRAPHX_GUARD_RAW_DATA_HPP
#define MIGRAPHX_GUARD_RAW_DATA_HPP
#include <migraphx/tensor_view.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct raw_data_base
{
......@@ -126,7 +126,7 @@ struct raw_data : raw_data_base
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
if(s.type() != migraphx::shape::get_type<T>{})
MIGRAPH_THROW("Incorrect data type for raw data");
MIGRAPHX_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
}
......@@ -143,8 +143,8 @@ struct raw_data : raw_data_base
template <class T,
class U,
MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
auto&& xshape = x.get_shape();
......@@ -166,8 +166,8 @@ bool operator==(const T& x, const U& y)
template <class T,
class U,
MIGRAPH_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
MIGRAPHX_REQUIRES(std::is_base_of<raw_data_base, T>{} &&
std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
return !(x == y);
......@@ -198,14 +198,14 @@ auto visit_all(T&& x, Ts&&... xs)
auto&& s = x.get_shape();
std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
MIGRAPH_THROW("Types must be the same");
MIGRAPHX_THROW("Types must be the same");
return [&](auto v) {
// Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
detail::visit_all_impl(s, v, x, xs...);
};
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment