Commit 01cf30d9 authored by Artur Wojcik

incorporate review feedback

parent 14e20a73
......@@ -42,6 +42,8 @@
 #include <algorithm>
 #include <cstdarg>
+
+namespace migraphx {

 #ifdef MIGRAPHX_BUILD_TESTING
 static thread_local bool disable_exception_catch = false; // NOLINT
......@@ -51,8 +53,6 @@ extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(
 }
 #endif
-
-namespace migraphx {

 template <class F>
 migraphx_status try_(F f, bool output = true) // NOLINT
 {
......
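A note on the hunk above: `try_` is the exception firewall of the C API. Exceptions thrown by the C++ core are caught at the `extern "C"` boundary and converted to a `migraphx_status`, and the thread-local `disable_exception_catch` flag, compiled in only when MIGRAPHX_BUILD_TESTING is defined, lets a test harness switch the firewall off so failures surface as real exceptions. A minimal self-contained sketch of that pattern; the names `my_status` and `run_guarded` are illustrative, not the MIGraphX API:

enum my_status { status_ok = 0, status_error = 1 };

// Test-only escape hatch, as in the diff: when set, exceptions propagate
// so the test framework can observe them instead of a status code.
static thread_local bool disable_exception_catch = false;

template <class F>
my_status run_guarded(F f) // analogous to try_
{
    if(disable_exception_catch)
    {
        f(); // let any exception escape to the caller
        return status_ok;
    }
    try
    {
        f();
        return status_ok;
    }
    catch(...)
    {
        return status_error; // a real API would map exception types to codes
    }
}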
......@@ -105,6 +105,11 @@ struct dynamic_loader_impl
         }
     }

+    dynamic_loader_impl(const dynamic_loader_impl&) = delete;
+    dynamic_loader_impl& operator=(const dynamic_loader_impl&) = delete;
+    dynamic_loader_impl(dynamic_loader_impl&&) = default;
+
     ~dynamic_loader_impl()
     {
         if(handle != nullptr)
......@@ -112,6 +117,7 @@ struct dynamic_loader_impl
             FreeLibrary(handle);
         }
     }
+
     static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
     {
         auto t = tmp_dir{"migx-dynload"};
......
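The added declarations make `dynamic_loader_impl` explicitly non-copyable: with two copies, both destructors would call `FreeLibrary` on the same handle. A minimal sketch of the same RAII idea; `lib_handle` and `free_library` are illustrative stand-ins for the real handle type and for `FreeLibrary`/`dlclose`:

#include <utility>

static void free_library(void* /*handle*/) {} // stand-in for FreeLibrary/dlclose

struct lib_handle
{
    void* handle = nullptr;

    explicit lib_handle(void* h = nullptr) : handle{h} {}
    lib_handle(const lib_handle&) = delete; // copying would free the handle twice
    lib_handle& operator=(const lib_handle&) = delete;
    // Transfer ownership and null out the source so only one destructor frees.
    lib_handle(lib_handle&& other) noexcept : handle{std::exchange(other.handle, nullptr)} {}

    ~lib_handle()
    {
        if(handle != nullptr)
            free_library(handle);
    }
};

The diff instead defaults the move constructor, which copies the raw handle and leaves the moved-from object still holding it; that appears safe here because the impl object is handed around through `std::shared_ptr` (see `from_buffer`), but `std::exchange` is the more defensive choice.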
......@@ -75,26 +75,28 @@ struct random_uniform
         result.visit([&](auto output) {
             using type = typename decltype(output)::value_type;
-#ifdef _MSC_VER
-            // According to the C++ specification, the effect is undefined if the result type
-            // for the generator is not one of short, int, long, long long, unsigned short,
-            // unsigned int, unsigned long, or unsigned long long. See
-            // https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution.
-            if constexpr(std::is_same_v<type, unsigned char> || std::is_same_v<type, signed char>)
+            if constexpr(std::is_integral<type>{})
             {
-                std::uniform_int_distribution<int> dis{std::numeric_limits<type>::min(),
-                                                       std::numeric_limits<type>::max()};
-                std::generate(output.begin(), output.end(), [&] { return dis(gen); });
-            }
-            else
+#ifdef _MSC_VER
+                // According to the C++ specification, the effect is undefined if the result type
+                // for the generator is not one of short, int, long, long long, unsigned short,
+                // unsigned int, unsigned long, or unsigned long long. See
+                // https://en.cppreference.com/w/cpp/numeric/random/uniform_int_distribution.
+                if constexpr(sizeof(type) == 1)
+                {
+                    std::uniform_int_distribution<int> dis{std::numeric_limits<type>::min(),
+                                                           std::numeric_limits<type>::max()};
+                    std::generate(output.begin(), output.end(), [&] { return dis(gen); });
+                }
+                else
 #endif
-            if constexpr(std::is_integral<type>{})
-            {
-                // default range for all integer types is
-                // (0, std::uniform_int_distribution<type>::max()).
-                // Todo: enable different ranges
-                std::uniform_int_distribution<type> dis;
-                std::generate(output.begin(), output.end(), [&] { return dis(gen); });
+                {
+                    // default range for all integer types is
+                    // (0, std::uniform_int_distribution<type>::max()).
+                    // Todo: enable different ranges
+                    std::uniform_int_distribution<type> dis;
+                    std::generate(output.begin(), output.end(), [&] { return dis(gen); });
+                }
             }
             else
             {
......
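The review feedback here replaces the explicit `unsigned char`/`signed char` test with `sizeof(type) == 1`, which also covers plain `char` and any other one-byte integral alias, and nests the MSVC special case inside a single outer `is_integral` branch. The underlying issue is the one the comment cites: `std::uniform_int_distribution` is only specified for the standard integer types `short` and wider, and MSVC's standard library rejects one-byte types at compile time. A standalone sketch of the widen-then-narrow workaround; `random_fill` is an illustrative helper, not MIGraphX code:

#include <cstddef>
#include <limits>
#include <random>
#include <vector>

template <class T> // integral T only
std::vector<T> random_fill(std::size_t n, std::mt19937& gen)
{
    std::vector<T> out(n);
    if constexpr(sizeof(T) == 1)
    {
        // One-byte result types are undefined for uniform_int_distribution:
        // draw ints restricted to T's range, then narrow.
        std::uniform_int_distribution<int> dis{std::numeric_limits<T>::min(),
                                               std::numeric_limits<T>::max()};
        for(auto& v : out)
            v = static_cast<T>(dis(gen));
    }
    else
    {
        // Default-constructed distribution ranges over [0, max()],
        // matching the comment in the diff.
        std::uniform_int_distribution<T> dis;
        for(auto& v : out)
            v = dis(gen);
    }
    return out;
}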
......@@ -91,118 +91,28 @@ struct post_op : reflect_equality<post_op>, reflect_stream<post_op>
     }
 };

+template <class F>
+struct execute_wrapper
+{
+    F f;
+    argument operator()(context&, const std::vector<argument>& args) const
+    {
+        return f(args);
+    }
+};
+
+template <class F>
+execute_wrapper<F> make_execute_wrapper(F f)
+{
+    return {std::move(f)};
+}

 template <class Derived, class Primitive>
 struct dnnl_op : auto_register_op<Derived>
 {
     std::vector<post_op> post_ops;
     std::function<argument(context& ctx, const std::vector<argument>& args)> execute;

-    class executable
-    {
-        std::unordered_map<int, dnnl::memory::desc> md;
-        Primitive prim;
-        std::vector<int> arg_lookup;
-#ifdef _DEBUG
-        const dnnl_op& self;
-        const Derived& derived;
-        std::string name;
-        dnnl::primitive_attr prim_attr;
-        const std::vector<shape>& inputs;
-        const shape& output_shape;
-#endif
-
-        public:
-        // clang-format off
-        executable(const dnnl_op& op, const shape& out_shape, const std::vector<shape>& in_shapes)
-            : md{op.to_memory_desc(out_shape, in_shapes)},
-              prim{op.get_primitive(md)},
-              arg_lookup{op.create_arg_map(in_shapes.size())}
-#ifdef _DEBUG
-              , self{op},
-              derived{static_cast<const Derived&>(op)},
-              name{derived.name()},
-              prim_attr{op.get_primitive_attr(md)},
-              inputs{in_shapes},
-              output_shape{out_shape}
-#endif
-        // clang-format on
-        {
-        }
-
-        argument operator()(context&, const std::vector<argument>& args)
-        {
-#ifdef _DEBUG
-            // Check that the memory descriptors have not changed
-            auto debug_args = args;
-            debug_args.pop_back();
-            auto debug_md = self.to_memory_desc(output_shape, to_shapes(debug_args));
-            for(auto&& p : debug_md)
-            {
-                if(md.count(p.first) == 0)
-                    MIGRAPHX_THROW(name +
-                                   ": Missing memory descriptor for: " + std::to_string(p.first));
-                if(p.second == md.at(p.first))
-                    continue;
-                MIGRAPHX_THROW(name +
-                               ": Memory descriptor has changed for: " + std::to_string(p.first));
-            }
-            // Check post_ops args are correct
-            auto pos = prim_attr.get_post_ops();
-            auto prim_input_size = inputs.size() - self.get_extra_post_op_args();
-            int j = 0;
-            for(int i = 0; i < pos.len(); i++)
-            {
-                auto arg = j + prim_input_size;
-                auto kind = pos.kind(i);
-                std::string mesg =
-                    "Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
-                try
-                {
-                    dnnl::algorithm algo;
-                    dnnl::memory::desc mdesc;
-                    float scale = 0;
-                    float alpha = 0;
-                    float beta = 0;
-                    if(kind == dnnl::primitive::kind::binary)
-                    {
-                        pos.get_params_binary(i, algo, mdesc);
-                        if(mdesc != md.at(arg_lookup.at(arg)))
-                            MIGRAPHX_THROW(mesg +
-                                           "Memory descriptor doesn't match for binary post op");
-                        j++;
-                    }
-                    else if(kind == dnnl::primitive::kind::eltwise)
-                    {
-                        pos.get_params_eltwise(i, scale, algo, alpha, beta);
-                    }
-                    else if(kind == dnnl::primitive::kind::sum)
-                    {
-                        pos.get_params_sum(i, scale);
-                        algo = dnnl::algorithm::binary_add;
-                    }
-                    else
-                    {
-                        MIGRAPHX_THROW("Unknown kind");
-                    }
-                    if(to_dnnl_algo(self.post_ops[i].algo) != algo)
-                        MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
-                                       self.post_ops[i].algo + " != " + to_string(algo));
-                }
-                catch(const dnnl::error& e)
-                {
-                    MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what());
-                }
-            }
-#endif
-            std::unordered_map<int, dnnl::memory> m;
-            m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
-                to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
-            for(int i = 0; i < args.size() - 1; i++)
-                m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
-            prim.execute(get_dnnl_context().stream, m);
-            return args.back();
-        }
-    };

     template <class Self, class F>
     static auto reflect_base(Self& self, F f)
     {
......@@ -406,7 +316,86 @@ struct dnnl_op : auto_register_op<Derived>
     {
         // Compensate for allocation
         inputs.pop_back();
-        execute = executable{*this, output_shape, inputs};
+        const auto& self = static_cast<const Derived&>(*this);
+        auto name = self.name();
+        auto md = to_memory_desc(output_shape, inputs);
+        auto prim = get_primitive(md);
+        auto arg_lookup = create_arg_map(inputs.size());
+#ifndef NDEBUG
+        auto prim_attr = get_primitive_attr(md);
+#endif
+        execute = make_execute_wrapper([=](const std::vector<argument>& args) {
+#ifndef NDEBUG
+            // Check that the memory descriptors have not changed
+            auto debug_args = args;
+            debug_args.pop_back();
+            auto debug_md = to_memory_desc(output_shape, to_shapes(debug_args));
+            for(auto&& p : debug_md)
+            {
+                if(md.count(p.first) == 0)
+                    MIGRAPHX_THROW(name +
+                                   ": Missing memory descriptor for: " + std::to_string(p.first));
+                if(p.second == md.at(p.first))
+                    continue;
+                MIGRAPHX_THROW(name +
+                               ": Memory descriptor has changed for: " + std::to_string(p.first));
+            }
+            // Check post_ops args are correct
+            auto pos = prim_attr.get_post_ops();
+            auto prim_input_size = inputs.size() - this->get_extra_post_op_args();
+            int j = 0;
+            for(int i = 0; i < pos.len(); i++)
+            {
+                auto arg = j + prim_input_size;
+                auto kind = pos.kind(i);
+                std::string mesg =
+                    "Post op " + std::to_string(i) + "@" + std::to_string(arg) + ": ";
+                try
+                {
+                    dnnl::algorithm algo;
+                    dnnl::memory::desc mdesc;
+                    float scale = 0;
+                    float alpha = 0;
+                    float beta = 0;
+                    if(kind == dnnl::primitive::kind::binary)
+                    {
+                        pos.get_params_binary(i, algo, mdesc);
+                        if(mdesc != md.at(arg_lookup.at(arg)))
+                            MIGRAPHX_THROW(mesg +
+                                           "Memory descriptor doesn't match for binary post op");
+                        j++;
+                    }
+                    else if(kind == dnnl::primitive::kind::eltwise)
+                    {
+                        pos.get_params_eltwise(i, scale, algo, alpha, beta);
+                    }
+                    else if(kind == dnnl::primitive::kind::sum)
+                    {
+                        pos.get_params_sum(i, scale);
+                        algo = dnnl::algorithm::binary_add;
+                    }
+                    else
+                    {
+                        MIGRAPHX_THROW("Unknown kind");
+                    }
+                    if(to_dnnl_algo(post_ops[i].algo) != algo)
+                        MIGRAPHX_THROW(mesg + "Algorithm doesn't match for post op " +
+                                       post_ops[i].algo + " != " + to_string(algo));
+                }
+                catch(const dnnl::error& e)
+                {
+                    MIGRAPHX_THROW(mesg + "Failed to get post ops argument " + ": " + e.what());
+                }
+            }
+#endif
+            std::unordered_map<int, dnnl::memory> m;
+            m[MIGRAPHX_DNNL_PREFIX(ARG_DST)] =
+                to_dnnl_memory(md.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)), args.back());
+            for(int i = 0; i < args.size() - 1; i++)
+                m[arg_lookup[i]] = to_dnnl_memory(md.at(arg_lookup[i]), args[i]);
+            prim.execute(get_dnnl_context().stream, m);
+            return args.back();
+        });
     }

     std::vector<shape> trim_post_op_inputs(const std::vector<shape>& inputs) const
     {
......
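The larger change in this file replaces the hand-written `executable` functor class with a capturing lambda built in place: `md`, `prim`, and `arg_lookup` (plus `prim_attr` in debug builds) are captured by value, the debug guard switches from MSVC's `_DEBUG` to the standard `NDEBUG`, and `make_execute_wrapper` adapts the one-argument lambda to the two-argument signature of the `execute` `std::function`, whose `context&` parameter is unused here. A compilable sketch of that adapter, with stub `context` and `argument` types standing in for the MIGraphX ones:

#include <functional>
#include <utility>
#include <vector>

struct context {};  // stand-ins for the MIGraphX types
struct argument {};

// Adapts a callable that takes only the arguments to the wider
// (context&, args) signature expected by the execute std::function.
template <class F>
struct execute_wrapper
{
    F f;
    argument operator()(context&, const std::vector<argument>& args) const { return f(args); }
};

template <class F>
execute_wrapper<F> make_execute_wrapper(F f)
{
    return {std::move(f)};
}

int main()
{
    std::function<argument(context&, const std::vector<argument>&)> execute =
        make_execute_wrapper([](const std::vector<argument>& args) { return args.back(); });
    context ctx;
    std::vector<argument> args(2);
    execute(ctx, args); // the context argument is simply dropped by the wrapper
}

Because the lambda in the diff captures by value (`[=]`), the precompiled primitive and descriptor map live inside the closure and stay valid for every later call to `execute`.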
......@@ -260,7 +260,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
     if(fs::exists(driver))
 #endif
     {
-        value v;
         v["srcs"] = to_value(hsrcs);
         v["params"] = to_value(params);
......
......@@ -679,10 +679,6 @@ def add_function(name: str, *args, **kwargs) -> Function:
     return f


-def register_functions(path: Union[Path, str]) -> None:
-    runpy.run_path(path if isinstance(path, str) else str(path))


 def once(f: Callable) -> Any:
     @wraps(f)
     def decorated(*args, **kwargs):
......@@ -1286,21 +1282,17 @@ def template_eval(template, **kwargs):
     return template


-def invoke(path: Union[Path, str]) -> str:
+def run(path: Union[Path, str]) -> str:
     return template_eval(open(path).read())


-def run(args: List[str]) -> None:
-    register_functions(args[0])
-    if len(args) > 1:
-        r = invoke(args[1])
+if __name__ == "__main__":
+    sys.modules['api'] = sys.modules['__main__']
+    runpy.run_path(sys.argv[1])
+    if len(sys.argv) > 2:
+        r = run(sys.argv[2])
         sys.stdout.write(r)
     else:
         sys.stdout.write(generate_c_header())
         sys.stdout.write(generate_c_api_body())
         # sys.stdout.write(generate_cpp_header())
-
-
-if __name__ == "__main__":
-    sys.modules['api'] = sys.modules['__main__']
-    run(sys.argv[1:])
......@@ -21,7 +21,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
 #####################################################################################
-import os, sys, argparse, subprocess, te, api
+import api, argparse, os, runpy, subprocess, sys, te
 from pathlib import Path

 clang_format_path = Path('clang-format.exe' if os.name ==
......@@ -43,12 +43,12 @@ def clang_format(buffer, **kwargs):
 def api_generate(input_path: Path, output_path: Path):
     with open(output_path, 'w') as f:
-        f.write(clang_format(api.invoke(input_path)))
+        f.write(clang_format(api.run(input_path)))


 def te_generate(input_path: Path, output_path: Path):
     with open(output_path, 'w') as f:
-        f.write(clang_format(te.invoke(input_path)))
+        f.write(clang_format(te.run(input_path)))
def main():
......@@ -66,11 +66,10 @@ def main():
         return
     try:
-        for f in [
-                f for f in Path('include').absolute().iterdir() if f.is_file()
-        ]:
+        files = Path('include').absolute().iterdir()
+        for f in [f for f in files if f.is_file()]:
             te_generate(f, src_dir / f'include/migraphx/{f.name}')
-        api.register_functions(str(migraphx_py_path))
+        runpy.run_path(str(migraphx_py_path))
         api_generate(work_dir / 'api/migraphx.h',
                      src_dir / 'api/include/migraphx/migraphx.h')
         print('Finished generating header migraphx.h')
......
......@@ -431,9 +431,9 @@ def template_eval(template, **kwargs):
     return template


-def invoke(p):
+def run(p):
     return template_eval(open(p).read())


 if __name__ == '__main__':
-    sys.stdout.write(invoke(sys.argv[1]))
+    sys.stdout.write(run(sys.argv[1]))