Commit 79278d88 authored by Paul's avatar Paul
Browse files

Merge

parents 3f4d78bd 10f37f49
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
...@@ -16,12 +39,24 @@ ...@@ -16,12 +39,24 @@
#include <migraphx/convert_to_json.hpp> #include <migraphx/convert_to_json.hpp>
#include <algorithm> #include <algorithm>
#include <cstdarg> #include <cstdarg>
namespace migraphx { namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F> template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT migraphx_status try_(F f, bool output = true) // NOLINT
{ {
if(disable_exception_catch)
{
f();
}
else
{
try try
{ {
f(); f();
...@@ -45,6 +80,7 @@ migraphx_status try_(F f, bool output = true) // NOLINT ...@@ -45,6 +80,7 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{ {
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
} }
}
return migraphx_status_success; return migraphx_status_success;
} }
...@@ -213,6 +249,11 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -213,6 +249,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s)
{
return m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}}), {});
}
struct experimental_custom_op struct experimental_custom_op
{ {
std::string name; std::string name;
...@@ -237,7 +278,12 @@ struct custom_operation ...@@ -237,7 +278,12 @@ struct custom_operation
return op.compute_shape(std::move(inputs)); return op.compute_shape(std::move(inputs));
} }
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); } // TODO: Compute method with module_args
argument
compute(migraphx::context ctx, migraphx::shape output_shape, std::vector<argument> inputs) const
{
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
}
}; };
template <class CustomOp> template <class CustomOp>
...@@ -272,6 +318,7 @@ void destroy(T* x) ...@@ -272,6 +318,7 @@ void destroy(T* x)
{ {
delete x; // NOLINT delete x; // NOLINT
} }
// TODO: Move to interface preamble // TODO: Move to interface preamble
template <class C, class D> template <class C, class D>
struct manage_generic_ptr struct manage_generic_ptr
...@@ -280,23 +327,27 @@ struct manage_generic_ptr ...@@ -280,23 +327,27 @@ struct manage_generic_ptr
manage_generic_ptr(std::nullptr_t) {} manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter) manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter) : data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{ {
copier(&data, pdata); copier(&data, pdata);
} }
manage_generic_ptr(const manage_generic_ptr& rhs) manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter) : data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{ {
if(copier) if(copier)
copier(&data, rhs.data); copier(&data, rhs.data);
} }
manage_generic_ptr(manage_generic_ptr&& other) noexcept manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter) : data(other.data),
obj_typename(other.obj_typename),
copier(other.copier),
deleter(other.deleter)
{ {
other.data = nullptr; other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr; other.copier = nullptr;
other.deleter = nullptr; other.deleter = nullptr;
} }
...@@ -304,6 +355,7 @@ struct manage_generic_ptr ...@@ -304,6 +355,7 @@ struct manage_generic_ptr
manage_generic_ptr& operator=(manage_generic_ptr rhs) manage_generic_ptr& operator=(manage_generic_ptr rhs)
{ {
std::swap(data, rhs.data); std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier); std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter); std::swap(deleter, rhs.deleter);
return *this; return *this;
...@@ -316,6 +368,7 @@ struct manage_generic_ptr ...@@ -316,6 +368,7 @@ struct manage_generic_ptr
} }
void* data = nullptr; void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr; C copier = nullptr;
D deleter = nullptr; D deleter = nullptr;
}; };
...@@ -547,23 +600,59 @@ struct migraphx_experimental_custom_op ...@@ -547,23 +600,59 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op(void* p, migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
Ts&&... xs) Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...) : object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{ {
} }
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete> manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
object_ptr = nullptr; object_ptr = nullptr;
migraphx::experimental_custom_op xobject; migraphx::experimental_custom_op xobject;
migraphx_experimental_custom_op_compute compute_f = nullptr;
migraphx::argument compute(migraphx::context ctx,
migraphx::shape output,
std::vector<migraphx::argument> inputs) const
{
std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr)
throw std::runtime_error("compute function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_context_t>(&(ctx)),
object_cast<migraphx_shape_t>(&(output)),
object_cast<migraphx_arguments_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object;
}
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr; migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
std::remove_pointer_t<migraphx_shape_t> out; std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr) if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing."); throw std::runtime_error("compute_shape function is missing.");
auto api_error_result = std::array<char, 256> exception_msg;
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs))); exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success) if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape."); {
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute_shape of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object; return (&out)->object;
} }
}; };
...@@ -692,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -692,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).standard();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{ {
auto api_error_result = migraphx::try_([&] { destroy((argument)); }); auto api_error_result = migraphx::try_([&] { destroy((argument)); });
...@@ -1118,6 +1217,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou ...@@ -1118,6 +1217,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_instruction_t>(
migraphx::add_allocation((module->object), (s->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{ {
auto api_error_result = migraphx::try_([&] { destroy((program)); }); auto api_error_result = migraphx::try_([&] { destroy((program)); });
...@@ -1740,15 +1854,24 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -1740,15 +1854,24 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name) const char* name)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
*experimental_custom_op = *experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name)); allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (obj_typename), (name));
}); });
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_compute input)
{
auto api_error_result = migraphx::try_([&] { (obj)->compute_f = (input); });
return api_error_result;
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape( extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input) migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H #ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H #define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
...@@ -106,8 +130,18 @@ typedef const struct migraphx_context* const_migraphx_context_t; ...@@ -106,8 +130,18 @@ typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t; typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t; typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_context_t ctx,
migraphx_shape_t output,
migraphx_arguments_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out, typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj, void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_shapes_t inputs); migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input); typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
...@@ -146,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); ...@@ -146,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x); migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
...@@ -272,6 +308,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out, ...@@ -272,6 +308,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module, migraphx_module_t module,
migraphx_instructions_t args); migraphx_instructions_t args);
migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output, migraphx_status migraphx_program_assign_to(migraphx_program_t output,
...@@ -452,8 +492,13 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi ...@@ -452,8 +492,13 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj, void* obj,
migraphx_experimental_custom_op_copy c, migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d, migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name); const char* name);
migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_compute input);
migraphx_status migraphx_experimental_custom_op_set_compute_shape( migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h" #include "migraphx.h"
#include <cstring>
#include <initializer_list> #include <initializer_list>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <memory> #include <memory>
...@@ -35,6 +59,42 @@ struct rank<0> ...@@ -35,6 +59,42 @@ struct rank<0>
{ {
}; };
template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
return name;
}
template <class T>
const std::string& get_type_name()
{
static const std::string name = compute_type_name<T>();
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return get_type_name<T>();
}
template <class T, class F, class... Ts> template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs) T* make(F f, Ts&&... xs)
{ {
...@@ -287,13 +347,22 @@ struct interface_base : Base ...@@ -287,13 +347,22 @@ struct interface_base : Base
protected: protected:
template <class F> template <class F>
static migraphx_status try_(F f) // NOLINT static migraphx_status try_(F f, char* ex_msg = nullptr, size_t ex_msg_size = 0) // NOLINT
{ {
try try
{ {
f(); f();
return migraphx_status_success; return migraphx_status_success;
} }
catch(const std::exception& ex)
{
if(ex_msg)
{
std::strncpy(ex_msg, ex.what(), ex_msg_size);
ex_msg[ex_msg_size - 1] = '\0';
}
return migraphx_status_unknown_error;
}
catch(...) catch(...)
{ {
return migraphx_status_unknown_error; return migraphx_status_unknown_error;
...@@ -326,8 +395,12 @@ struct interface_base : Base ...@@ -326,8 +395,12 @@ struct interface_base : Base
{ {
static F f = pf; static F f = pf;
(void)f; // avoid warning on gcc (void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status { call(setter,
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); }); this->get_handle_ptr(),
[](auto out, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<1>{}, f, out, obj, xs...); }, ex_msg, ex_msg_size);
}); });
} }
...@@ -378,11 +451,14 @@ struct interface_base : Base ...@@ -378,11 +451,14 @@ struct interface_base : Base
return x; return x;
} }
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
template <class T> template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x}) auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{ {
return as_handle<T>{x}; return as_handle<T>{x};
} }
#pragma GCC diagnostic pop
template <class T> template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}}) auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
...@@ -441,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -441,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape); MIGRAPHX_HANDLE_CONSTRUCTOR(shape)
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -498,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -498,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
bool standard() const
{
bool result = false;
call(&migraphx_shape_standard, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const shape& px, const shape& py) friend bool operator==(const shape& px, const shape& py)
{ {
bool pout; bool pout;
...@@ -505,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -505,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); } friend bool operator!=(const shape& px, const shape& py) { return not(px == py); }
}; };
/** /**
...@@ -518,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -518,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(argument); MIGRAPHX_HANDLE_CONSTRUCTOR(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
...@@ -542,6 +625,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -542,6 +625,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout; return pout;
} }
template <typename T>
std::vector<T> as_vector() const
{
size_t vector_len = this->get_shape().bytes() / sizeof(T);
T* buffer_ptr = reinterpret_cast<T*>(this->data());
return {buffer_ptr, buffer_ptr + vector_len};
}
/// Generate an argument using random data /// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0) static argument generate(shape ps, size_t pseed = 0)
{ {
...@@ -556,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -556,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout; return pout;
} }
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); } friend bool operator!=(const argument& px, const argument& py) { return not(px == py); }
}; };
/// A target for compilation /// A target for compilation
...@@ -564,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -564,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(target); MIGRAPHX_HANDLE_CONSTRUCTOR(target)
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -574,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -574,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes)
size_t size() const size_t size() const
{ {
...@@ -593,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -593,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std::vector<const char*> names() const std::vector<const char*> names() const
{ {
std::vector<const char*> result(this->size()); std::vector<const char*> result(this->size());
if(!result.empty()) if(not result.empty())
{ {
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr()); call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
} }
...@@ -604,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -604,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters); MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.") MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
...@@ -631,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -631,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments); MIGRAPHX_HANDLE_CONSTRUCTOR(arguments)
size_t size() const size_t size() const
{ {
...@@ -650,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -650,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes); MIGRAPHX_HANDLE_CONSTRUCTOR(shapes)
size_t size() const size_t size() const
{ {
...@@ -669,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -669,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(operation); MIGRAPHX_HANDLE_CONSTRUCTOR(operation)
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -687,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -687,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction); MIGRAPHX_HANDLE_CONSTRUCTOR(instruction)
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions); MIGRAPHX_HANDLE_CONSTRUCTOR(instructions)
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -706,7 +797,7 @@ struct module; ...@@ -706,7 +797,7 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules); MIGRAPHX_HANDLE_CONSTRUCTOR(modules)
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
...@@ -779,13 +870,20 @@ struct module ...@@ -779,13 +870,20 @@ struct module
return instruction(ret_ins, own{}); return instruction(ret_ins, own{});
} }
instruction add_allocation(const migraphx::shape& s)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_allocation, &ret_ins, mm.get(), s.get_handle_ptr());
return instruction(ret_ins, own{});
}
migraphx_module_t get_handle_ptr() const { return mm.get(); } migraphx_module_t get_handle_ptr() const { return mm.get(); }
private: private:
std::shared_ptr<migraphx_module> mm; std::shared_ptr<migraphx_module> mm;
}; };
struct context struct context : handle_lookup<context, migraphx_context>
{ {
context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {} context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
...@@ -813,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) ...@@ -813,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options); MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options)
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -837,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -837,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program); MIGRAPHX_HANDLE_CONSTRUCTOR(program)
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -917,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -917,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu, this->share_handle()}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return not(px == py); }
}; };
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options); MIGRAPHX_HANDLE_CONSTRUCTOR(file_options)
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
// set file format // set file format
...@@ -965,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -965,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options); MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options)
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1047,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1047,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options); MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options)
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1100,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1100,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names)
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1125,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1125,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options); MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options)
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
...@@ -1155,6 +1253,7 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op ...@@ -1155,6 +1253,7 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct experimental_custom_op_base struct experimental_custom_op_base
{ {
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0; virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default; virtual ~experimental_custom_op_base() = default;
}; };
...@@ -1164,8 +1263,12 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental ...@@ -1164,8 +1263,12 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
template <class T> template <class T>
experimental_custom_op(T& obj) experimental_custom_op(T& obj)
{ {
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str()); this->make_interface(&migraphx_experimental_custom_op_create,
obj,
get_type_name(obj).c_str(),
obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape); MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
} }
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); } void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import api import api
...@@ -98,6 +121,7 @@ def shape(h): ...@@ -98,6 +121,7 @@ def shape(h):
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True)
@auto_handle() @auto_handle()
...@@ -221,6 +245,10 @@ def module(h): ...@@ -221,6 +245,10 @@ def module(h):
h.method('add_return', h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'), api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref') returns='migraphx::instruction_ref')
h.method('add_allocation',
api.params(s='const migraphx::shape&'),
invoke='migraphx::add_allocation($@)',
returns='migraphx::instruction_ref')
@auto_handle() @auto_handle()
...@@ -412,7 +440,13 @@ def context(h): ...@@ -412,7 +440,13 @@ def context(h):
@api.interface('migraphx_experimental_custom_op', @api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op') 'migraphx::experimental_custom_op')
def experimental_custom_op(h): def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*')) h.constructor('create',
api.params(obj_typename='const char*', name='const char*'))
h.virtual('compute',
api.params(ctx='migraphx::context',
output='migraphx::shape',
inputs='std::vector<migraphx::argument>'),
returns='migraphx::argument')
h.virtual('compute_shape', h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'), api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape') returns='migraphx::shape')
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
...@@ -16,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m, ...@@ -16,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto a = args[0]; auto a = args[0];
auto b = args[1]; auto b = args[1];
auto input_type = a->get_shape().type(); auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0)) if(not float_equal(alpha.at<float>(0), 1.0))
{ {
auto alpha_literal = m.add_literal(alpha); auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a}); a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <unordered_map> #include <unordered_map>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -42,7 +65,7 @@ void auto_contiguous::apply(module& m) const ...@@ -42,7 +65,7 @@ void auto_contiguous::apply(module& m) const
if(ins->outputs().empty() and ins != last) if(ins->outputs().empty() and ins != last)
continue; continue;
shape s = ins->get_shape(); shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0) if(not s.dynamic() and not s.standard() and s.elements() != 0)
{ {
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins); auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c); m.replace_instruction(ins, c);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/compile_src.hpp> #include <migraphx/compile_src.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp> #include <migraphx/tmp_dir.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set> #include <unordered_set>
namespace migraphx { namespace migraphx {
...@@ -24,9 +48,11 @@ void dead_code_elimination::apply(module& m) const ...@@ -24,9 +48,11 @@ void dead_code_elimination::apply(module& m) const
// Skip the last instruction // Skip the last instruction
if(i == last) if(i == last)
break; break;
// Skip instruction with empty shape as output unless its a builtin or undefined or identity // Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
if(i->get_shape().elements() == 0 and i->name().front() != '@' and // identity, allocate]
i->name() != "undefined" and i->name() != "identity") if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and
not contains({"undefined", "identity", "allocate"}, i->name()))
continue; continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last)); assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited; std::unordered_set<instruction_ref> visited;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dom_info.hpp> #include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
add_executable(driver add_executable(driver
main.cpp main.cpp
......
#include <migraphx/operators.hpp>
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp" #include "models.hpp"
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); migraphx::module_ref mmain = p.get_main_module();
auto m0 = auto x_main_module_0 = mmain->add_literal(migraphx::abs(
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}}); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 0)));
auto mx0 = mm->add_literal( auto x_main_module_1 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto mx1 = mm->add_literal( auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto mx2 = mm->add_literal( auto x_0 = mmain->add_parameter(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2)); "0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto mx3 = mm->add_literal( auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 3));
auto mx4 = mm->add_literal( auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 4));
auto mx5 = mm->add_literal( auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 5));
auto mx6 = mm->add_literal( auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 6));
auto mx7 = mm->add_literal(migraphx::generate_literal( auto x_main_module_8 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 7));
auto mx8 = mm->add_literal( auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 8));
auto mx9 = mm->add_literal(migraphx::generate_literal( auto x_main_module_10 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 9));
auto mx10 = mm->add_literal( auto x_main_module_11 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10)); migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 10));
auto mx11 = mm->add_literal(migraphx::generate_literal( auto x_main_module_12 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 11));
auto mx12 = mm->add_literal( auto x_main_module_13 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12)); migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 12));
auto mx13 = mm->add_literal(migraphx::generate_literal( auto x_main_module_14 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 13));
auto mx14 = mm->add_literal( auto x_main_module_15 = mmain->add_literal(migraphx::generate_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14)); migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 14));
auto mx15 = mm->add_literal(migraphx::generate_literal( auto x_main_module_16 = mmain->add_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15)); migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 15));
migraphx::op::convolution convolution16; auto x_main_module_17 = mmain->add_literal(migraphx::generate_literal(
convolution16.padding = {2, 2}; migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 16));
convolution16.stride = {4, 4}; auto x_main_module_18 = mmain->add_literal(
convolution16.dilation = {1, 1}; migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 17));
convolution16.group = 1; auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
auto mx16 = mm->add_instruction(convolution16, m0, mx15); migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
migraphx::op::broadcast broadcast17; auto x_main_module_20 = mmain->add_instruction(
broadcast17.axis = 1; migraphx::make_json_op("convolution",
broadcast17.broadcast_lens = {batch, 64, 55, 55}; "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
auto mx17 = mm->add_instruction(broadcast17, mx14); "4],use_dynamic_same_auto_pad:0}"),
migraphx::op::add add18; x_0,
auto mx18 = mm->add_instruction(add18, mx16, mx17); x_main_module_19);
migraphx::op::relu relu19; auto x_main_module_21 = mmain->add_instruction(
auto mx19 = mm->add_instruction(relu19, mx18); migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,64,55,55]}"), x_main_module_18);
migraphx::op::pooling pooling20; auto x_main_module_22 =
pooling20.mode = migraphx::op::pooling_mode::max; mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
pooling20.padding = {0, 0}; auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
pooling20.stride = {2, 2}; auto x_main_module_24 = mmain->add_instruction(
pooling20.lengths = {3, 3}; migraphx::make_json_op(
auto mx20 = mm->add_instruction(pooling20, mx19); "pooling",
migraphx::op::convolution convolution21; "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
convolution21.padding = {2, 2}; x_main_module_23);
convolution21.stride = {1, 1}; auto x_main_module_25 = mmain->add_instruction(
convolution21.dilation = {1, 1}; migraphx::make_json_op("convolution",
convolution21.group = 1; "{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
auto mx21 = mm->add_instruction(convolution21, mx20, mx13); "1],use_dynamic_same_auto_pad:0}"),
migraphx::op::broadcast broadcast22; x_main_module_24,
broadcast22.axis = 1; x_main_module_17);
broadcast22.broadcast_lens = {batch, 192, 27, 27}; auto x_main_module_26 = mmain->add_instruction(
auto mx22 = mm->add_instruction(broadcast22, mx12); migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,192,27,27]}"), x_main_module_16);
migraphx::op::add add23; auto x_main_module_27 =
auto mx23 = mm->add_instruction(add23, mx21, mx22); mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
migraphx::op::relu relu24; auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
auto mx24 = mm->add_instruction(relu24, mx23); auto x_main_module_29 = mmain->add_instruction(
migraphx::op::pooling pooling25; migraphx::make_json_op(
pooling25.mode = migraphx::op::pooling_mode::max; "pooling",
pooling25.padding = {0, 0}; "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
pooling25.stride = {2, 2}; x_main_module_28);
pooling25.lengths = {3, 3}; auto x_main_module_30 = mmain->add_instruction(
auto mx25 = mm->add_instruction(pooling25, mx24); migraphx::make_json_op("convolution",
migraphx::op::convolution convolution26; "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
convolution26.padding = {1, 1}; "1],use_dynamic_same_auto_pad:0}"),
convolution26.stride = {1, 1}; x_main_module_29,
convolution26.dilation = {1, 1}; x_main_module_15);
convolution26.group = 1; auto x_main_module_31 = mmain->add_instruction(
auto mx26 = mm->add_instruction(convolution26, mx25, mx11); migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,384,13,13]}"), x_main_module_14);
migraphx::op::broadcast broadcast27; auto x_main_module_32 =
broadcast27.axis = 1; mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
broadcast27.broadcast_lens = {batch, 384, 13, 13}; auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto mx27 = mm->add_instruction(broadcast27, mx10); auto x_main_module_34 = mmain->add_instruction(
migraphx::op::add add28; migraphx::make_json_op("convolution",
auto mx28 = mm->add_instruction(add28, mx26, mx27); "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
migraphx::op::relu relu29; "1],use_dynamic_same_auto_pad:0}"),
auto mx29 = mm->add_instruction(relu29, mx28); x_main_module_33,
migraphx::op::convolution convolution30; x_main_module_13);
convolution30.padding = {1, 1}; auto x_main_module_35 = mmain->add_instruction(
convolution30.stride = {1, 1}; migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_12);
convolution30.dilation = {1, 1}; auto x_main_module_36 =
convolution30.group = 1; mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto mx30 = mm->add_instruction(convolution30, mx29, mx9); auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
migraphx::op::broadcast broadcast31; auto x_main_module_38 = mmain->add_instruction(
broadcast31.axis = 1; migraphx::make_json_op("convolution",
broadcast31.broadcast_lens = {batch, 256, 13, 13}; "{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
auto mx31 = mm->add_instruction(broadcast31, mx8); "1],use_dynamic_same_auto_pad:0}"),
migraphx::op::add add32; x_main_module_37,
auto mx32 = mm->add_instruction(add32, mx30, mx31); x_main_module_11);
migraphx::op::relu relu33; auto x_main_module_39 = mmain->add_instruction(
auto mx33 = mm->add_instruction(relu33, mx32); migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_10);
migraphx::op::convolution convolution34; auto x_main_module_40 =
convolution34.padding = {1, 1}; mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
convolution34.stride = {1, 1}; auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
convolution34.dilation = {1, 1}; auto x_main_module_42 = mmain->add_instruction(
convolution34.group = 1; migraphx::make_json_op(
auto mx34 = mm->add_instruction(convolution34, mx33, mx7); "pooling",
migraphx::op::broadcast broadcast35; "{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
broadcast35.axis = 1; x_main_module_41);
broadcast35.broadcast_lens = {batch, 256, 13, 13}; auto x_main_module_43 =
auto mx35 = mm->add_instruction(broadcast35, mx6); mmain->add_instruction(migraphx::make_json_op("flatten", "{axis:1}"), x_main_module_42);
migraphx::op::add add36; auto x_main_module_44 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_43);
auto mx36 = mm->add_instruction(add36, mx34, mx35); auto x_main_module_45 = mmain->add_instruction(
migraphx::op::relu relu37; migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_9);
auto mx37 = mm->add_instruction(relu37, mx36); auto x_main_module_46 =
migraphx::op::pooling pooling38; mmain->add_instruction(migraphx::make_op("dot"), x_main_module_44, x_main_module_45);
pooling38.mode = migraphx::op::pooling_mode::max; auto x_main_module_47 = mmain->add_instruction(
pooling38.padding = {0, 0}; migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_8);
pooling38.stride = {2, 2}; auto x_main_module_48 = mmain->add_instruction(
pooling38.lengths = {3, 3}; migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_2);
auto mx38 = mm->add_instruction(pooling38, mx37); auto x_main_module_49 =
migraphx::op::flatten flatten39; mmain->add_instruction(migraphx::make_op("mul"), x_main_module_47, x_main_module_48);
flatten39.axis = 1; auto x_main_module_50 =
auto mx39 = mm->add_instruction(flatten39, mx38); mmain->add_instruction(migraphx::make_op("add"), x_main_module_46, x_main_module_49);
migraphx::op::identity identity40; auto x_main_module_51 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_50);
auto mx40 = mm->add_instruction(identity40, mx39); auto x_main_module_52 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_51);
migraphx::op::transpose transpose41; auto x_main_module_53 = mmain->add_instruction(
transpose41.dims = {1, 0}; migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_7);
auto mx41 = mm->add_instruction(transpose41, mx5); auto x_main_module_54 =
migraphx::op::multibroadcast multibroadcast42; mmain->add_instruction(migraphx::make_op("dot"), x_main_module_52, x_main_module_53);
multibroadcast42.output_lens = {batch, 4096}; auto x_main_module_55 = mmain->add_instruction(
auto mx42 = mm->add_instruction(multibroadcast42, mx4); migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_6);
float dot43_alpha = 1; auto x_main_module_56 = mmain->add_instruction(
float dot43_beta = 1; migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_1);
auto mx43 = migraphx::add_apply_alpha_beta( auto x_main_module_57 =
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta); mmain->add_instruction(migraphx::make_op("mul"), x_main_module_55, x_main_module_56);
migraphx::op::relu relu44; auto x_main_module_58 =
auto mx44 = mm->add_instruction(relu44, mx43); mmain->add_instruction(migraphx::make_op("add"), x_main_module_54, x_main_module_57);
migraphx::op::identity identity45; auto x_main_module_59 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_58);
auto mx45 = mm->add_instruction(identity45, mx44); auto x_main_module_60 = mmain->add_instruction(
migraphx::op::transpose transpose46; migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_5);
transpose46.dims = {1, 0}; auto x_main_module_61 =
auto mx46 = mm->add_instruction(transpose46, mx3); mmain->add_instruction(migraphx::make_op("dot"), x_main_module_59, x_main_module_60);
migraphx::op::multibroadcast multibroadcast47; auto x_main_module_62 = mmain->add_instruction(
multibroadcast47.output_lens = {batch, 4096}; migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_4);
auto mx47 = mm->add_instruction(multibroadcast47, mx2); auto x_main_module_63 = mmain->add_instruction(
float dot48_alpha = 1; migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_0);
float dot48_beta = 1; auto x_main_module_64 =
auto mx48 = migraphx::add_apply_alpha_beta( mmain->add_instruction(migraphx::make_op("mul"), x_main_module_62, x_main_module_63);
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta); auto x_main_module_65 =
migraphx::op::relu relu49; mmain->add_instruction(migraphx::make_op("add"), x_main_module_61, x_main_module_64);
auto mx49 = mm->add_instruction(relu49, mx48); mmain->add_return({x_main_module_65});
migraphx::op::transpose transpose50;
transpose50.dims = {1, 0};
auto mx50 = mm->add_instruction(transpose50, mx1);
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
float dot52_alpha = 1;
float dot52_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
return p; return p;
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP #define MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <list>
#include <set> #include <set>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -16,9 +41,16 @@ ...@@ -16,9 +41,16 @@
#include <migraphx/requires.hpp> #include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#ifndef _WIN32
#include <unistd.h>
#endif
namespace migraphx { namespace migraphx {
namespace driver { namespace driver {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -51,6 +83,65 @@ template <class T> ...@@ -51,6 +83,65 @@ template <class T>
using is_multi_value = using is_multi_value =
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>; std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
inline std::string colorize(color c, const std::string& s)
{
std::stringstream ss;
ss << c << s << color::reset;
return ss.str();
}
template <class T>
struct type_name
{
static const std::string& apply() { return migraphx::get_type_name<T>(); }
};
template <>
struct type_name<std::string>
{
static const std::string& apply()
{
static const std::string name = "std::string";
return name;
}
};
template <class T>
struct type_name<std::vector<T>>
{
static const std::string& apply()
{
static const std::string name = "std::vector<" + type_name<T>::apply() + ">";
return name;
}
};
template <class T> template <class T>
struct value_parser struct value_parser
{ {
...@@ -62,7 +153,7 @@ struct value_parser ...@@ -62,7 +153,7 @@ struct value_parser
ss.str(x); ss.str(x);
ss >> result; ss >> result;
if(ss.fail()) if(ss.fail())
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return result; return result;
} }
...@@ -74,7 +165,7 @@ struct value_parser ...@@ -74,7 +165,7 @@ struct value_parser
ss.str(x); ss.str(x);
ss >> i; ss >> i;
if(ss.fail()) if(ss.fail())
throw std::runtime_error("Failed to parse: " + x); throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return static_cast<T>(i); return static_cast<T>(i);
} }
...@@ -92,13 +183,42 @@ struct argument_parser ...@@ -92,13 +183,42 @@ struct argument_parser
{ {
struct argument struct argument
{ {
using action_function =
std::function<bool(argument_parser&, const std::vector<std::string>&)>;
using validate_function =
std::function<void(const argument_parser&, const std::vector<std::string>&)>;
std::vector<std::string> flags; std::vector<std::string> flags;
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{}; action_function action{};
std::string type = ""; std::string type = "";
std::string help = ""; std::string help = "";
std::string metavar = ""; std::string metavar = "";
std::string default_value = ""; std::string default_value = "";
std::string group = "";
unsigned nargs = 1; unsigned nargs = 1;
bool required = false;
std::vector<validate_function> validations{};
std::string usage(const std::string& flag) const
{
std::stringstream ss;
if(flag.empty())
{
ss << metavar;
}
else
{
ss << flag;
if(not type.empty())
ss << " [" << type << "]";
}
return ss.str();
}
std::string usage() const
{
if(flags.empty())
return usage("");
return usage(flags.front());
}
}; };
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})> template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
...@@ -131,12 +251,14 @@ struct argument_parser ...@@ -131,12 +251,14 @@ struct argument_parser
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) { arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty()) if(params.empty())
throw std::runtime_error("Flag with no value."); throw std::runtime_error("Flag with no value.");
if(not is_multi_value<T>{} and params.size() > 1)
throw std::runtime_error("Too many arguments passed.");
x = value_parser<T>::apply(params.back()); x = value_parser<T>::apply(params.back());
return false; return false;
}}); }});
argument& arg = arguments.back(); argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>(); arg.type = type_name<T>::apply();
migraphx::each_args([&](auto f) { f(x, arg); }, fs...); migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0) if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x); arg.default_value = as_string_value(x);
...@@ -158,6 +280,11 @@ struct argument_parser ...@@ -158,6 +280,11 @@ struct argument_parser
return [=](auto&&, auto& arg) { arg.nargs = n; }; return [=](auto&&, auto& arg) { arg.nargs = n; };
} }
MIGRAPHX_DRIVER_STATIC auto required()
{
return [=](auto&&, auto& arg) { arg.required = true; };
}
template <class F> template <class F>
MIGRAPHX_DRIVER_STATIC auto write_action(F f) MIGRAPHX_DRIVER_STATIC auto write_action(F f)
{ {
...@@ -192,13 +319,141 @@ struct argument_parser ...@@ -192,13 +319,141 @@ struct argument_parser
}); });
} }
MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "") template <class F>
MIGRAPHX_DRIVER_STATIC auto validate(F f)
{
return [=](const auto& x, auto& arg) {
arg.validations.push_back(
[&, f](auto& self, const std::vector<std::string>& params) { f(self, x, params); });
};
}
MIGRAPHX_DRIVER_STATIC auto file_exist()
{
return validate([](auto&, auto&, auto& params) {
if(params.empty())
throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back());
});
}
template <class F>
argument* find_argument(F f)
{
auto it = std::find_if(arguments.begin(), arguments.end(), f);
if(it == arguments.end())
return nullptr;
return std::addressof(*it);
}
template <class F>
bool has_argument(F f)
{
return find_argument(f) != nullptr;
}
template <class F>
std::vector<argument*> find_arguments(F f)
{
std::vector<argument*> result;
for(auto& arg : arguments)
{
if(not f(arg))
continue;
result.push_back(&arg);
}
return result;
}
std::vector<argument*> get_group_arguments(const std::string& group)
{
return find_arguments([&](const auto& arg) { return arg.group == group; });
}
std::vector<argument*> get_required_arguments()
{
return find_arguments([&](const auto& arg) { return arg.required; });
}
template <class SequenceContainer>
std::vector<std::string> get_argument_usages(SequenceContainer args)
{
std::vector<std::string> usage_flags;
std::unordered_set<std::string> found_groups;
// Remove arguments that belong to a group
auto it = std::remove_if(args.begin(), args.end(), [&](const argument* arg) {
if(arg->group.empty())
return false;
found_groups.insert(arg->group);
return true;
});
args.erase(it, args.end());
transform(found_groups, std::back_inserter(usage_flags), [&](auto&& group) {
std::vector<std::string> either_flags;
transform(get_group_arguments(group), std::back_inserter(either_flags), [](auto* arg) {
return arg->usage();
});
return "(" + join_strings(either_flags, "|") + ")";
});
transform(args, std::back_inserter(usage_flags), [&](auto* arg) { return arg->usage(); });
return usage_flags;
}
auto show_help(const std::string& msg = "")
{ {
return do_action([=](auto& self) { return do_action([=](auto& self) {
argument* input_argument =
self.find_argument([](const auto& arg) { return arg.flags.empty(); });
auto required_usages = get_argument_usages(get_required_arguments());
if(required_usages.empty() && input_argument)
required_usages.push_back(input_argument->metavar);
required_usages.insert(required_usages.begin(), "<options>");
print_usage(required_usages);
std::cout << std::endl;
if(self.find_argument([](const auto& arg) { return arg.nargs == 0; }))
{
std::cout << color::fg_yellow << "FLAGS:" << color::reset << std::endl;
std::cout << std::endl;
for(auto&& arg : self.arguments) for(auto&& arg : self.arguments)
{
if(arg.nargs != 0)
continue;
const int col_align = 35;
std::string prefix = " ";
int len = 0;
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
len += prefix.length() + a.length();
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset;
int spaces = col_align - len;
if(spaces < 0)
{ {
std::cout << std::endl; std::cout << std::endl;
}
else
{
for(int i = 0; i < spaces; i++)
std::cout << " ";
}
std::cout << arg.help << std::endl;
}
std::cout << std::endl;
}
if(self.find_argument([](const auto& arg) { return arg.nargs != 0; }))
{
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : self.arguments)
{
if(arg.nargs == 0)
continue;
std::cout << std::endl;
std::string prefix = " "; std::string prefix = " ";
std::cout << color::fg_green;
if(arg.flags.empty()) if(arg.flags.empty())
{ {
std::cout << prefix; std::cout << prefix;
...@@ -210,9 +465,10 @@ struct argument_parser ...@@ -210,9 +465,10 @@ struct argument_parser
std::cout << a; std::cout << a;
prefix = ", "; prefix = ", ";
} }
std::cout << color::reset;
if(not arg.type.empty()) if(not arg.type.empty())
{ {
std::cout << " [" << arg.type << "]"; std::cout << " [" << color::fg_blue << arg.type << color::reset << "]";
if(not arg.default_value.empty()) if(not arg.default_value.empty())
std::cout << " (Default: " << arg.default_value << ")"; std::cout << " (Default: " << arg.default_value << ")";
} }
...@@ -220,6 +476,7 @@ struct argument_parser ...@@ -220,6 +476,7 @@ struct argument_parser
std::cout << " " << arg.help << std::endl; std::cout << " " << arg.help << std::endl;
} }
std::cout << std::endl; std::cout << std::endl;
}
if(not msg.empty()) if(not msg.empty())
std::cout << msg << std::endl; std::cout << msg << std::endl;
}); });
...@@ -240,6 +497,11 @@ struct argument_parser ...@@ -240,6 +497,11 @@ struct argument_parser
return [=](auto&, auto& arg) { arg.type = type; }; return [=](auto&, auto& arg) { arg.type = type; };
} }
MIGRAPHX_DRIVER_STATIC auto group(const std::string& group)
{
return [=](auto&, auto& arg) { arg.group = group; };
}
template <class T> template <class T>
MIGRAPHX_DRIVER_STATIC auto set_value(T value) MIGRAPHX_DRIVER_STATIC auto set_value(T value)
{ {
...@@ -253,6 +515,109 @@ struct argument_parser ...@@ -253,6 +515,109 @@ struct argument_parser
}; };
} }
template <class T>
void set_exe_name_to(T& x)
{
actions.push_back([&](const auto& self) { x = self.exe_name; });
}
void print_try_help()
{
if(has_argument([](const auto& a) { return contains(a.flags, "--help"); }))
{
std::cout << std::endl;
std::cout << "For more information try '" << color::fg_green << "--help" << color::reset
<< "'" << std::endl;
}
}
void print_usage(const std::vector<std::string>& flags) const
{
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " ";
std::cout << join_strings(flags, " ") << std::endl;
}
auto spellcheck(const std::vector<std::string>& inputs)
{
struct result_t
{
const argument* arg = nullptr;
std::string correct = "";
std::string incorrect = "";
std::ptrdiff_t distance = std::numeric_limits<std::ptrdiff_t>::max();
};
result_t result;
for(const auto& input : inputs)
{
if(input.empty())
continue;
if(input[0] != '-')
continue;
for(const auto& arg : arguments)
{
for(const auto& flag : arg.flags)
{
if(flag.empty())
continue;
if(flag[0] != '-')
continue;
auto d =
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance)
result = result_t{&arg, flag, input, d};
}
}
}
return result;
}
bool
run_action(const argument& arg, const std::string& flag, const std::vector<std::string>& inputs)
{
std::string msg = "";
try
{
for(const auto& v : arg.validations)
v(*this, inputs);
return arg.action(*this, inputs);
}
catch(const std::exception& e)
{
msg = e.what();
}
catch(...)
{
msg = "unknown exception";
}
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto sc = spellcheck(inputs);
if(sc.distance < 5)
{
std::cout << "Found argument '" << color::fg_yellow << sc.incorrect << color::reset
<< "'"
<< " which wasn't expected, or isn't valid in this context" << std::endl;
std::cout << " "
<< "Did you mean " << color::fg_green << sc.correct << color::reset << "?"
<< std::endl;
std::cout << std::endl;
print_usage({sc.arg->usage(sc.correct)});
}
else
{
const auto& flag_name = flag.empty() ? arg.metavar : flag;
std::cout << "Invalid input to '" << color::fg_yellow;
std::cout << arg.usage(flag_name);
std::cout << color::reset << "'" << std::endl;
std::cout << " " << msg << std::endl;
std::cout << std::endl;
print_usage({arg.usage()});
}
std::cout << std::endl;
print_try_help();
return true;
}
bool parse(std::vector<std::string> args) bool parse(std::vector<std::string> args)
{ {
std::unordered_map<std::string, unsigned> keywords; std::unordered_map<std::string, unsigned> keywords;
...@@ -263,8 +628,11 @@ struct argument_parser ...@@ -263,8 +628,11 @@ struct argument_parser
} }
auto arg_map = auto arg_map =
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; }); generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
std::list<const argument*> missing_arguments;
std::unordered_set<std::string> groups_used;
for(auto&& arg : arguments) for(auto&& arg : arguments)
{ {
bool used = false;
auto flags = arg.flags; auto flags = arg.flags;
if(flags.empty()) if(flags.empty())
flags = {""}; flags = {""};
...@@ -272,14 +640,41 @@ struct argument_parser ...@@ -272,14 +640,41 @@ struct argument_parser
{ {
if(arg_map.count(flag) > 0) if(arg_map.count(flag) > 0)
{ {
if(arg.action(*this, arg_map[flag])) if(run_action(arg, flag, arg_map[flag]))
return true; return true;
used = true;
} }
} }
if(used and not arg.group.empty())
groups_used.insert(arg.group);
if(arg.required and not used)
missing_arguments.push_back(&arg);
} }
// Remove arguments from a group that is being used
missing_arguments.remove_if(
[&](const argument* arg) { return groups_used.count(arg->group); });
if(not missing_arguments.empty())
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
std::cout << "The following required arguments were not provided:" << std::endl;
std::cout << " " << color::fg_red
<< join_strings(get_argument_usages(std::move(missing_arguments)), " ")
<< color::reset << std::endl;
std::cout << std::endl;
auto required_usages = get_argument_usages(get_required_arguments());
print_usage(required_usages);
print_try_help();
return true;
}
for(auto&& action : actions)
action(*this);
return false; return false;
} }
void set_exe_name(const std::string& s) { exe_name = s; }
const std::string& get_exe_name() const { return exe_name; }
using string_map = std::unordered_map<std::string, std::vector<std::string>>; using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class IsKeyword> template <class IsKeyword>
static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword) static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
...@@ -314,7 +709,9 @@ struct argument_parser ...@@ -314,7 +709,9 @@ struct argument_parser
} }
private: private:
std::vector<argument> arguments; std::list<argument> arguments;
std::string exe_name = "";
std::vector<std::function<void(argument_parser&)>> actions;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP #define MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
...@@ -18,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -18,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands() inline auto& get_commands()
{ {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m; static std::unordered_map<
std::string,
std::function<void(const std::string& exe_name, std::vector<std::string> args)>>
m;
return m; return m;
} }
...@@ -42,10 +68,11 @@ const std::string& command_name() ...@@ -42,10 +68,11 @@ const std::string& command_name()
} }
template <class T> template <class T>
void run_command(std::vector<std::string> args, bool add_help = false) void run_command(const std::string& exe_name, std::vector<std::string> args, bool add_help = false)
{ {
T x; T x;
argument_parser ap; argument_parser ap;
ap.set_exe_name(exe_name + " " + command_name<T>());
if(add_help) if(add_help)
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help()); ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap); x.parse(ap);
...@@ -58,7 +85,9 @@ template <class T> ...@@ -58,7 +85,9 @@ template <class T>
int auto_register_command() int auto_register_command()
{ {
auto& m = get_commands(); auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); }; m[command_name<T>()] = [](const std::string& exe_name, std::vector<std::string> args) {
run_command<T>(exe_name, args, true);
};
return 0; return 0;
} }
......
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify.hpp" #include "verify.hpp"
#include "argument_parser.hpp" #include "argument_parser.hpp"
#include "command.hpp" #include "command.hpp"
...@@ -50,8 +73,12 @@ struct loader ...@@ -50,8 +73,12 @@ struct loader
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
ap(file, {}, ap.metavar("<input file>")); ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet")); ap(model,
{"--model"},
ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"),
ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx")); ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx")); ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
...@@ -187,6 +214,9 @@ struct loader ...@@ -187,6 +214,9 @@ struct loader
auto last = std::prev(mm->end(), trim); auto last = std::prev(mm->end(), trim);
mm->remove_instructions(last, mm->end()); mm->remove_instructions(last, mm->end());
} }
// Remove unused variable when exporting to cpp
if(output_type == "cpp")
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
if(optimize) if(optimize)
{ {
migraphx::run_passes(*p.get_main_module(), migraphx::run_passes(*p.get_main_module(),
...@@ -552,26 +582,62 @@ struct onnx : command<onnx> ...@@ -552,26 +582,62 @@ struct onnx : command<onnx>
struct main_command struct main_command
{ {
static std::string get_command_help() static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
"COMMANDS:"))
{ {
std::string result = "Commands:\n"; std::string result = title + "\n";
return std::accumulate(get_commands().begin(), std::vector<std::string> commands(get_commands().size());
std::transform(get_commands().begin(),
get_commands().end(), get_commands().end(),
result, commands.begin(),
[](auto r, auto&& p) { return r + " " + p.first + "\n"; }); [](const auto& p) { return colorize(color::fg_green, p.first); });
std::sort(commands.begin(), commands.end());
return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
return r + " " + s + "\n";
});
} }
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) + std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
"." + std::to_string(MIGRAPHX_VERSION_MINOR); "." + std::to_string(MIGRAPHX_VERSION_MINOR);
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help())); ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
ap(nullptr, ap(nullptr,
{"-v", "--version"}, {"-v", "--version"},
ap.help("Show MIGraphX version"), ap.help("Show MIGraphX version"),
ap.show_help(version_str)); ap.show_help(version_str));
// Trim command off of exe name
ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
ap.set_exe_name_to(exe_name);
} }
void run() {} std::vector<std::string> wrong_commands{};
std::string exe_name = "<exe>";
void run()
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
return get_commands().count(c) > 0;
});
if(it == wrong_commands.end())
{
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl;
}
else
{
std::cout << "command '" << color::fg_yellow << *it << color::reset
<< "' must be first argument" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " " << *it << " <options>" << std::endl;
}
std::cout << std::endl;
}
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
...@@ -593,11 +659,11 @@ int main(int argc, const char* argv[]) ...@@ -593,11 +659,11 @@ int main(int argc, const char* argv[])
auto cmd = args.front(); auto cmd = args.front();
if(m.count(cmd) > 0) if(m.count(cmd) > 0)
{ {
m.at(cmd)({args.begin() + 1, args.end()}); m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
} }
else else
{ {
run_command<main_command>(args); run_command<main_command>(argv[0], args);
} }
return 0; return 0;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "marker_roctx.hpp" #include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp> #include <migraphx/dynamic_loader.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment