Unverified Commit c9ffb38d authored by Umang Yadav, committed by GitHub

Add output_alias and runs_on_offload_target flags for custom ops (#1309)

Adds two methods to the custom op virtual class (`migraphx::experimental_custom_op_base`).

`bool runs_on_offload_target()`: return true if the custom op runs directly on the GPU; in that case the op expects its parameters to reside in GPU memory and writes its output to GPU memory. If it returns false, the op expects its parameters to reside in host memory and puts the result back into host memory.
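As a sketch, a minimal host-side op under this API (the `host_negate` op here is hypothetical, not part of this change; it only illustrates the contract described above):

```cpp
// Minimal host-side custom op sketch. Since runs_on_offload_target() returns
// false, MIGraphX copies the inputs into host memory before calling compute()
// and copies the returned buffer back to the GPU afterwards.
struct host_negate final : migraphx::experimental_custom_op_base
{
    virtual std::string name() const override { return "host_negate"; }
    virtual bool runs_on_offload_target() const override { return false; }
    virtual migraphx::argument
    compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
    {
        auto* in  = reinterpret_cast<float*>(inputs[0].data()); // host pointer
        auto* out = reinterpret_cast<float*>(inputs[1].data()); // host pointer
        for(size_t i = 0; i < inputs[0].get_shape().elements(); i++)
            out[i] = -in[i];
        return inputs[1]; // last input is the output buffer
    }
    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
    {
        return inputs.back();
    }
};
```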

`output_alias`: indicates that the output of the custom op aliases an input buffer, i.e. it interprets the same input buffer with a different shape and strides.
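Condensed from the `stride_two` op in the updated GPU tests below, the aliasing pattern looks like this (a sketch; returning `{0}` tells MIGraphX the output reuses input 0's buffer):

```cpp
// Inside a custom op definition: the output is input 0's memory viewed with
// the dims/strides computed in compute_shape(), so no output copy is needed.
virtual std::vector<size_t> output_alias(migraphx::shapes) const override { return {0}; }
virtual migraphx::argument
compute(migraphx::context, migraphx::shape out_shape, migraphx::arguments inputs) const override
{
    return {out_shape, inputs[0].data()}; // same buffer, new shape/strides view
}
```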

Updates `as_vector()` in the C++ API to handle non-standard shapes. This required exposing the element-index to space-index conversion method on the `shape` class.
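A small sketch of what the new `index()` method exposes (offsets worked out by hand for a transposed, non-standard view):

```cpp
#include <iostream>
#include <migraphx/migraphx.hpp>

int main()
{
    // A 2x2 transposed view: lengths {2, 2} with strides {1, 2} (non-standard).
    migraphx::shape s{migraphx_shape_float_type, {2, 2}, {1, 2}};
    // The i-th element in iteration order lives at buffer offset s.index(i),
    // so as_vector() reads offsets 0, 2, 1, 3 instead of 0, 1, 2, 3.
    for(size_t i = 0; i < s.elements(); i++)
        std::cout << s.index(i) << " "; // prints: 0 2 1 3
}
```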
parent e19f78ae
@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N)
 struct square_custom_op final : migraphx::experimental_custom_op_base
 {
     virtual std::string name() const override { return "square_custom_op"; }
+    // Flag to identify whether the custom op runs on the GPU or on the host.
+    // Based on this flag, MIGraphX injects the necessary copies to and from the
+    // GPU for the input and output buffers. Therefore, if the custom op runs on
+    // the GPU it can assume its input buffers are in GPU memory, and similarly
+    // for the host.
+    virtual bool runs_on_offload_target() const override { return true; }
     virtual migraphx::argument
     compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
     {
@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base
         // is output argument, so it should be returned from compute method.
         auto* input_buffer  = reinterpret_cast<float*>(inputs[0].data());
         auto* output_buffer = reinterpret_cast<float*>(inputs[1].data());
-        size_t n_elements = inputs[0].get_shape().bytes() / sizeof(inputs[0].get_shape().type());
+        size_t n_elements = inputs[0].get_shape().elements();
         MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
         const unsigned blocks            = 512;
         const unsigned threads_per_block = 256;
@@ -103,7 +110,7 @@ int main(int argc, const char* argv[])
     options.set_offload_copy();
     p.compile(migraphx::target("gpu"), options);
     migraphx::program_parameters pp;
-    std::vector<float> x_data(s.bytes() / sizeof(s.type()));
+    std::vector<float> x_data(s.elements());
     std::iota(x_data.begin(), x_data.end(), 0);
     pp.add("x", migraphx::argument(s, x_data.data()));
     auto results = p.eval(pp);
...
@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode,
 struct abs_custom_op final : migraphx::experimental_custom_op_base
 {
     virtual std::string name() const override { return "abs_custom_op"; }
+    // Flag to identify whether the custom op runs on the GPU or on the host.
+    // Based on this flag, MIGraphX injects the necessary copies to and from the
+    // GPU for the input and output buffers. Therefore, if the custom op runs on
+    // the GPU it can assume its input buffers are in GPU memory, and similarly
+    // for the host.
+    virtual bool runs_on_offload_target() const override { return true; }
     virtual migraphx::argument compute(migraphx::context ctx,
                                        migraphx::shape output_shape,
                                        migraphx::arguments args) const override
...
@@ -51,6 +51,13 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx)
 struct sscal_custom_op final : migraphx::experimental_custom_op_base
 {
     virtual std::string name() const override { return "sscal_custom_op"; }
+    // Flag to identify whether the custom op runs on the GPU or on the host.
+    // Based on this flag, MIGraphX injects the necessary copies to and from the
+    // GPU for the input and output buffers. Therefore, if the custom op runs on
+    // the GPU it can assume its input buffers are in GPU memory, and similarly
+    // for the host.
+    virtual bool runs_on_offload_target() const override { return true; }
     virtual migraphx::argument compute(migraphx::context ctx,
                                        migraphx::shape output_shape,
                                        migraphx::arguments args) const override
@@ -107,7 +114,7 @@ int main(int argc, const char* argv[])
     options.set_offload_copy();
     p.compile(migraphx::target("gpu"), options);
     migraphx::program_parameters pp;
-    std::vector<float> x_data(x_shape.bytes() / sizeof(x_shape.type()));
+    std::vector<float> x_data(x_shape.elements());
     std::vector<float> scale_data{-1};
     std::iota(x_data.begin(), x_data.end(), 0);
     pp.add("x", migraphx::argument(x_shape, x_data.data()));
...
@@ -265,11 +265,18 @@ struct experimental_custom_op
 template <class CustomOp>
 struct custom_operation
 {
     template <class Self, class F>
     static auto reflect(Self&, F)
     {
         return pack();
     }

+    value attributes() const
+    {
+        return {{"custom_op", true}, {"target", op.runs_on_offload_target() ? "gpu" : "cpu"}};
+    }
+
     CustomOp op;

     std::string name() const { return op.xobject.name; }
@@ -284,6 +291,23 @@ struct custom_operation
     {
         return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
     }

+    std::ptrdiff_t output_alias(std::vector<shape> inputs) const
+    {
+        auto alias_vec = op.output_alias(std::move(inputs));
+        // TODO: For now, only support one output alias
+        if(alias_vec.empty())
+        {
+            return -1;
+        }
+        if(alias_vec.size() > 1)
+        {
+            MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
+        }
+        return alias_vec.front();
+    }
+
+    bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
 };

 template <class CustomOp>
@@ -613,9 +637,9 @@ struct migraphx_experimental_custom_op
                               migraphx::shape output,
                               std::vector<migraphx::argument> inputs) const
     {
-        std::remove_pointer_t<migraphx_argument_t> out;
         if(compute_f == nullptr)
             throw std::runtime_error("compute function is missing.");
+        std::remove_pointer_t<migraphx_argument_t> out;
         std::array<char, 256> exception_msg;
         exception_msg.front() = '\0';
         auto api_error_result = compute_f(&out,
@@ -637,9 +661,9 @@ struct migraphx_experimental_custom_op
     migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
     migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
     {
-        std::remove_pointer_t<migraphx_shape_t> out;
         if(compute_shape_f == nullptr)
             throw std::runtime_error("compute_shape function is missing.");
+        std::remove_pointer_t<migraphx_shape_t> out;
         std::array<char, 256> exception_msg;
         exception_msg.front() = '\0';
         auto api_error_result = compute_shape_f(&out,
@@ -655,6 +679,49 @@ struct migraphx_experimental_custom_op
         }
         return (&out)->object;
     }
+    migraphx_experimental_custom_op_output_alias output_alias_f = nullptr;
+    std::vector<size_t> output_alias(std::vector<migraphx::shape> inputs) const
+    {
+        if(output_alias_f == nullptr)
+            throw std::runtime_error("output_alias function is missing.");
+        std::array<size_t, 1024> out;
+        std::remove_pointer_t<size_t*> out_size = 1024;
+        std::array<char, 256> exception_msg;
+        exception_msg.front() = '\0';
+        auto api_error_result = output_alias_f(out.data(),
+                                               &out_size,
+                                               object_ptr.data,
+                                               exception_msg.data(),
+                                               exception_msg.size(),
+                                               object_cast<migraphx_shapes_t>(&(inputs)));
+        if(api_error_result != migraphx_status_success)
+        {
+            const std::string exception_str(exception_msg.data());
+            throw std::runtime_error("Error in output_alias of: " +
+                                     std::string(object_ptr.obj_typename) + ": " + exception_str);
+        }
+        return {out.begin(), out.begin() + out_size}; // cppcheck-suppress returnDanglingLifetime;
+    }
+    migraphx_experimental_custom_op_runs_on_offload_target runs_on_offload_target_f = nullptr;
+    bool runs_on_offload_target() const
+    {
+        if(runs_on_offload_target_f == nullptr)
+            throw std::runtime_error("runs_on_offload_target function is missing.");
+        std::remove_pointer_t<bool*> out;
+        std::array<char, 256> exception_msg;
+        exception_msg.front() = '\0';
+        auto api_error_result = runs_on_offload_target_f(
+            &out, object_ptr.data, exception_msg.data(), exception_msg.size());
+        if(api_error_result != migraphx_status_success)
+        {
+            const std::string exception_str(exception_msg.data());
+            throw std::runtime_error("Error in runs_on_offload_target of: " +
+                                     std::string(object_ptr.obj_typename) + ": " + exception_str);
+        }
+        return out;
+    }
 };

 extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
@@ -758,6 +825,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
     return api_error_result;
 }

+extern "C" migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape)
+{
+    auto api_error_result = migraphx::try_([&] {
+        if(shape == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
+        *out = (shape->object).elements();
+    });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
 {
     auto api_error_result = migraphx::try_([&] {
@@ -791,6 +868,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
     return api_error_result;
 }

+extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
+{
+    auto api_error_result = migraphx::try_([&] {
+        if(shape == nullptr)
+            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
+        *out = (shape->object).index((i));
+    });
+    return api_error_result;
+}
+
 extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
 {
     auto api_error_result = migraphx::try_([&] { destroy((argument)); });
@@ -1879,6 +1966,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
     return api_error_result;
 }

+extern "C" migraphx_status
+migraphx_experimental_custom_op_set_output_alias(migraphx_experimental_custom_op_t obj,
+                                                 migraphx_experimental_custom_op_output_alias input)
+{
+    auto api_error_result = migraphx::try_([&] { (obj)->output_alias_f = (input); });
+    return api_error_result;
+}
+
+extern "C" migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
+    migraphx_experimental_custom_op_t obj,
+    migraphx_experimental_custom_op_runs_on_offload_target input)
+{
+    auto api_error_result = migraphx::try_([&] { (obj)->runs_on_offload_target_f = (input); });
+    return api_error_result;
+}
+
 extern "C" migraphx_status
 migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
 {
...
@@ -26,7 +26,6 @@
 #include <stdlib.h>
 #include <stdbool.h>
 // Add new types here
 // clang-format off
 #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraph
                                                                          size_t exception_msg_size,
                                                                          migraphx_shapes_t inputs);
+typedef migraphx_status (*migraphx_experimental_custom_op_output_alias)(size_t* out,
+                                                                        size_t* out_size,
+                                                                        void* obj,
+                                                                        char* exception_msg,
+                                                                        size_t exception_msg_size,
+                                                                        migraphx_shapes_t inputs);
+
+typedef migraphx_status (*migraphx_experimental_custom_op_runs_on_offload_target)(
+    bool* out, void* obj, char* exception_msg, size_t exception_msg_size);
+
 typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
 typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
 migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape);
+migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape);
 migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
 migraphx_status
@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
 migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
+migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i);
 migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
 migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
@@ -502,6 +515,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t ob
 migraphx_status migraphx_experimental_custom_op_set_compute_shape(
     migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
+migraphx_status migraphx_experimental_custom_op_set_output_alias(
+    migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
+
+migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
+    migraphx_experimental_custom_op_t obj,
+    migraphx_experimental_custom_op_runs_on_offload_target input);
+
 migraphx_status
 migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
...
@@ -25,10 +25,12 @@
 #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP

 #include "migraphx.h"
+#include <algorithm>
 #include <cstring>
 #include <initializer_list>
 #include <migraphx/migraphx.h>
 #include <memory>
+#include <numeric>
 #include <exception>
 #include <vector>
 #include <cassert>
@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
         this->set_handle(p, std::move(lifetime)); \
     }

+template <size_t N>
+struct out_params
+{
+};
+
 template <class Base>
 struct interface_base : Base
 {
@@ -391,7 +398,22 @@ struct interface_base : Base
     }

     template <class T, class Setter, class F>
-    void set_fp(Setter setter, F pf)
+    void set_fp(Setter setter, F pf, out_params<2>)
+    {
+        static F f = pf;
+        (void)f; // avoid warning on gcc
+        call(setter,
+             this->get_handle_ptr(),
+             [](auto out1, auto out2, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
+                 -> migraphx_status {
+                 return try_([&] { call_cast_arg<T>(rank<2>{}, f, out1, out2, obj, xs...); },
+                             ex_msg,
+                             ex_msg_size);
+             });
+    }
+
+    template <class T, class Setter, class F>
+    void set_fp(Setter setter, F pf, out_params<1>)
     {
         static F f = pf;
         (void)f; // avoid warning on gcc
@@ -405,11 +427,27 @@ struct interface_base : Base
     }

     template <class T, class Setter, class F>
-    void set_auto_fp(Setter setter, F f)
+    void set_fp(Setter setter, F pf, out_params<0>)
+    {
+        static F f = pf;
+        (void)f; // avoid warning on gcc
+        call(setter,
+             this->get_handle_ptr(),
+             [](void* obj, char* ex_msg, size_t ex_msg_size, auto... xs) -> migraphx_status {
+                 return try_(
+                     [&] { call_cast_arg<T>(rank<0>{}, f, obj, xs...); }, ex_msg, ex_msg_size);
+             });
+    }
+
+    template <class T, class Setter, class F, class Out>
+    void set_auto_fp(Setter setter, F f, Out nums)
     {
-        return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) {
-            auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...);
-        });
+        return set_fp<T>(
+            setter,
+            [=](T& obj, auto out1, auto out2, auto... xs) {
+                auto_invoke(f, out1, out2, obj, auto_convert_param(rank<2>{}, xs)...);
+            },
+            nums);
     }

     struct no_out_arg
@@ -419,7 +457,7 @@ struct interface_base : Base
     template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
     static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
     {
-        f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...);
+        f(reinterpret_cast<T*>(obj), no_out_arg{}, no_out_arg{}, xs...);
     }

     template <class T,
@@ -430,17 +468,35 @@ struct interface_base : Base
               class = std::enable_if_t<std::is_void<X>{}>>
     static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
     {
-        f(*reinterpret_cast<T*>(obj), result, xs...);
+        f(*reinterpret_cast<T*>(obj), result, no_out_arg{}, xs...);
     }

+    template <class T,
+              class F,
+              class R1,
+              class R2,
+              class X,
+              class... Xs,
+              class = std::enable_if_t<std::is_void<X>{}>>
+    static void call_cast_arg(rank<2>, F f, R1 result1, R2 result2, X* obj, Xs... xs)
+    {
+        f(*reinterpret_cast<T*>(obj), result1, result2, xs...);
+    }
+
+    template <class F, class T1, class T2, class... Ts>
+    void auto_invoke(F f, T1* out1, T2* out2, Ts&&... xs)
+    {
+        auto_assign(rank<2>{}, out1, out2, f(std::forward<Ts>(xs)...));
+    }
+
     template <class F, class T, class... Ts>
-    void auto_invoke(F f, T* out, Ts&&... xs)
+    void auto_invoke(F f, T* out, no_out_arg, Ts&&... xs)
     {
-        auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...));
+        auto_assign(rank<1>{}, out, f(std::forward<Ts>(xs)...));
     }

     template <class F, class T, class... Ts>
-    void auto_invoke(F f, no_out_arg, Ts&&... xs)
+    void auto_invoke(F f, no_out_arg, no_out_arg, Ts&&... xs)
     {
         f(std::forward<Ts>(xs)...);
     }
@@ -469,7 +525,7 @@ struct interface_base : Base
     template <class T, class U>
     void auto_assign(rank<0>, T* out, U x)
     {
-        return *out = x;
+        *out = x;
     }

     template <class T, class U>
@@ -477,12 +533,21 @@ struct interface_base : Base
     {
         x.assign_to_handle(out);
     }

+    template <class T1, class T2, class U, class = std::enable_if_t<std::is_same<T2, size_t>{}>>
+    auto auto_assign(rank<2>, T1* out_ptr, T2* out_size, U x)
+    {
+        *out_size = std::min(*out_size, x.size());
+        std::copy_n(x.begin(), *out_size, out_ptr);
+    }
 };

 // NOLINTNEXTLINE
-#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \
-    this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \
-                         [](T& x, auto... xs) { return x.name(xs...); })
+#define MIGRAPHX_INTERFACE_LIFT(n_out, T, prefix, name) \
+    this->set_auto_fp<T>( \
+        &migraphx_##prefix##_set_##name, \
+        [](T& x, auto... xs) { return x.name(xs...); }, \
+        out_params<n_out>{})

 template <class Base, class T>
 using require_interface =
@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
         return pout;
     }

+    size_t elements() const
+    {
+        size_t pout;
+        call(&migraphx_shape_elements, &pout, this->get_handle_ptr());
+        return pout;
+    }
+
     size_t bytes() const
     {
         size_t pout;
@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
         return result;
     }

+    // map element index to space index
+    size_t index(size_t i) const
+    {
+        size_t result;
+        call(&migraphx_shape_index, &result, this->get_handle_ptr(), i);
+        return result;
+    }
+
     friend bool operator==(const shape& px, const shape& py)
     {
         bool pout;
@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
     template <typename T>
     std::vector<T> as_vector() const
     {
-        size_t vector_len = this->get_shape().bytes() / sizeof(T);
-        T* buffer_ptr     = reinterpret_cast<T*>(this->data());
-        return {buffer_ptr, buffer_ptr + vector_len};
+        auto ss           = this->get_shape();
+        auto num_elements = ss.elements();
+        std::vector<T> res(num_elements);
+        T* buffer_ptr = reinterpret_cast<T*>(this->data());
+        for(size_t i = 0; i < num_elements; i++)
+        {
+            res[i] = buffer_ptr[ss.index(i)];
+        }
+        return res;
     }

     /// Generate an argument using random data
@@ -1255,7 +1341,10 @@ struct experimental_custom_op_base
     virtual std::string name() const = 0;
     virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
     virtual shape compute_shape(shapes inputs) const = 0;
+    virtual std::vector<size_t> output_alias(shapes) const { return {}; }
+    // TODO: Return target string instead of bool
+    virtual bool runs_on_offload_target() const = 0;
     virtual ~experimental_custom_op_base() = default;
 };

 struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental_custom_op)>
@@ -1267,8 +1356,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
              obj,
              get_type_name(obj).c_str(),
              obj.name().c_str());
-        MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape);
-        MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute);
+        MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute_shape);
+        MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute);
+        MIGRAPHX_INTERFACE_LIFT(2, T, experimental_custom_op, output_alias);
+        MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, runs_on_offload_target);
     }

     void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
...
@@ -115,6 +115,7 @@ def shape(h):
              const=True)
     h.method('strides', returns='const std::vector<size_t>&', const=True)
     h.method('type', returns='migraphx::shape::type_t', const=True)
+    h.method('elements', returns='size_t', const=True)
     h.method('bytes', returns='size_t', const=True)
     h.method('equal',
              api.params(x='const migraphx::shape&'),
@@ -122,6 +123,7 @@ def shape(h):
              returns='bool',
              const=True)
     h.method('standard', returns='bool', const=True)
+    h.method('index', api.params(i='size_t'), returns='size_t', const=True)

 @auto_handle()
@@ -450,4 +452,8 @@ def experimental_custom_op(h):
     h.virtual('compute_shape',
               api.params(inputs='std::vector<migraphx::shape>'),
               returns='migraphx::shape')
+    h.virtual('output_alias',
+              api.params(inputs='std::vector<migraphx::shape>'),
+              returns='std::vector<size_t>')
+    h.virtual('runs_on_offload_target', returns='bool')
     h.method('register', invoke='migraphx::register_custom_op($@)')
@@ -26,6 +26,8 @@
 #include <migraphx/manage_ptr.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/instruction_ref.hpp>
+#include <migraphx/stringutils.hpp>
 #include <migraphx/op/convolution.hpp>
 #include <migraphx/op/deconvolution.hpp>
@@ -171,7 +173,8 @@ struct miopen_apply
         init();
         for(auto it = mod->begin(); it != mod->end(); it++)
         {
-            auto s = it->get_shape();
+            auto s     = it->get_shape();
+            auto attrs = it->get_operator().attributes();
             if(apply_map.count(it->name()) > 0)
             {
                 check_shape(s, apply_map.at(it->name())(it));
@@ -180,11 +183,37 @@ struct miopen_apply
             {
                 check_shape(s, insert_precompile_op(it));
             }
+            else if(attrs.contains("target"))
+            {
+                check_shape(s, insert_custom_op(it, attrs));
+            }
         }
         copy_params();
     }

+    instruction_ref insert_custom_op(instruction_ref ins, const value& attrs) const
+    {
+        const auto& custom_op = ins->get_operator();
+        if(attrs.at("target") == "cpu")
+        {
+            auto s = ins->get_shape();
+            std::vector<instruction_ref> cpu_inputs;
+            auto inputs = ins->inputs();
+            auto output = inputs.back();
+            std::transform(
+                inputs.begin(), inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) {
+                    return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in);
+                });
+            cpu_inputs.front() =
+                mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs);
+            auto cpu_out = mod->insert_instruction(ins, custom_op, cpu_inputs);
+            auto gpu_out =
+                mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output);
+            return mod->replace_instruction(ins, gpu_out);
+        }
+        return ins;
+    }
+
     instruction_ref insert_precompile_op(instruction_ref ins) const
     {
         auto output = insert_allocation(ins, ins->get_shape());
...
@@ -43,6 +43,8 @@ struct sigmoid_custom_op final : migraphx::experimental_custom_op_base
         return inputs[1];
     }

+    virtual bool runs_on_offload_target() const override { return true; }
+
     virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
     {
         if(inputs.size() != 2)
@@ -111,4 +113,45 @@ TEST_CASE(run_sigmoid_with_incorrect_shape)
         "Error in compute_shape of: sigmoid_custom_op: op must have two inputs"));
 }

+struct identity_custom_op final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "identity_custom_op"; }
+    virtual migraphx::argument
+    compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
+    {
+        return inputs[0];
+    }
+    virtual bool runs_on_offload_target() const override { return true; }
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(inputs.size() != 1)
+        {
+            throw std::runtime_error("Identity op must have only one input");
+        }
+        return inputs.back();
+    }
+    virtual std::vector<size_t> output_alias(migraphx::shapes) const override { return {0, 1}; }
+};
+
+TEST_CASE(run_custom_op_with_invalid_output_alias)
+{
+    identity_custom_op i_op;
+    migraphx::register_experimental_custom_op(i_op);
+    auto op = migraphx::operation("identity_custom_op");
+    EXPECT(op.name() == "identity_custom_op");
+    migraphx::program p;
+    migraphx::shape s{migraphx_shape_float_type, {12}};
+    migraphx::module m = p.get_main_module();
+    auto x             = m.add_parameter("x", s);
+    auto i_ins         = m.add_instruction(migraphx::operation("identity_custom_op"), {x});
+    migraphx_test_private_disable_exception_catch(true);
+    EXPECT(test::throws<std::exception>(
+        [&] { p.compile(migraphx::target("ref")); },
+        "Currently, CustomOps in MIGraphX only supports one output_alias"));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -24,40 +24,89 @@
 #include <hip/hip_runtime_api.h>
 #include <migraphx/migraphx.h>
 #include <migraphx/migraphx.hpp>
+#include <numeric>
 #include <stdexcept>
 #include "test.hpp"

 #define MIGRAPHX_HIP_ASSERT(x) (EXPECT(x == hipSuccess))

-struct simple_custom_op final : migraphx::experimental_custom_op_base
+struct half_copy_host final : migraphx::experimental_custom_op_base
 {
-    virtual std::string name() const override { return "simple_custom_op"; }
+    virtual std::string name() const override { return "half_copy_host"; }
+    virtual bool runs_on_offload_target() const override { return false; }
     virtual migraphx::argument
     compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
     {
-        // sets first half size_bytes of the input 0, and rest of the half bytes are copied.
-        int* h_output = nullptr;
-        auto* d_output = reinterpret_cast<int*>(inputs[0].data());
-        auto input_bytes = inputs[0].get_shape().bytes();
-        auto* output_ptr = inputs[1].data();
-        auto copy_bytes = input_bytes / 2;
+        // This custom op sets the first half of the input's bytes to 0 and copies the rest.
+        // It does its computation on the host, so `runs_on_offload_target()` returns false.
+        // Based on the `runs_on_offload_target()` flag, MIGraphX injects the necessary buffer
+        // copies between GPU and host for the input buffers as well as the output buffers.
+        auto* input_buffer_ptr  = inputs[0].data();
+        auto* output_buffer_ptr = inputs[1].data();
+        auto input_bytes        = inputs[0].get_shape().bytes();
+        auto copy_bytes         = input_bytes / 2;
         MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
-        MIGRAPHX_HIP_ASSERT(hipHostMalloc(&h_output, input_bytes));
-        MIGRAPHX_HIP_ASSERT(hipMemcpyAsync(
-            h_output, d_output, input_bytes, hipMemcpyDeviceToHost, ctx.get_queue<hipStream_t>()));
+        MIGRAPHX_HIP_ASSERT(hipMemcpyAsync(output_buffer_ptr,
+                                           input_buffer_ptr,
+                                           input_bytes,
+                                           hipMemcpyHostToHost,
+                                           ctx.get_queue<hipStream_t>()));
+        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
+        MIGRAPHX_HIP_ASSERT(hipMemset(output_buffer_ptr, 0, copy_bytes));
         MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
-        MIGRAPHX_HIP_ASSERT(hipMemset(h_output, 0, copy_bytes));
+        return inputs[1];
+    }
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(not inputs[0].standard() or not inputs[1].standard())
+        {
+            throw std::runtime_error("Input args must be standard shaped");
+        }
+        if(inputs.size() != 2)
+        {
+            throw std::runtime_error("number of inputs must be 2");
+        }
+        return inputs.back();
+    }
+};
+
+struct half_copy_device final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "half_copy_device"; }
+    virtual bool runs_on_offload_target() const override { return true; }
+    virtual migraphx::argument
+    compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
+    {
+        // This custom op sets the first half of the input's bytes to 0 and copies the rest.
+        // It does its computation on the GPU, so `runs_on_offload_target()` returns true.
+        auto* input_buffer_ptr  = inputs[0].data();
+        auto* output_buffer_ptr = inputs[1].data();
+        auto input_bytes        = inputs[0].get_shape().bytes();
+        auto copy_bytes         = input_bytes / 2;
+        MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
+        MIGRAPHX_HIP_ASSERT(hipMemcpyAsync(output_buffer_ptr,
+                                           input_buffer_ptr,
+                                           input_bytes,
+                                           hipMemcpyDeviceToDevice,
+                                           ctx.get_queue<hipStream_t>()));
         MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
-        MIGRAPHX_HIP_ASSERT(hipMemcpy(output_ptr, h_output, input_bytes, hipMemcpyHostToDevice));
+        MIGRAPHX_HIP_ASSERT(hipMemset(output_buffer_ptr, 0, copy_bytes));
         MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
-        MIGRAPHX_HIP_ASSERT(hipHostFree(h_output));
         return inputs[1];
     }
     virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
     {
-        if(not inputs[0].standard())
+        if(not inputs[0].standard() or not inputs[1].standard())
         {
-            throw std::runtime_error("first arg must be standard shaped");
+            throw std::runtime_error("Input args must be standard shaped");
         }
         if(inputs.size() != 2)
         {
@@ -67,36 +116,208 @@ struct simple_custom_op final : migraphx::experimental_custom_op_base
     }
 };

-TEST_CASE(run_simple_custom_op)
+// overwrites the input buffer
+struct half_copy_device_same_buffer final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "half_copy_device_same_buffer"; }
+    virtual bool runs_on_offload_target() const override { return true; }
+    virtual migraphx::argument
+    compute(migraphx::context, migraphx::shape, migraphx::arguments inputs) const override
+    {
+        // This custom op sets the first half of the input's bytes to 0 in place, leaving the
+        // rest of the buffer untouched. It does its computation on the device, so
+        // `runs_on_offload_target()` returns true.
+        auto* buffer_ptr = inputs[0].data();
+        auto input_bytes = inputs[0].get_shape().bytes();
+        auto copy_bytes  = input_bytes / 2;
+        MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
+        MIGRAPHX_HIP_ASSERT(hipMemset(buffer_ptr, 0, copy_bytes));
+        MIGRAPHX_HIP_ASSERT(hipDeviceSynchronize());
+        return inputs[0];
+    }
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(not inputs[0].standard())
+        {
+            throw std::runtime_error("Input arg must be standard shaped");
+        }
+        return inputs.front();
+    }
+};
+
+TEST_CASE(register_half_copy_op)
+{
+    half_copy_host hch;
+    migraphx::register_experimental_custom_op(hch);
+    auto op = migraphx::operation("half_copy_host");
+    EXPECT(op.name() == "half_copy_host");
+
+    half_copy_device hcd;
+    migraphx::register_experimental_custom_op(hcd);
+    op = migraphx::operation("half_copy_device");
+    EXPECT(op.name() == "half_copy_device");
+
+    half_copy_device_same_buffer hcdsb;
+    migraphx::register_experimental_custom_op(hcdsb);
+    op = migraphx::operation("half_copy_device_same_buffer");
+    EXPECT(op.name() == "half_copy_device_same_buffer");
+}
+
+TEST_CASE(half_copy_custom_op_test)
+{
+    auto run_test_prog = [](const std::string& op_name, bool buffer_alloc) {
+        migraphx::program p;
+        migraphx::module m = p.get_main_module();
+        migraphx::shape s{migraphx_shape_float_type, {4, 3}};
+        auto x = m.add_parameter("x", s);
+        migraphx::instructions inputs = {x};
+        if(buffer_alloc)
+        {
+            auto alloc = m.add_allocation(s);
+            inputs     = {x, alloc};
+        }
+        auto half_copy_ins = m.add_instruction(migraphx::operation(op_name.c_str()), inputs);
+        m.add_return({half_copy_ins});
+        migraphx::compile_options options;
+        options.set_offload_copy();
+        p.compile(migraphx::target("gpu"), options);
+        migraphx::program_parameters pp;
+        std::vector<float> x_data(12);
+        std::iota(x_data.begin(), x_data.end(), 0);
+        pp.add("x", migraphx::argument(s, x_data.data()));
+        auto results    = p.eval(pp);
+        auto result     = results[0];
+        auto result_vec = result.as_vector<float>();
+        std::vector<float> expected_result(12, 0);
+        std::iota(expected_result.begin() + 6, expected_result.end(), 6);
+        EXPECT(bool{result == migraphx::argument(s, expected_result.data())});
+    };
+    // register all the ops
+    half_copy_host hch;
+    migraphx::register_experimental_custom_op(hch);
+    half_copy_device hcd;
+    migraphx::register_experimental_custom_op(hcd);
+    half_copy_device_same_buffer hcdsb;
+    migraphx::register_experimental_custom_op(hcdsb);
+    std::vector<std::pair<std::string, bool>> tests_config = {
+        {"half_copy_host", true},
+        {"half_copy_device", true},
+        {"half_copy_device_same_buffer", false}};
+    for(const auto& i : tests_config)
+    {
+        run_test_prog(i.first, i.second);
+    }
+}
+
+struct stride_two final : migraphx::experimental_custom_op_base
+{
+    virtual std::string name() const override { return "stride_two"; }
+    virtual migraphx::argument
+    compute(migraphx::context, migraphx::shape out_shape, migraphx::arguments inputs) const override
+    {
+        return {out_shape, inputs[0].data()};
+    }
+    virtual migraphx::shape compute_shape(migraphx::shapes inputs) const override
+    {
+        if(inputs.size() != 1)
+        {
+            throw std::runtime_error("stride_two op must have only one input argument");
+        };
+        if(not inputs[0].standard())
+        {
+            throw std::runtime_error("stride_two op only works on the standard input shapes");
+        }
+        migraphx::shape input_s  = inputs[0];
+        std::vector<size_t> dims = input_s.lengths();
+        std::vector<size_t> new_dims;
+        std::vector<size_t> strides = input_s.strides();
+        std::vector<size_t> new_strides;
+        std::for_each(dims.begin(), dims.end(), [&](auto i) { new_dims.push_back(i / 2); });
+        std::for_each(
+            strides.begin(), strides.end(), [&](auto i) { new_strides.push_back(i * 2); });
+        migraphx::shape output_shape{input_s.type(), new_dims, new_strides};
+        return output_shape;
+    }
+    virtual bool runs_on_offload_target() const override { return true; }
+    virtual std::vector<size_t> output_alias(migraphx::shapes) const override { return {0}; };
+};
+
+TEST_CASE(stride_two_custom_op_test)
 {
-    simple_custom_op simple_op;
-    migraphx::register_experimental_custom_op(simple_op);
+    stride_two st;
+    migraphx::register_experimental_custom_op(st);
+    migraphx::program p;
+    migraphx::module m = p.get_main_module();
+    migraphx::shape s{migraphx_shape_float_type, {4, 4, 4}};
+    auto x              = m.add_parameter("x", s);
+    auto stride_two_ins = m.add_instruction(migraphx::operation("stride_two"), {x});
+    m.add_return({stride_two_ins});
+    migraphx::compile_options options;
+    options.set_offload_copy();
+    p.compile(migraphx::target("gpu"), options);
+    migraphx::program_parameters pp;
+    std::vector<float> x_data(64);
+    std::iota(x_data.begin(), x_data.end(), 0);
+    pp.add("x", migraphx::argument(s, x_data.data()));
+    auto results    = p.eval(pp);
+    auto result     = results[0];
+    auto result_vec = result.as_vector<float>();
+    std::vector<float> expected_result = {0, 2, 8, 10, 32, 34, 40, 42};
+    EXPECT(result_vec == expected_result);
+}
+
+TEST_CASE(custom_op_with_pre_and_post_subgraph_test)
+{
+    half_copy_host hco;
+    migraphx::register_experimental_custom_op(hco);
+    stride_two st;
+    migraphx::register_experimental_custom_op(st);
     migraphx::program p;
-    migraphx::shape s{migraphx_shape_int32_type, {4, 3}};
-    migraphx::shape trans_shape{migraphx_shape_int32_type, {3, 4}};
+    migraphx::shape s{migraphx_shape_float_type, {4, 6}};
     migraphx::module m = p.get_main_module();
     auto x = m.add_parameter("x", s);
-    auto neg = m.add_instruction(migraphx::operation("neg"), x);
-    auto alloc = m.add_allocation(trans_shape);
-    auto neg_trans =
-        m.add_instruction(migraphx::operation("transpose", "{permutation: [1, 0]}"), {neg});
-    auto neg_cont = m.add_instruction(migraphx::operation("contiguous"), {neg_trans});
-    auto custom_kernel =
-        m.add_instruction(migraphx::operation("simple_custom_op"), {neg_cont, alloc});
-    auto relu = m.add_instruction(migraphx::operation("relu"), custom_kernel);
-    m.add_return({relu});
+    // pre-subgraph
+    auto neg_ins = m.add_instruction(migraphx::operation("neg"), x);
+    auto trans_ins =
+        m.add_instruction(migraphx::operation("transpose", "{permutation: [1, 0]}"), {neg_ins});
+    auto cont_ins = m.add_instruction(migraphx::operation("contiguous"), {trans_ins});
+    // custom_op
+    migraphx::shape trans_shape{migraphx_shape_float_type, {6, 4}};
+    auto alloc = m.add_allocation(trans_shape);
+    auto half_copy_ins =
+        m.add_instruction(migraphx::operation("half_copy_host"), {cont_ins, alloc});
+    // post-subgraph
+    auto abs_ins = m.add_instruction(migraphx::operation("abs"), {half_copy_ins});
+    // another custom_op
+    auto stride_two_ins = m.add_instruction(migraphx::operation("stride_two"), {abs_ins});
+    // post-subgraph
+    auto relu_ins = m.add_instruction(migraphx::operation("relu"), {stride_two_ins});
+    m.add_return({relu_ins});
     migraphx::compile_options options;
     options.set_offload_copy();
     p.compile(migraphx::target("gpu"), options);
     migraphx::program_parameters pp;
-    std::vector<int> x_data(12, -3);
+    std::vector<float> x_data(s.elements());
+    std::iota(x_data.begin(), x_data.end(), 0);
     pp.add("x", migraphx::argument(s, x_data.data()));
     auto results = p.eval(pp);
     auto result = results[0];
-    auto result_vec = result.as_vector<int>();
-    std::vector<int> expected_result(12, 0);
-    std::fill(expected_result.begin() + 6, expected_result.end(), 3);
-    EXPECT(bool{result == migraphx::argument(trans_shape, expected_result.data())});
+    auto result_vec = result.as_vector<float>();
+    std::vector<float> expected_result = {0, 0, 0, 0, 4, 16};
+    EXPECT(bool{result == migraphx::argument(migraphx::shape{migraphx_shape_float_type, {3, 2}},
+                                             expected_result.data())});
 }

 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -82,10 +82,10 @@ TEST_CASE(if_pl_test)
     migraphx::program_parameters pp;
     auto param_shapes = p.get_parameter_shapes();
     auto xs           = param_shapes["x"];
-    std::vector<float> xd(xs.bytes() / sizeof(float), 1.0);
+    std::vector<float> xd(xs.elements(), 1.0);
     pp.add("x", migraphx::argument(xs, xd.data()));
     auto ys = param_shapes["y"];
-    std::vector<float> yd(ys.bytes() / sizeof(float), 2.0);
+    std::vector<float> yd(ys.elements(), 2.0);
     pp.add("y", migraphx::argument(ys, yd.data()));
     char ccond = cond;
     pp.add("cond", migraphx::argument(param_shapes["cond"], &ccond));
...
@@ -21,7 +21,10 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
 #####################################################################################
-import string, sys, re, runpy
+import string
+import sys
+import re
+import runpy
 from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -308,18 +311,39 @@ class Parameter:
         return self.substitute('${type} ${name}', prefix=prefix)

     def virtual_output_args(self, prefix: Optional[str] = None) -> List[str]:
-        return [
-            '&{prefix}{n}'.format(prefix=prefix or '', n=n)
-            for t, n in self.cparams
-        ]
+        container_type = self.type.remove_generic().basic().str()
+        decl_list: List[str] = []
+        container = (container_type == "std::vector"
+                     or container_type == "vector")
+        for t, n, in self.cparams:
+            if not decl_list and container:
+                decl_list.append('{prefix}{n}.data()'.format(prefix=prefix
+                                                             or '',
+                                                             n=n))
+            else:
+                decl_list.append('&{prefix}{n}'.format(prefix=prefix or '',
+                                                       n=n))
+        return decl_list

     def virtual_output_declarations(self,
                                     prefix: Optional[str] = None) -> List[str]:
-        return [
-            'std::remove_pointer_t<{type}> {prefix}{n};'.format(
-                type=Type(t).str(), prefix=prefix or '', n=n)
-            for t, n in self.cparams
-        ]
+        container_type = self.type.remove_generic().basic().str()
+        container = (container_type == "std::vector"
+                     or container_type == "vector")
+        decl_list: List[str] = []
+        for t, n, in self.cparams:
+            if not decl_list and container:
+                inner_t = self.type.inner_type()
+                if inner_t:
+                    decl_list.append(
+                        'std::array<{inner_t}, 1024> {prefix}{n};'.format(
+                            inner_t=inner_t.str(), prefix=prefix or '', n=n))
+            else:
+                decl_list.append(
+                    'std::remove_pointer_t<{type}> {prefix}{n}'.format(
+                        type=Type(t).str(), prefix=prefix or '', n=n))
+                decl_list[-1] += '=1024;' if container else ';'
+        return decl_list

     def virtual_output(self, prefix: Optional[str] = None) -> str:
         write = self.virtual_write
@@ -694,9 +718,14 @@ def generate_cpp_header() -> str:
                      [c.generate() for c in cpp_classes])

+c_type_map: Dict[str, Type] = {}

-def cwrap(name: str) -> Callable:
+def cwrap(name: str, c_type: Optional[str] = None) -> Callable:
     def with_cwrap(f):
         type_map[name] = f
+        if c_type:
+            c_type_map[name] = Type(c_type)

         @wraps(f)
         def decorated(*args, **kwargs):
@@ -917,6 +946,9 @@ def vector_c_wrap(p: Parameter) -> None:
     # Not a generic type
     if not inner:
         return
+    if inner.str() in c_type_map:
+        inner = c_type_map[inner.str()]
     t = inner.add_pointer()
     if p.type.is_reference():
         if p.type.is_const():
@@ -927,6 +959,12 @@ def vector_c_wrap(p: Parameter) -> None:
             p.add_size_param()
             p.bad_param('${name} == nullptr or ${size} == nullptr',
                         'Null pointer')
+    elif p.virtual:
+        p.add_param(t)
+        p.add_size_param()
+        p.bad_param('${name} == nullptr or ${size} == nullptr',
+                    'Null pointer')
+        p.virtual_write = '{${name}.begin(), ${name}.begin()+${size}}; // cppcheck-suppress returnDanglingLifetime'
     else:
         p.add_param(t)
         p.bad_param('${name} == nullptr', 'Null pointer')
@@ -946,7 +984,7 @@ def vector_c_wrap(p: Parameter) -> None:
     p.write = ['std::copy(${result}.begin(), ${result}.end(), ${name})']

-@cwrap('std::string')
+@cwrap('std::string', 'char*')
 def string_c_wrap(p: Parameter) -> None:
     t = Type('char*')
     if p.returns:
@@ -1061,9 +1099,9 @@ struct ${ctype} {
 c_api_virtual_impl = Template('''
 ${return_type} ${name}(${params}) const
 {
-    ${output_decls}
     if (${fname} == nullptr)
         throw std::runtime_error("${name} function is missing.");
+    ${output_decls}
     std::array<char, 256> exception_msg;
     exception_msg.front() = '\\0';
     auto api_error_result = ${fname}(${args});
...