Commit 13d14c66 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Merge branch 'develop' into dyn_resize_gather

parents f4e7d9d9 d1abf06f
...@@ -38,26 +38,32 @@ ...@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/json.hpp> #include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm> #include <algorithm>
#include <cstdarg> #include <cstdarg>
namespace migraphx { namespace migraphx {
#ifdef MIGRAPHX_BUILD_TESTING
static thread_local bool disable_exception_catch = false; // NOLINT static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b) extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{ {
disable_exception_catch = b; disable_exception_catch = b;
} }
#endif
template <class F> template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT migraphx_status try_(F f, bool output = true) // NOLINT
{ {
#ifdef MIGRAPHX_BUILD_TESTING
if(disable_exception_catch) if(disable_exception_catch)
{ {
f(); f();
} }
else else
{ {
#endif
try try
{ {
f(); f();
...@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT ...@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{ {
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
} }
#ifdef MIGRAPHX_BUILD_TESTING
} }
#endif
return migraphx_status_success; return migraphx_status_success;
} }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h> #include <migraphx/api/export.h>
......
...@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe> ...@@ -66,7 +66,7 @@ template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name() std::string compute_type_name()
{ {
std::string name; std::string name;
#ifdef _MSC_VER #if defined(_MSC_VER) && !defined(__clang__)
name = typeid(PrivateMigraphTypeNameProbe).name(); name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7); name = name.substr(7);
#else #else
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
namespace migraphx { namespace migraphx {
......
...@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const ...@@ -46,7 +46,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
fs::path full_path = td.path / src.path; fs::path full_path = td.path / src.path;
fs::path parent_path = full_path.parent_path(); fs::path parent_path = full_path.parent_path();
fs::create_directories(parent_path); fs::create_directories(parent_path);
write_buffer(full_path.string(), src.content.first, src.len()); write_buffer(full_path.string(), src.content.data(), src.content.size());
if(src.path.extension().string() == ".cpp") if(src.path.extension().string() == ".cpp")
{ {
params += " " + src.path.filename().string(); params += " " + src.path.filename().string();
......
...@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m, ...@@ -213,13 +213,13 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
ins->get_literal().visit([&](auto v) { ins->get_literal().visit([&](auto v) {
assert(v.size() == 1); assert(v.size() == 1);
auto x = v.front(); auto x = v.front();
if(std::isinf(x)) if(std::isinf(static_cast<double>(x)))
{ {
string_literal = "__builtin_huge_val()"; string_literal = "__builtin_huge_val()";
if(x < 0) if(x < 0)
string_literal = "-__builtin_huge_val()"; string_literal = "-__builtin_huge_val()";
} }
else if(std::isnan(x)) else if(std::isnan(static_cast<double>(x)))
string_literal = "__builtin_nan()"; string_literal = "__builtin_nan()";
else else
string_literal = ins->get_literal().to_string(); string_literal = ins->get_literal().to_string();
......
...@@ -45,7 +45,15 @@ if(NOT WIN32) ...@@ -45,7 +45,15 @@ if(NOT WIN32)
endif() endif()
rocm_clang_tidy_check(driver) rocm_clang_tidy_check(driver)
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py) file(STRINGS "${CMAKE_SOURCE_DIR}/test/onnx/.onnxrt-commit" String_output)
target_compile_definitions(driver PUBLIC MIGRAPHX_ORT_SHA1="${String_output}")
target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf)
if(MIGRAPHX_ENABLE_PYTHON)
target_link_libraries(driver migraphx_py)
target_compile_definitions(driver PRIVATE MIGRAPHX_ENABLE_PYTHON)
endif()
rocm_install_targets( rocm_install_targets(
TARGETS driver TARGETS driver
......
...@@ -187,6 +187,13 @@ struct value_parser ...@@ -187,6 +187,13 @@ struct value_parser
} }
}; };
// version for std::optional object
template <class T>
struct value_parser<std::optional<T>>
{
static T apply(const std::string& x) { return value_parser<T>::apply(x); }
};
struct argument_parser struct argument_parser
{ {
struct argument struct argument
......
...@@ -32,7 +32,9 @@ ...@@ -32,7 +32,9 @@
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#ifdef MIGRAPHX_ENABLE_PYTHON
#include <migraphx/py.hpp> #include <migraphx/py.hpp>
#endif
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <migraphx/load_save.hpp> #include <migraphx/load_save.hpp>
...@@ -281,10 +283,12 @@ struct loader ...@@ -281,10 +283,12 @@ struct loader
options.format = "json"; options.format = "json";
p = migraphx::load(file, options); p = migraphx::load(file, options);
} }
#ifdef MIGRAPHX_ENABLE_PYTHON
else if(file_type == "py") else if(file_type == "py")
{ {
p = migraphx::load_py(file); p = migraphx::load_py(file);
} }
#endif
else if(file_type == "migraphx") else if(file_type == "migraphx")
{ {
p = migraphx::load(file); p = migraphx::load(file);
...@@ -475,13 +479,15 @@ struct compiler ...@@ -475,13 +479,15 @@ struct compiler
{ {
if(is_offload_copy_set(p) and not co.offload_copy) if(is_offload_copy_set(p) and not co.offload_copy)
{ {
std::cout << "MIGraphX program was likely compiled with offload_copy set, Try " std::cout
"passing " << "[WARNING]: MIGraphX program was likely compiled with offload_copy "
"`--enable-offload-copy` if program run fails.\n"; "set, Try "
"passing "
"`--enable-offload-copy` if program run fails.\n";
} }
else if(co.offload_copy) else if(co.offload_copy)
{ {
std::cout << "MIGraphX program was likely compiled without " std::cout << "[WARNING]: MIGraphX program was likely compiled without "
"offload_copy set, Try " "offload_copy set, Try "
"removing " "removing "
"`--enable-offload-copy` flag if passed to driver, if program run " "`--enable-offload-copy` flag if passed to driver, if program run "
...@@ -534,13 +540,17 @@ struct params : command<params> ...@@ -534,13 +540,17 @@ struct params : command<params>
struct verify : command<verify> struct verify : command<verify>
{ {
compiler c; compiler c;
double tolerance = 80; std::optional<double> rms_tol;
std::optional<double> atol;
std::optional<double> rtol;
bool per_instruction = false; bool per_instruction = false;
bool reduce = false; bool reduce = false;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
c.parse(ap); c.parse(ap);
ap(tolerance, {"--tolerance"}, ap.help("Tolerance for errors")); ap(rms_tol, {"--rms-tol"}, ap.help("Tolerance for the RMS error"));
ap(atol, {"--atol"}, ap.help("Tolerance for the elementwise absolute difference"));
ap(rtol, {"--rtol"}, ap.help("Tolerance for the elementwise relative difference"));
ap(per_instruction, ap(per_instruction,
{"-i", "--per-instruction"}, {"-i", "--per-instruction"},
ap.help("Verify each instruction"), ap.help("Verify each instruction"),
...@@ -559,21 +569,30 @@ struct verify : command<verify> ...@@ -559,21 +569,30 @@ struct verify : command<verify>
auto quantize = precision::fp32; auto quantize = precision::fp32;
if(c.to_fp16) if(c.to_fp16)
{
quantize = precision::fp16; quantize = precision::fp16;
}
if(c.to_int8) if(c.to_int8)
{
quantize = precision::int8; quantize = precision::int8;
}
auto tols = get_tolerances(p, quantize, rms_tol, atol, rtol);
std::cout << "rms_tol: " << tols.rms_tol << std::endl;
std::cout << "atol: " << tols.atol << std::endl;
std::cout << "rtol: " << tols.rtol << std::endl;
if(per_instruction) if(per_instruction)
{ {
verify_instructions(p, t, c.co, quantize, tolerance); verify_instructions(p, t, c.co, quantize, tols);
} }
else if(reduce) else if(reduce)
{ {
verify_reduced_program(p, t, c.co, quantize, m, tolerance); verify_reduced_program(p, t, c.co, quantize, m, tols);
} }
else else
{ {
verify_program(c.l.file, p, t, c.co, quantize, m, tolerance); verify_program(c.l.file, p, t, c.co, quantize, m, tols);
} }
} }
}; };
...@@ -802,6 +821,13 @@ int main(int argc, const char* argv[]) ...@@ -802,6 +821,13 @@ int main(int argc, const char* argv[])
auto&& m = get_commands(); auto&& m = get_commands();
auto cmd = args.front(); auto cmd = args.front();
if(cmd == "ort-sha")
{
std::cout << MIGRAPHX_ORT_SHA1 << std::endl;
return 0;
}
if(m.count(cmd) > 0) if(m.count(cmd) > 0)
{ {
m.at(cmd)(argv[0], {args.begin() + 1, args.end()}); m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
......
...@@ -30,11 +30,48 @@ ...@@ -30,11 +30,48 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
* model.
*/
verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol)
{
bool has_fp16 = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::half_type); });
});
migraphx::verify::tolerance result{};
if(has_fp16 or quantize == precision::fp16)
{
result.rms_tol = 8e-2;
result.atol = 4e-2;
result.rtol = 4e-2;
}
if(rms_tol)
{
result.rms_tol = *rms_tol;
}
if(atol)
{
result.atol = *atol;
}
if(rtol)
{
result.rtol = *rtol;
}
return result;
}
std::vector<argument> run_ref(program p, const parameter_map& inputs) std::vector<argument> run_ref(program p, const parameter_map& inputs)
{ {
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -76,15 +113,25 @@ void verify_program(const std::string& name, ...@@ -76,15 +113,25 @@ void verify_program(const std::string& name,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
auto x = run_ref(p, inputs); auto ref_outs = run_ref(p, inputs);
auto y = run_target(p, t, options, quantize, inputs); auto target_outs = run_target(p, t, options, quantize, inputs);
std::size_t output_num = x.size(); std::size_t output_num = ref_outs.size();
for(std::size_t i = 0; i < output_num; ++i) for(std::size_t i = 0; i < output_num; ++i)
{ {
verify_args(name, x[i], y[i], tolerance); if(ref_outs[i].get_shape().type() != target_outs[i].get_shape().type() or
ref_outs[i].get_shape().lens() != target_outs[i].get_shape().lens())
{
std::cout << "FAILED: " << name << std::endl;
std::cout << "Shape mismatch {" << ref_outs[i].get_shape() << "} != {"
<< target_outs[i].get_shape() << "}" << std::endl;
}
else
{
verify_args(name, target_outs[i], verify::expected{ref_outs[i]}, tols);
}
} }
} }
...@@ -92,7 +139,7 @@ void verify_instructions(const program& prog, ...@@ -92,7 +139,7 @@ void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options, compile_options options,
precision quantize, precision quantize,
double tolerance) verify::tolerance tols)
{ {
const auto* mm_prog = prog.get_main_module(); const auto* mm_prog = prog.get_main_module();
for(auto&& ins : (*mm_prog)) for(auto&& ins : (*mm_prog))
...@@ -123,8 +170,7 @@ void verify_instructions(const program& prog, ...@@ -123,8 +170,7 @@ void verify_instructions(const program& prog,
{ {
std::cout << "Verify: " << ins.name() << std::endl; std::cout << "Verify: " << ins.name() << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program( verify_program(ins.name(), p, t, options, quantize, create_param_map(p, false), tols);
ins.name(), p, t, options, quantize, create_param_map(p, false), tolerance);
} }
catch(...) catch(...)
{ {
...@@ -140,14 +186,22 @@ void verify_reduced(program p, ...@@ -140,14 +186,22 @@ void verify_reduced(program p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(mm->end(), n + 1); auto last = std::prev(mm->end(), n);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << n << std::endl; std::cout << "Verify: " << n << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, quantize, inputs, tolerance); try
{
verify_program(std::to_string(n), p, t, options, quantize, inputs, tols);
}
catch(const std::exception& e)
{
std::cout << "FAILED: " << n << std::endl;
std::cout << "Exception: " << e.what() << std::endl;
}
} }
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
...@@ -155,14 +209,20 @@ void verify_reduced_program(const program& p, ...@@ -155,14 +209,20 @@ void verify_reduced_program(const program& p,
compile_options options, compile_options options,
precision quantize, precision quantize,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) verify::tolerance tols)
{ {
const auto* mm = p.get_main_module(); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end()); auto n = std::distance(mm->begin(), mm->end());
std::cout << "Verify steps: " << n << std::endl; std::cout << "Verify steps: " << n << std::endl;
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 1; i < n; i++)
{ {
verify_reduced(p, i, t, options, quantize, inputs, tolerance); auto last = std::prev(mm->end(), i + 1);
if(contains({"@literal", "@param"}, last->name()))
{
std::cout << "Skip: " << i << std::endl;
continue;
}
verify_reduced(p, i, t, options, quantize, inputs, tols);
} }
} }
......
...@@ -26,29 +26,36 @@ ...@@ -26,29 +26,36 @@
#include "precision.hpp" #include "precision.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/verify.hpp>
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
verify::tolerance get_tolerances(const program& p,
precision quantize,
std::optional<double> rms_tol,
std::optional<double> atol,
std::optional<double> rtol);
void verify_program(const std::string& name, void verify_program(const std::string& name,
const program& p, const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 100); verify::tolerance tols = verify::tolerance{});
void verify_instructions(const program& prog, void verify_instructions(const program& prog,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
void verify_reduced_program(const program& p, void verify_reduced_program(const program& p,
const target& t, const target& t,
compile_options options = compile_options{}, compile_options options = compile_options{},
precision quantize = precision::fp32, precision quantize = precision::fp32,
const parameter_map& inputs = {}, const parameter_map& inputs = {},
double tolerance = 80); verify::tolerance tols = verify::tolerance{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
......
...@@ -27,11 +27,20 @@ ...@@ -27,11 +27,20 @@
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp> #include <migraphx/tmp_dir.hpp>
#include <utility> #include <utility>
#ifdef _WIN32
// cppcheck-suppress definePrefix
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#else
#include <dlfcn.h> #include <dlfcn.h>
#endif
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#ifndef _WIN32
void check_load_error(bool flush = false) void check_load_error(bool flush = false)
{ {
char* error_msg = dlerror(); char* error_msg = dlerror();
...@@ -81,6 +90,48 @@ fs::path dynamic_loader::path(void* address) ...@@ -81,6 +90,48 @@ fs::path dynamic_loader::path(void* address)
return p; return p;
} }
#else
struct dynamic_loader_impl
{
dynamic_loader_impl() = default;
dynamic_loader_impl(const fs::path& p, tmp_dir t = {})
: handle{LoadLibrary(p.string().c_str())}, temp{std::move(t)}
{
if(handle == nullptr)
{
MIGRAPHX_THROW("Error loading DLL: " + p.string() + " (" +
std::to_string(GetLastError()) + ")");
}
}
dynamic_loader_impl(const dynamic_loader_impl&) = delete;
dynamic_loader_impl& operator=(const dynamic_loader_impl&) = delete;
dynamic_loader_impl(dynamic_loader_impl&&) = default;
~dynamic_loader_impl()
{
if(handle != nullptr)
{
FreeLibrary(handle);
}
}
static std::shared_ptr<dynamic_loader_impl> from_buffer(const char* image, std::size_t size)
{
auto t = tmp_dir{"migx-dynload"};
auto f = t.path / "tmp.dll";
write_buffer(f.string(), image, size);
return std::make_shared<dynamic_loader_impl>(f, std::move(t));
}
HMODULE handle = nullptr;
tmp_dir temp;
};
#endif
optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p) optional<dynamic_loader> dynamic_loader::try_load(const fs::path& p)
{ {
try try
...@@ -109,12 +160,19 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer) ...@@ -109,12 +160,19 @@ dynamic_loader::dynamic_loader(const std::vector<char>& buffer)
std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const std::shared_ptr<void> dynamic_loader::get_symbol(const std::string& name) const
{ {
#ifndef _WIN32
// flush any previous error messages // flush any previous error messages
check_load_error(true); check_load_error(true);
void* symbol = dlsym(impl->handle.get(), name.c_str()); void* symbol = dlsym(impl->handle.get(), name.c_str());
if(symbol == nullptr) if(symbol == nullptr)
check_load_error(); check_load_error();
return {impl, symbol}; return {impl, symbol};
#else
FARPROC addr = GetProcAddress(impl->handle, name.c_str());
if(addr == nullptr)
MIGRAPHX_THROW("Symbol not found: " + name + " (" + std::to_string(GetLastError()) + ")");
return {impl, reinterpret_cast<void*>(addr)};
#endif
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -46,7 +46,7 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument> ...@@ -46,7 +46,7 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument>
{ {
argument() = default; argument() = default;
argument(const shape& s); explicit argument(const shape& s);
template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})> template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
argument(shape s, F d) argument(shape s, F d)
......
...@@ -62,10 +62,9 @@ const int auto_register<Action, T>::static_register = auto_register_action<Actio ...@@ -62,10 +62,9 @@ const int auto_register<Action, T>::static_register = auto_register_action<Actio
#define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x #define MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) migraphx_auto_register_##x
#define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x) #define MIGRAPHX_AUTO_REGISTER_NAME(x) MIGRAPHX_AUTO_REGISTER_NAME_DETAIL(x)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_AUTO_REGISTER(...) \ #define MIGRAPHX_AUTO_REGISTER(...) \
void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)(migraphx::auto_register<__VA_ARGS__> x = \ [[maybe_unused]] void MIGRAPHX_AUTO_REGISTER_NAME(__LINE__)( \
migraphx::auto_register<__VA_ARGS__>{}) \ migraphx::auto_register<__VA_ARGS__> x = migraphx::auto_register<__VA_ARGS__>{});
__attribute__((unused));
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -70,13 +70,19 @@ struct check_shapes ...@@ -70,13 +70,19 @@ struct check_shapes
check_dynamic(); check_dynamic();
} }
template <class Op> template <class Op, MIGRAPHX_REQUIRES(not std::is_convertible<Op, std::string>{})>
check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false) check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
: begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d) : begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
{ {
check_dynamic(); check_dynamic();
} }
check_shapes(const std::vector<shape>& s, const std::string& n, const bool d = false)
: begin(s.begin()), end(s.end()), name(n), dynamic_allowed(d)
{
check_dynamic();
}
void check_dynamic() const void check_dynamic() const
{ {
if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); })) if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
...@@ -147,7 +153,7 @@ struct check_shapes ...@@ -147,7 +153,7 @@ struct check_shapes
{ {
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() != n) if(begin->ndim() != n)
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported"); MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
...@@ -162,7 +168,7 @@ struct check_shapes ...@@ -162,7 +168,7 @@ struct check_shapes
{ {
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() > n) if(begin->ndim() > n)
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -178,7 +184,7 @@ struct check_shapes ...@@ -178,7 +184,7 @@ struct check_shapes
{ {
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() < n) if(begin->ndim() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -228,6 +234,16 @@ struct check_shapes ...@@ -228,6 +234,16 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes have the same layout.
*/
const check_shapes& same_layout() const
{
if(not this->same([](const shape& s) { return find_permutation(s); }))
MIGRAPHX_THROW(prefix() + "Layouts do not match");
return *this;
}
/*! /*!
* Check all shapes are standard. * Check all shapes are standard.
*/ */
...@@ -238,6 +254,16 @@ struct check_shapes ...@@ -238,6 +254,16 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are scalar.
*/
const check_shapes& scalar() const
{
if(not this->all_of([](const shape& s) { return s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar");
return *this;
}
/*! /*!
* Check all shapes are standard or scalar. * Check all shapes are standard or scalar.
*/ */
......
...@@ -37,8 +37,18 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -37,8 +37,18 @@ inline namespace MIGRAPHX_INLINE_NS {
struct src_file struct src_file
{ {
fs::path path; fs::path path;
std::pair<const char*, const char*> content; std::string_view content;
std::size_t len() const { return content.second - content.first; }
src_file() = default;
src_file(fs::path file_path, std::string_view file_content)
: path{std::move(file_path)}, content{file_content}
{
}
explicit src_file(const std::pair<std::string_view, std::string_view>& pair)
: path{pair.first}, content{pair.second}
{
}
}; };
struct MIGRAPHX_EXPORT src_compiler struct MIGRAPHX_EXPORT src_compiler
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_CONFIG_HPP #define MIGRAPHX_GUARD_CONFIG_HPP
#include <migraphx/export.h> #include <migraphx/export.h>
#include <ciso646>
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN) #if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
......
...@@ -38,12 +38,14 @@ struct dynamic_loader_impl; ...@@ -38,12 +38,14 @@ struct dynamic_loader_impl;
struct MIGRAPHX_EXPORT dynamic_loader struct MIGRAPHX_EXPORT dynamic_loader
{ {
#ifndef _WIN32
template <class T> template <class T>
static fs::path path(T* address) static fs::path path(T* address)
{ {
return path(reinterpret_cast<void*>(address)); return path(reinterpret_cast<void*>(address));
} }
static fs::path path(void* address); static fs::path path(void* address);
#endif
static optional<dynamic_loader> try_load(const fs::path& p); static optional<dynamic_loader> try_load(const fs::path& p);
......
...@@ -29,6 +29,17 @@ ...@@ -29,6 +29,17 @@
#if defined(CPPCHECK) #if defined(CPPCHECK)
#define MIGRAPHX_HAS_FILESYSTEM 1 #define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 1 #define MIGRAPHX_HAS_FILESYSTEM_TS 1
#elif defined(_WIN32)
#if _MSC_VER >= 1920
#define MIGRAPHX_HAS_FILESYSTEM 1
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#elif _MSC_VER >= 1900
#define MIGRAPHX_HAS_FILESYSTEM 0
#define MIGRAPHX_HAS_FILESYSTEM_TS 1
#else
#define MIGRAPHX_HAS_FILESYSTEM 0
#define MIGRAPHX_HAS_FILESYSTEM_TS 0
#endif
#elif defined(__has_include) #elif defined(__has_include)
#if __has_include(<filesystem>) && __cplusplus >= 201703L #if __has_include(<filesystem>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_FILESYSTEM 1 #define MIGRAPHX_HAS_FILESYSTEM 1
......
...@@ -27,9 +27,6 @@ ...@@ -27,9 +27,6 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
#ifdef _MSC_VER
#include <iso646.h>
#endif
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment