Unverified Commit 7f97b8ef authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Merge branch 'simplify_1_mul_div_ops' into divide_by_zero_check

parents 2ba401f0 d1fed367
......@@ -21,7 +21,10 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import string, sys, re, runpy
import string
import sys
import re
import runpy
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
......@@ -308,18 +311,39 @@ class Parameter:
return self.substitute('${type} ${name}', prefix=prefix)
def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
return [
'&{prefix}{n}'.format(prefix=prefix or '', n=n)
for t, n in self.cparams
]
container_type = self.type.remove_generic().basic().str()
decl_list: List[str] = []
container = (container_type == "std::vector"
or container_type == "vector")
for t, n, in self.cparams:
if not decl_list and container:
decl_list.append('{prefix}{n}.data()'.format(prefix=prefix
or '',
n=n))
else:
decl_list.append('&{prefix}{n}'.format(prefix=prefix or '',
n=n))
return decl_list
def virtual_output_declarations(self,
prefix: Optional[str] = None) -> List[str]:
return [
'std::remove_pointer_t<{type}> {prefix}{n};'.format(
type=Type(t).str(), prefix=prefix or '', n=n)
for t, n in self.cparams
]
container_type = self.type.remove_generic().basic().str()
container = (container_type == "std::vector"
or container_type == "vector")
decl_list: List[str] = []
for t, n, in self.cparams:
if not decl_list and container:
inner_t = self.type.inner_type()
if inner_t:
decl_list.append(
'std::array<{inner_t}, 1024> {prefix}{n};'.format(
inner_t=inner_t.str(), prefix=prefix or '', n=n))
else:
decl_list.append(
'std::remove_pointer_t<{type}> {prefix}{n}'.format(
type=Type(t).str(), prefix=prefix or '', n=n))
decl_list[-1] += '=1024;' if container else ';'
return decl_list
def virtual_output(self, prefix: Optional[str] = None) -> str:
write = self.virtual_write
......@@ -694,9 +718,14 @@ def generate_cpp_header() -> str:
[c.generate() for c in cpp_classes])
def cwrap(name: str) -> Callable:
c_type_map: Dict[str, Type] = {}
def cwrap(name: str, c_type: Optional[str] = None) -> Callable:
def with_cwrap(f):
type_map[name] = f
if c_type:
c_type_map[name] = Type(c_type)
@wraps(f)
def decorated(*args, **kwargs):
......@@ -917,6 +946,9 @@ def vector_c_wrap(p: Parameter) -> None:
# Not a generic type
if not inner:
return
if inner.str() in c_type_map:
inner = c_type_map[inner.str()]
t = inner.add_pointer()
if p.type.is_reference():
if p.type.is_const():
......@@ -927,6 +959,12 @@ def vector_c_wrap(p: Parameter) -> None:
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
elif p.virtual:
p.add_param(t)
p.add_size_param()
p.bad_param('${name} == nullptr or ${size} == nullptr',
'Null pointer')
p.virtual_write = '{${name}.begin(), ${name}.begin()+${size}}; // cppcheck-suppress returnDanglingLifetime'
else:
p.add_param(t)
p.bad_param('${name} == nullptr', 'Null pointer')
......@@ -946,7 +984,7 @@ def vector_c_wrap(p: Parameter) -> None:
p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']
@cwrap('std::string')
@cwrap('std::string', 'char*')
def string_c_wrap(p: Parameter) -> None:
t = Type('char*')
if p.returns:
......@@ -1061,9 +1099,9 @@ struct ${ctype} {
c_api_virtual_impl = Template('''
${return_type} ${name}(${params}) const
{
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
${output_decls}
std::array<char, 256> exception_msg;
exception_msg.front() = '\\0';
auto api_error_result = ${fname}(${args});
......
......@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/shape.hpp>
......@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
options.output_node_names = std::vector<std::string>(names.begin(), names.end());
}
std::vector<argument>
run_async(program& p, const parameter_map& params, void* s, std::string_view name)
{
execution_environment exec_env{any_ptr(s, name), true};
return p.eval(params, exec_env);
}
template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{
......@@ -265,11 +273,18 @@ struct experimental_custom_op
template <class CustomOp>
struct custom_operation
{
template <class Self, class F>
static auto reflect(Self&, F)
{
return pack();
}
value attributes() const
{
return {{"custom_op", true}, {"target", op.runs_on_offload_target() ? "gpu" : "cpu"}};
}
CustomOp op;
std::string name() const { return op.xobject.name; }
......@@ -284,6 +299,23 @@ struct custom_operation
{
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
}
std::ptrdiff_t output_alias(std::vector<shape> inputs) const
{
auto alias_vec = op.output_alias(std::move(inputs));
// TODO: For now, only support one output alias
if(alias_vec.empty())
{
return -1;
}
if(alias_vec.size() > 1)
{
MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
}
return alias_vec.front();
}
bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
};
template <class CustomOp>
......
......@@ -26,7 +26,6 @@
#include <stdlib.h>
#include <stdbool.h>
// Add new types here
// clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
......
......@@ -66,12 +66,21 @@ any_ptr get_queue_context(T&)
{
return {};
}
template <class T>
void wait_for_context(T&, any_ptr)
{
}
template <class T>
void finish_on_context(T&, any_ptr){}
<%
interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('wait_for', queue = 'any_ptr', returns = 'void', default = 'wait_for_context'),
virtual('finish_on', queue = 'any_ptr', returns = 'void', default = 'finish_on_context'),
virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx)
......
......@@ -68,8 +68,10 @@ struct operation
*
* @param ctx This is the context created by the `target` during compilation. Implementations
* can use the target's `context` class rather than the `context` interface class.
* @param output This is the output shape. It is equivalent to running `compute_shape` with each
* `shape` of the `argument`.
* @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
* For a fixed shape, the returned argument will have the same shape as `output`.
* For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
* set in the dynamic shape `output`.
* @param input This is the `argument` result from the previous instruction's computation.
* @return Return an `argument` of the result computation. The `shape` of `argument` should be
* the same the `output` shape.
......@@ -137,7 +139,7 @@ auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
-> decltype(x.normalize_compute_shape(inputs))
{
dependent_type<operation, T> y = x;
normalize_attributes(y, inputs[0].lens());
normalize_attributes(y, inputs[0].max_lens());
return any_cast<T>(y).normalize_compute_shape(inputs);
}
......@@ -558,7 +560,7 @@ lifetime get_lifetime_op(const T&)
inline bool operator!=(const operation& x, const operation& y)
{
return !(x == y);
return not(x == y);
}
inline value
......
......@@ -37,8 +37,10 @@
#include <migraphx/compile_options.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/support_metric.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/supported_segments.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -64,12 +66,12 @@ struct target
*/
context get_context() const;
/**
* @brief Check how well an instruction is supported on a target with the given metric
* @param ins Instruction to check if it's supported
* @param metric Used to define how the return value should be interpreted
* @return The value based on the chosen metric. Negative numbers mean unsupported
* @brief Get the ranges of instructions that are supported on a target
* @param module Module to check for supported instructions
* @param metric Used to define how the quality of the support should be measured
* @return the supported segments of the graph
*/
float is_supported(T&, instruction_ref ins, support_metric m) const;
supported_segments target_is_supported(T&, const_module_ref mod, support_metric metric) const;
/**
* @brief copy an argument to the current target.
*
......@@ -115,9 +117,9 @@ argument copy_from_target(T&, const argument& arg)
}
template <class T>
float target_is_supported(T&, instruction_ref, support_metric)
supported_segments target_find_supported(T&, const_module_ref, support_metric)
{
return 0;
return {};
}
<%
......@@ -125,7 +127,7 @@ interface('target',
virtual('name', returns='std::string', const=True),
virtual('get_passes', ctx='context&', options='const compile_options&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True),
virtual('is_supported', returns='float', ins='instruction_ref', m='support_metric', const=True, default='target_is_supported'),
virtual('find_supported', returns='supported_segments', mod='const_module_ref', m='support_metric', const=True, default='target_find_supported'),
virtual('copy_to',
returns = 'argument',
input = 'const argument&',
......
......@@ -22,11 +22,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import subprocess
import subprocess, os
#Debug flag
debug = False
__repo_dir__ = os.path.normpath(
os.path.join(os.path.realpath(__file__), '..', '..'))
# Markdown code blob we should use to insert into notebook files
def getipynb_markdownBlockAsList():
......@@ -222,14 +225,15 @@ def getDelimiter(filename):
def main():
message = open('LICENSE').read()
message = open(os.path.join(__repo_dir__, 'LICENSE')).read()
#Get a list of all the files in our git repo
#bashCommand = "git ls-files --exclude-standard"
#print (bashCommand.split())
proc = subprocess.run("git ls-files --exclude-standard",
shell=True,
stdout=subprocess.PIPE)
stdout=subprocess.PIPE,
cwd=__repo_dir__)
fileList = proc.stdout.decode().split('\n')
message = message.split('\n')
......@@ -237,7 +241,8 @@ def main():
print("Target file list:\n" + str(fileList))
print("Output Message:\n" + str(message))
for file in fileList:
for rfile in fileList:
file = os.path.join(__repo_dir__, rfile)
#print(file)
commentDelim = getDelimiter(file)
if commentDelim is not None:
......
......@@ -23,7 +23,9 @@
#####################################################################################
import string, sys, re
trivial = ['std::size_t', 'instruction_ref', 'support_metric']
trivial = [
'std::size_t', 'instruction_ref', 'support_metric', 'const_module_ref'
]
headers = '''
#include <algorithm>
......@@ -134,7 +136,7 @@ private:
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type (PrivateDetailTypeErasedT value,
typename std::enable_if<
!std::is_reference<PrivateDetailTypeErasedU>::value,
not std::is_reference<PrivateDetailTypeErasedU>::value,
int
>::type * = nullptr) noexcept :
private_detail_te_value (std::move(value))
......@@ -176,7 +178,7 @@ private:
private_detail_te_handle_base_type & private_detail_te_get_handle ()
{
assert(private_detail_te_handle_mem_var != nullptr);
if (!private_detail_te_handle_mem_var.unique())
if (not private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment