Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......
......@@ -50,10 +50,10 @@
"metadata": {},
"outputs": [],
"source": [
"if not os.path.exists(\"yolov4_fp16.msgpack\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.msgpack\n",
"if not os.path.exists(\"yolov4.msgpack\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.msgpack"
"if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n",
"if not os.path.exists(\"yolov4.mxr\"):\n",
" !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
]
},
{
......@@ -115,8 +115,8 @@
"outputs": [],
"source": [
"# Load serialized model (either single- or half-precision)\n",
"model = migraphx.load(\"yolov4.msgpack\", format=\"msgpack\")\n",
"#model = migraphx.load(\"yolov4_fp16.msgpack\", format=\"msgpack\")\n",
"model = migraphx.load(\"yolov4.mxr\", format=\"msgpack\")\n",
"#model = migraphx.load(\"yolov4_fp16.mxr\", format=\"msgpack\")\n",
"\n",
"# Get the name of the input parameter and convert image data to an MIGraphX argument\n",
"input_name = next(iter(model.get_parameter_shapes()))\n",
......@@ -192,4 +192,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
\ No newline at end of file
}
......@@ -12,7 +12,7 @@ RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/4.5
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
apt-utils \
build-essential \
clang-format-5.0 \
clang-format-10 \
cmake \
curl \
doxygen \
......
......@@ -38,6 +38,7 @@ add_library(migraphx
msgpack.cpp
normalize_attributes.cpp
normalize_ops.cpp
op_enums.cpp
operation.cpp
opt/memory_coloring.cpp
opt/memory_coloring_impl.cpp
......@@ -114,6 +115,7 @@ register_migraphx_ops(
identity
if_op
im2col
isnan
leaky_relu
less
load
......@@ -161,6 +163,9 @@ register_migraphx_ops(
rsqrt
scalar
scatter
scatternd_none
scatternd_add
scatternd_mul
sigmoid
sign
sinh
......@@ -211,7 +216,6 @@ target_link_libraries(migraphx PRIVATE msgpackc-cxx)
target_link_libraries(migraphx INTERFACE $<BUILD_INTERFACE:msgpackc-cxx>)
add_library(migraphx_all_targets INTERFACE)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref)
set(PACKAGE_DEPENDS)
......@@ -222,6 +226,7 @@ add_subdirectory(tf)
add_subdirectory(py)
add_subdirectory(targets/ref)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref)
if(MIGRAPHX_ENABLE_CPU)
add_subdirectory(targets/cpu)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu)
......@@ -239,7 +244,7 @@ if(HAVE_HALF_EXPR)
endif()
rocm_export_targets(
TARGETS migraphx::migraphx migraphx_all_targets
TARGETS migraphx::migraphx_c
NAMESPACE migraphx::
DEPENDS
Threads
......
This diff is collapsed.
......@@ -25,7 +25,8 @@ extern "C" {
#endif
// return code, more to be added later
typedef enum {
typedef enum
{
migraphx_status_success = 0,
migraphx_status_bad_param = 1,
migraphx_status_unknown_target = 3,
......@@ -35,7 +36,8 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum {
typedef enum
{
migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
......@@ -62,6 +64,15 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_instruction* migraphx_instruction_t;
typedef const struct migraphx_instruction* const_migraphx_instruction_t;
typedef struct migraphx_instructions* migraphx_instructions_t;
typedef const struct migraphx_instructions* const_migraphx_instructions_t;
typedef struct migraphx_modules* migraphx_modules_t;
typedef const struct migraphx_modules* const_migraphx_modules_t;
typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t;
......@@ -89,8 +100,24 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;
typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t;
typedef struct migraphx_experimental_custom_op* migraphx_experimental_custom_op_t;
typedef const struct migraphx_experimental_custom_op* const_migraphx_experimental_custom_op_t;
typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraphx_shape_t out,
void* obj,
migraphx_shapes_t inputs);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
migraphx_shape_datatype_t type,
size_t* lengths,
......@@ -121,6 +148,9 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
const_migraphx_argument_t input);
migraphx_status
migraphx_argument_create(migraphx_argument_t* argument, const_migraphx_shape_t shape, void* buffer);
......@@ -137,11 +167,17 @@ migraphx_argument_generate(migraphx_argument_t* out, const_migraphx_shape_t s, s
migraphx_status migraphx_target_destroy(migraphx_target_t target);
migraphx_status migraphx_target_assign_to(migraphx_target_t output, const_migraphx_target_t input);
migraphx_status migraphx_target_create(migraphx_target_t* target, const char* name);
migraphx_status migraphx_program_parameter_shapes_destroy(
migraphx_program_parameter_shapes_t program_parameter_shapes);
migraphx_status
migraphx_program_parameter_shapes_assign_to(migraphx_program_parameter_shapes_t output,
const_migraphx_program_parameter_shapes_t input);
migraphx_status migraphx_program_parameter_shapes_size(
size_t* out, migraphx_program_parameter_shapes_t program_parameter_shapes);
......@@ -156,6 +192,9 @@ migraphx_status migraphx_program_parameter_shapes_names(
migraphx_status
migraphx_program_parameters_destroy(migraphx_program_parameters_t program_parameters);
migraphx_status migraphx_program_parameters_assign_to(migraphx_program_parameters_t output,
const_migraphx_program_parameters_t input);
migraphx_status
migraphx_program_parameters_create(migraphx_program_parameters_t* program_parameters);
......@@ -165,6 +204,9 @@ migraphx_status migraphx_program_parameters_add(migraphx_program_parameters_t pr
migraphx_status migraphx_arguments_destroy(migraphx_arguments_t arguments);
migraphx_status migraphx_arguments_assign_to(migraphx_arguments_t output,
const_migraphx_arguments_t input);
migraphx_status migraphx_arguments_size(size_t* out, migraphx_arguments_t arguments);
migraphx_status
......@@ -172,18 +214,73 @@ migraphx_arguments_get(const_migraphx_argument_t* out, migraphx_arguments_t argu
migraphx_status migraphx_shapes_destroy(migraphx_shapes_t shapes);
migraphx_status migraphx_shapes_assign_to(migraphx_shapes_t output, const_migraphx_shapes_t input);
migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_instruction_destroy(migraphx_instruction_t instruction);
migraphx_status migraphx_instruction_assign_to(migraphx_instruction_t output,
const_migraphx_instruction_t input);
migraphx_status migraphx_instructions_destroy(migraphx_instructions_t instructions);
migraphx_status migraphx_instructions_assign_to(migraphx_instructions_t output,
const_migraphx_instructions_t input);
migraphx_status migraphx_instructions_create(migraphx_instructions_t* instructions,
const_migraphx_instruction_t* ptr,
size_t size);
migraphx_status migraphx_modules_destroy(migraphx_modules_t modules);
migraphx_status migraphx_modules_assign_to(migraphx_modules_t output,
const_migraphx_modules_t input);
migraphx_status
migraphx_modules_create(migraphx_modules_t* modules, migraphx_module_t* ptr, size_t size);
migraphx_status migraphx_module_create(migraphx_module_t* module, char* name);
migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_module_add_instruction(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args);
migraphx_status migraphx_module_add_instruction_with_mod_args(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_operation_t op,
migraphx_instructions_t args,
migraphx_modules_t module_refs);
migraphx_status migraphx_module_add_parameter(migraphx_instruction_t* out,
migraphx_module_t module,
const char* name,
const_migraphx_shape_t shape);
migraphx_status migraphx_module_add_return(migraphx_instruction_t* out,
migraphx_module_t module,
migraphx_instructions_t args);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_assign_to(migraphx_program_t output,
const_migraphx_program_t input);
migraphx_status migraphx_program_create(migraphx_program_t* program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_create_module(migraphx_module_t* out,
migraphx_program_t program,
const char* name);
migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options_t options);
......@@ -205,8 +302,14 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
migraphx_status migraphx_program_experimental_get_context(migraphx_context_t* out,
const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
const_migraphx_operation_t input);
migraphx_status migraphx_operation_create(migraphx_operation_t* operation,
const char* name,
const char* attributes,
......@@ -222,6 +325,9 @@ migraphx_save(migraphx_program_t p, const char* name, migraphx_file_options_t op
migraphx_status migraphx_onnx_options_destroy(migraphx_onnx_options_t onnx_options);
migraphx_status migraphx_onnx_options_assign_to(migraphx_onnx_options_t output,
const_migraphx_onnx_options_t input);
migraphx_status migraphx_onnx_options_create(migraphx_onnx_options_t* onnx_options);
migraphx_status migraphx_onnx_options_set_input_parameter_shape(
......@@ -236,6 +342,9 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options);
migraphx_status migraphx_file_options_assign_to(migraphx_file_options_t output,
const_migraphx_file_options_t input);
migraphx_status migraphx_file_options_create(migraphx_file_options_t* file_options);
migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t file_options,
......@@ -243,6 +352,9 @@ migraphx_status migraphx_file_options_set_file_format(migraphx_file_options_t fi
migraphx_status migraphx_compile_options_destroy(migraphx_compile_options_t compile_options);
migraphx_status migraphx_compile_options_assign_to(migraphx_compile_options_t output,
const_migraphx_compile_options_t input);
migraphx_status migraphx_compile_options_create(migraphx_compile_options_t* compile_options);
migraphx_status
......@@ -261,6 +373,9 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx_status migraphx_tf_options_destroy(migraphx_tf_options_t tf_options);
migraphx_status migraphx_tf_options_assign_to(migraphx_tf_options_t output,
const_migraphx_tf_options_t input);
migraphx_status migraphx_tf_options_create(migraphx_tf_options_t* tf_options);
migraphx_status migraphx_tf_options_set_nhwc(migraphx_tf_options_t tf_options, bool is_nhwc);
......@@ -282,6 +397,9 @@ migraphx_parse_tf(migraphx_program_t* out, const char* name, migraphx_tf_options
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
migraphx_status migraphx_quantize_op_names_assign_to(migraphx_quantize_op_names_t output,
const_migraphx_quantize_op_names_t input);
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names,
......@@ -295,6 +413,10 @@ migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
migraphx_status
migraphx_quantize_int8_options_assign_to(migraphx_quantize_int8_options_t output,
const_migraphx_quantize_int8_options_t input);
migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
......@@ -309,6 +431,28 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
migraphx_status
migraphx_experimental_custom_op_assign_to(migraphx_experimental_custom_op_t output,
const_migraphx_experimental_custom_op_t input);
migraphx_status
migraphx_experimental_custom_op_create(migraphx_experimental_custom_op_t* experimental_custom_op,
void* obj,
migraphx_experimental_custom_op_copy c,
migraphx_experimental_custom_op_delete d,
const char* name);
migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
#ifdef __cplusplus
}
#endif
......
#ifndef MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h"
#include <initializer_list>
#include <migraphx/migraphx.h>
#include <memory>
#include <exception>
......@@ -13,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT
#endif
template <int N>
struct rank : rank<N - 1>
{
};
template <>
struct rank<0>
{
};
template <class T, class F, class... Ts>
T* make(F f, Ts&&... xs)
{
......@@ -152,6 +164,35 @@ struct array_base
}
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-template-friend"
#endif
template <class T>
struct holder
{
// Friend injection
friend auto migraphx_adl_handle_lookup(holder<T>);
// Function left unimplemented since its only used in non-evaluated
// context
T get() const;
};
template <class C, class T>
struct handle_lookup
{
friend auto migraphx_adl_handle_lookup(holder<T>) { return holder<C>{}; }
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
template <class T>
using as_handle = decltype(
migraphx_adl_handle_lookup(holder<std::remove_cv_t<std::remove_pointer_t<T>>>{}).get());
struct own
{
};
......@@ -159,8 +200,8 @@ struct borrow
{
};
template <class T, class D, D Deleter>
struct handle_base
template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{
handle_base() : m_handle(nullptr) {}
template <class F, class... Ts>
......@@ -190,17 +231,158 @@ struct handle_base
m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
}
template <class U>
void assign_to_handle(U* x)
{
Assigner(x, this->get_handle_ptr());
}
protected:
std::shared_ptr<T> m_handle;
};
template <class Base>
struct interface_base : Base
{
interface_base() : Base() {}
protected:
template <class F>
static migraphx_status try_(F f) // NOLINT
{
try
{
f();
return migraphx_status_success;
}
catch(...)
{
return migraphx_status_unknown_error;
}
}
template <class F, class T, class... Ts>
void make_interface(F f, T& obj, Ts&&... xs)
{
auto copy = [](void** out, void* input) {
return try_([&] {
T** y = reinterpret_cast<T**>(out);
T* x = reinterpret_cast<T*>(input);
assert(x != nullptr and y != nullptr and *y == nullptr);
*y = new T(*x); // NOLINT
});
};
auto del = [](void* input) {
return try_([&] {
T* x = reinterpret_cast<T*>(input);
delete x; // NOLINT
});
};
this->make_handle(f, &obj, copy, del, std::forward<Ts>(xs)...);
}
template <class T, class Setter, class F>
void set_fp(Setter setter, F pf)
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter, this->get_handle_ptr(), [](auto... xs) -> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<1>{}, f, xs...); });
});
}
template <class T, class Setter, class F>
void set_auto_fp(Setter setter, F f)
{
return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
});
}
struct no_out_arg
{
};
template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
{
f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...);
}
template <class T,
class F,
class R,
class X,
class... Xs,
class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
{
f(*reinterpret_cast<T*>(obj), result, xs...);
}
template <class F, class T, class... Ts>
void auto_invoke(F f, T* out, Ts&&... xs)
{
auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...));
}
template <class F, class T, class... Ts>
void auto_invoke(F f, no_out_arg, Ts&&... xs)
{
f(std::forward<Ts>(xs)...);
}
template <class T, class = std::enable_if_t<std::is_fundamental<T>{} or std::is_enum<T>{}>>
T auto_convert_param(rank<0>, T x)
{
return x;
}
template <class T>
auto auto_convert_param(rank<1>, T x) -> decltype(as_handle<T>{x})
{
return as_handle<T>{x};
}
template <class T>
auto auto_convert_param(rank<2>, T x) -> decltype(as_handle<T>{x, borrow{}})
{
return as_handle<T>{x, borrow{}};
}
template <class T, class U>
void auto_assign(rank<0>, T* out, U x)
{
return *out = x;
}
template <class T, class U>
auto auto_assign(rank<1>, T* out, U x) -> decltype(x.assign_to_handle(out))
{
x.assign_to_handle(out);
}
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); })
template <class Base, class T>
using require_interface =
std::enable_if_t<std::is_base_of<Base, T>{} and not std::is_same<T, Base>{} and
std::is_copy_constructible<T>{} and std::is_final<T>{}>;
#ifdef DOXYGEN
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy>
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<name, \
const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \
migraphx_##name##_assign_to>
#endif
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_BASE(name) MIGRAPHX_DETAIL_HANDLE_BASE(name, )
......@@ -485,12 +667,116 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
};
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
{
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
}
std::string name()
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return {out_name.data()};
}
};
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); }
};
struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
instructions(Ts... xs)
{
std::array<const_migraphx_instruction_t, sizeof...(Ts)> a{xs.get_handle_ptr()...};
this->make_handle(&migraphx_instructions_create, a.data(), a.size());
}
};
struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules)
{
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
modules(Ts... xs)
{
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...};
this->make_handle(&migraphx_modules_create, a.data(), a.size());
}
};
struct module
{
migraphx_module_t mm;
module(const migraphx_module_t& m) : mm(m) {}
void print() const { call(&migraphx_module_print, mm); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction,
&op_ins,
mm,
op.get_handle_ptr(),
args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_instruction(const migraphx::operation& op,
const migraphx::instructions& args,
const migraphx::modules& module_args)
{
migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args,
&op_ins,
mm,
op.get_handle_ptr(),
args.get_handle_ptr(),
module_args.get_handle_ptr());
return instruction(op_ins, own{});
}
instruction add_parameter(const std::string& name, shape s)
{
migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{});
}
instruction add_return(const migraphx::instructions& args)
{
migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr());
return instruction(ret_ins, own{});
}
};
struct context
{
migraphx_context_t ctx;
void finish() const { call(&migraphx_context_finish, ctx); }
};
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
......@@ -519,7 +805,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
/// A program represents the all computation graphs to be compiled and executed
struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() {}
program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); }
......@@ -589,27 +875,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
struct operation : MIGRAPHX_HANDLE_BASE(operation)
{
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); }
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs)
context experimental_get_context()
{
this->make_handle(&migraphx_operation_create, name, attributes, xs...);
migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx};
}
std::string name()
module create_module(const std::string& name)
{
std::array<char, 1024> out_name;
call(&migraphx_operation_name, out_name.data(), 1024, this->get_handle_ptr());
return {out_name.data()};
migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
// options for migraphx file format options
......@@ -850,6 +1130,32 @@ quantize_int8(const program& prog, const target& ptarget, const quantize_int8_op
options.get_handle_ptr());
}
struct experimental_custom_op_base
{
virtual std::string name() const = 0;
virtual shape compute_shape(shapes inputs) const = 0;
virtual ~experimental_custom_op_base() = default;
};
struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
{
template <class T>
experimental_custom_op(T& obj)
{
this->make_interface(&migraphx_experimental_custom_op_create, obj, obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
}
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
};
template <class T, class = require_interface<experimental_custom_op_base, T>>
void register_experimental_custom_op(T& obj)
{
experimental_custom_op op{obj};
op.register_op();
}
#ifndef DOXYGEN
} // namespace api
#endif
......
......@@ -178,14 +178,55 @@ def shapes(h):
returns='const migraphx::shape&')
@api.handle('migraphx_instruction', 'migraphx::instruction_ref')
def instruction(h):
pass
@api.handle('migraphx_instructions', 'std::vector<migraphx::instruction_ref>')
def instructions(h):
h.constructor(
'create',
api.params(ptr='const_migraphx_instruction_t*', size='size_t'),
fname='migraphx::to_obj_vector<const_migraphx_instruction_t>')
@api.handle('migraphx_modules', 'std::vector<migraphx::module*>')
def modules(h):
h.constructor('create',
api.params(ptr='migraphx_module_t*', size='size_t'),
fname='migraphx::to_objptr_vector<migraphx::module*>')
@auto_handle(ref=True)
def module(h):
h.constructor('create', api.params(name='std::string'))
h.method('print', invoke='migraphx::print_module($@)', const=True)
h.method('add_instruction',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
h.method('add_instruction_with_mod_args',
api.params(op='migraphx::operation',
args='std::vector<migraphx::instruction_ref>',
module_refs='std::vector<migraphx::module*>'),
fname='add_instruction',
returns='migraphx::instruction_ref')
h.method('add_parameter',
api.params(name='const char*', shape='const migraphx::shape&'),
returns='migraphx::instruction_ref')
h.method('add_return',
api.params(args='std::vector<migraphx::instruction_ref>'),
returns='migraphx::instruction_ref')
@auto_handle()
def program(h):
h.constructor('create')
h.method('get_main_module', returns='migraphx::module*')
h.method('create_module',
api.params(name='const char*'),
returns='migraphx::module*')
h.method(
'compile',
api.params(target='migraphx::target',
......@@ -207,6 +248,10 @@ def program(h):
invoke='migraphx::equal($@)',
returns='bool',
const=True)
h.method('experimental_get_context',
invoke='migraphx::get_context($@)',
const=True,
returns='migraphx::context')
@auto_handle()
......@@ -353,3 +398,18 @@ api.add_function('migraphx_quantize_int8',
target='migraphx::target',
options='migraphx::quantize_int8_options'),
fname='migraphx::quantize_int8_wrap')
@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
@api.interface('migraphx_experimental_custom_op',
'migraphx::experimental_custom_op')
def experimental_custom_op(h):
h.constructor('create', api.params(name='const char*'))
h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape')
h.method('register', invoke='migraphx::register_custom_op($@)')
......@@ -106,7 +106,11 @@ bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const shape& argument::get_shape() const { return this->m_shape; }
argument argument::reshape(const shape& s) const { return {s, this->m_data}; }
argument argument::reshape(const shape& s) const
{
assert(s.element_space() <= this->get_shape().element_space());
return {s, this->m_data};
}
argument::data_t argument::data_t::share() const
{
......
......@@ -10,8 +10,35 @@ inline namespace MIGRAPHX_INLINE_NS {
void auto_contiguous::apply(module& p) const
{
std::string key = "require_std_shape";
for(auto ins : reverse_iterator_for(p))
{
auto&& attr = ins->get_operator().attributes();
if((attr.get(key, false)))
{
auto args = ins->inputs();
auto new_args = args;
std::transform(args.begin(), args.end(), new_args.begin(), [&](auto in) {
if(in->name() == "contiguous")
{
return in;
}
return p.insert_instruction(ins, make_op("contiguous"), in);
});
if(new_args != args)
{
p.replace_instruction(ins, ins->get_operator(), new_args);
}
}
}
auto last = std::prev(p.end());
for(auto ins : iterator_for(p))
{
// for last instruction that is NOT a return
if(ins->outputs().empty() and ins != last)
continue;
shape s = ins->get_shape();
if(not s.standard() and s.elements() != 0)
{
......
......@@ -34,7 +34,14 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
params += " -o " + out;
td.execute(compiler, params);
if(not launcher.empty())
{
td.execute(launcher, compiler + " " + params);
}
else
{
td.execute(compiler, params);
}
auto out_path = td.path / out;
if(not fs::exists(out_path))
......
......@@ -88,6 +88,7 @@ struct cpp_generator_impl
std::stringstream fs{};
std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr;
std::function<std::string(shape)> fresult = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
};
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
......@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
......@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
return this->generate_point_op(ins->get_operator(), args);
auto s = this->generate_point_op(ins->get_operator(), args);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
return s;
});
return f;
}
......
......@@ -61,7 +61,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu19;
auto mx19 = mm->add_instruction(relu19, mx18);
migraphx::op::pooling pooling20;
pooling20.mode = "max";
pooling20.mode = migraphx::op::pooling_mode::max;
pooling20.padding = {0, 0};
pooling20.stride = {2, 2};
pooling20.lengths = {3, 3};
......@@ -81,7 +81,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu24;
auto mx24 = mm->add_instruction(relu24, mx23);
migraphx::op::pooling pooling25;
pooling25.mode = "max";
pooling25.mode = migraphx::op::pooling_mode::max;
pooling25.padding = {0, 0};
pooling25.stride = {2, 2};
pooling25.lengths = {3, 3};
......@@ -129,7 +129,7 @@ migraphx::program alexnet(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu37;
auto mx37 = mm->add_instruction(relu37, mx36);
migraphx::op::pooling pooling38;
pooling38.mode = "max";
pooling38.mode = migraphx::op::pooling_mode::max;
pooling38.padding = {0, 0};
pooling38.stride = {2, 2};
pooling38.lengths = {3, 3};
......
......@@ -995,7 +995,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu492;
auto mx492 = mm->add_instruction(relu492, mx491);
migraphx::op::pooling pooling493;
pooling493.mode = "max";
pooling493.mode = migraphx::op::pooling_mode::max;
pooling493.padding = {0, 0};
pooling493.stride = {2, 2};
pooling493.lengths = {3, 3};
......@@ -1025,7 +1025,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu499;
auto mx499 = mm->add_instruction(relu499, mx498);
migraphx::op::pooling pooling500;
pooling500.mode = "max";
pooling500.mode = migraphx::op::pooling_mode::max;
pooling500.padding = {0, 0};
pooling500.stride = {2, 2};
pooling500.lengths = {3, 3};
......@@ -1103,7 +1103,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu518;
auto mx518 = mm->add_instruction(relu518, mx517);
migraphx::op::pooling pooling519;
pooling519.mode = "average";
pooling519.mode = migraphx::op::pooling_mode::average;
pooling519.padding = {1, 1};
pooling519.stride = {1, 1};
pooling519.lengths = {3, 3};
......@@ -1196,7 +1196,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu541;
auto mx541 = mm->add_instruction(relu541, mx540);
migraphx::op::pooling pooling542;
pooling542.mode = "average";
pooling542.mode = migraphx::op::pooling_mode::average;
pooling542.padding = {1, 1};
pooling542.stride = {1, 1};
pooling542.lengths = {3, 3};
......@@ -1289,7 +1289,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu564;
auto mx564 = mm->add_instruction(relu564, mx563);
migraphx::op::pooling pooling565;
pooling565.mode = "average";
pooling565.mode = migraphx::op::pooling_mode::average;
pooling565.padding = {1, 1};
pooling565.stride = {1, 1};
pooling565.lengths = {3, 3};
......@@ -1358,7 +1358,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu581;
auto mx581 = mm->add_instruction(relu581, mx580);
migraphx::op::pooling pooling582;
pooling582.mode = "max";
pooling582.mode = migraphx::op::pooling_mode::max;
pooling582.padding = {0, 0};
pooling582.stride = {2, 2};
pooling582.lengths = {3, 3};
......@@ -1475,7 +1475,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu610;
auto mx610 = mm->add_instruction(relu610, mx609);
migraphx::op::pooling pooling611;
pooling611.mode = "average";
pooling611.mode = migraphx::op::pooling_mode::average;
pooling611.padding = {1, 1};
pooling611.stride = {1, 1};
pooling611.lengths = {3, 3};
......@@ -1604,7 +1604,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu642;
auto mx642 = mm->add_instruction(relu642, mx641);
migraphx::op::pooling pooling643;
pooling643.mode = "average";
pooling643.mode = migraphx::op::pooling_mode::average;
pooling643.padding = {1, 1};
pooling643.stride = {1, 1};
pooling643.lengths = {3, 3};
......@@ -1733,7 +1733,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu674;
auto mx674 = mm->add_instruction(relu674, mx673);
migraphx::op::pooling pooling675;
pooling675.mode = "average";
pooling675.mode = migraphx::op::pooling_mode::average;
pooling675.padding = {1, 1};
pooling675.stride = {1, 1};
pooling675.lengths = {3, 3};
......@@ -1862,7 +1862,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu706;
auto mx706 = mm->add_instruction(relu706, mx705);
migraphx::op::pooling pooling707;
pooling707.mode = "average";
pooling707.mode = migraphx::op::pooling_mode::average;
pooling707.padding = {1, 1};
pooling707.stride = {1, 1};
pooling707.lengths = {3, 3};
......@@ -1955,7 +1955,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
migraphx::op::relu relu729;
auto mx729 = mm->add_instruction(relu729, mx728);
migraphx::op::pooling pooling730;
pooling730.mode = "max";
pooling730.mode = migraphx::op::pooling_mode::max;
pooling730.padding = {0, 0};
pooling730.stride = {2, 2};
pooling730.lengths = {3, 3};
......@@ -2066,7 +2066,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat757.axis = 1;
auto mx757 = mm->add_instruction(concat757, mx753, mx756);
migraphx::op::pooling pooling758;
pooling758.mode = "average";
pooling758.mode = migraphx::op::pooling_mode::average;
pooling758.padding = {1, 1};
pooling758.stride = {1, 1};
pooling758.lengths = {3, 3};
......@@ -2189,7 +2189,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat788.axis = 1;
auto mx788 = mm->add_instruction(concat788, mx784, mx787);
migraphx::op::pooling pooling789;
pooling789.mode = "average";
pooling789.mode = migraphx::op::pooling_mode::average;
pooling789.padding = {1, 1};
pooling789.stride = {1, 1};
pooling789.lengths = {3, 3};
......@@ -2210,7 +2210,7 @@ migraphx::program inceptionv3(unsigned batch) // NOLINT(readability-function-siz
concat793.axis = 1;
auto mx793 = mm->add_instruction(concat793, mx765, mx775, mx788, mx792);
migraphx::op::pooling pooling794;
pooling794.mode = "average";
pooling794.mode = migraphx::op::pooling_mode::average;
pooling794.padding = {0, 0};
pooling794.stride = {8, 8};
pooling794.lengths = {8, 8};
......
......@@ -505,8 +505,10 @@ struct roctx : command<roctx>
struct op : command<op>
{
bool show_ops = false;
std::string op_name{};
void parse(argument_parser& ap)
{
ap(op_name, {}, ap.metavar("<MIGraphX operator name>"));
ap(show_ops,
{"--list", "-l"},
ap.help("List all the operators of MIGraphX"),
......@@ -519,6 +521,12 @@ struct op : command<op>
for(const auto& name : get_operators())
std::cout << name << std::endl;
}
else
{
auto op = load_op(op_name);
std::cout << op_name << ": " << std::endl;
std::cout << to_pretty_json_string(op.to_value()) << std::endl;
}
}
};
......
......@@ -87,6 +87,6 @@ target get_target(bool gpu)
void compile_program(program& p, bool gpu) { p.compile(get_target(gpu)); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
......@@ -561,7 +561,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu269;
auto mx269 = mm->add_instruction(relu269, mx268);
migraphx::op::pooling pooling270;
pooling270.mode = "max";
pooling270.mode = migraphx::op::pooling_mode::max;
pooling270.padding = {1, 1};
pooling270.stride = {2, 2};
pooling270.lengths = {3, 3};
......@@ -1215,7 +1215,7 @@ migraphx::program resnet50(unsigned batch) // NOLINT(readability-function-size)
migraphx::op::relu relu438;
auto mx438 = mm->add_instruction(relu438, mx437);
migraphx::op::pooling pooling439;
pooling439.mode = "average";
pooling439.mode = migraphx::op::pooling_mode::average;
pooling439.padding = {0, 0};
pooling439.stride = {1, 1};
pooling439.lengths = {7, 7};
......
......@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r)
continue;
p.replace_instruction(ins, eq);
processed_ins.emplace(ins);
auto outputs = eq->outputs();
std::vector<instruction_ref> outputs;
std::copy_if(eq->outputs().begin(),
eq->outputs().end(),
std::back_inserter(outputs),
[&](auto x) { return p.has_instruction(x); });
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y);
});
......
......@@ -78,15 +78,16 @@ void eliminate_contiguous::apply(module& p) const
continue;
// Make a copy so we can modify it while we iterate
auto args = ins->inputs();
auto args = ins->inputs();
auto new_args = args;
auto mod_args = ins->module_inputs();
for(auto arg : ins->inputs())
{
if(arg->name() == op_name)
{
auto new_args = args;
auto prev = arg->inputs().front();
auto prev = arg->inputs().front();
replace(new_args, arg, prev);
if(try_compute_shape(ins, new_args, ins->module_inputs()))
if(try_compute_shape(ins, new_args, mod_args))
{
instruction::replace_argument(ins, arg, prev);
}
......
......@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
void eliminate_data_type::apply(module& m) const
{
static const std::vector<std::string> skip_op_names = {
"convert", "get_tuple_elem", "if", "loop", "roialign"};
static const std::vector<std::string> skip_op_names = {"convert",
"get_tuple_elem",
"if",
"loop",
"roialign",
"scatternd_add",
"scatternd_mul",
"scatternd_none"};
for(auto ins : iterator_for(m))
{
if(ins->name()[0] == '@')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment