Commit 79278d88 authored by Paul's avatar Paul
Browse files

Merge

parents 3f4d78bd 10f37f49
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/shape.hpp>
......@@ -16,12 +39,24 @@
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
if(disable_exception_catch)
{
f();
}
else
{
try
{
f();
......@@ -45,6 +80,7 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return migraphx_status_unknown_error;
}
}
return migraphx_status_success;
}
......@@ -213,6 +249,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s)
{
return m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}}), {});
}
struct experimental_custom_op
{
std::string name;
......@@ -237,7 +278,12 @@ struct custom_operation
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
// TODO: Compute method with module_args
argument
compute(migraphx::context ctx, migraphx::shape output_shape, std::vector<argument> inputs) const
{
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
}
};
template <class CustomOp>
......@@ -272,6 +318,7 @@ void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
......@@ -280,23 +327,27 @@ struct manage_generic_ptr
manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
: data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
: data(other.data),
obj_typename(other.obj_typename),
copier(other.copier),
deleter(other.deleter)
{
other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr;
other.deleter = nullptr;
}
......@@ -304,6 +355,7 @@ struct manage_generic_ptr
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
......@@ -316,6 +368,7 @@ struct manage_generic_ptr
}
void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr;
D deleter = nullptr;
};
......@@ -547,23 +600,59 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
: object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{
}
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
object_ptr = nullptr;
migraphx::experimental_custom_op xobject;
migraphx_experimental_custom_op_compute compute_f = nullptr;
migraphx::argument compute(migraphx::context ctx,
migraphx::shape output,
std::vector<migraphx::argument> inputs) const
{
std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr)
throw std::runtime_error("compute function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_context_t>(&(ctx)),
object_cast<migraphx_shape_t>(&(output)),
object_cast<migraphx_arguments_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object;
}
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing.");
auto api_error_result =
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs)));
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape.");
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute_shape of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object;
}
};
......@@ -692,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).standard();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
......@@ -1118,6 +1217,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_instruction_t>(
migraphx::add_allocation((module->object), (s->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] { destroy((program)); });
......@@ -1740,15 +1854,24 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
*experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name));
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (obj_typename), (name));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_compute input)
{
auto api_error_result = migraphx::try_([&] { (obj)->compute_f = (input); });
return api_error_result;
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#include <stdlib.h>
#include <stdbool.h>
// Add new types here
// clang-format off
......@@ -106,8 +130,18 @@ typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_context_t ctx,
migraphx_shape_t output,
migraphx_arguments_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
......@@ -146,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
......@@ -272,6 +308,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args);
migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
......@@ -452,8 +492,13 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name);
migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_compute input);
migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <cstring>
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
......@@ -35,6 +59,42 @@ struct rank<0>
{
};
template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
return name;
}
template <class T>
const std::string& get_type_name()
{
static const std::string name = compute_type_name<T>();
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return get_type_name<T>();
}
template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs)
{
......@@ -287,13 +347,22 @@ struct interface_base : Base
protected:
template <class F>
static migraphx_status try_(F f) // NOLINT
static migraphx_status try_(F f, char* ex_msg = nullptr, size_t ex_msg_size = 0) // NOLINT
{
try
{
f();
return migraphx_status_success;
}
catch(const std::exception& ex)
{
if(ex_msg)
{
std::strncpy(ex_msg, ex.what(), ex_msg_size);
ex_msg[ex_msg_size - 1] = '\0';
}
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
......@@ -326,8 +395,12 @@ struct interface_base : Base
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
call(setter,
this->get_handle_ptr(),
[](auto out, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<1>{}, f, out, obj, xs...); }, ex_msg, ex_msg_size);
});
}
......@@ -378,11 +451,14 @@ struct interface_base : Base
return x;
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{
return as_handle<T>{x};
}
#pragma GCC diagnostic pop
template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
......@@ -441,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
MIGRAPHX_HANDLE_CONSTRUCTOR(shape)
/// Construct a scalar shape
shape(migraphx_shape_datatype_t type)
......@@ -498,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout;
}
bool standard() const
{
bool result = false;
call(&migraphx_shape_standard, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const shape& px, const shape& py)
{
bool pout;
......@@ -505,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout;
}
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); }
friend bool operator!=(const shape& px, const shape& py) { return not(px == py); }
};
/**
......@@ -518,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
argument() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
MIGRAPHX_HANDLE_CONSTRUCTOR(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
......@@ -542,6 +625,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout;
}
template <typename T>
std::vector<T> as_vector() const
{
size_t vector_len = this->get_shape().bytes() / sizeof(T);
T* buffer_ptr = reinterpret_cast<T*>(this->data());
return {buffer_ptr, buffer_ptr + vector_len};
}
/// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0)
{
......@@ -556,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout;
}
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); }
friend bool operator!=(const argument& px, const argument& py) { return not(px == py); }
};
/// A target for compilation
......@@ -564,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
target() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(target);
MIGRAPHX_HANDLE_CONSTRUCTOR(target)
/// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); }
......@@ -574,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
program_parameter_shapes() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes)
size_t size() const
{
......@@ -593,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std::vector<const char*> names() const
{
std::vector<const char*> result(this->size());
if(!result.empty())
if(not result.empty())
{
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
}
......@@ -604,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
......@@ -631,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments)
size_t size() const
{
......@@ -650,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes)
size_t size() const
{
......@@ -669,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
MIGRAPHX_HANDLE_CONSTRUCTOR(operation)
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
......@@ -687,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction)
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions)
template <class... Ts>
instructions(Ts... xs)
......@@ -706,7 +797,7 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
MIGRAPHX_HANDLE_CONSTRUCTOR(modules)
template <class... Ts>
modules(Ts... xs)
......@@ -779,13 +870,20 @@ struct module
return instruction(ret_ins, own{});
}
instruction add_allocation(const migraphx::shape& s)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_allocation, &ret_ins, mm.get(), s.get_handle_ptr());
return instruction(ret_ins, own{});
}
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
};
struct context
struct context : handle_lookup<context, migraphx_context>
{
context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
......@@ -813,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
compile_options() { this->make_handle(&migraphx_compile_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options)
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
......@@ -837,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() { this->make_handle(&migraphx_program_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program);
MIGRAPHX_HANDLE_CONSTRUCTOR(program)
/// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const
......@@ -917,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu, this->share_handle()};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
friend bool operator!=(const program& px, const program& py) { return not(px == py); }
};
// options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options)
file_options() { this->make_handle(&migraphx_file_options_create); }
// set file format
......@@ -965,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options)
/// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -1047,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options() { this->make_handle(&migraphx_tf_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options)
/// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -1100,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names)
void add(const std::string& name)
{
......@@ -1125,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options)
/// Add an operator that should be quantized
void add_op_name(const std::string& name)
......@@ -1155,6 +1253,7 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct experimental_custom_op_base
{
virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default;
};
......@@ -1164,8 +1263,12 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
template <class T>
experimental_custom_op(T& obj)
{
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
this->make_interface(&migraphx_experimental_custom_op_create,
obj,
get_type_name(obj).c_str(),
obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
}
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import api
......@@ -98,6 +121,7 @@ def shape(h):
invoke='migraphx::equal($@)',
returns='bool',
const=True)
h.method('standard', returns='bool', const=True)
@auto_handle()
......@@ -221,6 +245,10 @@ def module(h):
h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
h.method('add_allocation',
api.params(s='const migraphx::shape&'),
invoke='migraphx::add_allocation($@)',
returns='migraphx::instruction_ref')
@auto_handle()
......@@ -412,7 +440,13 @@ def context(h):
@api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op')
def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*'))
h.constructor('create',
api.params(obj_typename='const char*', name='const char*'))
h.virtual('compute',
api.params(ctx='migraphx::context',
output='migraphx::shape',
inputs='std::vector<migraphx::argument>'),
returns='migraphx::argument')
h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape')
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
......@@ -16,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
if(not float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <unordered_map>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
......@@ -42,7 +65,7 @@ void auto_contiguous::apply(module& m) const
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
if(not s.dynamic() and not s.standard() and s.elements() != 0)
{
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/compile_src.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <string>
#include <vector>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
namespace migraphx {
......@@ -24,9 +48,11 @@ void dead_code_elimination::apply(module& m) const
// Skip the last instruction
if(i == last)
break;
// Skip instruction with empty shape as output unless its a builtin or undefined or identity
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and
not contains({"undefined", "identity", "allocate"}, i->name()))
continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
add_executable(driver
main.cpp
......
#include <migraphx/operators.hpp>
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto m0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto mx0 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
auto mx1 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
auto mx2 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
auto mx3 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
auto mx4 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
auto mx5 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
auto mx6 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
auto mx7 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
auto mx8 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
auto mx9 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
auto mx10 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
auto mx11 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
auto mx12 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
auto mx13 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
auto mx14 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
auto mx15 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));
migraphx::op::convolution convolution16;
convolution16.padding = {2, 2};
convolution16.stride = {4, 4};
convolution16.dilation = {1, 1};
convolution16.group = 1;
auto mx16 = mm->add_instruction(convolution16, m0, mx15);
migraphx::op::broadcast broadcast17;
broadcast17.axis = 1;
broadcast17.broadcast_lens = {batch, 64, 55, 55};
auto mx17 = mm->add_instruction(broadcast17, mx14);
migraphx::op::add add18;
auto mx18 = mm->add_instruction(add18, mx16, mx17);
migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20;
pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0};
pooling20.stride = {2, 2};
pooling20.lengths = {3, 3};
auto mx20 = mm->add_instruction(pooling20, mx19);
migraphx::op::convolution convolution21;
convolution21.padding = {2, 2};
convolution21.stride = {1, 1};
convolution21.dilation = {1, 1};
convolution21.group = 1;
auto mx21 = mm->add_instruction(convolution21, mx20, mx13);
migraphx::op::broadcast broadcast22;
broadcast22.axis = 1;
broadcast22.broadcast_lens = {batch, 192, 27, 27};
auto mx22 = mm->add_instruction(broadcast22, mx12);
migraphx::op::add add23;
auto mx23 = mm->add_instruction(add23, mx21, mx22);
migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25;
pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0};
pooling25.stride = {2, 2};
pooling25.lengths = {3, 3};
auto mx25 = mm->add_instruction(pooling25, mx24);
migraphx::op::convolution convolution26;
convolution26.padding = {1, 1};
convolution26.stride = {1, 1};
convolution26.dilation = {1, 1};
convolution26.group = 1;
auto mx26 = mm->add_instruction(convolution26, mx25, mx11);
migraphx::op::broadcast broadcast27;
broadcast27.axis = 1;
broadcast27.broadcast_lens = {batch, 384, 13, 13};
auto mx27 = mm->add_instruction(broadcast27, mx10);
migraphx::op::add add28;
auto mx28 = mm->add_instruction(add28, mx26, mx27);
migraphx::op::relu relu29;
auto mx29 = mm->add_instruction(relu29, mx28);
migraphx::op::convolution convolution30;
convolution30.padding = {1, 1};
convolution30.stride = {1, 1};
convolution30.dilation = {1, 1};
convolution30.group = 1;
auto mx30 = mm->add_instruction(convolution30, mx29, mx9);
migraphx::op::broadcast broadcast31;
broadcast31.axis = 1;
broadcast31.broadcast_lens = {batch, 256, 13, 13};
auto mx31 = mm->add_instruction(broadcast31, mx8);
migraphx::op::add add32;
auto mx32 = mm->add_instruction(add32, mx30, mx31);
migraphx::op::relu relu33;
auto mx33 = mm->add_instruction(relu33, mx32);
migraphx::op::convolution convolution34;
convolution34.padding = {1, 1};
convolution34.stride = {1, 1};
convolution34.dilation = {1, 1};
convolution34.group = 1;
auto mx34 = mm->add_instruction(convolution34, mx33, mx7);
migraphx::op::broadcast broadcast35;
broadcast35.axis = 1;
broadcast35.broadcast_lens = {batch, 256, 13, 13};
auto mx35 = mm->add_instruction(broadcast35, mx6);
migraphx::op::add add36;
auto mx36 = mm->add_instruction(add36, mx34, mx35);
migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38;
pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0};
pooling38.stride = {2, 2};
pooling38.lengths = {3, 3};
auto mx38 = mm->add_instruction(pooling38, mx37);
migraphx::op::flatten flatten39;
flatten39.axis = 1;
auto mx39 = mm->add_instruction(flatten39, mx38);
migraphx::op::identity identity40;
auto mx40 = mm->add_instruction(identity40, mx39);
migraphx::op::transpose transpose41;
transpose41.dims = {1, 0};
auto mx41 = mm->add_instruction(transpose41, mx5);
migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4);
float dot43_alpha = 1;
float dot43_beta = 1;
auto mx43 = migraphx::add_apply_alpha_beta(
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45;
auto mx45 = mm->add_instruction(identity45, mx44);
migraphx::op::transpose transpose46;
transpose46.dims = {1, 0};
auto mx46 = mm->add_instruction(transpose46, mx3);
migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2);
float dot48_alpha = 1;
float dot48_beta = 1;
auto mx48 = migraphx::add_apply_alpha_beta(
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50;
transpose50.dims = {1, 0};
auto mx50 = mm->add_instruction(transpose50, mx1);
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
float dot52_alpha = 1;
float dot52_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
migraphx::module_ref mmain = p.get_main_module();
auto x_main_module_0 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 0)));
auto x_main_module_1 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto x_0 = mmain->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 3));
auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 4));
auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 5));
auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 6));
auto x_main_module_8 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 7));
auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 8));
auto x_main_module_10 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 9));
auto x_main_module_11 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 10));
auto x_main_module_12 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 11));
auto x_main_module_13 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 12));
auto x_main_module_14 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 13));
auto x_main_module_15 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 14));
auto x_main_module_16 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 15));
auto x_main_module_17 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 16));
auto x_main_module_18 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 17));
auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
auto x_main_module_20 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
"4],use_dynamic_same_auto_pad:0}"),
x_0,
x_main_module_19);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,64,55,55]}"), x_main_module_18);
auto x_main_module_22 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
auto x_main_module_24 = mmain->add_instruction(
migraphx::make_json_op(
"pooling",
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_23);
auto x_main_module_25 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_24,
x_main_module_17);
auto x_main_module_26 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,192,27,27]}"), x_main_module_16);
auto x_main_module_27 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
auto x_main_module_29 = mmain->add_instruction(
migraphx::make_json_op(
"pooling",
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_28);
auto x_main_module_30 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_29,
x_main_module_15);
auto x_main_module_31 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,384,13,13]}"), x_main_module_14);
auto x_main_module_32 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_33,
x_main_module_13);
auto x_main_module_35 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_12);
auto x_main_module_36 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_37,
x_main_module_11);
auto x_main_module_39 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_10);
auto x_main_module_40 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
auto x_main_module_42 = mmain->add_instruction(
migraphx::make_json_op(
"pooling",
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_41);
auto x_main_module_43 =
mmain->add_instruction(migraphx::make_json_op("flatten", "{axis:1}"), x_main_module_42);
auto x_main_module_44 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_43);
auto x_main_module_45 = mmain->add_instruction(
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_9);
auto x_main_module_46 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_44, x_main_module_45);
auto x_main_module_47 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_8);
auto x_main_module_48 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_2);
auto x_main_module_49 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_47, x_main_module_48);
auto x_main_module_50 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_46, x_main_module_49);
auto x_main_module_51 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_50);
auto x_main_module_52 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_51);
auto x_main_module_53 = mmain->add_instruction(
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_7);
auto x_main_module_54 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_52, x_main_module_53);
auto x_main_module_55 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_6);
auto x_main_module_56 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_1);
auto x_main_module_57 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_55, x_main_module_56);
auto x_main_module_58 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_54, x_main_module_57);
auto x_main_module_59 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_58);
auto x_main_module_60 = mmain->add_instruction(
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_5);
auto x_main_module_61 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_59, x_main_module_60);
auto x_main_module_62 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_4);
auto x_main_module_63 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_0);
auto x_main_module_64 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_62, x_main_module_63);
auto x_main_module_65 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_61, x_main_module_64);
mmain->add_return({x_main_module_65});
return p;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGUMENT_PARSER_HPP
#include <algorithm>
#include <functional>
#include <iostream>
#include <list>
#include <set>
#include <string>
#include <sstream>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -16,9 +41,16 @@
#include <migraphx/requires.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/rank.hpp>
#ifndef _WIN32
#include <unistd.h>
#endif
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -51,6 +83,65 @@ template <class T>
using is_multi_value =
std::integral_constant<bool, (is_container<T>{} and not std::is_convertible<T, std::string>{})>;
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
inline std::string colorize(color c, const std::string& s)
{
std::stringstream ss;
ss << c << s << color::reset;
return ss.str();
}
template <class T>
struct type_name
{
static const std::string& apply() { return migraphx::get_type_name<T>(); }
};
template <>
struct type_name<std::string>
{
static const std::string& apply()
{
static const std::string name = "std::string";
return name;
}
};
template <class T>
struct type_name<std::vector<T>>
{
static const std::string& apply()
{
static const std::string name = "std::vector<" + type_name<T>::apply() + ">";
return name;
}
};
template <class T>
struct value_parser
{
......@@ -62,7 +153,7 @@ struct value_parser
ss.str(x);
ss >> result;
if(ss.fail())
throw std::runtime_error("Failed to parse: " + x);
throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return result;
}
......@@ -74,7 +165,7 @@ struct value_parser
ss.str(x);
ss >> i;
if(ss.fail())
throw std::runtime_error("Failed to parse: " + x);
throw std::runtime_error("Failed to parse '" + x + "' as " + type_name<T>::apply());
return static_cast<T>(i);
}
......@@ -92,13 +183,42 @@ struct argument_parser
{
struct argument
{
using action_function =
std::function<bool(argument_parser&, const std::vector<std::string>&)>;
using validate_function =
std::function<void(const argument_parser&, const std::vector<std::string>&)>;
std::vector<std::string> flags;
std::function<bool(argument_parser&, const std::vector<std::string>&)> action{};
action_function action{};
std::string type = "";
std::string help = "";
std::string metavar = "";
std::string default_value = "";
std::string group = "";
unsigned nargs = 1;
bool required = false;
std::vector<validate_function> validations{};
std::string usage(const std::string& flag) const
{
std::stringstream ss;
if(flag.empty())
{
ss << metavar;
}
else
{
ss << flag;
if(not type.empty())
ss << " [" << type << "]";
}
return ss.str();
}
std::string usage() const
{
if(flags.empty())
return usage("");
return usage(flags.front());
}
};
template <class T, MIGRAPHX_REQUIRES(is_multi_value<T>{})>
......@@ -131,12 +251,14 @@ struct argument_parser
arguments.push_back({flags, [&](auto&&, const std::vector<std::string>& params) {
if(params.empty())
throw std::runtime_error("Flag with no value.");
if(not is_multi_value<T>{} and params.size() > 1)
throw std::runtime_error("Too many arguments passed.");
x = value_parser<T>::apply(params.back());
return false;
}});
argument& arg = arguments.back();
arg.type = migraphx::get_type_name<T>();
arg.type = type_name<T>::apply();
migraphx::each_args([&](auto f) { f(x, arg); }, fs...);
if(not arg.default_value.empty() and arg.nargs > 0)
arg.default_value = as_string_value(x);
......@@ -158,6 +280,11 @@ struct argument_parser
return [=](auto&&, auto& arg) { arg.nargs = n; };
}
MIGRAPHX_DRIVER_STATIC auto required()
{
return [=](auto&&, auto& arg) { arg.required = true; };
}
template <class F>
MIGRAPHX_DRIVER_STATIC auto write_action(F f)
{
......@@ -192,13 +319,141 @@ struct argument_parser
});
}
MIGRAPHX_DRIVER_STATIC auto show_help(const std::string& msg = "")
template <class F>
MIGRAPHX_DRIVER_STATIC auto validate(F f)
{
return [=](const auto& x, auto& arg) {
arg.validations.push_back(
[&, f](auto& self, const std::vector<std::string>& params) { f(self, x, params); });
};
}
MIGRAPHX_DRIVER_STATIC auto file_exist()
{
return validate([](auto&, auto&, auto& params) {
if(params.empty())
throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back());
});
}
template <class F>
argument* find_argument(F f)
{
auto it = std::find_if(arguments.begin(), arguments.end(), f);
if(it == arguments.end())
return nullptr;
return std::addressof(*it);
}
template <class F>
bool has_argument(F f)
{
return find_argument(f) != nullptr;
}
template <class F>
std::vector<argument*> find_arguments(F f)
{
std::vector<argument*> result;
for(auto& arg : arguments)
{
if(not f(arg))
continue;
result.push_back(&arg);
}
return result;
}
std::vector<argument*> get_group_arguments(const std::string& group)
{
return find_arguments([&](const auto& arg) { return arg.group == group; });
}
std::vector<argument*> get_required_arguments()
{
return find_arguments([&](const auto& arg) { return arg.required; });
}
template <class SequenceContainer>
std::vector<std::string> get_argument_usages(SequenceContainer args)
{
std::vector<std::string> usage_flags;
std::unordered_set<std::string> found_groups;
// Remove arguments that belong to a group
auto it = std::remove_if(args.begin(), args.end(), [&](const argument* arg) {
if(arg->group.empty())
return false;
found_groups.insert(arg->group);
return true;
});
args.erase(it, args.end());
transform(found_groups, std::back_inserter(usage_flags), [&](auto&& group) {
std::vector<std::string> either_flags;
transform(get_group_arguments(group), std::back_inserter(either_flags), [](auto* arg) {
return arg->usage();
});
return "(" + join_strings(either_flags, "|") + ")";
});
transform(args, std::back_inserter(usage_flags), [&](auto* arg) { return arg->usage(); });
return usage_flags;
}
auto show_help(const std::string& msg = "")
{
return do_action([=](auto& self) {
argument* input_argument =
self.find_argument([](const auto& arg) { return arg.flags.empty(); });
auto required_usages = get_argument_usages(get_required_arguments());
if(required_usages.empty() && input_argument)
required_usages.push_back(input_argument->metavar);
required_usages.insert(required_usages.begin(), "<options>");
print_usage(required_usages);
std::cout << std::endl;
if(self.find_argument([](const auto& arg) { return arg.nargs == 0; }))
{
std::cout << color::fg_yellow << "FLAGS:" << color::reset << std::endl;
std::cout << std::endl;
for(auto&& arg : self.arguments)
{
if(arg.nargs != 0)
continue;
const int col_align = 35;
std::string prefix = " ";
int len = 0;
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
len += prefix.length() + a.length();
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset;
int spaces = col_align - len;
if(spaces < 0)
{
std::cout << std::endl;
}
else
{
for(int i = 0; i < spaces; i++)
std::cout << " ";
}
std::cout << arg.help << std::endl;
}
std::cout << std::endl;
}
if(self.find_argument([](const auto& arg) { return arg.nargs != 0; }))
{
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : self.arguments)
{
if(arg.nargs == 0)
continue;
std::cout << std::endl;
std::string prefix = " ";
std::cout << color::fg_green;
if(arg.flags.empty())
{
std::cout << prefix;
......@@ -210,9 +465,10 @@ struct argument_parser
std::cout << a;
prefix = ", ";
}
std::cout << color::reset;
if(not arg.type.empty())
{
std::cout << " [" << arg.type << "]";
std::cout << " [" << color::fg_blue << arg.type << color::reset << "]";
if(not arg.default_value.empty())
std::cout << " (Default: " << arg.default_value << ")";
}
......@@ -220,6 +476,7 @@ struct argument_parser
std::cout << " " << arg.help << std::endl;
}
std::cout << std::endl;
}
if(not msg.empty())
std::cout << msg << std::endl;
});
......@@ -240,6 +497,11 @@ struct argument_parser
return [=](auto&, auto& arg) { arg.type = type; };
}
MIGRAPHX_DRIVER_STATIC auto group(const std::string& group)
{
return [=](auto&, auto& arg) { arg.group = group; };
}
template <class T>
MIGRAPHX_DRIVER_STATIC auto set_value(T value)
{
......@@ -253,6 +515,109 @@ struct argument_parser
};
}
template <class T>
void set_exe_name_to(T& x)
{
actions.push_back([&](const auto& self) { x = self.exe_name; });
}
void print_try_help()
{
if(has_argument([](const auto& a) { return contains(a.flags, "--help"); }))
{
std::cout << std::endl;
std::cout << "For more information try '" << color::fg_green << "--help" << color::reset
<< "'" << std::endl;
}
}
void print_usage(const std::vector<std::string>& flags) const
{
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " ";
std::cout << join_strings(flags, " ") << std::endl;
}
auto spellcheck(const std::vector<std::string>& inputs)
{
struct result_t
{
const argument* arg = nullptr;
std::string correct = "";
std::string incorrect = "";
std::ptrdiff_t distance = std::numeric_limits<std::ptrdiff_t>::max();
};
result_t result;
for(const auto& input : inputs)
{
if(input.empty())
continue;
if(input[0] != '-')
continue;
for(const auto& arg : arguments)
{
for(const auto& flag : arg.flags)
{
if(flag.empty())
continue;
if(flag[0] != '-')
continue;
auto d =
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
if(d < result.distance)
result = result_t{&arg, flag, input, d};
}
}
}
return result;
}
bool
run_action(const argument& arg, const std::string& flag, const std::vector<std::string>& inputs)
{
std::string msg = "";
try
{
for(const auto& v : arg.validations)
v(*this, inputs);
return arg.action(*this, inputs);
}
catch(const std::exception& e)
{
msg = e.what();
}
catch(...)
{
msg = "unknown exception";
}
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto sc = spellcheck(inputs);
if(sc.distance < 5)
{
std::cout << "Found argument '" << color::fg_yellow << sc.incorrect << color::reset
<< "'"
<< " which wasn't expected, or isn't valid in this context" << std::endl;
std::cout << " "
<< "Did you mean " << color::fg_green << sc.correct << color::reset << "?"
<< std::endl;
std::cout << std::endl;
print_usage({sc.arg->usage(sc.correct)});
}
else
{
const auto& flag_name = flag.empty() ? arg.metavar : flag;
std::cout << "Invalid input to '" << color::fg_yellow;
std::cout << arg.usage(flag_name);
std::cout << color::reset << "'" << std::endl;
std::cout << " " << msg << std::endl;
std::cout << std::endl;
print_usage({arg.usage()});
}
std::cout << std::endl;
print_try_help();
return true;
}
bool parse(std::vector<std::string> args)
{
std::unordered_map<std::string, unsigned> keywords;
......@@ -263,8 +628,11 @@ struct argument_parser
}
auto arg_map =
generic_parse(std::move(args), [&](const std::string& x) { return keywords[x]; });
std::list<const argument*> missing_arguments;
std::unordered_set<std::string> groups_used;
for(auto&& arg : arguments)
{
bool used = false;
auto flags = arg.flags;
if(flags.empty())
flags = {""};
......@@ -272,14 +640,41 @@ struct argument_parser
{
if(arg_map.count(flag) > 0)
{
if(arg.action(*this, arg_map[flag]))
if(run_action(arg, flag, arg_map[flag]))
return true;
used = true;
}
}
if(used and not arg.group.empty())
groups_used.insert(arg.group);
if(arg.required and not used)
missing_arguments.push_back(&arg);
}
// Remove arguments from a group that is being used
missing_arguments.remove_if(
[&](const argument* arg) { return groups_used.count(arg->group); });
if(not missing_arguments.empty())
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
std::cout << "The following required arguments were not provided:" << std::endl;
std::cout << " " << color::fg_red
<< join_strings(get_argument_usages(std::move(missing_arguments)), " ")
<< color::reset << std::endl;
std::cout << std::endl;
auto required_usages = get_argument_usages(get_required_arguments());
print_usage(required_usages);
print_try_help();
return true;
}
for(auto&& action : actions)
action(*this);
return false;
}
void set_exe_name(const std::string& s) { exe_name = s; }
const std::string& get_exe_name() const { return exe_name; }
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class IsKeyword>
static string_map generic_parse(std::vector<std::string> as, IsKeyword is_keyword)
......@@ -314,7 +709,9 @@ struct argument_parser
}
private:
std::vector<argument> arguments;
std::list<argument> arguments;
std::string exe_name = "";
std::vector<std::function<void(argument_parser&)>> actions;
};
} // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
......@@ -18,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands()
{
// NOLINTNEXTLINE
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m;
static std::unordered_map<
std::string,
std::function<void(const std::string& exe_name, std::vector<std::string> args)>>
m;
return m;
}
......@@ -42,10 +68,11 @@ const std::string& command_name()
}
template <class T>
void run_command(std::vector<std::string> args, bool add_help = false)
void run_command(const std::string& exe_name, std::vector<std::string> args, bool add_help = false)
{
T x;
argument_parser ap;
ap.set_exe_name(exe_name + " " + command_name<T>());
if(add_help)
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap);
......@@ -58,7 +85,9 @@ template <class T>
int auto_register_command()
{
auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); };
m[command_name<T>()] = [](const std::string& exe_name, std::vector<std::string> args) {
run_command<T>(exe_name, args, true);
};
return 0;
}
......
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
......@@ -50,8 +73,12 @@ struct loader
void parse(argument_parser& ap)
{
ap(file, {}, ap.metavar("<input file>"));
ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet"));
ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
ap(model,
{"--model"},
ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"),
ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
......@@ -187,6 +214,9 @@ struct loader
auto last = std::prev(mm->end(), trim);
mm->remove_instructions(last, mm->end());
}
// Remove unused variable when exporting to cpp
if(output_type == "cpp")
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
if(optimize)
{
migraphx::run_passes(*p.get_main_module(),
......@@ -552,26 +582,62 @@ struct onnx : command<onnx>
struct main_command
{
static std::string get_command_help()
static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
"COMMANDS:"))
{
std::string result = "Commands:\n";
return std::accumulate(get_commands().begin(),
std::string result = title + "\n";
std::vector<std::string> commands(get_commands().size());
std::transform(get_commands().begin(),
get_commands().end(),
result,
[](auto r, auto&& p) { return r + " " + p.first + "\n"; });
commands.begin(),
[](const auto& p) { return colorize(color::fg_green, p.first); });
std::sort(commands.begin(), commands.end());
return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
return r + " " + s + "\n";
});
}
void parse(argument_parser& ap)
{
std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
"." + std::to_string(MIGRAPHX_VERSION_MINOR);
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
ap(nullptr,
{"-v", "--version"},
ap.help("Show MIGraphX version"),
ap.show_help(version_str));
// Trim command off of exe name
ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
ap.set_exe_name_to(exe_name);
}
void run() {}
std::vector<std::string> wrong_commands{};
std::string exe_name = "<exe>";
void run()
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
return get_commands().count(c) > 0;
});
if(it == wrong_commands.end())
{
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl;
}
else
{
std::cout << "command '" << color::fg_yellow << *it << color::reset
<< "' must be first argument" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " " << *it << " <options>" << std::endl;
}
std::cout << std::endl;
}
};
} // namespace MIGRAPHX_INLINE_NS
......@@ -593,11 +659,11 @@ int main(int argc, const char* argv[])
auto cmd = args.front();
if(m.count(cmd) > 0)
{
m.at(cmd)({args.begin() + 1, args.end()});
m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
}
else
{
run_command<main_command>(args);
run_command<main_command>(argv[0], args);
}
return 0;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment