"vscode:/vscode.git/clone" did not exist on "5d8601685548ea1fbfd0b290bdba910c394e5c66"
Commit 180dc7a0 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into jit-unroll-stride

parents e535f7ef f7d987ba
...@@ -21,7 +21,10 @@ ...@@ -21,7 +21,10 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
import string, sys, re, runpy import string
import sys
import re
import runpy
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
...@@ -308,18 +311,39 @@ class Parameter: ...@@ -308,18 +311,39 @@ class Parameter:
return self.substitute('${type} ${name}', prefix=prefix) return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]: def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [ container_type = self.type.remove_generic().basic().str()
'&{prefix}{n}'.format(prefix=prefix or '', n=n) decl_list: List[str] = []
for t, n in self.cparams container = (container_type == "std::vector"
] or container_type == "vector")
for t, n, in self.cparams:
if not decl_list and container:
decl_list.append('{prefix}{n}.data()'.format(prefix=prefix
or '',
n=n))
else:
decl_list.append('&{prefix}{n}'.format(prefix=prefix or '',
n=n))
return decl_list
def virtual_output_declarations(self, def virtual_output_declarations(self,
prefix: Optional[str] = None) -> List[str]: prefix: Optional[str] = None) -> List[str]:
return [ container_type = self.type.remove_generic().basic().str()
'std::remove_pointer_t<{type}> {prefix}{n};'.format( container = (container_type == "std::vector"
type=Type(t).str(), prefix=prefix or '', n=n) or container_type == "vector")
for t, n in self.cparams decl_list: List[str] = []
] for t, n, in self.cparams:
if not decl_list and container:
inner_t = self.type.inner_type()
if inner_t:
decl_list.append(
'std::array<{inner_t}, 1024> {prefix}{n};'.format(
inner_t=inner_t.str(), prefix=prefix or '', n=n))
else:
decl_list.append(
'std::remove_pointer_t<{type}> {prefix}{n}'.format(
type=Type(t).str(), prefix=prefix or '', n=n))
decl_list[-1] += '=1024;' if container else ';'
return decl_list
def virtual_output(self, prefix: Optional[str] = None) -> str: def virtual_output(self, prefix: Optional[str] = None) -> str:
write = self.virtual_write write = self.virtual_write
...@@ -694,9 +718,14 @@ def generate_cpp_header() -> str: ...@@ -694,9 +718,14 @@ def generate_cpp_header() -> str:
[c.generate() for c in cpp_classes]) [c.generate() for c in cpp_classes])
def cwrap(name: str) -> Callable: c_type_map: Dict[str, Type] = {}
def cwrap(name: str, c_type: Optional[str] = None) -> Callable:
def with_cwrap(f): def with_cwrap(f):
type_map[name] = f type_map[name] = f
if c_type:
c_type_map[name] = Type(c_type)
@wraps(f) @wraps(f)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
...@@ -917,6 +946,9 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -917,6 +946,9 @@ def vector_c_wrap(p: Parameter) -> None:
# Not a generic type # Not a generic type
if not inner: if not inner:
return return
if inner.str() in c_type_map:
inner = c_type_map[inner.str()]
t = inner.add_pointer() t = inner.add_pointer()
if p.type.is_reference(): if p.type.is_reference():
if p.type.is_const(): if p.type.is_const():
...@@ -927,6 +959,12 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -927,6 +959,12 @@ def vector_c_wrap(p: Parameter) -> None:
p.add_size_param() p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr', p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer') 'Null pointer')
elif p.virtual:
p.add_param(t)
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
p.virtual_write = '{${name}.begin(), ${name}.begin()+${size}}; // cppcheck-suppress returnDanglingLifetime'
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
...@@ -946,7 +984,7 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -946,7 +984,7 @@ def vector_c_wrap(p: Parameter) -> None:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})'] p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string') @cwrap('std::string', 'char*')
def string_c_wrap(p: Parameter) -> None: def string_c_wrap(p: Parameter) -> None:
t = Type('char*') t = Type('char*')
if p.returns: if p.returns:
...@@ -1061,9 +1099,9 @@ struct ${ctype} { ...@@ -1061,9 +1099,9 @@ struct ${ctype} {
c_api_virtual_impl = Template(''' c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const ${return_type} ${name}(${params}) const
{ {
${output_decls}
if (${fname} == nullptr) if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing."); throw std::runtime_error("${name} function is missing.");
${output_decls}
std::array<char, 256> exception_msg; std::array<char, 256> exception_msg;
exception_msg.front() = '\\0'; exception_msg.front() = '\\0';
auto api_error_result = ${fname}(${args}); auto api_error_result = ${fname}(${args});
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names) ...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
options.output_node_names = std::vector<std::string>(names.begin(), names.end()); options.output_node_names = std::vector<std::string>(names.begin(), names.end());
} }
std::vector<argument>
run_async(program& p, const parameter_map& params, void* s, std::string_view name)
{
execution_environment exec_env{any_ptr(s, name), true};
return p.eval(params, exec_env);
}
template <class Value> template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m) std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{ {
...@@ -265,11 +273,18 @@ struct experimental_custom_op ...@@ -265,11 +273,18 @@ struct experimental_custom_op
template <class CustomOp> template <class CustomOp>
struct custom_operation struct custom_operation
{ {
template <class Self, class F> template <class Self, class F>
static auto reflect(Self&, F) static auto reflect(Self&, F)
{ {
return pack(); return pack();
} }
value attributes() const
{
return {{"custom_op", true}, {"target", op.runs_on_offload_target() ? "gpu" : "cpu"}};
}
CustomOp op; CustomOp op;
std::string name() const { return op.xobject.name; } std::string name() const { return op.xobject.name; }
...@@ -284,6 +299,23 @@ struct custom_operation ...@@ -284,6 +299,23 @@ struct custom_operation
{ {
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs)); return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
} }
std::ptrdiff_t output_alias(std::vector<shape> inputs) const
{
auto alias_vec = op.output_alias(std::move(inputs));
// TODO: For now, only support one output alias
if(alias_vec.empty())
{
return -1;
}
if(alias_vec.size() > 1)
{
MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
}
return alias_vec.front();
}
bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
}; };
template <class CustomOp> template <class CustomOp>
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
......
...@@ -66,12 +66,21 @@ any_ptr get_queue_context(T&) ...@@ -66,12 +66,21 @@ any_ptr get_queue_context(T&)
{ {
return {}; return {};
} }
template <class T>
void wait_for_context(T&, any_ptr)
{
}
template <class T>
void finish_on_context(T&, any_ptr){}
<% <%
interface('context', interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'), virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'),
virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'),
virtual('finish', returns = 'void', const = True)) %> virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx) inline void migraphx_to_value(value& v, const context& ctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment