Commit 5e951d3b authored by Shucai Xiao's avatar Shucai Xiao
Browse files

backup code changes

parent 0bd9d460
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/pass.hpp> #include <migraphx/pass.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -34,10 +36,92 @@ struct target ...@@ -34,10 +36,92 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief copy an argument to the current target.
*
* @param arg Input argument to be copied to the target
* @return Argument in the target.
*/
argument copy_to(const argument& arg) const;
/**
* @brief copy an argument from the current target.
*
* @param arg Input argument to be copied from the target
* @return Argument in the host.
*/
argument copy_from(const argument& arg) const;
/**
* @brief Allocate an argument based on the input shape
*
* @param s Shape of the argument to be allocated in the target
* @return Allocated argument in the target.
*/
argument allocate(const shape& s) const;
}; };
#else #else
template <class T>
auto target_allocate(rank<1>, T& x, const shape& s) -> decltype(x.allocate(s))
{
return x.allocate(s);
}
template <class T>
argument target_allocate(rank<0>, T& x, const shape&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
argument target_allocate(T& x, const shape& s)
{
return target_allocate(rank<1>{}, x, s);
}
template <class T>
auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(arg))
{
return x.copy_to(arg);
}
template <class T>
argument copy_to_target(rank<0>, T& x, const argument&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
argument copy_to_target(T& x, const argument& arg)
{
return copy_to_target(rank<1>{}, x, arg);
}
template <class T>
auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_from(arg))
{
return x.copy_from(arg);
}
template <class T>
argument copy_from_target(rank<0>, T& x, const argument&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
argument copy_from_target(T& x, const argument& arg)
{
return copy_from_target(rank<1>{}, x, arg);
}
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
...@@ -46,6 +130,9 @@ struct target ...@@ -46,6 +130,9 @@ struct target
* std::string name() const; * std::string name() const;
* std::vector<pass> get_passes(context& ctx) const; * std::vector<pass> get_passes(context& ctx) const;
* context get_context() const; * context get_context() const;
* argument copy_to(const argument& input) const;
* argument copy_from(const argument& input) const;
* argument allocate(const shape& s) const;
* }; * };
* *
*/ */
...@@ -125,16 +212,16 @@ struct target ...@@ -125,16 +212,16 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
argument copy_to(const argument& arg) const argument copy_to(const argument& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().copy_to(arg); return (*this).private_detail_te_get_handle().copy_to(input);
} }
argument copy_from(const argument& arg) const argument copy_from(const argument& input) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().copy_from(arg); return (*this).private_detail_te_get_handle().copy_from(input);
} }
argument allocate(const shape& s) const argument allocate(const shape& s) const
...@@ -159,8 +246,8 @@ struct target ...@@ -159,8 +246,8 @@ struct target
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0; virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
virtual argument copy_to(const argument& arg) const = 0; virtual argument copy_to(const argument& input) const = 0;
virtual argument copy_from(const argument& arg) const = 0; virtual argument copy_from(const argument& input) const = 0;
virtual argument allocate(const shape& s) const = 0; virtual argument allocate(const shape& s) const = 0;
}; };
...@@ -202,19 +289,22 @@ struct target ...@@ -202,19 +289,22 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
argument copy_to(const argument& arg) const override argument copy_to(const argument& input) const override
{ {
return private_detail_te_value.copy_to(arg);
return copy_to_target(private_detail_te_value, input);
} }
argument copy_from(const argument& arg) const override argument copy_from(const argument& input) const override
{ {
return private_detail_te_value.copy_from(arg);
return copy_from_target(private_detail_te_value, input);
} }
argument allocate(const shape& s) const override argument allocate(const shape& s) const override
{ {
return private_detail_te_value.allocate(s);
return target_allocate(private_detail_te_value, s);
} }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
......
...@@ -515,7 +515,8 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string> ...@@ -515,7 +515,8 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
// scale and shift is need for only int8 type, and we do not // scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0 // consider shift, so set shift to 0
std::vector<float> vec_val; std::vector<float> vec_val;
t.copy_from(args.front()).visit([&](auto output) { auto&& arg = t.copy_from(args.front());
arg.visit([&](auto output) {
vec_val.assign(output.begin(), output.end()); vec_val.assign(output.begin(), output.end());
}); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end()); auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
......
...@@ -16,8 +16,8 @@ struct target ...@@ -16,8 +16,8 @@ struct target
std::vector<pass> get_passes(migraphx::context& ctx) const; std::vector<pass> get_passes(migraphx::context& ctx) const;
migraphx::context get_context() const { return context{}; } migraphx::context get_context() const { return context{}; }
argument copy_to(const argument& arg) const { return arg; } argument copy_to(const argument& arg) const { return std::move(arg); }
argument copy_from(const argument& arg) const { return arg; } argument copy_from(const argument& arg) const { return std::move(arg); }
argument allocate(const shape& s) const; argument allocate(const shape& s) const;
}; };
......
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/pass.hpp> #include <migraphx/pass.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -34,15 +36,114 @@ struct target ...@@ -34,15 +36,114 @@ struct target
* @return The context to be used during compilation and execution. * @return The context to be used during compilation and execution.
*/ */
context get_context() const; context get_context() const;
/**
* @brief copy an argument to the current target.
*
* @param arg Input argument to be copied to the target
* @return Argument in the target.
*/
argument copy_to(const argument& arg) const;
/**
* @brief copy an argument from the current target.
*
* @param arg Input argument to be copied from the target
* @return Argument in the host.
*/
argument copy_from(const argument& arg) const;
/**
* @brief Allocate an argument based on the input shape
*
* @param s Shape of the argument to be allocated in the target
* @return Allocated argument in the target.
*/
argument allocate(const shape& s) const;
}; };
#else #else
template <class T>
auto target_allocate(rank<1>, T& x, const shape& s)
-> decltype(x.allocate(s))
{
return x.allocate(s);
}
template <class T>
argument target_allocate(rank<0>, T& x, const shape&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
argument target_allocate(T& x, const shape& s)
{
return target_allocate(rank<1>{}, x, s);
}
template <class T>
auto copy_to_target(rank<1>, T& x, const argument& arg)
-> decltype(x.copy_to(arg))
{
return x.copy_to(arg);
}
template <class T>
argument copy_to_target(rank<0>, T& x, const argument&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
argument copy_to_target(T& x, const argument& arg)
{
return copy_to_target(rank<1>{}, x, arg);
}
template <class T>
auto copy_from_target(
rank<1>, T& x, const argument& arg)
-> decltype(x.copy_from(arg))
{
return x.copy_from(arg);
}
template <class T>
argument copy_from_target(rank<0>, T& x, const argument&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
return argument{};
}
template <class T>
argument copy_from_target(T& x, const argument& arg)
{
return copy_from_target(rank<1>{}, x, arg);
}
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True) virtual('get_context', returns='context', const=True),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_to_target'),
virtual('copy_from',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_from_target'),
virtual('allocate', s='const shape&', returns='argument', const=True,
default = 'target_allocate')
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment