Commit 20b1d690 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into tests

parents 17aaaa1e ba729cfc
......@@ -11,6 +11,8 @@
#include <migraphx/context.hpp>
#include <migraphx/pass.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -34,15 +36,103 @@ struct target
* @return The context to be used during compilation and execution.
*/
context get_context() const;
/**
* @brief copy an argument to the current target.
*
* @param arg Input argument to be copied to the target
* @return Argument in the target.
*/
argument copy_to(const argument& arg) const;
/**
* @brief copy an argument from the current target.
*
* @param arg Input argument to be copied from the target
* @return Argument in the host.
*/
argument copy_from(const argument& arg) const;
/**
* @brief Allocate an argument based on the input shape
*
* @param s Shape of the argument to be allocated in the target
* @return Allocated argument in the target.
*/
argument allocate(const shape& s) const;
};
#else
template <class T>
auto target_allocate(rank<1>, T& x, const shape& s) -> decltype(x.allocate(s))
{
return x.allocate(s);
}
template <class T>
argument target_allocate(rank<0>, T& x, const shape&)
{
std::string name = x.name();
MIGRAPHX_THROW("Not computable: " + name);
}
template <class T>
argument target_allocate(T& x, const shape& s)
{
return target_allocate(rank<1>{}, x, s);
}
template <class T>
auto copy_to_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_to(arg))
{
return x.copy_to(arg);
}
template <class T>
argument copy_to_target(rank<0>, T&, const argument& arg)
{
return arg;
}
template <class T>
argument copy_to_target(T& x, const argument& arg)
{
return copy_to_target(rank<1>{}, x, arg);
}
template <class T>
auto copy_from_target(rank<1>, T& x, const argument& arg) -> decltype(x.copy_from(arg))
{
return x.copy_from(arg);
}
template <class T>
argument copy_from_target(rank<0>, T&, const argument& arg)
{
return arg;
}
template <class T>
argument copy_from_target(T& x, const argument& arg)
{
return copy_from_target(rank<1>{}, x, arg);
}
<%
interface('target',
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True)
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_to_target'),
virtual('copy_from',
returns = 'argument',
input = 'const argument&',
const = True,
default = 'copy_from_target'),
virtual('allocate', s='const shape&', returns='argument', const=True,
default = 'target_allocate')
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment