Commit 9bff4331 authored by Paul's avatar Paul
Browse files

Merge

parents 214b313f 94a7f6ee
...@@ -58,14 +58,14 @@ add_library(migraphx ...@@ -58,14 +58,14 @@ add_library(migraphx
layout_nhwc.cpp layout_nhwc.cpp
load_save.cpp load_save.cpp
make_op.cpp make_op.cpp
memory_coloring.cpp
module.cpp module.cpp
msgpack.cpp msgpack.cpp
normalize_attributes.cpp normalize_attributes.cpp
normalize_ops.cpp normalize_ops.cpp
op_enums.cpp op_enums.cpp
operation.cpp operation.cpp
opt/memory_coloring.cpp optimize_module.cpp
opt/memory_coloring_impl.cpp
pad_calc.cpp pad_calc.cpp
pass_manager.cpp pass_manager.cpp
permutation.cpp permutation.cpp
...@@ -199,6 +199,7 @@ register_migraphx_ops( ...@@ -199,6 +199,7 @@ register_migraphx_ops(
scatternd_add scatternd_add
scatternd_mul scatternd_mul
scatternd_none scatternd_none
select_module
sigmoid sigmoid
sign sign
sinh sinh
...@@ -250,7 +251,10 @@ find_package(PkgConfig) ...@@ -250,7 +251,10 @@ find_package(PkgConfig)
pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3) pkg_check_modules(SQLITE3 REQUIRED IMPORTED_TARGET sqlite3)
target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3) target_link_libraries(migraphx PRIVATE PkgConfig::SQLITE3)
find_package(msgpack REQUIRED) find_package(msgpackc-cxx QUIET)
if(NOT msgpackc-cxx_FOUND)
find_package(msgpack REQUIRED)
endif()
target_link_libraries(migraphx PRIVATE msgpackc-cxx) target_link_libraries(migraphx PRIVATE msgpackc-cxx)
# Make this available to the tests # Make this available to the tests
target_link_libraries(migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>) target_link_libraries(migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>)
......
...@@ -29,7 +29,7 @@ set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) ...@@ -29,7 +29,7 @@ set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 3.0) rocm_set_soversion(migraphx_c 3.0)
rocm_clang_tidy_check(migraphx_c) rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_c TARGETS migraphx_c
......
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
...@@ -134,6 +133,11 @@ void set_offload_copy(compile_options& options, bool value) { options.offload_co ...@@ -134,6 +133,11 @@ void set_offload_copy(compile_options& options, bool value) { options.offload_co
void set_fast_math(compile_options& options, bool value) { options.fast_math = value; } void set_fast_math(compile_options& options, bool value) { options.fast_math = value; }
void set_exhaustive_tune_flag(compile_options& options, bool value)
{
options.exhaustive_tune = value;
}
void set_file_format(file_options& options, const char* format) { options.format = format; } void set_file_format(file_options& options, const char* format) { options.format = format; }
void set_default_dim_value(onnx_options& options, size_t value) void set_default_dim_value(onnx_options& options, size_t value)
...@@ -1690,6 +1694,19 @@ migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_option ...@@ -1690,6 +1694,19 @@ migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_option
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_compile_options_set_exhaustive_tune_flag(migraphx_compile_options_t compile_options,
bool value)
{
auto api_error_result = migraphx::try_([&] {
if(compile_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param,
"Bad parameter compile_options: Null pointer");
migraphx::set_exhaustive_tune_flag((compile_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options) migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options)
{ {
......
...@@ -427,6 +427,10 @@ migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_opt ...@@ -427,6 +427,10 @@ migraphx_compile_options_set_offload_copy(migraphx_compile_options_t compile_opt
migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options, migraphx_status migraphx_compile_options_set_fast_math(migraphx_compile_options_t compile_options,
bool value); bool value);
migraphx_status
migraphx_compile_options_set_exhaustive_tune_flag(migraphx_compile_options_t compile_options,
bool value);
migraphx_status migraphx_status
migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options); migraphx_parse_onnx(migraphx_program_t* out, const char* name, migraphx_onnx_options_t options);
......
...@@ -1015,6 +1015,12 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) ...@@ -1015,6 +1015,12 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
call(&migraphx_compile_options_set_fast_math, this->get_handle_ptr(), value); call(&migraphx_compile_options_set_fast_math, this->get_handle_ptr(), value);
} }
/// Set or un-set exhaustive search to find fastest kernel
void set_exhaustive_tune_flag(bool value = true)
{
call(&migraphx_compile_options_set_exhaustive_tune_flag, this->get_handle_ptr(), value);
}
}; };
/// A program represents the all computation graphs to be compiled and executed /// A program represents the all computation graphs to be compiled and executed
......
...@@ -354,6 +354,9 @@ def compile_options(h): ...@@ -354,6 +354,9 @@ def compile_options(h):
h.method('set_fast_math', h.method('set_fast_math',
api.params(value='bool'), api.params(value='bool'),
invoke='migraphx::set_fast_math($@)') invoke='migraphx::set_fast_math($@)')
h.method('set_exhaustive_tune_flag',
api.params(value='bool'),
invoke='migraphx::set_exhaustive_tune_flag($@)')
api.add_function('migraphx_parse_onnx', api.add_function('migraphx_parse_onnx',
......
...@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -77,7 +77,6 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
} }
auto offset = s1.ndim() - s0.ndim(); auto offset = s1.ndim() - s0.ndim();
std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims()); std::vector<shape::dynamic_dimension> out_dims(s1.dyn_dims());
shape::dynamic_dimension one_dyn_dim{1, 1, 0};
std::transform( std::transform(
s0.dyn_dims().cbegin(), s0.dyn_dims().cbegin(),
s0.dyn_dims().cend(), s0.dyn_dims().cend(),
...@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha ...@@ -88,7 +87,7 @@ std::vector<shape::dynamic_dimension> compute_broadcasted_dyn_dims(shape s0, sha
{ {
return a; return a;
} }
else if(a == one_dyn_dim or b == one_dyn_dim) else if(a == 1 or b == 1)
{ {
// setting opt to 0, may need to be changed // setting opt to 0, may need to be changed
return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0}; return shape::dynamic_dimension{std::max(a.min, b.min), std::max(a.max, b.max), 0};
......
...@@ -70,9 +70,6 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const ...@@ -70,9 +70,6 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
if(not fs::exists(out_path)) if(not fs::exists(out_path))
MIGRAPHX_THROW("Output file missing: " + out); MIGRAPHX_THROW("Output file missing: " + out);
if(process)
out_path = process(out_path);
return read_buffer(out_path.string()); return read_buffer(out_path.string());
} }
......
...@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const ...@@ -51,8 +51,8 @@ void dead_code_elimination::apply(module& m) const
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined, // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate] // identity, allocate]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and not(i->name().front() == '@') and not contains({"identity", "allocate"}, i->name()) and
not contains({"undefined", "identity", "allocate"}, i->name())) not i->is_undefined())
continue; continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last)); assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited; std::unordered_set<instruction_ref> visited;
......
...@@ -42,6 +42,7 @@ add_custom_command( ...@@ -42,6 +42,7 @@ add_custom_command(
) )
set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver) set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/driver)
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
rocm_install_targets( rocm_install_targets(
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "verify.hpp" #include "verify.hpp"
#include "argument_parser.hpp" #include "argument_parser.hpp"
#include "command.hpp" #include "command.hpp"
...@@ -109,8 +110,12 @@ struct loader ...@@ -109,8 +110,12 @@ struct loader
ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true)); ap(brief, {"--brief"}, ap.help("Make the output brief."), ap.set_value(true));
ap(output_type, ap(output_type,
{"--cpp"}, {"--cpp"},
ap.help("Print out the program as cpp program."), ap.help("Print out the program as C++ program."),
ap.set_value("cpp")); ap.set_value("cpp"));
ap(output_type,
{"--python", "--py"},
ap.help("Print out the program as python program."),
ap.set_value("py"));
ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json")); ap(output_type, {"--json"}, ap.help("Print out program as json."), ap.set_value("json"));
ap(output_type, ap(output_type,
{"--text"}, {"--text"},
...@@ -259,7 +264,9 @@ struct loader ...@@ -259,7 +264,9 @@ struct loader
type = "binary"; type = "binary";
} }
if(type == "cpp") if(type == "py")
p.print_py(*os);
else if(type == "cpp")
p.print_cpp(*os); p.print_cpp(*os);
else if(type == "graphviz") else if(type == "graphviz")
p.print_graph(*os, brief); p.print_graph(*os, brief);
...@@ -298,8 +305,12 @@ struct compiler_target ...@@ -298,8 +305,12 @@ struct compiler_target
{ {
#ifdef HAVE_GPU #ifdef HAVE_GPU
std::string target_name = "gpu"; std::string target_name = "gpu";
#else #elif HAVE_CPU
std::string target_name = "cpu"; std::string target_name = "cpu";
#elif HAVE_FPGA
std::string target_name = "fpga"
#else
std::string target_name = "ref"
#endif #endif
void parse(argument_parser& ap) void parse(argument_parser& ap)
...@@ -320,8 +331,7 @@ struct compiler ...@@ -320,8 +331,7 @@ struct compiler
loader l; loader l;
program_params parameters; program_params parameters;
compiler_target ct; compiler_target ct;
bool offload_copy = false; compile_options co;
bool fast_math = true;
precision quantize = precision::fp32; precision quantize = precision::fp32;
std::vector<std::string> fill0; std::vector<std::string> fill0;
...@@ -331,19 +341,26 @@ struct compiler ...@@ -331,19 +341,26 @@ struct compiler
l.parse(ap); l.parse(ap);
parameters.parse(ap); parameters.parse(ap);
ct.parse(ap); ct.parse(ap);
ap(offload_copy, ap(co.offload_copy,
{"--enable-offload-copy"}, {"--enable-offload-copy"},
ap.help("Enable implicit offload copying"), ap.help("Enable implicit offload copying"),
ap.set_value(true)); ap.set_value(true));
ap(fast_math, ap(co.fast_math,
{"--disable-fast-math"}, {"--disable-fast-math"},
ap.help("Disable fast math optimization"), ap.help("Disable fast math optimization"),
ap.set_value(false)); ap.set_value(false));
ap(co.exhaustive_tune,
{"--exhaustive-tune"},
ap.help("Exhastively search for best tuning parameters for kernels"),
ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16)); ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8)); ap(quantize, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(precision::int8));
} }
auto params(const program& p) { return parameters.generate(p, ct.get_target(), offload_copy); } auto params(const program& p)
{
return parameters.generate(p, ct.get_target(), co.offload_copy);
}
program compile() program compile()
{ {
...@@ -360,10 +377,7 @@ struct compiler ...@@ -360,10 +377,7 @@ struct compiler
{ {
quantize_int8(p, t, {params(p)}); quantize_int8(p, t, {params(p)});
} }
compile_options options; p.compile(t, co);
options.offload_copy = offload_copy;
options.fast_math = fast_math;
p.compile(t, options);
l.save(p); l.save(p);
return p; return p;
} }
...@@ -396,60 +410,41 @@ struct params : command<params> ...@@ -396,60 +410,41 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
loader l; compiler c;
program_params parameters;
compiler_target ct;
double tolerance = 80; double tolerance = 80;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
bool offload_copy = false;
bool fast_math = true;
precision quantize = precision::fp32;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
l.parse(ap); c.parse(ap);
parameters.parse(ap);
ct.parse(ap);
ap(offload_copy,
{"--enable-offload-copy"},
ap.help("Enable implicit offload copying"),
ap.set_value(true));
ap(fast_math,
{"--disable-fast-math"},
ap.help("Disable fast math optimization"),
ap.set_value(false));
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors")); ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
ap.set_value(true)); ap.set_value(true));
ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true)); ap(reduce, {"-r", "--reduce"}, ap.help("Reduce program and verify"), ap.set_value(true));
ap(quantize, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(precision::fp16));
} }
void run() void run()
{ {
auto p = l.load(); auto p = c.l.load();
l.save(p); c.l.save(p);
std::cout << p << std::endl; std::cout << p << std::endl;
compile_options options; auto t = c.ct.get_target();
options.offload_copy = offload_copy; auto m = c.parameters.generate(p, t, true);
options.fast_math = fast_math;
auto t = ct.get_target();
auto m = parameters.generate(p, t, true);
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, options, quantize, tolerance); verify_instructions(p, t, c.co, c.quantize, tolerance);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, options, quantize, m, tolerance); verify_reduced_program(p, t, c.co, c.quantize, m, tolerance);
} }
else else
{ {
verify_program(l.file, p, t, options, quantize, m, tolerance); verify_program(c.l.file, p, t, c.co, c.quantize, m, tolerance);
} }
} }
}; };
......
...@@ -108,8 +108,6 @@ target get_target(bool gpu) ...@@ -108,8 +108,6 @@ target get_target(bool gpu)
return make_target("cpu"); return make_target("cpu");
} }
void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -37,7 +37,6 @@ parameter_map create_param_map(const program& p, const target& t, bool offload = ...@@ -37,7 +37,6 @@ parameter_map create_param_map(const program& p, const target& t, bool offload =
parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu); parameter_map fill_param_map(parameter_map& m, const program& p, bool gpu);
parameter_map create_param_map(const program& p, bool gpu = true); parameter_map create_param_map(const program& p, bool gpu = true);
target get_target(bool gpu); target get_target(bool gpu);
void compile_program(program& p, bool gpu = true);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "verify.hpp" #include "verify.hpp"
#include "perf.hpp" #include "perf.hpp"
#include <migraphx/ref/target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/verify_args.hpp> #include <migraphx/verify_args.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -37,7 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -37,7 +37,7 @@ inline namespace MIGRAPHX_INLINE_NS {
std::vector<argument> run_ref(program p, const parameter_map& inputs) std::vector<argument> run_ref(program p, const parameter_map& inputs)
{ {
p.compile(ref::target{}); p.compile(migraphx::make_target("ref"));
auto out = p.eval(inputs); auto out = p.eval(inputs);
std::cout << p << std::endl; std::cout << p << std::endl;
return out; return out;
......
...@@ -21,25 +21,44 @@ ...@@ -21,25 +21,44 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/manage_ptr.hpp>
#include <migraphx/dynamic_loader.hpp> #include <migraphx/dynamic_loader.hpp>
#include <migraphx/errors.hpp> #include <migraphx/errors.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp> #include <migraphx/tmp_dir.hpp>
#include <utility> #include <utility>
#include <dlfcn.h> #include <dlfcn.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void check_load_error(bool flush = false)
{
char* error_msg = dlerror();
if(not flush and error_msg != nullptr)
MIGRAPHX_THROW("Dynamic loading or symbol lookup failed with " + std::string(error_msg));
}
struct dynamic_loader_impl struct dynamic_loader_impl
{ {
dynamic_loader_impl() = default; dynamic_loader_impl() = default;
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr) dynamic_loader_impl(const fs::path& p, std::shared_ptr<tmp_dir> t = nullptr)
: handle(dlopen(p.string().c_str(), RTLD_LAZY), &dlclose), temp(std::move(t)) : handle(dlopen(p.string().c_str(), RTLD_LAZY),
manage_deleter<decltype(&dlclose), &dlclose>{}),
temp(std::move(t))
{ {
check_load_error();
} }
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size) static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
{ {
auto t = std::make_shared<tmp_dir>("dloader"); auto t = std::make_shared<tmp_dir>("dloader");
...@@ -68,10 +87,11 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer) ...@@ -68,10 +87,11 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{ {
dlerror(); // flush any previous error messages
check_load_error(true);
void* symbol = dlsym(impl->handle.get(), name.c_str()); void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr) if(symbol == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name); check_load_error();
return {impl, symbol}; return {impl, symbol};
} }
......
...@@ -28,11 +28,8 @@ ...@@ -28,11 +28,8 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/serialize.hpp> #include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/pass_config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -38,6 +38,7 @@ void eliminate_data_type::apply(module& m) const ...@@ -38,6 +38,7 @@ void eliminate_data_type::apply(module& m) const
"if", "if",
"loop", "loop",
"roialign", "roialign",
"nonmaxsuppression",
"scatternd_add", "scatternd_add",
"scatternd_mul", "scatternd_mul",
"scatternd_none"}; "scatternd_none"};
......
...@@ -30,23 +30,31 @@ namespace migraphx { ...@@ -30,23 +30,31 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
T generic_read_file(const std::string& filename) T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{ {
std::ifstream is(filename, std::ios::binary | std::ios::ate); std::ifstream is(filename, std::ios::binary | std::ios::ate);
std::streamsize size = is.tellg(); if(nbytes == 0)
if(size < 1) {
// if there is a non-zero offset and nbytes is not set,
// calculate size of remaining bytes to read
nbytes = is.tellg();
if(offset > nbytes)
MIGRAPHX_THROW("offset is larger than file size");
nbytes -= offset;
}
if(nbytes < 1)
MIGRAPHX_THROW("Invalid size for: " + filename); MIGRAPHX_THROW("Invalid size for: " + filename);
is.seekg(0, std::ios::beg); is.seekg(offset, std::ios::beg);
T buffer(size, 0); T buffer(nbytes, 0);
if(not is.read(&buffer[0], size)) if(not is.read(&buffer[0], nbytes))
MIGRAPHX_THROW("Error reading file: " + filename); MIGRAPHX_THROW("Error reading file: " + filename);
return buffer; return buffer;
} }
std::vector<char> read_buffer(const std::string& filename) std::vector<char> read_buffer(const std::string& filename, size_t offset, size_t nbytes)
{ {
return generic_read_file<std::vector<char>>(filename); return generic_read_file<std::vector<char>>(filename, offset, nbytes);
} }
std::string read_string(const std::string& filename) std::string read_string(const std::string& filename)
......
...@@ -87,7 +87,7 @@ struct check_shapes ...@@ -87,7 +87,7 @@ struct check_shapes
} }
/*! /*!
* Check if the number of shape objects is equal to atleast one of the * Require the number of shape objects to equal to one of the
* given sizes. * given sizes.
* \param ns template parameter pack of sizes to check against * \param ns template parameter pack of sizes to check against
*/ */
...@@ -100,6 +100,23 @@ struct check_shapes ...@@ -100,6 +100,23 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Require the number of shape objects to equal at least a given amount. Use this
* method for ops that can take any number (variadic) of inputs.
* \param n min. number of shapes
*/
const check_shapes& has_at_least(std::size_t n) const
{
if(this->size() < n)
MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected at least " +
to_string(n) + " but given " + std::to_string(size()));
return *this;
}
/*!
* Require all shapes to have the same number of elements.
* \param n number of
*/
const check_shapes& nelements(std::size_t n) const const check_shapes& nelements(std::size_t n) const
{ {
if(not this->all_of([&](const shape& s) { return s.elements() == n; })) if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
...@@ -198,7 +215,7 @@ struct check_shapes ...@@ -198,7 +215,7 @@ struct check_shapes
*/ */
const check_shapes& same_ndims() const const check_shapes& same_ndims() const
{ {
if(not this->same([](const shape& s) { return s.max_lens().size(); })) if(not this->same([](const shape& s) { return s.ndim(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match"); MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this; return *this;
} }
......
...@@ -32,8 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -32,8 +32,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct compile_options struct compile_options
{ {
bool offload_copy = false; bool offload_copy = false;
bool fast_math = true; bool fast_math = true;
bool exhaustive_tune = false;
tracer trace{}; tracer trace{};
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment