Unverified Commit 7f97b8ef authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check

parents 2ba401f0 d1fed367
...@@ -21,7 +21,10 @@ ...@@ -21,7 +21,10 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
import string, sys, re, runpy import string
import sys
import re
import runpy
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
...@@ -308,18 +311,39 @@ class Parameter: ...@@ -308,18 +311,39 @@ class Parameter:
return self.substitute('${type} ${name}', prefix=prefix) return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]: def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [ container_type = self.type.remove_generic().basic().str()
'&{prefix}{n}'.format(prefix=prefix or '', n=n) decl_list: List[str] = []
for t, n in self.cparams container = (container_type == "std::vector"
] or container_type == "vector")
for t, n, in self.cparams:
if not decl_list and container:
decl_list.append('{prefix}{n}.data()'.format(prefix=prefix
or '',
n=n))
else:
decl_list.append('&{prefix}{n}'.format(prefix=prefix or '',
n=n))
return decl_list
def virtual_output_declarations(self, def virtual_output_declarations(self,
prefix: Optional[str] = None) -> List[str]: prefix: Optional[str] = None) -> List[str]:
return [ container_type = self.type.remove_generic().basic().str()
'std::remove_pointer_t<{type}> {prefix}{n};'.format( container = (container_type == "std::vector"
type=Type(t).str(), prefix=prefix or '', n=n) or container_type == "vector")
for t, n in self.cparams decl_list: List[str] = []
] for t, n, in self.cparams:
if not decl_list and container:
inner_t = self.type.inner_type()
if inner_t:
decl_list.append(
'std::array<{inner_t}, 1024> {prefix}{n};'.format(
inner_t=inner_t.str(), prefix=prefix or '', n=n))
else:
decl_list.append(
'std::remove_pointer_t<{type}> {prefix}{n}'.format(
type=Type(t).str(), prefix=prefix or '', n=n))
decl_list[-1] += '=1024;' if container else ';'
return decl_list
def virtual_output(self, prefix: Optional[str] = None) -> str: def virtual_output(self, prefix: Optional[str] = None) -> str:
write = self.virtual_write write = self.virtual_write
...@@ -694,9 +718,14 @@ def generate_cpp_header() -> str: ...@@ -694,9 +718,14 @@ def generate_cpp_header() -> str:
[c.generate() for c in cpp_classes]) [c.generate() for c in cpp_classes])
def cwrap(name: str) -> Callable: c_type_map: Dict[str, Type] = {}
def cwrap(name: str, c_type: Optional[str] = None) -> Callable:
def with_cwrap(f): def with_cwrap(f):
type_map[name] = f type_map[name] = f
if c_type:
c_type_map[name] = Type(c_type)
@wraps(f) @wraps(f)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
...@@ -917,6 +946,9 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -917,6 +946,9 @@ def vector_c_wrap(p: Parameter) -> None:
# Not a generic type # Not a generic type
if not inner: if not inner:
return return
if inner.str() in c_type_map:
inner = c_type_map[inner.str()]
t = inner.add_pointer() t = inner.add_pointer()
if p.type.is_reference(): if p.type.is_reference():
if p.type.is_const(): if p.type.is_const():
...@@ -927,6 +959,12 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -927,6 +959,12 @@ def vector_c_wrap(p: Parameter) -> None:
p.add_size_param() p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr', p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer') 'Null pointer')
elif p.virtual:
p.add_param(t)
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
p.virtual_write = '{${name}.begin(), ${name}.begin()+${size}}; // cppcheck-suppress returnDanglingLifetime'
else: else:
p.add_param(t) p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer') p.bad_param('${name} == nullptr', 'Null pointer')
...@@ -946,7 +984,7 @@ def vector_c_wrap(p: Parameter) -> None: ...@@ -946,7 +984,7 @@ def vector_c_wrap(p: Parameter) -> None:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})'] p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string') @cwrap('std::string', 'char*')
def string_c_wrap(p: Parameter) -> None: def string_c_wrap(p: Parameter) -> None:
t = Type('char*') t = Type('char*')
if p.returns: if p.returns:
...@@ -1061,9 +1099,9 @@ struct ${ctype} { ...@@ -1061,9 +1099,9 @@ struct ${ctype} {
c_api_virtual_impl = Template(''' c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const ${return_type} ${name}(${params}) const
{ {
${output_decls}
if (${fname} == nullptr) if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing."); throw std::runtime_error("${name} function is missing.");
${output_decls}
std::array<char, 256> exception_msg; std::array<char, 256> exception_msg;
exception_msg.front() = '\\0'; exception_msg.front() = '\\0';
auto api_error_result = ${fname}(${args}); auto api_error_result = ${fname}(${args});
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names) ...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
options.output_node_names = std::vector<std::string>(names.begin(), names.end()); options.output_node_names = std::vector<std::string>(names.begin(), names.end());
} }
std::vector<argument>
run_async(program& p, const parameter_map& params, void* s, std::string_view name)
{
execution_environment exec_env{any_ptr(s, name), true};
return p.eval(params, exec_env);
}
template <class Value> template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m) std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{ {
...@@ -265,11 +273,18 @@ struct experimental_custom_op ...@@ -265,11 +273,18 @@ struct experimental_custom_op
template <class CustomOp> template <class CustomOp>
struct custom_operation struct custom_operation
{ {
template <class Self, class F> template <class Self, class F>
static auto reflect(Self&, F) static auto reflect(Self&, F)
{ {
return pack(); return pack();
} }
value attributes() const
{
return {{"custom_op", true}, {"target", op.runs_on_offload_target() ? "gpu" : "cpu"}};
}
CustomOp op; CustomOp op;
std::string name() const { return op.xobject.name; } std::string name() const { return op.xobject.name; }
...@@ -284,6 +299,23 @@ struct custom_operation ...@@ -284,6 +299,23 @@ struct custom_operation
{ {
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs)); return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
} }
std::ptrdiff_t output_alias(std::vector<shape> inputs) const
{
auto alias_vec = op.output_alias(std::move(inputs));
// TODO: For now, only support one output alias
if(alias_vec.empty())
{
return -1;
}
if(alias_vec.size() > 1)
{
MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
}
return alias_vec.front();
}
bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
}; };
template <class CustomOp> template <class CustomOp>
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
......
...@@ -66,12 +66,21 @@ any_ptr get_queue_context(T&) ...@@ -66,12 +66,21 @@ any_ptr get_queue_context(T&)
{ {
return {}; return {};
} }
template <class T>
void wait_for_context(T&, any_ptr)
{
}
template <class T>
void finish_on_context(T&, any_ptr){}
<% <%
interface('context', interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'), virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'), virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'),
virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'),
virtual('finish', returns = 'void', const = True)) %> virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx) inline void migraphx_to_value(value& v, const context& ctx)
......
...@@ -68,8 +68,10 @@ struct operation ...@@ -68,8 +68,10 @@ struct operation
* *
* @param ctx This is the context created by the `target` during compilation. Implementations * @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class. * can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each * @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
* `shape` of the `argument`. * For a fixed shape, the returned argument will have the same shape as `output`.
* For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
* set in the dynamic shape `output`.
* @param input This is the `argument` result from the previous instruction's computation. * @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be * @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape. * the same the `output` shape.
...@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs) ...@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs)) -> decltype(x.normalize_compute_shape(inputs))
{ {
dependent_type<operation, T> y = x; dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens()); normalize_attributes(y, inputs[0].max_lens());
return any_cast<T>(y).normalize_compute_shape(inputs); return any_cast<T>(y).normalize_compute_shape(inputs);
} }
...@@ -558,7 +560,7 @@ lifetime get_lifetime_op(const T&) ...@@ -558,7 +560,7 @@ lifetime get_lifetime_op(const T&)
inline bool operator!=(const operation& x, const operation& y) inline bool operator!=(const operation& x, const operation& y)
{ {
return !(x == y); return not(x == y);
} }
inline value inline value
......
...@@ -37,8 +37,10 @@ ...@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp> #include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -64,12 +66,12 @@ struct target ...@@ -64,12 +66,12 @@ struct target
*/ */
context get_context() const; context get_context() const;
/** /**
* @brief Check how well an instruction is supported on a target with the given metric * @brief Get the ranges of instructions that are supported on a target
* @param ins Instruction to check if it's supported * @param module Module to check for supported instructions
* @param metric Used to define how the return value should be interpreted * @param metric Used to define how the quality of the support should be measured
* @return The value based on the chosen metric. Negative numbers mean unsupported * @return the supported segments of the graph
*/ */
float is_supported(T&, instruction_ref ins, support_metric m) const; supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
/** /**
* @brief copy an argument to the current target. * @brief copy an argument to the current target.
* *
...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg) ...@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
} }
template <class T> template <class T>
float target_is_supported(T&, instruction_ref, support_metric) supported_segments target_find_supported(T&, const_module_ref, support_metric)
{ {
return 0; return {};
} }
<% <%
...@@ -125,7 +127,7 @@ interface('target', ...@@ -125,7 +127,7 @@ interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True), virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True), virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'), virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'),
virtual('copy_to', virtual('copy_to',
returns = 'argument', returns = 'argument',
input = 'const argument&', input = 'const argument&',
......
...@@ -22,11 +22,14 @@ ...@@ -22,11 +22,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
import subprocess import subprocess, os
#Debug flag #Debug flag
debug = False debug = False
__repo_dir__ = os.path.normpath(
os.path.join(os.path.realpath(__file__), '..', '..'))
# Markdown code blob we should use to insert into notebook files # Markdown code blob we should use to insert into notebook files
def getipynb_markdownBlockAsList(): def getipynb_markdownBlockAsList():
...@@ -222,14 +225,15 @@ def getDelimiter(filename): ...@@ -222,14 +225,15 @@ def getDelimiter(filename):
def main(): def main():
message = open('LICENSE').read() message = open(os.path.join(__repo_dir__, 'LICENSE')).read()
#Get a list of all the files in our git repo #Get a list of all the files in our git repo
#bashCommand = "git ls-files --exclude-standard" #bashCommand = "git ls-files --exclude-standard"
#print (bashCommand.split()) #print (bashCommand.split())
proc = subprocess.run("git ls-files --exclude-standard", proc = subprocess.run("git ls-files --exclude-standard",
shell=True, shell=True,
stdout=subprocess.PIPE) stdout=subprocess.PIPE,
cwd=__repo_dir__)
fileList = proc.stdout.decode().split('\n') fileList = proc.stdout.decode().split('\n')
message = message.split('\n') message = message.split('\n')
...@@ -237,7 +241,8 @@ def main(): ...@@ -237,7 +241,8 @@ def main():
print("Target file list:\n" + str(fileList)) print("Target file list:\n" + str(fileList))
print("Output Message:\n" + str(message)) print("Output Message:\n" + str(message))
for file in fileList: for rfile in fileList:
file = os.path.join(__repo_dir__, rfile)
#print(file) #print(file)
commentDelim = getDelimiter(file) commentDelim = getDelimiter(file)
if commentDelim is not None: if commentDelim is not None:
......
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
##################################################################################### #####################################################################################
import string, sys, re import string, sys, re
trivial = ['std::size_t', 'instruction_ref', 'support_metric'] trivial = [
'std::size_t', 'instruction_ref', 'support_metric', 'const_module_ref'
]
headers = ''' headers = '''
#include <algorithm> #include <algorithm>
...@@ -134,7 +136,7 @@ private: ...@@ -134,7 +136,7 @@ private:
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type (PrivateDetailTypeErasedT value, private_detail_te_handle_type (PrivateDetailTypeErasedT value,
typename std::enable_if< typename std::enable_if<
!std::is_reference<PrivateDetailTypeErasedU>::value, not std::is_reference<PrivateDetailTypeErasedU>::value,
int int
>::type * = nullptr) noexcept : >::type * = nullptr) noexcept :
private_detail_te_value (std::move(value)) private_detail_te_value (std::move(value))
...@@ -176,7 +178,7 @@ private: ...@@ -176,7 +178,7 @@ private:
private_detail_te_handle_base_type & private_detail_te_get_handle () private_detail_te_handle_base_type & private_detail_te_get_handle ()
{ {
assert(private_detail_te_handle_mem_var != nullptr); assert(private_detail_te_handle_mem_var != nullptr);
if (!private_detail_te_handle_mem_var.unique()) if (not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone(); private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var; return *private_detail_te_handle_mem_var;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment