Commit 6b4c86ab authored by Paul's avatar Paul
Browse files

Merge

parents 8dfd08e1 f7d987ba
......@@ -21,7 +21,10 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import string, sys, re, runpy
import string
import sys
import re
import runpy
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
......@@ -308,18 +311,39 @@ class Parameter:
return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [
'&{prefix}{n}'.format(prefix=prefix or '', n=n)
for t, n in self.cparams
]
container_type = self.type.remove_generic().basic().str()
decl_list: List[str] = []
container = (container_type == "std::vector"
or container_type == "vector")
for t, n, in self.cparams:
if not decl_list and container:
decl_list.append('{prefix}{n}.data()'.format(prefix=prefix
or '',
n=n))
else:
decl_list.append('&{prefix}{n}'.format(prefix=prefix or '',
n=n))
return decl_list
def virtual_output_declarations(self,
prefix: Optional[str] = None) -> List[str]:
return [
'std::remove_pointer_t<{type}> {prefix}{n};'.format(
type=Type(t).str(), prefix=prefix or '', n=n)
for t, n in self.cparams
]
container_type = self.type.remove_generic().basic().str()
container = (container_type == "std::vector"
or container_type == "vector")
decl_list: List[str] = []
for t, n, in self.cparams:
if not decl_list and container:
inner_t = self.type.inner_type()
if inner_t:
decl_list.append(
'std::array<{inner_t}, 1024> {prefix}{n};'.format(
inner_t=inner_t.str(), prefix=prefix or '', n=n))
else:
decl_list.append(
'std::remove_pointer_t<{type}> {prefix}{n}'.format(
type=Type(t).str(), prefix=prefix or '', n=n))
decl_list[-1] += '=1024;' if container else ';'
return decl_list
def virtual_output(self, prefix: Optional[str] = None) -> str:
write = self.virtual_write
......@@ -694,9 +718,14 @@ def generate_cpp_header() -> str:
[c.generate() for c in cpp_classes])
def cwrap(name: str) -> Callable:
c_type_map: Dict[str, Type] = {}
def cwrap(name: str, c_type: Optional[str] = None) -> Callable:
def with_cwrap(f):
type_map[name] = f
if c_type:
c_type_map[name] = Type(c_type)
@wraps(f)
def decorated(*args, **kwargs):
......@@ -917,6 +946,9 @@ def vector_c_wrap(p: Parameter) -> None:
# Not a generic type
if not inner:
return
if inner.str() in c_type_map:
inner = c_type_map[inner.str()]
t = inner.add_pointer()
if p.type.is_reference():
if p.type.is_const():
......@@ -927,6 +959,12 @@ def vector_c_wrap(p: Parameter) -> None:
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
elif p.virtual:
p.add_param(t)
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
p.virtual_write = '{${name}.begin(), ${name}.begin()+${size}}; // cppcheck-suppress returnDanglingLifetime'
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
......@@ -946,7 +984,7 @@ def vector_c_wrap(p: Parameter) -> None:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string')
@cwrap('std::string', 'char*')
def string_c_wrap(p: Parameter) -> None:
t = Type('char*')
if p.returns:
......@@ -1061,9 +1099,9 @@ struct ${ctype} {
c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const
{
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
${output_decls}
std::array<char, 256> exception_msg;
exception_msg.front() = '\\0';
auto api_error_result = ${fname}(${args});
......
......@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/shape.hpp>
......@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
options.output_node_names = std::vector<std::string>(names.begin(), names.end());
}
std::vector<argument>
run_async(program& p, const parameter_map& params, void* s, std::string_view name)
{
execution_environment exec_env{any_ptr(s, name), true};
return p.eval(params, exec_env);
}
template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{
......@@ -265,11 +273,18 @@ struct experimental_custom_op
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
value attributes() const
{
return {{"custom_op", true}, {"target", op.runs_on_offload_target() ? "gpu" : "cpu"}};
}
CustomOp op;
std::string name() const { return op.xobject.name; }
......@@ -284,6 +299,23 @@ struct custom_operation
{
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
}
std::ptrdiff_t output_alias(std::vector<shape> inputs) const
{
auto alias_vec = op.output_alias(std::move(inputs));
// TODO: For now, only support one output alias
if(alias_vec.empty())
{
return -1;
}
if(alias_vec.size() > 1)
{
MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
}
return alias_vec.front();
}
bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
};
template <class CustomOp>
......
......@@ -26,7 +26,6 @@
#include <stdlib.h>
#include <stdbool.h>
// Add new types here
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
......
......@@ -66,12 +66,21 @@ any_ptr get_queue_context(T&)
{
return {};
}
template <class T>
void wait_for_context(T&, any_ptr)
{
}
template <class T>
void finish_on_context(T&, any_ptr){}
<%
interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'),
virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'),
virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment