Commit c1ec929c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents abe2a889 03225b57
*.swp #==============================================================================#
# File extensions to be ignored anywhere in the tree.
#==============================================================================#
# Temp files created by most text editors
*~
# Merge files created by git
*.orig
# Byte compiled python modules
*.pyc
*.pyd
# Vim swap files
.*.sw?
.sw?
# Visual Studio
.vs
/.vscode/*
# Sublime Text settings
*.sublime-workspace
*.sublime-project
# Eclipse Project settings
*.*project
.settings
# OS X specific files
.DS_store
#==============================================================================#
# Explicit files to ignore (only matches one).
#==============================================================================#
# Various tags
/tags
/TAGS
/GPATH
/GRTAGS
/GSYMS
/GTAGS
/ID
.gitusers
/compile_commands.json
/CMakeSettings.json
#==============================================================================#
# Directories to ignore (do not add trailing '/'s, they skip symlinks).
#==============================================================================#
# Nested build directory
/build*
# Downloaded models
test/onnx/models
# VS2017 and VSCode config files.
.vscode
.vs
...@@ -38,6 +38,7 @@ add_library(migraphx ...@@ -38,6 +38,7 @@ add_library(migraphx
msgpack.cpp msgpack.cpp
normalize_attributes.cpp normalize_attributes.cpp
normalize_ops.cpp normalize_ops.cpp
op_enums.cpp
operation.cpp operation.cpp
opt/memory_coloring.cpp opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp opt/memory_coloring_impl.cpp
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
...@@ -72,6 +73,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t) ...@@ -72,6 +73,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type"); MIGRAPHX_THROW(migraphx_status_bad_param, "Unknown type");
} }
template <class T>
auto to_obj_vector(const T* x, std::size_t n)
{
std::vector<decltype((*x)->object)> result;
std::transform(x, x + n, std::back_inserter(result), [&](auto&& y) { return y->object; });
return result;
}
template <class T, class U>
auto to_objptr_vector(const U* x, std::size_t n)
{
std::vector<T> result;
std::transform(
x, x + n, std::back_inserter(result), [&](auto&& y) { return std::addressof(y->object); });
return result;
}
target get_target(const std::string& name) { return make_target(name); } target get_target(const std::string& name) { return make_target(name); }
void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; } void set_offload_copy(compile_options& options, bool value) { options.offload_copy = value; }
...@@ -194,6 +212,8 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -194,6 +212,8 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
template <class T, class U, class Target = std::remove_pointer_t<T>> template <class T, class U, class Target = std::remove_pointer_t<T>>
...@@ -289,6 +309,36 @@ struct migraphx_shapes ...@@ -289,6 +309,36 @@ struct migraphx_shapes
std::vector<migraphx::shape> object; std::vector<migraphx::shape> object;
}; };
extern "C" struct migraphx_instruction;
struct migraphx_instruction
{
template <class... Ts>
migraphx_instruction(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::instruction_ref object;
};
extern "C" struct migraphx_instructions;
struct migraphx_instructions
{
template <class... Ts>
migraphx_instructions(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
std::vector<migraphx::instruction_ref> object;
};
extern "C" struct migraphx_modules;
struct migraphx_modules
{
template <class... Ts>
migraphx_modules(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
std::vector<migraphx::module*> object;
};
extern "C" struct migraphx_module; extern "C" struct migraphx_module;
struct migraphx_module struct migraphx_module
{ {
...@@ -379,6 +429,16 @@ struct migraphx_quantize_int8_options ...@@ -379,6 +429,16 @@ struct migraphx_quantize_int8_options
migraphx::quantize_int8_options object; migraphx::quantize_int8_options object;
}; };
extern "C" struct migraphx_context;
struct migraphx_context
{
template <class... Ts>
migraphx_context(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::context object;
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
...@@ -762,6 +822,77 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_ ...@@ -762,6 +822,77 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction)
{
auto api_error_result = migraphx::try_([&] { destroy((instruction)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions)
{
auto api_error_result = migraphx::try_([&] { destroy((instructions)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size)
{
auto api_error_result = migraphx::try_([&] {
*instructions =
object_cast<migraphx_instructions_t>(allocate<std::vector<migraphx::instruction_ref>>(
migraphx::to_obj_vector<const_migraphx_instruction_t>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_destroy(migraphx_modules_t modules)
{
auto api_error_result = migraphx::try_([&] { destroy((modules)); });
return api_error_result;
}
extern "C" migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input)
{
auto api_error_result = migraphx::try_([&] { *output = *input; });
return api_error_result;
}
extern "C" migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size)
{
auto api_error_result = migraphx::try_([&] {
*modules = object_cast<migraphx_modules_t>(allocate<std::vector<migraphx::module*>>(
migraphx::to_objptr_vector<migraphx::module*>((ptr), (size))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_create(migraphx_module_t* module, char* name)
{
auto api_error_result = migraphx::try_([&] {
if(name == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
*module = object_cast<migraphx_module_t>(allocate<migraphx::module>((std::string(name))));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module) extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -772,6 +903,76 @@ extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module) ...@@ -772,6 +903,76 @@ extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object)));
});
return api_error_result;
}
extern "C" migraphx_status
migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(op == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter op: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
if(module_refs == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module_refs: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_instruction((op->object), (args->object), (module_refs->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(shape == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
*out = allocate<migraphx_instruction_t>(
(module->object).add_parameter((name), (shape->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args)
{
auto api_error_result = migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
if(args == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter args: Null pointer");
*out = allocate<migraphx_instruction_t>((module->object).add_return((args->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{ {
auto api_error_result = migraphx::try_([&] { destroy((program)); }); auto api_error_result = migraphx::try_([&] { destroy((program)); });
...@@ -785,6 +986,13 @@ extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output, ...@@ -785,6 +986,13 @@ extern "C" migraphx_status migraphx_program_assign_to(migraphx_program_t output,
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_program_create(migraphx_program_t* program)
{
auto api_error_result = migraphx::try_(
[&] { *program = object_cast<migraphx_program_t>(allocate<migraphx::program>()); });
return api_error_result;
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program) migraphx_program_t program)
{ {
...@@ -796,6 +1004,17 @@ extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* o ...@@ -796,6 +1004,17 @@ extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* o
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_program_create_module(migraphx_module_t* out, migraphx_program_t program, const char* name)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).create_module((name)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options_t options) migraphx_compile_options_t options)
...@@ -883,6 +1102,17 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap ...@@ -883,6 +1102,17 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status
migraphx_program_experimental_get_context(migraphx_context_t* out, const_migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_context_t>(migraphx::get_context((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation) extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{ {
auto api_error_result = migraphx::try_([&] { destroy((operation)); }); auto api_error_result = migraphx::try_([&] { destroy((operation)); });
...@@ -1324,3 +1554,13 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -1324,3 +1554,13 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
}); });
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
(context->object).finish();
});
return api_error_result;
}
...@@ -64,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t; ...@@ -64,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t; typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t; typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_instruction* migraphx_instruction_t;
typedef const struct migraphx_instruction* const_migraphx_instruction_t;
typedef struct migraphx_instructions* migraphx_instructions_t;
typedef const struct migraphx_instructions* const_migraphx_instructions_t;
typedef struct migraphx_modules* migraphx_modules_t;
typedef const struct migraphx_modules* const_migraphx_modules_t;
typedef struct migraphx_module* migraphx_module_t; typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t; typedef const struct migraphx_module* const_migraphx_module_t;
...@@ -91,6 +100,9 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name ...@@ -91,6 +100,9 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t; typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t; typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;
typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t;
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
...@@ -198,16 +210,66 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); ...@@ -198,16 +210,66 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx); migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input);
migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module); migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output, migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input); const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out, migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program); migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program,
const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program, migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options_t options); migraphx_compile_options_t options);
...@@ -229,6 +291,9 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...@@ -229,6 +291,9 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out,
const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output, migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
...@@ -355,6 +420,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -355,6 +420,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target, migraphx_target_t target,
migraphx_quantize_int8_options_t options); migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <memory> #include <memory>
#include <exception> #include <exception>
...@@ -523,12 +525,116 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -523,12 +525,116 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
}; };
}; };
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return {out_name.data()};
}
};
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); }
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
instructions(Ts... xs)
{
std::array<const_migraphx_instruction_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_instructions_create, a.data(), a.size());
}
};
struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
modules(Ts... xs)
{
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...};
this->make_handle(&migraphx_modules_create, a.data(), a.size());
}
};
struct module struct module
{ {
migraphx_module_t mm; migraphx_module_t mm;
module(const migraphx_module_t& m) : mm(m) {} module(const migraphx_module_t& m) : mm(m) {}
void print() const { call(&migraphx_module_print, mm); } void print() const { call(&migraphx_module_print, mm); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction,
&op_ins,
mm,
op.get_handle_ptr(),
args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_instruction(const migraphx::operation& op,
const migraphx::instructions& args,
const migraphx::modules& module_args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args,
&op_ins,
mm,
op.get_handle_ptr(),
args.get_handle_ptr(),
module_args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_parameter(const std::string& name, shape s)
{
migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{});
}
instruction add_return(const migraphx::instructions& args)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr());
return instruction(ret_ins, own{});
}
};
struct context
{
migraphx_context_t ctx;
void finish() const { call(&migraphx_context_finish, ctx); }
}; };
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
...@@ -557,7 +663,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) ...@@ -557,7 +663,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
/// A program represents the all computation graphs to be compiled and executed /// A program represents the all computation graphs to be compiled and executed
struct program : MIGRAPHX_HANDLE_BASE(program) struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() {} program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); } program(migraphx_program* p, own) { this->set_handle(p, own{}); }
...@@ -627,27 +733,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -627,27 +733,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu}; return module{p_modu};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } context experimental_get_context()
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{ {
this->make_handle(&migraphx_operation_create, name, attributes, xs...); migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx};
} }
std::string name() module create_module(const std::string& name)
{ {
std::array<char, 1024> out_name; migraphx_module_t p_modu;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr()); call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return {out_name.data()}; return module{p_modu};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
}; };
// options for migraphx file format options // options for migraphx file format options
......
...@@ -178,14 +178,55 @@ def shapes(h): ...@@ -178,14 +178,55 @@ def shapes(h):
returns='const migraphx::shape&') returns='const migraphx::shape&')
@api.handle('migraphx_instruction', 'migraphx::instruction_ref')
def instruction(h):
pass
@api.handle('migraphx_instructions', 'std::vector<migraphx::instruction_ref>')
def instructions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
@api.handle('migraphx_modules', 'std::vector<migraphx::module*>')
def modules(h):
h.constructor('create',
api.params(ptr='migraphx_module_t*', size='size_t'),
fname='migraphx::to_objptr_vector<migraphx::module*>')
@auto_handle(ref=True) @auto_handle(ref=True)
def module(h): def module(h):
h.constructor('create', api.params(name='std::string'))
h.method('print', invoke='migraphx::print_module($@)', const=True) h.method('print', invoke='migraphx::print_module($@)', const=True)
h.method('add_instruction',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
h.method('add_instruction_with_mod_args',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>',
module_refs='std::vector<migraphx::module*>'),
fname='add_instruction',
returns='migraphx::instruction_ref')
h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref')
h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
@auto_handle() @auto_handle()
def program(h): def program(h):
h.constructor('create')
h.method('get_main_module', returns='migraphx::module*') h.method('get_main_module', returns='migraphx::module*')
h.method('create_module',
api.params(name='const char*'),
returns='migraphx::module*')
h.method( h.method(
'compile', 'compile',
api.params(target='migraphx::target', api.params(target='migraphx::target',
...@@ -207,6 +248,10 @@ def program(h): ...@@ -207,6 +248,10 @@ def program(h):
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('experimental_get_context',
invoke='migraphx::get_context($@)',
const=True,
returns='migraphx::context')
@auto_handle() @auto_handle()
...@@ -353,3 +398,8 @@ api.add_function('migraphx_quantize_int8', ...@@ -353,3 +398,8 @@ api.add_function('migraphx_quantize_int8',
target='migraphx::target', target='migraphx::target',
options='migraphx::quantize_int8_options'), options='migraphx::quantize_int8_options'),
fname='migraphx::quantize_int8_wrap') fname='migraphx::quantize_int8_wrap')
@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
...@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu19; migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18); auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20; migraphx::op::pooling pooling20;
pooling20.mode = "max"; pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0}; pooling20.padding = {0, 0};
pooling20.stride = {2, 2}; pooling20.stride = {2, 2};
pooling20.lengths = {3, 3}; pooling20.lengths = {3, 3};
...@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu24; migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23); auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25; migraphx::op::pooling pooling25;
pooling25.mode = "max"; pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0}; pooling25.padding = {0, 0};
pooling25.stride = {2, 2}; pooling25.stride = {2, 2};
pooling25.lengths = {3, 3}; pooling25.lengths = {3, 3};
...@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size) ...@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu37; migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36); auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38; migraphx::op::pooling pooling38;
pooling38.mode = "max"; pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0}; pooling38.padding = {0, 0};
pooling38.stride = {2, 2}; pooling38.stride = {2, 2};
pooling38.lengths = {3, 3}; pooling38.lengths = {3, 3};
......
...@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu492; migraphx::op::relu relu492;
auto mx492 = mm->add_instruction(relu492, mx491); auto mx492 = mm->add_instruction(relu492, mx491);
migraphx::op::pooling pooling493; migraphx::op::pooling pooling493;
pooling493.mode = "max"; pooling493.mode = migraphx::op::pooling_mode::max;
pooling493.padding = {0, 0}; pooling493.padding = {0, 0};
pooling493.stride = {2, 2}; pooling493.stride = {2, 2};
pooling493.lengths = {3, 3}; pooling493.lengths = {3, 3};
...@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu499; migraphx::op::relu relu499;
auto mx499 = mm->add_instruction(relu499, mx498); auto mx499 = mm->add_instruction(relu499, mx498);
migraphx::op::pooling pooling500; migraphx::op::pooling pooling500;
pooling500.mode = "max"; pooling500.mode = migraphx::op::pooling_mode::max;
pooling500.padding = {0, 0}; pooling500.padding = {0, 0};
pooling500.stride = {2, 2}; pooling500.stride = {2, 2};
pooling500.lengths = {3, 3}; pooling500.lengths = {3, 3};
...@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu518; migraphx::op::relu relu518;
auto mx518 = mm->add_instruction(relu518, mx517); auto mx518 = mm->add_instruction(relu518, mx517);
migraphx::op::pooling pooling519; migraphx::op::pooling pooling519;
pooling519.mode = "average"; pooling519.mode = migraphx::op::pooling_mode::average;
pooling519.padding = {1, 1}; pooling519.padding = {1, 1};
pooling519.stride = {1, 1}; pooling519.stride = {1, 1};
pooling519.lengths = {3, 3}; pooling519.lengths = {3, 3};
...@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu541; migraphx::op::relu relu541;
auto mx541 = mm->add_instruction(relu541, mx540); auto mx541 = mm->add_instruction(relu541, mx540);
migraphx::op::pooling pooling542; migraphx::op::pooling pooling542;
pooling542.mode = "average"; pooling542.mode = migraphx::op::pooling_mode::average;
pooling542.padding = {1, 1}; pooling542.padding = {1, 1};
pooling542.stride = {1, 1}; pooling542.stride = {1, 1};
pooling542.lengths = {3, 3}; pooling542.lengths = {3, 3};
...@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu564; migraphx::op::relu relu564;
auto mx564 = mm->add_instruction(relu564, mx563); auto mx564 = mm->add_instruction(relu564, mx563);
migraphx::op::pooling pooling565; migraphx::op::pooling pooling565;
pooling565.mode = "average"; pooling565.mode = migraphx::op::pooling_mode::average;
pooling565.padding = {1, 1}; pooling565.padding = {1, 1};
pooling565.stride = {1, 1}; pooling565.stride = {1, 1};
pooling565.lengths = {3, 3}; pooling565.lengths = {3, 3};
...@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu581; migraphx::op::relu relu581;
auto mx581 = mm->add_instruction(relu581, mx580); auto mx581 = mm->add_instruction(relu581, mx580);
migraphx::op::pooling pooling582; migraphx::op::pooling pooling582;
pooling582.mode = "max"; pooling582.mode = migraphx::op::pooling_mode::max;
pooling582.padding = {0, 0}; pooling582.padding = {0, 0};
pooling582.stride = {2, 2}; pooling582.stride = {2, 2};
pooling582.lengths = {3, 3}; pooling582.lengths = {3, 3};
...@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu610; migraphx::op::relu relu610;
auto mx610 = mm->add_instruction(relu610, mx609); auto mx610 = mm->add_instruction(relu610, mx609);
migraphx::op::pooling pooling611; migraphx::op::pooling pooling611;
pooling611.mode = "average"; pooling611.mode = migraphx::op::pooling_mode::average;
pooling611.padding = {1, 1}; pooling611.padding = {1, 1};
pooling611.stride = {1, 1}; pooling611.stride = {1, 1};
pooling611.lengths = {3, 3}; pooling611.lengths = {3, 3};
...@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu642; migraphx::op::relu relu642;
auto mx642 = mm->add_instruction(relu642, mx641); auto mx642 = mm->add_instruction(relu642, mx641);
migraphx::op::pooling pooling643; migraphx::op::pooling pooling643;
pooling643.mode = "average"; pooling643.mode = migraphx::op::pooling_mode::average;
pooling643.padding = {1, 1}; pooling643.padding = {1, 1};
pooling643.stride = {1, 1}; pooling643.stride = {1, 1};
pooling643.lengths = {3, 3}; pooling643.lengths = {3, 3};
...@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu674; migraphx::op::relu relu674;
auto mx674 = mm->add_instruction(relu674, mx673); auto mx674 = mm->add_instruction(relu674, mx673);
migraphx::op::pooling pooling675; migraphx::op::pooling pooling675;
pooling675.mode = "average"; pooling675.mode = migraphx::op::pooling_mode::average;
pooling675.padding = {1, 1}; pooling675.padding = {1, 1};
pooling675.stride = {1, 1}; pooling675.stride = {1, 1};
pooling675.lengths = {3, 3}; pooling675.lengths = {3, 3};
...@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu706; migraphx::op::relu relu706;
auto mx706 = mm->add_instruction(relu706, mx705); auto mx706 = mm->add_instruction(relu706, mx705);
migraphx::op::pooling pooling707; migraphx::op::pooling pooling707;
pooling707.mode = "average"; pooling707.mode = migraphx::op::pooling_mode::average;
pooling707.padding = {1, 1}; pooling707.padding = {1, 1};
pooling707.stride = {1, 1}; pooling707.stride = {1, 1};
pooling707.lengths = {3, 3}; pooling707.lengths = {3, 3};
...@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu729; migraphx::op::relu relu729;
auto mx729 = mm->add_instruction(relu729, mx728); auto mx729 = mm->add_instruction(relu729, mx728);
migraphx::op::pooling pooling730; migraphx::op::pooling pooling730;
pooling730.mode = "max"; pooling730.mode = migraphx::op::pooling_mode::max;
pooling730.padding = {0, 0}; pooling730.padding = {0, 0};
pooling730.stride = {2, 2}; pooling730.stride = {2, 2};
pooling730.lengths = {3, 3}; pooling730.lengths = {3, 3};
...@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757.axis = 1; concat757.axis = 1;
auto mx757 = mm->add_instruction(concat757, mx753, mx756); auto mx757 = mm->add_instruction(concat757, mx753, mx756);
migraphx::op::pooling pooling758; migraphx::op::pooling pooling758;
pooling758.mode = "average"; pooling758.mode = migraphx::op::pooling_mode::average;
pooling758.padding = {1, 1}; pooling758.padding = {1, 1};
pooling758.stride = {1, 1}; pooling758.stride = {1, 1};
pooling758.lengths = {3, 3}; pooling758.lengths = {3, 3};
...@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788.axis = 1; concat788.axis = 1;
auto mx788 = mm->add_instruction(concat788, mx784, mx787); auto mx788 = mm->add_instruction(concat788, mx784, mx787);
migraphx::op::pooling pooling789; migraphx::op::pooling pooling789;
pooling789.mode = "average"; pooling789.mode = migraphx::op::pooling_mode::average;
pooling789.padding = {1, 1}; pooling789.padding = {1, 1};
pooling789.stride = {1, 1}; pooling789.stride = {1, 1};
pooling789.lengths = {3, 3}; pooling789.lengths = {3, 3};
...@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz ...@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793.axis = 1; concat793.axis = 1;
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792); auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
migraphx::op::pooling pooling794; migraphx::op::pooling pooling794;
pooling794.mode = "average"; pooling794.mode = migraphx::op::pooling_mode::average;
pooling794.padding = {0, 0}; pooling794.padding = {0, 0};
pooling794.stride = {8, 8}; pooling794.stride = {8, 8};
pooling794.lengths = {8, 8}; pooling794.lengths = {8, 8};
......
...@@ -505,8 +505,10 @@ struct roctx : command<roctx> ...@@ -505,8 +505,10 @@ struct roctx : command<roctx>
struct op : command<op> struct op : command<op>
{ {
bool show_ops = false; bool show_ops = false;
std::string op_name{};
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
ap(op_name, {}, ap.metavar("<MIGraphX operator name>"));
ap(show_ops, ap(show_ops,
{"--list", "-l"}, {"--list", "-l"},
ap.help("List all the operators of MIGraphX"), ap.help("List all the operators of MIGraphX"),
...@@ -519,6 +521,12 @@ struct op : command<op> ...@@ -519,6 +521,12 @@ struct op : command<op>
for(const auto& name : get_operators()) for(const auto& name : get_operators())
std::cout << name << std::endl; std::cout << name << std::endl;
} }
else
{
auto op = load_op(op_name);
std::cout << op_name << ": " << std::endl;
std::cout << to_pretty_json_string(op.to_value()) << std::endl;
}
} }
}; };
......
...@@ -87,6 +87,6 @@ target get_target(bool gpu) ...@@ -87,6 +87,6 @@ target get_target(bool gpu)
void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); } void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace driver } // namespace driver
} // namespace migraphx } // namespace migraphx
...@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) ...@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu269; migraphx::op::relu relu269;
auto mx269 = mm->add_instruction(relu269, mx268); auto mx269 = mm->add_instruction(relu269, mx268);
migraphx::op::pooling pooling270; migraphx::op::pooling pooling270;
pooling270.mode = "max"; pooling270.mode = migraphx::op::pooling_mode::max;
pooling270.padding = {1, 1}; pooling270.padding = {1, 1};
pooling270.stride = {2, 2}; pooling270.stride = {2, 2};
pooling270.lengths = {3, 3}; pooling270.lengths = {3, 3};
...@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size) ...@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu438; migraphx::op::relu relu438;
auto mx438 = mm->add_instruction(relu438, mx437); auto mx438 = mm->add_instruction(relu438, mx437);
migraphx::op::pooling pooling439; migraphx::op::pooling pooling439;
pooling439.mode = "average"; pooling439.mode = migraphx::op::pooling_mode::average;
pooling439.padding = {0, 0}; pooling439.padding = {0, 0};
pooling439.stride = {1, 1}; pooling439.stride = {1, 1};
pooling439.lengths = {7, 7}; pooling439.lengths = {7, 7};
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
std::string to_pretty_json_string(const value& val, std::size_t indent = 4);
std::string to_json_string(const value& val); std::string to_json_string(const value& val);
value from_json_string(const std::string& str); value from_json_string(const std::string& str);
value from_json_string(const char* str, std::size_t size); value from_json_string(const char* str, std::size_t size);
......
#ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP #define MIGRAPHX_GUARD_OPERATORS_COMMON_HPP
#include <ostream>
#include <vector>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <utility> #include <utility>
...@@ -15,6 +17,14 @@ enum padding_mode_t ...@@ -15,6 +17,14 @@ enum padding_mode_t
valid valid
}; };
// The pooling modes must correspond 1-1 to the operators defined for struct parse_pooling.
// Used in pooling and roialign operators.
enum class pooling_mode
{
average,
max
};
// indicate rnn computation direction // indicate rnn computation direction
enum class rnn_direction enum class rnn_direction
{ {
...@@ -23,6 +33,7 @@ enum class rnn_direction ...@@ -23,6 +33,7 @@ enum class rnn_direction
bidirectional, bidirectional,
}; };
std::ostream& operator<<(std::ostream& os, pooling_mode v);
std::ostream& operator<<(std::ostream& os, rnn_direction v); std::ostream& operator<<(std::ostream& os, rnn_direction v);
} // namespace op } // namespace op
......
...@@ -16,12 +16,13 @@ ...@@ -16,12 +16,13 @@
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
struct pooling struct pooling
{ {
std::string mode = "average"; pooling_mode mode = {pooling_mode::average};
std::vector<std::size_t> padding = {0, 0}; std::vector<std::size_t> padding = {0, 0};
std::vector<std::size_t> stride = {1, 1}; std::vector<std::size_t> stride = {1, 1};
std::vector<std::size_t> lengths = {1, 1}; std::vector<std::size_t> lengths = {1, 1};
......
...@@ -38,18 +38,38 @@ struct prefix_scan_op : op_name<Derived> ...@@ -38,18 +38,38 @@ struct prefix_scan_op : op_name<Derived>
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
return inputs.at(0); auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
} }
argument compute(const shape&, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result = args[0].copy(); argument result{output_shape};
auto s = result.get_shape(); auto s = args[0].get_shape();
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}}; if(s == output_shape)
auto lens = s.lens(); {
lens[axis] = 1; result = args[0].copy();
auto batch = shape{s.type(), lens, s.strides()}; }
auto& self = static_cast<const Derived&>(*this); else
{
visit_all(result, args[0])([&](auto output, auto input) {
par_for(output_shape.elements(),
[&](auto i) { output[output_shape.index(i)] = input[s.index(i)]; });
});
s = output_shape;
}
auto slice = shape{s.type(), {s.lens()[axis]}, {s.strides()[axis]}};
auto lens = s.lens();
lens[axis] = 1;
auto batch = shape{s.type(), lens, s.strides()};
auto& self = static_cast<const Derived&>(*this);
result.visit([&](auto output) { result.visit([&](auto output) {
using type = decltype(output); using type = decltype(output);
par_for(batch.elements(), [&](auto i) { par_for(batch.elements(), [&](auto i) {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <limits> #include <limits>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
...@@ -21,7 +22,7 @@ namespace op { ...@@ -21,7 +22,7 @@ namespace op {
struct roialign struct roialign
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode = "half_pixel";
std::string mode = "avg"; pooling_mode mode = {pooling_mode::average};
int64_t output_height = 1; int64_t output_height = 1;
int64_t output_width = 1; int64_t output_width = 1;
int64_t sampling_ratio = 0; int64_t sampling_ratio = 0;
...@@ -241,16 +242,17 @@ struct roialign ...@@ -241,16 +242,17 @@ struct roialign
in_dims[0] * in_dims[1]); in_dims[0] * in_dims[1]);
double output_val; double output_val;
std::tie(output_val, vec_index[c]) = std::tie(output_val, vec_index[c]) =
(mode == "avg") ? this->calc_pooling(offset_bottom_data, (mode == migraphx::op::pooling_mode::average)
bin_grid_size, ? this->calc_pooling(offset_bottom_data,
pre_calc, bin_grid_size,
vec_index[c], pre_calc,
avg_pool{}) vec_index[c],
: this->calc_pooling(offset_bottom_data, avg_pool{})
bin_grid_size, : this->calc_pooling(offset_bottom_data,
pre_calc, bin_grid_size,
vec_index[c], pre_calc,
max_pool{}); vec_index[c],
max_pool{});
output(n, c, ph, pw) = output_val; output(n, c, ph, pw) = output_val;
}); });
}); });
......
...@@ -461,8 +461,8 @@ lifetime get_lifetime_op(const T&) ...@@ -461,8 +461,8 @@ lifetime get_lifetime_op(const T&)
* shape compute_shape(const std::vector<shape>& input) const; * shape compute_shape(const std::vector<shape>& input) const;
* shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>& * shape compute_shape(const std::vector<shape>& inputs,const std::vector<module_ref>&
* mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>& * mod_args) const; argument compute(context& ctx,const shape& output,const std::vector<argument>&
* input) const; argument compute(const shape& output,const std::vector<argument>& input) * input) const; argument compute(const shape& output,const std::vector<argument>& input) const;
* const; argument compute(const shape& output,const std::vector<argument>& input,const * argument compute(const shape& output,const std::vector<argument>& input,const
* std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const * std::vector<module_ref>& module_args,std::function<std::vector<argument>(module_ref&, const
* std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const * std::unordered_map<std::string, argument>&)> run) const; argument compute(context& ctx,const
* shape& output,const std::vector<argument>& input,const std::vector<module_ref>& * shape& output,const std::vector<argument>& input,const std::vector<module_ref>&
......
...@@ -82,6 +82,9 @@ struct program ...@@ -82,6 +82,9 @@ struct program
const std::function<void(instruction_ref, const std::function<void(instruction_ref,
std::unordered_map<instruction_ref, std::string>)>& std::unordered_map<instruction_ref, std::string>)>&
print_func) const; print_func) const;
void print(const std::function<void(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
......
...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins, ...@@ -44,7 +44,7 @@ static void update_op(const instruction_ref& input, const instruction_ref& ins,
static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m) static void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& m)
{ {
auto op = any_cast<op::pooling>(ins->get_operator()); auto op = any_cast<op::pooling>(ins->get_operator());
if(op.mode == "average") if(op.mode == op::pooling_mode::average)
{ {
return; return;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment