Commit bc5d7f75 authored by Paul's avatar Paul
Browse files

Merge from develop

parents 47c0854d a5b0afa0
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Rewrite rnn to gemm and add.
*/
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const;
private:
// for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2) const;
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector>
#include <cassert>
......@@ -7,12 +7,12 @@
#include <numeric>
#include <memory>
#include <migraph/errors.hpp>
#include <migraph/half.hpp>
#include <migraph/config.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl;
......@@ -21,7 +21,7 @@ struct shape
// Add new types here
// clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
......@@ -35,22 +35,22 @@ struct shape
m(uint64_type, uint64_t)
// clang-format on
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
{
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
};
#undef MIGRAPH_SHAPE_ENUM_TYPES
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
template <class T, class = void>
struct get_type;
#define MIGRAPH_SHAPE_GET_TYPE(x, t) \
#define MIGRAPHX_SHAPE_GENERATE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
};
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE)
#undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE
template <class T>
struct get_type<const T> : get_type<T>
......@@ -62,6 +62,19 @@ struct shape
shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{
}
template <class Range1, class Range2>
shape(type_t t, const Range1& l, const Range2& s)
: shape(t,
std::vector<std::size_t>(l.begin(), l.end()),
std::vector<std::size_t>(s.begin(), s.end()))
{
}
type_t type() const;
const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const;
......@@ -141,6 +154,8 @@ struct shape
{
return reinterpret_cast<const T*>(buffer) + n;
}
type_t type_enum() const { return get_type<T>{}; }
};
template <class Visitor>
......@@ -148,12 +163,20 @@ struct shape
{
switch(this->type())
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
#undef MIGRAPH_SHAPE_VISITOR_CASE
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
}
MIGRAPH_THROW("Unknown type");
MIGRAPHX_THROW("Unknown type");
}
template <class Visitor>
static void visit_types(Visitor v)
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}
private:
......@@ -163,7 +186,7 @@ struct shape
std::string type_string() const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <migraph/shape.hpp>
#include <migraph/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F>
void shape_for_each(const migraph::shape& s, F f)
void shape_for_each(const migraphx::shape& s, F f)
{
// Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
......@@ -28,7 +28,7 @@ void shape_for_each(const migraph::shape& s, F f)
}
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string>
#include <migraph/config.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
......@@ -18,7 +18,7 @@ struct simplify_algebra
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
......@@ -19,7 +19,7 @@ struct simplify_reshapes
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP
#ifndef MIGRAPHX_GUARD_STREAMUTILS_HPP
#define MIGRAPHX_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
#include <migraph/rank.hpp>
#include <migraph/config.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
struct stream_range_container
......@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
detail::stream_write_value_impl(rank<1>{}, os, x);
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm>
#include <numeric>
#include <string>
#include <sstream>
#include <migraph/config.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
......@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
return ss.str();
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert>
#include <string>
......@@ -8,12 +8,12 @@
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
#include <migraph/config.hpp>
#include <migraphx/context.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
......@@ -127,6 +127,12 @@ struct target
return (*this).private_detail_te_get_handle().get_context();
}
friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
......@@ -244,7 +250,7 @@ inline const ValueType& any_cast(const target& x)
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
#ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#include <migraph/shape.hpp>
#include <migraph/float_equal.hpp>
#include <migraph/requires.hpp>
#include <migraph/config.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
#include <iostream>
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
struct tensor_view
......@@ -29,7 +29,7 @@ struct tensor_view
const T* data() const { return this->m_data; }
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const
{
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
......@@ -37,7 +37,7 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs)
{
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
......@@ -45,13 +45,13 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const
{
return m_data[m_shape.index(start, last)];
}
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last)
{
return m_data[m_shape.index(start, last)];
......@@ -164,12 +164,12 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
}
template <class T>
tensor_view<T> make_view(shape s, T* data)
tensor_view<T> make_view(const shape& s, T* data)
{
return {s, data};
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#define MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#include <chrono>
#include <migraph/config.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Duration, class F>
auto time(F f)
......@@ -16,7 +16,7 @@ auto time(F f)
return std::chrono::duration_cast<Duration>(finish - start).count();
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
#include <migraph/functional.hpp>
#include <migraph/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct tracer
{
......@@ -30,7 +30,7 @@ struct tracer
std::ostream* os = nullptr;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string>
#include <migraph/config.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
......@@ -18,7 +18,7 @@ const std::string& get_type_name()
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe =";
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
......@@ -38,10 +38,10 @@ const std::string& get_type_name()
template <class T>
const std::string& get_type_name(const T&)
{
return migraph::get_type_name<T>();
return migraphx::get_type_name<T>();
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -5,32 +5,32 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits>
#include <migraph/half.hpp>
#include <migraph/config.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_VERIFY_HPP
#define MIGRAPH_GUARD_VERIFY_HPP
#ifndef MIGRAPHX_GUARD_VERIFY_HPP
#define MIGRAPHX_GUARD_VERIFY_HPP
#include <algorithm>
#include <cmath>
......@@ -7,11 +7,11 @@
#include <iostream>
#include <numeric>
#include <migraph/float_equal.hpp>
#include <migraph/config.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Compute the value of a range
template <class R>
......@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
return error <= threshold;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraph/verify.hpp>
#include <migraph/argument.hpp>
#include <migraph/config.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
inline bool verify_args(const std::string& name,
const argument& cpu_arg,
......@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
return passed;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraph/instruction.hpp>
#include <migraph/builtin.hpp>
#include <migraph/erase.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/erase.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
......@@ -70,7 +70,7 @@ bool instruction::valid() const
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
catch(migraphx::exception&)
{
return false;
}
......@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool operator==(const instruction& x, const instruction& y)
{
if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments))
if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
return false;
if(x.name() == "@literal")
return x.lit == y.lit;
......@@ -162,26 +162,55 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this);
}
std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
argument instruction::eval() const
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
if(op.name() == "@literal")
{
return this->get_literal().get_argument();
}
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty())
return {};
args.push_back(a);
}
return op.compute(result, args);
}
return {};
}
instruction_ref instruction::get_output_alias(instruction_ref ins)
void instruction::finalize(context& ctx)
{
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
if(has_finalize(this->op))
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
if(i < 0)
return ins;
if(shallow)
return ins->inputs().at(i);
return get_output_alias(ins->inputs().at(i));
}
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
return op.compute_shape(compute_shapes(args));
return op.compute_shape(to_shapes(args));
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -7,35 +7,35 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(migraph_onnx onnx.cpp)
set_target_properties(migraph_onnx PROPERTIES EXPORT_NAME onnx)
rocm_clang_tidy_check(migraph_onnx)
target_link_libraries(migraph_onnx PRIVATE onnx-proto)
target_link_libraries(migraph_onnx PUBLIC migraph)
add_library(migraphx_onnx onnx.cpp)
set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets(
TARGETS migraph_onnx
TARGETS migraphx_onnx
)
add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx migraph_onnx)
target_link_libraries(read_onnx migraphx_onnx)
if(MIGRAPH_ENABLE_GPU)
if(MIGRAPHX_ENABLE_GPU)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraph_cpu migraph_gpu migraph_onnx)
target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)
add_executable(cifar10 cifar10.cpp)
rocm_clang_tidy_check(cifar10)
target_link_libraries(cifar10 migraph_cpu migraph_gpu migraph_onnx)
target_link_libraries(cifar10 migraphx_cpu migraphx_gpu migraphx_onnx)
add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu)
target_link_libraries(verify_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
add_executable(perf_onnx perf_onnx.cpp)
rocm_clang_tidy_check(perf_onnx)
target_link_libraries(perf_onnx migraph_onnx migraph_cpu migraph_gpu)
target_link_libraries(perf_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
endif()
......@@ -4,12 +4,12 @@
#include <numeric>
#include <stdexcept>
#include <migraph/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraph/cpu/target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include "softmax.hpp"
......@@ -53,19 +53,19 @@ int main(int argc, char const* argv[])
std::string gpu_cpu = argv[1];
std::string file = argv[2];
std::string datafile = argv[3];
auto prog = migraph::parse_onnx(file);
auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl;
auto imageset = read_cifar10_images(datafile);
if(gpu_cpu == "gpu")
{
// GPU target
prog.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
prog.compile(migraphx::gpu::target{});
migraphx::program::parameter_map m;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
for(auto&& x : prog.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
}
auto labels = imageset.first;
auto input = imageset.second;
......@@ -73,8 +73,8 @@ int main(int argc, char const* argv[])
for(int i = 0; i < 10; i++)
{
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
m["0"] = migraph::gpu::to_gpu(migraph::argument{s, &ptr[3072 * i]});
auto result = migraph::gpu::from_gpu(prog.eval(m));
m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]});
auto result = migraphx::gpu::from_gpu(prog.eval(m));
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits);
......@@ -86,15 +86,15 @@ int main(int argc, char const* argv[])
else
{
// CPU target
prog.compile(migraph::cpu::target{});
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
prog.compile(migraphx::cpu::target{});
auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
auto labels = imageset.first;
auto input = imageset.second;
auto ptr = input.data();
for(int i = 0; i < 10; i++)
{
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
auto input3 = migraph::argument{s, &ptr[3072 * i]};
auto input3 = migraphx::argument{s, &ptr[3072 * i]};
auto result = prog.eval({{"0", input3}});
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
......
......@@ -4,17 +4,20 @@
#include <numeric>
#include <stdexcept>
#include <migraph/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/generate.hpp>
#include "softmax.hpp"
auto reverse_int(unsigned int i)
{
unsigned char c1, c2, c3, c4;
unsigned char c1;
unsigned char c2;
unsigned char c3;
unsigned char c4;
c1 = i & 255u;
c2 = (i >> 8u) & 255u;
c3 = (i >> 16u) & 255u;
......@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& imag
if(file.is_open())
{
int magic_number = 0, n_rows = 0, n_cols = 0;
int magic_number = 0;
int n_rows = 0;
int n_cols = 0;
file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
magic_number = reverse_int(magic_number);
......@@ -113,20 +118,20 @@ int main(int argc, char const* argv[])
std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);
std::string file = argv[1];
auto prog = migraph::parse_onnx(file);
auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl << std::endl;
prog.compile(migraph::gpu::target{});
auto s = migraph::shape{migraph::shape::float_type, {1, 1, 28, 28}};
prog.compile(migraphx::gpu::target{});
auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl;
auto ptr = input.data();
migraph::program::parameter_map m;
migraphx::program::parameter_map m;
m["output"] =
migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output")));
migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
for(int i = 0; i < 20; i++)
{
std::cout << "label: " << labels[i] << " ----> ";
m["0"] = migraph::gpu::to_gpu(migraph::argument{s, &ptr[784 * i]});
auto result = migraph::gpu::from_gpu(prog.eval(m));
m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]});
auto result = migraphx::gpu::from_gpu(prog.eval(m));
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits);
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment