Commit 79278d88 authored by Paul's avatar Paul
Browse files

Merge

parents 3f4d78bd 10f37f49
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h>
#include <migraphx/rank.hpp>
#include <migraphx/shape.hpp>
......@@ -16,34 +39,47 @@
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
static thread_local bool disable_exception_catch = false; // NOLINT
extern "C" void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
try
if(disable_exception_catch)
{
f();
}
catch(const migraphx::exception& ex)
else
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
try
{
f();
}
catch(const migraphx::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
if(ex.error > 0)
return migraphx_status(ex.error);
else
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error;
}
catch(const std::exception& ex)
{
if(output)
std::cerr << "MIGraphX Error: " << ex.what() << std::endl;
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
}
}
return migraphx_status_success;
}
......@@ -213,6 +249,11 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
migraphx::instruction_ref add_allocation(module& m, const migraphx::shape& s)
{
return m.add_instruction(migraphx::make_op("allocate", {{"shape", migraphx::to_value(s)}}), {});
}
struct experimental_custom_op
{
std::string name;
......@@ -237,7 +278,12 @@ struct custom_operation
return op.compute_shape(std::move(inputs));
}
argument compute(const std::vector<argument>&) const { MIGRAPHX_THROW("Not computable"); }
// TODO: Compute method with module_args
argument
compute(migraphx::context ctx, migraphx::shape output_shape, std::vector<argument> inputs) const
{
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
}
};
template <class CustomOp>
......@@ -272,6 +318,7 @@ void destroy(T* x)
{
delete x; // NOLINT
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
......@@ -280,30 +327,35 @@ struct manage_generic_ptr
manage_generic_ptr(std::nullptr_t) {}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
manage_generic_ptr(void* pdata, const char* obj_tname, C pcopier, D pdeleter)
: data(nullptr), obj_typename(obj_tname), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
: data(nullptr), obj_typename(rhs.obj_typename), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
: data(other.data),
obj_typename(other.obj_typename),
copier(other.copier),
deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
other.data = nullptr;
other.obj_typename = "";
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(obj_typename, rhs.obj_typename);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
......@@ -315,9 +367,10 @@ struct manage_generic_ptr
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
void* data = nullptr;
const char* obj_typename = "";
C copier = nullptr;
D deleter = nullptr;
};
extern "C" struct migraphx_shape;
......@@ -547,23 +600,59 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op(void* p,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
: object_ptr(p, obj_typename, c, d), xobject(std::forward<Ts>(xs)...)
{
}
manage_generic_ptr<migraphx_experimental_custom_op_copy, migraphx_experimental_custom_op_delete>
object_ptr = nullptr;
migraphx::experimental_custom_op xobject;
migraphx_experimental_custom_op_compute compute_f = nullptr;
migraphx::argument compute(migraphx::context ctx,
migraphx::shape output,
std::vector<migraphx::argument> inputs) const
{
std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr)
throw std::runtime_error("compute function is missing.");
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_context_t>(&(ctx)),
object_cast<migraphx_shape_t>(&(output)),
object_cast<migraphx_arguments_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object;
}
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing.");
auto api_error_result =
compute_shape_f(&out, object_ptr.data, object_cast<migraphx_shapes_t>(&(inputs)));
std::array<char, 256> exception_msg;
exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out,
object_ptr.data,
exception_msg.data(),
exception_msg.size(),
object_cast<migraphx_shapes_t>(&(inputs)));
if(api_error_result != migraphx_status_success)
throw std::runtime_error("Error in compute_shape.");
{
const std::string exception_str(exception_msg.data());
throw std::runtime_error("Error in compute_shape of: " +
std::string(object_ptr.obj_typename) + ": " + exception_str);
}
return (&out)->object;
}
};
......@@ -692,6 +781,16 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
return api_error_result;
}
extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = (shape->object).standard();
});
return api_error_result;
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{
auto api_error_result = migraphx::try_([&] { destroy((argument)); });
......@@ -1118,6 +1217,21 @@ extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* ou
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(s == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter s: Null pointer");
*out = allocate<migraphx_instruction_t>(
migraphx::add_allocation((module->object), (s->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] { destroy((program)); });
......@@ -1740,15 +1854,24 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name)
{
auto api_error_result = migraphx::try_([&] {
*experimental_custom_op =
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (name));
allocate<migraphx_experimental_custom_op_t>((obj), (c), (d), (obj_typename), (name));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_compute input)
{
auto api_error_result = migraphx::try_([&] { (obj)->compute_f = (input); });
return api_error_result;
}
extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input)
{
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#define MIGRAPHX_GUARD_C_API_MIGRAPHX_H
#include <stdlib.h>
#include <stdbool.h>
// Add new types here
// clang-format off
......@@ -106,8 +130,18 @@ typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute)(migraphx_argument_t out,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_context_t ctx,
migraphx_shape_t output,
migraphx_arguments_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj,
char* exception_msg,
size_t exception_msg_size,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
......@@ -146,6 +180,8 @@ migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status
migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_shape_t x);
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
......@@ -272,6 +308,10 @@ migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args);
migraphx_status migraphx_module_add_allocation(migraphx_instruction_t* out,
migraphx_module_t module,
const_migraphx_shape_t s);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
......@@ -452,8 +492,13 @@ migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experi
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* obj_typename,
const char* name);
migraphx_status
migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_compute input);
migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <cstring>
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
......@@ -35,6 +59,42 @@ struct rank<0>
{
};
template <class PrivateMigraphTypeNameProbe>
std::string compute_type_name()
{
std::string name;
#ifdef _MSC_VER
name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7);
#else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__;
auto begin = name.find(parameter_name) + sizeof(parameter_name);
#if(defined(__GNUC__) && !defined(__clang__) && __GNUC__ == 4 && __GNUC_MINOR__ < 7)
auto length = name.find_last_of(",") - begin;
#else
auto length = name.find_first_of("];", begin) - begin;
#endif
name = name.substr(begin, length);
#endif
return name;
}
template <class T>
const std::string& get_type_name()
{
static const std::string name = compute_type_name<T>();
return name;
}
template <class T>
const std::string& get_type_name(const T&)
{
return get_type_name<T>();
}
template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs)
{
......@@ -287,13 +347,22 @@ struct interface_base : Base
protected:
template <class F>
static migraphx_status try_(F f) // NOLINT
static migraphx_status try_(F f, char* ex_msg = nullptr, size_t ex_msg_size = 0) // NOLINT
{
try
{
f();
return migraphx_status_success;
}
catch(const std::exception& ex)
{
if(ex_msg)
{
std::strncpy(ex_msg, ex.what(), ex_msg_size);
ex_msg[ex_msg_size - 1] = '\0';
}
return migraphx_status_unknown_error;
}
catch(...)
{
return migraphx_status_unknown_error;
......@@ -326,9 +395,13 @@ struct interface_base : Base
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
});
call(setter,
this->get_handle_ptr(),
[](auto out, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<1>{}, f, out, obj, xs...); }, ex_msg, ex_msg_size);
});
}
template <class T, class Setter, class F>
......@@ -378,11 +451,14 @@ struct interface_base : Base
return x;
}
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{
return as_handle<T>{x};
}
#pragma GCC diagnostic pop
template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
......@@ -441,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
MIGRAPHX_HANDLE_CONSTRUCTOR(shape)
/// Construct a scalar shape
shape(migraphx_shape_datatype_t type)
......@@ -498,6 +574,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout;
}
bool standard() const
{
bool result = false;
call(&migraphx_shape_standard, &result, this->get_handle_ptr());
return result;
}
friend bool operator==(const shape& px, const shape& py)
{
bool pout;
......@@ -505,7 +588,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout;
}
friend bool operator!=(const shape& px, const shape& py) { return !(px == py); }
friend bool operator!=(const shape& px, const shape& py) { return not(px == py); }
};
/**
......@@ -518,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{
argument() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
MIGRAPHX_HANDLE_CONSTRUCTOR(argument)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
......@@ -542,6 +625,14 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout;
}
template <typename T>
std::vector<T> as_vector() const
{
size_t vector_len = this->get_shape().bytes() / sizeof(T);
T* buffer_ptr = reinterpret_cast<T*>(this->data());
return {buffer_ptr, buffer_ptr + vector_len};
}
/// Generate an argument using random data
static argument generate(shape ps, size_t pseed = 0)
{
......@@ -556,7 +647,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
return pout;
}
friend bool operator!=(const argument& px, const argument& py) { return !(px == py); }
friend bool operator!=(const argument& px, const argument& py) { return not(px == py); }
};
/// A target for compilation
......@@ -564,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{
target() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(target);
MIGRAPHX_HANDLE_CONSTRUCTOR(target)
/// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); }
......@@ -574,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{
program_parameter_shapes() {}
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes)
size_t size() const
{
......@@ -593,7 +684,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
std::vector<const char*> names() const
{
std::vector<const char*> result(this->size());
if(!result.empty())
if(not result.empty())
{
call(&migraphx_program_parameter_shapes_names, result.data(), this->get_handle_ptr());
}
......@@ -604,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters)
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
......@@ -631,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
MIGRAPHX_HANDLE_CONSTRUCTOR(arguments)
size_t size() const
{
......@@ -650,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
MIGRAPHX_HANDLE_CONSTRUCTOR(shapes)
size_t size() const
{
......@@ -669,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
MIGRAPHX_HANDLE_CONSTRUCTOR(operation)
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
......@@ -687,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
MIGRAPHX_HANDLE_CONSTRUCTOR(instruction)
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions)
template <class... Ts>
instructions(Ts... xs)
......@@ -706,7 +797,7 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
MIGRAPHX_HANDLE_CONSTRUCTOR(modules)
template <class... Ts>
modules(Ts... xs)
......@@ -779,13 +870,20 @@ struct module
return instruction(ret_ins, own{});
}
instruction add_allocation(const migraphx::shape& s)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_allocation, &ret_ins, mm.get(), s.get_handle_ptr());
return instruction(ret_ins, own{});
}
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
};
struct context
struct context : handle_lookup<context, migraphx_context>
{
context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
......@@ -813,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{
compile_options() { this->make_handle(&migraphx_compile_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options)
/// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the
......@@ -837,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() { this->make_handle(&migraphx_program_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(program);
MIGRAPHX_HANDLE_CONSTRUCTOR(program)
/// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const
......@@ -917,13 +1015,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu, this->share_handle()};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
friend bool operator!=(const program& px, const program& py) { return not(px == py); }
};
// options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options)
file_options() { this->make_handle(&migraphx_file_options_create); }
// set file format
......@@ -965,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options)
/// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -1047,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{
tf_options() { this->make_handle(&migraphx_tf_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options)
/// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
......@@ -1100,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names)
void add(const std::string& name)
{
......@@ -1125,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options)
/// Add an operator that should be quantized
void add_op_name(const std::string& name)
......@@ -1154,9 +1252,10 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
struct experimental_custom_op_base
{
virtual std::string name() const = 0;
virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default;
virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default;
};
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
......@@ -1164,8 +1263,12 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
template <class T>
experimental_custom_op(T& obj)
{
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
this->make_interface(&migraphx_experimental_custom_op_create,
obj,
get_type_name(obj).c_str(),
obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
}
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import api
......@@ -98,6 +121,7 @@ def shape(h):
invoke='migraphx::equal($@)',
returns='bool',
const=True)
h.method('standard', returns='bool', const=True)
@auto_handle()
......@@ -221,6 +245,10 @@ def module(h):
h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
h.method('add_allocation',
api.params(s='const migraphx::shape&'),
invoke='migraphx::add_allocation($@)',
returns='migraphx::instruction_ref')
@auto_handle()
......@@ -412,7 +440,13 @@ def context(h):
@api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op')
def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*'))
h.constructor('create',
api.params(obj_typename='const char*', name='const char*'))
h.virtual('compute',
api.params(ctx='migraphx::context',
output='migraphx::shape',
inputs='std::vector<migraphx::argument>'),
returns='migraphx::argument')
h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape')
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
......@@ -16,7 +39,7 @@ instruction_ref insert_apply_alpha_beta(module& m,
auto a = args[0];
auto b = args[1];
auto input_type = a->get_shape().type();
if(!float_equal(alpha.at<float>(0), 1.0))
if(not float_equal(alpha.at<float>(0), 1.0))
{
auto alpha_literal = m.add_literal(alpha);
a = insert_common_op(m, pos, migraphx::make_op("mul"), {alpha_literal, a});
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <unordered_map>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
......@@ -42,7 +65,7 @@ void auto_contiguous::apply(module& m) const
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
if(not s.dynamic() and not s.standard() and s.elements() != 0)
{
auto c = m.insert_instruction(std::next(ins), make_op("contiguous"), ins);
m.replace_instruction(ins, c);
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/common.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/compile_src.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/tmp_dir.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <algorithm>
#include <string>
#include <vector>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp>
#include <migraphx/operation.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <unordered_set>
namespace migraphx {
......@@ -24,9 +48,11 @@ void dead_code_elimination::apply(module& m) const
// Skip the last instruction
if(i == last)
break;
// Skip instruction with empty shape as output unless its a builtin or undefined or identity
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
// Skip instruction with empty shape as output unless its [dynamic, builtin, undefined,
// identity, allocate]
if((not i->get_shape().dynamic() and i->get_shape().elements() == 0) and
i->name().front() != '@' and
not contains({"undefined", "identity", "allocate"}, i->name()))
continue;
assert(std::distance(m.begin(), i) <= std::distance(m.begin(), last));
std::unordered_set<instruction_ref> visited;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
......
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
add_executable(driver
main.cpp
......
#include <migraphx/operators.hpp>
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/apply_alpha_beta.hpp>
#include "models.hpp"
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto m0 =
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto mx0 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 0));
auto mx1 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 1));
auto mx2 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 2));
auto mx3 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 3));
auto mx4 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 4));
auto mx5 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 5));
auto mx6 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 6));
auto mx7 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 7));
auto mx8 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 8));
auto mx9 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 9));
auto mx10 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 10));
auto mx11 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 11));
auto mx12 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 12));
auto mx13 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 13));
auto mx14 = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 14));
auto mx15 = mm->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 15));
migraphx::op::convolution convolution16;
convolution16.padding = {2, 2};
convolution16.stride = {4, 4};
convolution16.dilation = {1, 1};
convolution16.group = 1;
auto mx16 = mm->add_instruction(convolution16, m0, mx15);
migraphx::op::broadcast broadcast17;
broadcast17.axis = 1;
broadcast17.broadcast_lens = {batch, 64, 55, 55};
auto mx17 = mm->add_instruction(broadcast17, mx14);
migraphx::op::add add18;
auto mx18 = mm->add_instruction(add18, mx16, mx17);
migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20;
pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0};
pooling20.stride = {2, 2};
pooling20.lengths = {3, 3};
auto mx20 = mm->add_instruction(pooling20, mx19);
migraphx::op::convolution convolution21;
convolution21.padding = {2, 2};
convolution21.stride = {1, 1};
convolution21.dilation = {1, 1};
convolution21.group = 1;
auto mx21 = mm->add_instruction(convolution21, mx20, mx13);
migraphx::op::broadcast broadcast22;
broadcast22.axis = 1;
broadcast22.broadcast_lens = {batch, 192, 27, 27};
auto mx22 = mm->add_instruction(broadcast22, mx12);
migraphx::op::add add23;
auto mx23 = mm->add_instruction(add23, mx21, mx22);
migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25;
pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0};
pooling25.stride = {2, 2};
pooling25.lengths = {3, 3};
auto mx25 = mm->add_instruction(pooling25, mx24);
migraphx::op::convolution convolution26;
convolution26.padding = {1, 1};
convolution26.stride = {1, 1};
convolution26.dilation = {1, 1};
convolution26.group = 1;
auto mx26 = mm->add_instruction(convolution26, mx25, mx11);
migraphx::op::broadcast broadcast27;
broadcast27.axis = 1;
broadcast27.broadcast_lens = {batch, 384, 13, 13};
auto mx27 = mm->add_instruction(broadcast27, mx10);
migraphx::op::add add28;
auto mx28 = mm->add_instruction(add28, mx26, mx27);
migraphx::op::relu relu29;
auto mx29 = mm->add_instruction(relu29, mx28);
migraphx::op::convolution convolution30;
convolution30.padding = {1, 1};
convolution30.stride = {1, 1};
convolution30.dilation = {1, 1};
convolution30.group = 1;
auto mx30 = mm->add_instruction(convolution30, mx29, mx9);
migraphx::op::broadcast broadcast31;
broadcast31.axis = 1;
broadcast31.broadcast_lens = {batch, 256, 13, 13};
auto mx31 = mm->add_instruction(broadcast31, mx8);
migraphx::op::add add32;
auto mx32 = mm->add_instruction(add32, mx30, mx31);
migraphx::op::relu relu33;
auto mx33 = mm->add_instruction(relu33, mx32);
migraphx::op::convolution convolution34;
convolution34.padding = {1, 1};
convolution34.stride = {1, 1};
convolution34.dilation = {1, 1};
convolution34.group = 1;
auto mx34 = mm->add_instruction(convolution34, mx33, mx7);
migraphx::op::broadcast broadcast35;
broadcast35.axis = 1;
broadcast35.broadcast_lens = {batch, 256, 13, 13};
auto mx35 = mm->add_instruction(broadcast35, mx6);
migraphx::op::add add36;
auto mx36 = mm->add_instruction(add36, mx34, mx35);
migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38;
pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0};
pooling38.stride = {2, 2};
pooling38.lengths = {3, 3};
auto mx38 = mm->add_instruction(pooling38, mx37);
migraphx::op::flatten flatten39;
flatten39.axis = 1;
auto mx39 = mm->add_instruction(flatten39, mx38);
migraphx::op::identity identity40;
auto mx40 = mm->add_instruction(identity40, mx39);
migraphx::op::transpose transpose41;
transpose41.dims = {1, 0};
auto mx41 = mm->add_instruction(transpose41, mx5);
migraphx::op::multibroadcast multibroadcast42;
multibroadcast42.output_lens = {batch, 4096};
auto mx42 = mm->add_instruction(multibroadcast42, mx4);
float dot43_alpha = 1;
float dot43_beta = 1;
auto mx43 = migraphx::add_apply_alpha_beta(
*mm, {mx40, mx41, mx42}, migraphx::make_op("dot"), dot43_alpha, dot43_beta);
migraphx::op::relu relu44;
auto mx44 = mm->add_instruction(relu44, mx43);
migraphx::op::identity identity45;
auto mx45 = mm->add_instruction(identity45, mx44);
migraphx::op::transpose transpose46;
transpose46.dims = {1, 0};
auto mx46 = mm->add_instruction(transpose46, mx3);
migraphx::op::multibroadcast multibroadcast47;
multibroadcast47.output_lens = {batch, 4096};
auto mx47 = mm->add_instruction(multibroadcast47, mx2);
float dot48_alpha = 1;
float dot48_beta = 1;
auto mx48 = migraphx::add_apply_alpha_beta(
*mm, {mx45, mx46, mx47}, migraphx::make_op("dot"), dot48_alpha, dot48_beta);
migraphx::op::relu relu49;
auto mx49 = mm->add_instruction(relu49, mx48);
migraphx::op::transpose transpose50;
transpose50.dims = {1, 0};
auto mx50 = mm->add_instruction(transpose50, mx1);
migraphx::op::multibroadcast multibroadcast51;
multibroadcast51.output_lens = {batch, 1000};
auto mx51 = mm->add_instruction(multibroadcast51, mx0);
float dot52_alpha = 1;
float dot52_beta = 1;
migraphx::add_apply_alpha_beta(
*mm, {mx49, mx50, mx51}, migraphx::make_op("dot"), dot52_alpha, dot52_beta);
migraphx::module_ref mmain = p.get_main_module();
auto x_main_module_0 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 0)));
auto x_main_module_1 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 1)));
auto x_main_module_2 = mmain->add_literal(migraphx::abs(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1}}, 2)));
auto x_0 = mmain->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {batch, 3, 224, 224}});
auto x_main_module_4 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000}}, 3));
auto x_main_module_5 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 4096}}, 4));
auto x_main_module_6 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 5));
auto x_main_module_7 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 4096}}, 6));
auto x_main_module_8 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096}}, 7));
auto x_main_module_9 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {4096, 9216}}, 8));
auto x_main_module_10 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 9));
auto x_main_module_11 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 256, 3, 3}}, 10));
auto x_main_module_12 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {256}}, 11));
auto x_main_module_13 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {256, 384, 3, 3}}, 12));
auto x_main_module_14 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {384}}, 13));
auto x_main_module_15 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {384, 192, 3, 3}}, 14));
auto x_main_module_16 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {192}}, 15));
auto x_main_module_17 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {192, 64, 5, 5}}, 16));
auto x_main_module_18 = mmain->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {64}}, 17));
auto x_main_module_19 = mmain->add_literal(migraphx::generate_literal(
migraphx::shape{migraphx::shape::float_type, {64, 3, 11, 11}}, 18));
auto x_main_module_20 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[4,"
"4],use_dynamic_same_auto_pad:0}"),
x_0,
x_main_module_19);
auto x_main_module_21 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,64,55,55]}"), x_main_module_18);
auto x_main_module_22 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_20, x_main_module_21);
auto x_main_module_23 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_22);
auto x_main_module_24 = mmain->add_instruction(
migraphx::make_json_op(
"pooling",
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_23);
auto x_main_module_25 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[2,2,2,2],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_24,
x_main_module_17);
auto x_main_module_26 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,192,27,27]}"), x_main_module_16);
auto x_main_module_27 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_25, x_main_module_26);
auto x_main_module_28 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_27);
auto x_main_module_29 = mmain->add_instruction(
migraphx::make_json_op(
"pooling",
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_28);
auto x_main_module_30 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_29,
x_main_module_15);
auto x_main_module_31 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,384,13,13]}"), x_main_module_14);
auto x_main_module_32 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_30, x_main_module_31);
auto x_main_module_33 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_32);
auto x_main_module_34 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_33,
x_main_module_13);
auto x_main_module_35 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_12);
auto x_main_module_36 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_34, x_main_module_35);
auto x_main_module_37 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_36);
auto x_main_module_38 = mmain->add_instruction(
migraphx::make_json_op("convolution",
"{dilation:[1,1],group:1,padding:[1,1,1,1],padding_mode:0,stride:[1,"
"1],use_dynamic_same_auto_pad:0}"),
x_main_module_37,
x_main_module_11);
auto x_main_module_39 = mmain->add_instruction(
migraphx::make_json_op("broadcast", "{axis:1,out_lens:[1,256,13,13]}"), x_main_module_10);
auto x_main_module_40 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_38, x_main_module_39);
auto x_main_module_41 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_40);
auto x_main_module_42 = mmain->add_instruction(
migraphx::make_json_op(
"pooling",
"{ceil_mode:0,lengths:[3,3],lp_order:2,mode:1,padding:[0,0,0,0],stride:[2,2]}"),
x_main_module_41);
auto x_main_module_43 =
mmain->add_instruction(migraphx::make_json_op("flatten", "{axis:1}"), x_main_module_42);
auto x_main_module_44 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_43);
auto x_main_module_45 = mmain->add_instruction(
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_9);
auto x_main_module_46 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_44, x_main_module_45);
auto x_main_module_47 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_8);
auto x_main_module_48 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_2);
auto x_main_module_49 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_47, x_main_module_48);
auto x_main_module_50 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_46, x_main_module_49);
auto x_main_module_51 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_50);
auto x_main_module_52 = mmain->add_instruction(migraphx::make_op("identity"), x_main_module_51);
auto x_main_module_53 = mmain->add_instruction(
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_7);
auto x_main_module_54 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_52, x_main_module_53);
auto x_main_module_55 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_6);
auto x_main_module_56 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,4096]}"), x_main_module_1);
auto x_main_module_57 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_55, x_main_module_56);
auto x_main_module_58 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_54, x_main_module_57);
auto x_main_module_59 = mmain->add_instruction(migraphx::make_op("relu"), x_main_module_58);
auto x_main_module_60 = mmain->add_instruction(
migraphx::make_json_op("transpose", "{permutation:[1,0]}"), x_main_module_5);
auto x_main_module_61 =
mmain->add_instruction(migraphx::make_op("dot"), x_main_module_59, x_main_module_60);
auto x_main_module_62 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_4);
auto x_main_module_63 = mmain->add_instruction(
migraphx::make_json_op("multibroadcast", "{out_lens:[1,1000]}"), x_main_module_0);
auto x_main_module_64 =
mmain->add_instruction(migraphx::make_op("mul"), x_main_module_62, x_main_module_63);
auto x_main_module_65 =
mmain->add_instruction(migraphx::make_op("add"), x_main_module_61, x_main_module_64);
mmain->add_return({x_main_module_65});
return p;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMAND_HPP
......@@ -18,7 +41,10 @@ inline namespace MIGRAPHX_INLINE_NS {
inline auto& get_commands()
{
// NOLINTNEXTLINE
static std::unordered_map<std::string, std::function<void(std::vector<std::string> args)>> m;
static std::unordered_map<
std::string,
std::function<void(const std::string& exe_name, std::vector<std::string> args)>>
m;
return m;
}
......@@ -42,10 +68,11 @@ const std::string& command_name()
}
template <class T>
void run_command(std::vector<std::string> args, bool add_help = false)
void run_command(const std::string& exe_name, std::vector<std::string> args, bool add_help = false)
{
T x;
argument_parser ap;
ap.set_exe_name(exe_name + " " + command_name<T>());
if(add_help)
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help());
x.parse(ap);
......@@ -58,7 +85,9 @@ template <class T>
int auto_register_command()
{
auto& m = get_commands();
m[command_name<T>()] = [](std::vector<std::string> args) { run_command<T>(args, true); };
m[command_name<T>()] = [](const std::string& exe_name, std::vector<std::string> args) {
run_command<T>(exe_name, args, true);
};
return 0;
}
......
This diff is collapsed.
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify.hpp"
#include "argument_parser.hpp"
#include "command.hpp"
......@@ -50,8 +73,12 @@ struct loader
void parse(argument_parser& ap)
{
ap(file, {}, ap.metavar("<input file>"));
ap(model, {"--model"}, ap.help("Load model"), ap.type("resnet50|inceptionv3|alexnet"));
ap(file, {}, ap.metavar("<input file>"), ap.file_exist(), ap.required(), ap.group("input"));
ap(model,
{"--model"},
ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"),
ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(file_type, {"--migraphx"}, ap.help("Load as MIGraphX"), ap.set_value("migraphx"));
......@@ -187,6 +214,9 @@ struct loader
auto last = std::prev(mm->end(), trim);
mm->remove_instructions(last, mm->end());
}
// Remove unused variable when exporting to cpp
if(output_type == "cpp")
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}});
if(optimize)
{
migraphx::run_passes(*p.get_main_module(),
......@@ -552,26 +582,62 @@ struct onnx : command<onnx>
struct main_command
{
static std::string get_command_help()
static std::string get_command_help(const std::string& title = colorize(color::fg_yellow,
"COMMANDS:"))
{
std::string result = "Commands:\n";
return std::accumulate(get_commands().begin(),
get_commands().end(),
result,
[](auto r, auto&& p) { return r + " " + p.first + "\n"; });
std::string result = title + "\n";
std::vector<std::string> commands(get_commands().size());
std::transform(get_commands().begin(),
get_commands().end(),
commands.begin(),
[](const auto& p) { return colorize(color::fg_green, p.first); });
std::sort(commands.begin(), commands.end());
return std::accumulate(commands.begin(), commands.end(), result, [](auto r, auto&& s) {
return r + " " + s + "\n";
});
}
void parse(argument_parser& ap)
{
std::string version_str = "MIGraphX Version: " + std::to_string(MIGRAPHX_VERSION_MAJOR) +
"." + std::to_string(MIGRAPHX_VERSION_MINOR);
ap(wrong_commands, {}, ap.metavar("<command>"), ap.append());
ap(nullptr, {"-h", "--help"}, ap.help("Show help"), ap.show_help(get_command_help()));
ap(nullptr,
{"-v", "--version"},
ap.help("Show MIGraphX version"),
ap.show_help(version_str));
// Trim command off of exe name
ap.set_exe_name(ap.get_exe_name().substr(0, ap.get_exe_name().size() - 5));
ap.set_exe_name_to(exe_name);
}
void run() {}
std::vector<std::string> wrong_commands{};
std::string exe_name = "<exe>";
void run()
{
std::cout << color::fg_red << color::bold << "error: " << color::reset;
auto it = std::find_if(wrong_commands.begin(), wrong_commands.end(), [](const auto& c) {
return get_commands().count(c) > 0;
});
if(it == wrong_commands.end())
{
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl;
}
else
{
std::cout << "command '" << color::fg_yellow << *it << color::reset
<< "' must be first argument" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " " << exe_name << " " << *it << " <options>" << std::endl;
}
std::cout << std::endl;
}
};
} // namespace MIGRAPHX_INLINE_NS
......@@ -593,11 +659,11 @@ int main(int argc, const char* argv[])
auto cmd = args.front();
if(m.count(cmd) > 0)
{
m.at(cmd)({args.begin() + 1, args.end()});
m.at(cmd)(argv[0], {args.begin() + 1, args.end()});
}
else
{
run_command<main_command>(args);
run_command<main_command>(argv[0], args);
}
return 0;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "marker_roctx.hpp"
#include <migraphx/dynamic_loader.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment