Commit 180dc7a0 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into jit-unroll-stride

parents e535f7ef f7d987ba
...@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N) ...@@ -45,6 +45,13 @@ __global__ void vector_square(T* C_d, const T* A_d, size_t N)
struct square_custom_op final : migraphx::experimental_custom_op_base struct square_custom_op final : migraphx::experimental_custom_op_base
{ {
virtual std::string name() const override { return "square_custom_op"; } virtual std::string name() const override { return "square_custom_op"; }
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual bool runs_on_offload_target() const override { return true; }
virtual migraphx::argument virtual migraphx::argument
compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override compute(migraphx::context ctx, migraphx::shape, migraphx::arguments inputs) const override
{ {
...@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base ...@@ -54,7 +61,7 @@ struct square_custom_op final : migraphx::experimental_custom_op_base
// is output argument, so it should be returned from compute method. // is output argument, so it should be returned from compute method.
auto* input_buffer = reinterpret_cast<float*>(inputs[0].data()); auto* input_buffer = reinterpret_cast<float*>(inputs[0].data());
auto* output_buffer = reinterpret_cast<float*>(inputs[1].data()); auto* output_buffer = reinterpret_cast<float*>(inputs[1].data());
size_t n_elements = inputs[0].get_shape().bytes() / sizeof(inputs[0].get_shape().type()); size_t n_elements = inputs[0].get_shape().elements();
MIGRAPHX_HIP_ASSERT(hipSetDevice(0)); MIGRAPHX_HIP_ASSERT(hipSetDevice(0));
const unsigned blocks = 512; const unsigned blocks = 512;
const unsigned threads_per_block = 256; const unsigned threads_per_block = 256;
...@@ -103,7 +110,7 @@ int main(int argc, const char* argv[]) ...@@ -103,7 +110,7 @@ int main(int argc, const char* argv[])
options.set_offload_copy(); options.set_offload_copy();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp; migraphx::program_parameters pp;
std::vector<float> x_data(s.bytes() / sizeof(s.type())); std::vector<float> x_data(s.elements());
std::iota(x_data.begin(), x_data.end(), 0); std::iota(x_data.begin(), x_data.end(), 0);
pp.add("x", migraphx::argument(s, x_data.data())); pp.add("x", migraphx::argument(s, x_data.data()));
auto results = p.eval(pp); auto results = p.eval(pp);
......
...@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode, ...@@ -93,6 +93,13 @@ inline auto make_activation_descriptor(miopenActivationMode_t mode,
struct abs_custom_op final : migraphx::experimental_custom_op_base struct abs_custom_op final : migraphx::experimental_custom_op_base
{ {
virtual std::string name() const override { return "abs_custom_op"; } virtual std::string name() const override { return "abs_custom_op"; }
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual bool runs_on_offload_target() const override { return true; }
virtual migraphx::argument compute(migraphx::context ctx, virtual migraphx::argument compute(migraphx::context ctx,
migraphx::shape output_shape, migraphx::shape output_shape,
migraphx::arguments args) const override migraphx::arguments args) const override
......
...@@ -51,6 +51,13 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx) ...@@ -51,6 +51,13 @@ rocblas_handle create_rocblas_handle_ptr(migraphx::context& ctx)
struct sscal_custom_op final : migraphx::experimental_custom_op_base struct sscal_custom_op final : migraphx::experimental_custom_op_base
{ {
virtual std::string name() const override { return "sscal_custom_op"; } virtual std::string name() const override { return "sscal_custom_op"; }
// flag to identify whether custom op runs on the GPU or on the host.
// Based on this flag MIGraphX would inject necessary copies to and from GPU for the input and
// output buffers as necessary. Therefore if custom_op runs on GPU then it can assume its input
// buffers are in GPU memory, and similarly for the host
virtual bool runs_on_offload_target() const override { return true; }
virtual migraphx::argument compute(migraphx::context ctx, virtual migraphx::argument compute(migraphx::context ctx,
migraphx::shape output_shape, migraphx::shape output_shape,
migraphx::arguments args) const override migraphx::arguments args) const override
...@@ -107,7 +114,7 @@ int main(int argc, const char* argv[]) ...@@ -107,7 +114,7 @@ int main(int argc, const char* argv[])
options.set_offload_copy(); options.set_offload_copy();
p.compile(migraphx::target("gpu"), options); p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp; migraphx::program_parameters pp;
std::vector<float> x_data(x_shape.bytes() / sizeof(x_shape.type())); std::vector<float> x_data(x_shape.elements());
std::vector<float> scale_data{-1}; std::vector<float> scale_data{-1};
std::iota(x_data.begin(), x_data.end(), 0); std::iota(x_data.begin(), x_data.end(), 0);
pp.add("x", migraphx::argument(x_shape, x_data.data())); pp.add("x", migraphx::argument(x_shape, x_data.data()));
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/execution_environment.hpp>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names) ...@@ -166,6 +167,13 @@ void set_output_names(tf_options& options, std::vector<const char*> names)
options.output_node_names = std::vector<std::string>(names.begin(), names.end()); options.output_node_names = std::vector<std::string>(names.begin(), names.end());
} }
// Evaluate the program asynchronously on a caller-supplied stream/queue.
// `s` is the raw queue pointer and `name` its static type name, which lets
// the backend rebuild a typed any_ptr; `true` marks the run as async.
std::vector<argument>
run_async(program& p, const parameter_map& params, void* s, std::string_view name)
{
    execution_environment env{any_ptr(s, name), true};
    return p.eval(params, env);
}
template <class Value> template <class Value>
std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m) std::vector<const char*> get_names(const std::unordered_map<std::string, Value>& m)
{ {
...@@ -265,11 +273,18 @@ struct experimental_custom_op ...@@ -265,11 +273,18 @@ struct experimental_custom_op
template <class CustomOp> template <class CustomOp>
struct custom_operation struct custom_operation
{ {
template <class Self, class F> template <class Self, class F>
static auto reflect(Self&, F) static auto reflect(Self&, F)
{ {
return pack(); return pack();
} }
value attributes() const
{
    // Tag the operation as a custom op and record which target it executes
    // on, so the framework can plan host<->device copies around it.
    const char* target = "cpu";
    if(op.runs_on_offload_target())
        target = "gpu";
    return {{"custom_op", true}, {"target", target}};
}
CustomOp op; CustomOp op;
std::string name() const { return op.xobject.name; } std::string name() const { return op.xobject.name; }
...@@ -284,6 +299,23 @@ struct custom_operation ...@@ -284,6 +299,23 @@ struct custom_operation
{ {
return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs)); return op.compute(std::move(ctx), std::move(output_shape), std::move(inputs));
} }
// Translate the user op's alias list into the single-alias form MIGraphX
// understands: -1 means "no alias", otherwise the one aliased input index.
std::ptrdiff_t output_alias(std::vector<shape> inputs) const
{
    // TODO: For now, only support one output alias
    const auto aliases = op.output_alias(std::move(inputs));
    if(aliases.size() > 1)
    {
        MIGRAPHX_THROW("Currently, CustomOps in MIGraphX only supports one output_alias");
    }
    return aliases.empty() ? -1 : static_cast<std::ptrdiff_t>(aliases.front());
}

// Forward the user op's placement flag (GPU/offload target vs host).
bool runs_on_offload_target() const { return op.runs_on_offload_target(); }
}; };
template <class CustomOp> template <class CustomOp>
...@@ -613,9 +645,9 @@ struct migraphx_experimental_custom_op ...@@ -613,9 +645,9 @@ struct migraphx_experimental_custom_op
migraphx::shape output, migraphx::shape output,
std::vector<migraphx::argument> inputs) const std::vector<migraphx::argument> inputs) const
{ {
std::remove_pointer_t<migraphx_argument_t> out;
if(compute_f == nullptr) if(compute_f == nullptr)
throw std::runtime_error("compute function is missing."); throw std::runtime_error("compute function is missing.");
std::remove_pointer_t<migraphx_argument_t> out;
std::array<char, 256> exception_msg; std::array<char, 256> exception_msg;
exception_msg.front() = '\0'; exception_msg.front() = '\0';
auto api_error_result = compute_f(&out, auto api_error_result = compute_f(&out,
...@@ -637,9 +669,9 @@ struct migraphx_experimental_custom_op ...@@ -637,9 +669,9 @@ struct migraphx_experimental_custom_op
migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr; migraphx_experimental_custom_op_compute_shape compute_shape_f = nullptr;
migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
{ {
std::remove_pointer_t<migraphx_shape_t> out;
if(compute_shape_f == nullptr) if(compute_shape_f == nullptr)
throw std::runtime_error("compute_shape function is missing."); throw std::runtime_error("compute_shape function is missing.");
std::remove_pointer_t<migraphx_shape_t> out;
std::array<char, 256> exception_msg; std::array<char, 256> exception_msg;
exception_msg.front() = '\0'; exception_msg.front() = '\0';
auto api_error_result = compute_shape_f(&out, auto api_error_result = compute_shape_f(&out,
...@@ -655,6 +687,49 @@ struct migraphx_experimental_custom_op ...@@ -655,6 +687,49 @@ struct migraphx_experimental_custom_op
} }
return (&out)->object; return (&out)->object;
} }
migraphx_experimental_custom_op_output_alias output_alias_f = nullptr;
// Dispatch to the registered C callback to obtain the output alias indices.
std::vector<size_t> output_alias(std::vector<migraphx::shape> inputs) const
{
    if(output_alias_f == nullptr)
        throw std::runtime_error("output_alias function is missing.");
    // Fixed-capacity staging buffer; the callback fills it and reports the
    // number of entries written back through out_size.
    std::array<size_t, 1024> out;
    size_t out_size = 1024;
    std::array<char, 256> err_buf;
    err_buf.front() = '\0';
    const auto status = output_alias_f(out.data(),
                                       &out_size,
                                       object_ptr.data,
                                       err_buf.data(),
                                       err_buf.size(),
                                       object_cast<migraphx_shapes_t>(&(inputs)));
    if(status != migraphx_status_success)
    {
        throw std::runtime_error("Error in output_alias of: " +
                                 std::string(object_ptr.obj_typename) + ": " +
                                 std::string(err_buf.data()));
    }
    return {out.begin(), out.begin() + out_size}; // cppcheck-suppress returnDanglingLifetime;
}
migraphx_experimental_custom_op_runs_on_offload_target runs_on_offload_target_f = nullptr;
// Dispatch to the registered C callback to ask whether the op runs on the
// offload target (e.g. GPU) rather than the host.
bool runs_on_offload_target() const
{
    if(runs_on_offload_target_f == nullptr)
        throw std::runtime_error("runs_on_offload_target function is missing.");
    // Initialize the output: reading an unwritten bool would be undefined
    // behavior if a callback reported success without storing a value.
    bool out = false;
    std::array<char, 256> exception_msg;
    exception_msg.front() = '\0';
    auto api_error_result = runs_on_offload_target_f(
        &out, object_ptr.data, exception_msg.data(), exception_msg.size());
    if(api_error_result != migraphx_status_success)
    {
        const std::string exception_str(exception_msg.data());
        throw std::runtime_error("Error in runs_on_offload_target of: " +
                                 std::string(object_ptr.obj_typename) + ": " + exception_str);
    }
    return out;
}
}; };
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
...@@ -758,6 +833,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, ...@@ -758,6 +833,16 @@ extern "C" migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out,
return api_error_result; return api_error_result;
} }
// C entry point: report the number of elements in `shape` (not bytes).
extern "C" migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape)
{
    return migraphx::try_([&] {
        if(shape == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
        *out = (shape->object).elements();
    });
}
extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { auto api_error_result = migraphx::try_([&] {
...@@ -791,6 +876,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha ...@@ -791,6 +876,16 @@ extern "C" migraphx_status migraphx_shape_standard(bool* out, const_migraphx_sha
return api_error_result; return api_error_result;
} }
// C entry point: map logical element index `i` to its offset in the
// underlying buffer space via shape::index.
extern "C" migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i)
{
    return migraphx::try_([&] {
        if(shape == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter shape: Null pointer");
        *out = (shape->object).index(i);
    });
}
extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument) extern "C" migraphx_status migraphx_argument_destroy(migraphx_argument_t argument)
{ {
auto api_error_result = migraphx::try_([&] { destroy((argument)); }); auto api_error_result = migraphx::try_([&] { destroy((argument)); });
...@@ -1348,6 +1443,23 @@ extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...@@ -1348,6 +1443,23 @@ extern "C" migraphx_status migraphx_program_run(migraphx_arguments_t* out,
return api_error_result; return api_error_result;
} }
// C entry point: run the program asynchronously on the caller's stream `s`
// (identified by its type name) and return the results as an owned handle.
extern "C" migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
                                                      migraphx_program_t program,
                                                      migraphx_program_parameters_t params,
                                                      void* s,
                                                      const char* name)
{
    return migraphx::try_([&] {
        // Validate both handles before touching their wrapped objects.
        if(program == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
        if(params == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter params: Null pointer");
        *out = allocate<migraphx_arguments_t>(
            migraphx::run_async((program->object), (params->object), (s), (name)));
    });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x) migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x)
{ {
...@@ -1879,6 +1991,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape( ...@@ -1879,6 +1991,22 @@ extern "C" migraphx_status migraphx_experimental_custom_op_set_compute_shape(
return api_error_result; return api_error_result;
} }
// C entry point: install the user's output_alias callback on the handle.
extern "C" migraphx_status
migraphx_experimental_custom_op_set_output_alias(migraphx_experimental_custom_op_t obj,
                                                 migraphx_experimental_custom_op_output_alias input)
{
    return migraphx::try_([&] { obj->output_alias_f = input; });
}
// C entry point: install the user's runs_on_offload_target callback.
extern "C" migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
    migraphx_experimental_custom_op_t obj,
    migraphx_experimental_custom_op_runs_on_offload_target input)
{
    return migraphx::try_([&] { obj->runs_on_offload_target_f = input; });
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op) migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op)
{ {
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
...@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraph ...@@ -144,6 +143,16 @@ typedef migraphx_status (*migraphx_experimental_custom_op_compute_shape)(migraph
size_t exception_msg_size, size_t exception_msg_size,
migraphx_shapes_t inputs); migraphx_shapes_t inputs);
// Callback: write up to *out_size alias indices into `out` and store the
// count actually written back into *out_size; failures are reported through
// the returned status plus a message copied into exception_msg.
typedef migraphx_status (*migraphx_experimental_custom_op_output_alias)(size_t* out,
                                                                        size_t* out_size,
                                                                        void* obj,
                                                                        char* exception_msg,
                                                                        size_t exception_msg_size,
                                                                        migraphx_shapes_t inputs);

// Callback: set *out to true when the custom op executes on the offload
// target (e.g. GPU) rather than on the host.
typedef migraphx_status (*migraphx_experimental_custom_op_runs_on_offload_target)(
    bool* out, void* obj, char* exception_msg, size_t exception_msg_size);
typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input); typedef migraphx_status (*migraphx_experimental_custom_op_copy)(void** out, void* input);
typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input); typedef migraphx_status (*migraphx_experimental_custom_op_delete)(void* input);
...@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap ...@@ -175,6 +184,8 @@ migraphx_shape_strides(const size_t** out, size_t* out_size, const_migraphx_shap
migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_type(migraphx_shape_datatype_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_elements(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_bytes(size_t* out, const_migraphx_shape_t shape);
migraphx_status migraphx_status
...@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha ...@@ -182,6 +193,8 @@ migraphx_shape_equal(bool* out, const_migraphx_shape_t shape, const_migraphx_sha
migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape); migraphx_status migraphx_shape_standard(bool* out, const_migraphx_shape_t shape);
migraphx_status migraphx_shape_index(size_t* out, const_migraphx_shape_t shape, size_t i);
migraphx_status migraphx_argument_destroy(migraphx_argument_t argument); migraphx_status migraphx_argument_destroy(migraphx_argument_t argument);
migraphx_status migraphx_argument_assign_to(migraphx_argument_t output, migraphx_status migraphx_argument_assign_to(migraphx_argument_t output,
...@@ -344,6 +357,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...@@ -344,6 +357,12 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_program_t program, migraphx_program_t program,
migraphx_program_parameters_t params); migraphx_program_parameters_t params);
migraphx_status migraphx_program_run_async(migraphx_arguments_t* out,
migraphx_program_t program,
migraphx_program_parameters_t params,
void* s,
const char* name);
migraphx_status migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
...@@ -502,6 +521,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t ob ...@@ -502,6 +521,13 @@ migraphx_experimental_custom_op_set_compute(migraphx_experimental_custom_op_t ob
migraphx_status migraphx_experimental_custom_op_set_compute_shape( migraphx_status migraphx_experimental_custom_op_set_compute_shape(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input); migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_compute_shape input);
migraphx_status migraphx_experimental_custom_op_set_output_alias(
migraphx_experimental_custom_op_t obj, migraphx_experimental_custom_op_output_alias input);
migraphx_status migraphx_experimental_custom_op_set_runs_on_offload_target(
migraphx_experimental_custom_op_t obj,
migraphx_experimental_custom_op_runs_on_offload_target input);
migraphx_status migraphx_status
migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_register(migraphx_experimental_custom_op_t experimental_custom_op);
......
...@@ -25,10 +25,12 @@ ...@@ -25,10 +25,12 @@
#define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP #define MIGRAPHX_GUARD_API_RTGLIB_MIGRAPHX_HPP
#include "migraphx.h" #include "migraphx.h"
#include <algorithm>
#include <cstring> #include <cstring>
#include <initializer_list> #include <initializer_list>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <memory> #include <memory>
#include <numeric>
#include <exception> #include <exception>
#include <vector> #include <vector>
#include <cassert> #include <cassert>
...@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -340,6 +342,11 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
this->set_handle(p, std::move(lifetime)); \ this->set_handle(p, std::move(lifetime)); \
} }
// Tag type used to select the set_fp overload matching the number of output
// parameters (N) in the generated C callback signature.
template <size_t N>
struct out_params
{
};
template <class Base> template <class Base>
struct interface_base : Base struct interface_base : Base
{ {
...@@ -391,7 +398,22 @@ struct interface_base : Base ...@@ -391,7 +398,22 @@ struct interface_base : Base
} }
template <class T, class Setter, class F> template <class T, class Setter, class F>
void set_fp(Setter setter, F pf) void set_fp(Setter setter, F pf, out_params<2>)
{
static F f = pf;
(void)f; // avoid warning on gcc
call(setter,
this->get_handle_ptr(),
[](auto out1, auto out2, void* obj, char* ex_msg, size_t ex_msg_size, auto... xs)
-> migraphx_status {
return try_([&] { call_cast_arg<T>(rank<2>{}, f, out1, out2, obj, xs...); },
ex_msg,
ex_msg_size);
});
}
template <class T, class Setter, class F>
void set_fp(Setter setter, F pf, out_params<1>)
{ {
static F f = pf; static F f = pf;
(void)f; // avoid warning on gcc (void)f; // avoid warning on gcc
...@@ -405,13 +427,29 @@ struct interface_base : Base ...@@ -405,13 +427,29 @@ struct interface_base : Base
} }
template <class T, class Setter, class F> template <class T, class Setter, class F>
void set_auto_fp(Setter setter, F f) void set_fp(Setter setter, F pf, out_params<0>)
{ {
return set_fp<T>(setter, [=](T& obj, auto out, auto... xs) { static F f = pf;
auto_invoke(f, out, obj, auto_convert_param(rank<2>{}, xs)...); (void)f; // avoid warning on gcc
call(setter,
this->get_handle_ptr(),
[](void* obj, char* ex_msg, size_t ex_msg_size, auto... xs) -> migraphx_status {
return try_(
[&] { call_cast_arg<T>(rank<0>{}, f, obj, xs...); }, ex_msg, ex_msg_size);
}); });
} }
// Adapt the user functor `f` to the canonical two-output callback shape and
// forward to the set_fp overload selected by `nums` (an out_params<N> tag).
template <class T, class Setter, class F, class Out>
void set_auto_fp(Setter setter, F f, Out nums)
{
    return set_fp<T>(
        setter,
        [=](T& obj, auto out1, auto out2, auto... xs) {
            auto_invoke(f, out1, out2, obj, auto_convert_param(rank<2>{}, xs)...);
        },
        nums);
}
struct no_out_arg struct no_out_arg
{ {
}; };
...@@ -419,7 +457,7 @@ struct interface_base : Base ...@@ -419,7 +457,7 @@ struct interface_base : Base
template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>> template <class T, class F, class X, class... Xs, class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs) static void call_cast_arg(rank<0>, F f, X* obj, Xs... xs)
{ {
f(reinterpret_cast<T*>(obj), no_out_arg{}, xs...); f(reinterpret_cast<T*>(obj), no_out_arg{}, no_out_arg{}, xs...);
} }
template <class T, template <class T,
...@@ -430,17 +468,35 @@ struct interface_base : Base ...@@ -430,17 +468,35 @@ struct interface_base : Base
class = std::enable_if_t<std::is_void<X>{}>> class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs) static void call_cast_arg(rank<1>, F f, R result, X* obj, Xs... xs)
{ {
f(*reinterpret_cast<T*>(obj), result, xs...); f(*reinterpret_cast<T*>(obj), result, no_out_arg{}, xs...);
}
// rank<2> overload: used for callbacks exposing two output parameters;
// recovers the typed object from the type-erased pointer and forwards both
// output slots to the wrapped functor.
template <class T,
          class F,
          class R1,
          class R2,
          class X,
          class... Xs,
          class = std::enable_if_t<std::is_void<X>{}>>
static void call_cast_arg(rank<2>, F f, R1 result1, R2 result2, X* obj, Xs... xs)
{
    f(*reinterpret_cast<T*>(obj), result1, result2, xs...);
}

// Invoke `f` and scatter its result across the two output pointers via the
// rank<2> auto_assign overload (pointer + size pair).
template <class F, class T1, class T2, class... Ts>
void auto_invoke(F f, T1* out1, T2* out2, Ts&&... xs)
{
    auto_assign(rank<2>{}, out1, out2, f(std::forward<Ts>(xs)...));
}
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
void auto_invoke(F f, T* out, Ts&&... xs) void auto_invoke(F f, T* out, no_out_arg, Ts&&... xs)
{ {
auto_assign(rank<2>{}, out, f(std::forward<Ts>(xs)...)); auto_assign(rank<1>{}, out, f(std::forward<Ts>(xs)...));
} }
template <class F, class T, class... Ts> template <class F, class T, class... Ts>
void auto_invoke(F f, no_out_arg, Ts&&... xs) void auto_invoke(F f, no_out_arg, no_out_arg, Ts&&... xs)
{ {
f(std::forward<Ts>(xs)...); f(std::forward<Ts>(xs)...);
} }
...@@ -469,7 +525,7 @@ struct interface_base : Base ...@@ -469,7 +525,7 @@ struct interface_base : Base
template <class T, class U> template <class T, class U>
void auto_assign(rank<0>, T* out, U x) void auto_assign(rank<0>, T* out, U x)
{ {
return *out = x; *out = x;
} }
template <class T, class U> template <class T, class U>
...@@ -477,12 +533,21 @@ struct interface_base : Base ...@@ -477,12 +533,21 @@ struct interface_base : Base
{ {
x.assign_to_handle(out); x.assign_to_handle(out);
} }
// Two-output assignment: copy container `x` into the caller's buffer
// `out_ptr`, truncating to the capacity the caller passed in *out_size, and
// write back the number of elements actually copied.
template <class T1, class T2, class U, class = std::enable_if_t<std::is_same<T2, size_t>{}>>
auto auto_assign(rank<2>, T1* out_ptr, T2* out_size, U x)
{
    *out_size = std::min(*out_size, x.size());
    std::copy_n(x.begin(), *out_size, out_ptr);
}
}; };
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_INTERFACE_LIFT(T, prefix, name) \ #define MIGRAPHX_INTERFACE_LIFT(n_out, T, prefix, name) \
this->set_auto_fp<T>(&migraphx_##prefix##_set_##name, \ this->set_auto_fp<T>( \
[](T& x, auto... xs) { return x.name(xs...); }) &migraphx_##prefix##_set_##name, \
[](T& x, auto... xs) { return x.name(xs...); }, \
out_params<n_out>{})
template <class Base, class T> template <class Base, class T>
using require_interface = using require_interface =
...@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -567,6 +632,13 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return pout; return pout;
} }
// Number of elements in the shape (contrast with bytes()).
size_t elements() const
{
    size_t count = 0;
    call(&migraphx_shape_elements, &count, this->get_handle_ptr());
    return count;
}
size_t bytes() const size_t bytes() const
{ {
size_t pout; size_t pout;
...@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -581,6 +653,14 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
return result; return result;
} }
// Map a logical element index `i` to its offset in the underlying data space.
size_t index(size_t i) const
{
    size_t out = 0;
    call(&migraphx_shape_index, &out, this->get_handle_ptr(), i);
    return out;
}
friend bool operator==(const shape& px, const shape& py) friend bool operator==(const shape& px, const shape& py)
{ {
bool pout; bool pout;
...@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -628,9 +708,15 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
template <typename T> template <typename T>
std::vector<T> as_vector() const std::vector<T> as_vector() const
{ {
size_t vector_len = this->get_shape().bytes() / sizeof(T); auto ss = this->get_shape();
auto num_elements = ss.elements();
std::vector<T> res(num_elements);
T* buffer_ptr = reinterpret_cast<T*>(this->data()); T* buffer_ptr = reinterpret_cast<T*>(this->data());
return {buffer_ptr, buffer_ptr + vector_len}; for(size_t i = 0; i < num_elements; i++)
{
res[i] = buffer_ptr[ss.index(i)];
}
return res;
} }
/// Generate an argument using random data /// Generate an argument using random data
...@@ -979,6 +1065,20 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -979,6 +1065,20 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return arguments(pout, own{}); return arguments(pout, own{});
} }
template <class Stream>
/// Overloaded to allow for execution_environment input
/// Runs the program asynchronously on the caller-supplied stream `s`; the
/// stream's static type name is forwarded so the C API can rebuild a typed
/// queue pointer. Returns the output arguments as an owned handle.
arguments run_async(const program_parameters& pparams, Stream* s) const
{
    migraphx_arguments_t pout;
    call(&migraphx_program_run_async,
         &pout,
         this->get_handle_ptr(),
         pparams.get_handle_ptr(),
         s,
         get_type_name<Stream>().c_str());
    return arguments(pout, own{});
}
void print() const { call(&migraphx_program_print, this->get_handle_ptr()); } void print() const { call(&migraphx_program_print, this->get_handle_ptr()); }
program sort() program sort()
...@@ -1255,6 +1355,9 @@ struct experimental_custom_op_base ...@@ -1255,6 +1355,9 @@ struct experimental_custom_op_base
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual argument compute(context ctx, shape output, arguments inputs) const = 0; virtual argument compute(context ctx, shape output, arguments inputs) const = 0;
virtual shape compute_shape(shapes inputs) const = 0; virtual shape compute_shape(shapes inputs) const = 0;
virtual std::vector<size_t> output_alias(shapes) const { return {}; }
// TODO: Return target string instead of bool
virtual bool runs_on_offload_target() const = 0;
virtual ~experimental_custom_op_base() = default; virtual ~experimental_custom_op_base() = default;
}; };
...@@ -1267,8 +1370,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental ...@@ -1267,8 +1370,10 @@ struct experimental_custom_op : interface_base<MIGRAPHX_HANDLE_BASE(experimental
obj, obj,
get_type_name(obj).c_str(), get_type_name(obj).c_str(),
obj.name().c_str()); obj.name().c_str());
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute_shape); MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute_shape);
MIGRAPHX_INTERFACE_LIFT(T, experimental_custom_op, compute); MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, compute);
MIGRAPHX_INTERFACE_LIFT(2, T, experimental_custom_op, output_alias);
MIGRAPHX_INTERFACE_LIFT(1, T, experimental_custom_op, runs_on_offload_target);
} }
void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); } void register_op() { call(&migraphx_experimental_custom_op_register, this->get_handle_ptr()); }
......
...@@ -115,6 +115,7 @@ def shape(h): ...@@ -115,6 +115,7 @@ def shape(h):
const=True) const=True)
h.method('strides', returns='const std::vector<size_t>&', const=True) h.method('strides', returns='const std::vector<size_t>&', const=True)
h.method('type', returns='migraphx::shape::type_t', const=True) h.method('type', returns='migraphx::shape::type_t', const=True)
h.method('elements', returns='size_t', const=True)
h.method('bytes', returns='size_t', const=True) h.method('bytes', returns='size_t', const=True)
h.method('equal', h.method('equal',
api.params(x='const migraphx::shape&'), api.params(x='const migraphx::shape&'),
...@@ -122,6 +123,7 @@ def shape(h): ...@@ -122,6 +123,7 @@ def shape(h):
returns='bool', returns='bool',
const=True) const=True)
h.method('standard', returns='bool', const=True) h.method('standard', returns='bool', const=True)
h.method('index', api.params(i='size_t'), returns='size_t', const=True)
@auto_handle() @auto_handle()
...@@ -274,6 +276,13 @@ def program(h): ...@@ -274,6 +276,13 @@ def program(h):
params='std::unordered_map<std::string, migraphx::argument>'), params='std::unordered_map<std::string, migraphx::argument>'),
invoke='migraphx::run($@)', invoke='migraphx::run($@)',
returns='std::vector<migraphx::argument>') returns='std::vector<migraphx::argument>')
h.method('run_async',
api.params(
params='std::unordered_map<std::string, migraphx::argument>',
s='void*',
name='const char *'),
invoke='migraphx::run_async($@)',
returns='std::vector<migraphx::argument>')
h.method('equal', h.method('equal',
api.params(x='const migraphx::program&'), api.params(x='const migraphx::program&'),
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
...@@ -450,4 +459,8 @@ def experimental_custom_op(h): ...@@ -450,4 +459,8 @@ def experimental_custom_op(h):
h.virtual('compute_shape', h.virtual('compute_shape',
api.params(inputs='std::vector<migraphx::shape>'), api.params(inputs='std::vector<migraphx::shape>'),
returns='migraphx::shape') returns='migraphx::shape')
h.virtual('output_alias',
api.params(inputs='std::vector<migraphx::shape>'),
returns='std::vector<size_t>')
h.virtual('runs_on_offload_target', returns='bool')
h.method('register', invoke='migraphx::register_custom_op($@)') h.method('register', invoke='migraphx::register_custom_op($@)')
...@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&) ...@@ -66,6 +66,15 @@ any_ptr get_queue_context(T&)
{ {
return {}; return {};
} }
template <class T>
void wait_for_context(T&, any_ptr)
{
}
template <class T>
void finish_on_context(T&, any_ptr)
{
}
#ifdef TYPE_ERASED_DECLARATION #ifdef TYPE_ERASED_DECLARATION
...@@ -78,6 +87,10 @@ struct context ...@@ -78,6 +87,10 @@ struct context
void from_value(const value& v); void from_value(const value& v);
// (optional) // (optional)
any_ptr get_queue(); any_ptr get_queue();
// (optional)
void wait_for(any_ptr queue);
// (optional)
void finish_on(any_ptr queue);
// //
void finish() const; void finish() const;
}; };
...@@ -165,6 +178,18 @@ struct context ...@@ -165,6 +178,18 @@ struct context
return (*this).private_detail_te_get_handle().get_queue(); return (*this).private_detail_te_get_handle().get_queue();
} }
void wait_for(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().wait_for(queue);
}
void finish_on(any_ptr queue)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().finish_on(queue);
}
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -187,6 +212,8 @@ struct context ...@@ -187,6 +212,8 @@ struct context
virtual value to_value() const = 0; virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0; virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0; virtual any_ptr get_queue() = 0;
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual void finish() const = 0; virtual void finish() const = 0;
}; };
...@@ -231,6 +258,33 @@ struct context ...@@ -231,6 +258,33 @@ struct context
return get_queue_context(private_detail_te_self); return get_queue_context(private_detail_te_self);
} }
template <class T>
static auto private_detail_te_default_wait_for(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.wait_for(queue))
{
private_detail_te_self.wait_for(queue);
}
template <class T>
static void private_detail_te_default_wait_for(float, T&& private_detail_te_self, any_ptr queue)
{
wait_for_context(private_detail_te_self, queue);
}
template <class T>
static auto private_detail_te_default_finish_on(char, T&& private_detail_te_self, any_ptr queue)
-> decltype(private_detail_te_self.finish_on(queue))
{
private_detail_te_self.finish_on(queue);
}
template <class T>
static void
private_detail_te_default_finish_on(float, T&& private_detail_te_self, any_ptr queue)
{
finish_on_context(private_detail_te_self, queue);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -248,7 +302,7 @@ struct context ...@@ -248,7 +302,7 @@ struct context
PrivateDetailTypeErasedT value, PrivateDetailTypeErasedT value,
typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value, typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value)) : private_detail_te_value(value)
{ {
} }
...@@ -277,6 +331,18 @@ struct context ...@@ -277,6 +331,18 @@ struct context
return private_detail_te_default_get_queue(char(0), private_detail_te_value); return private_detail_te_default_get_queue(char(0), private_detail_te_value);
} }
void wait_for(any_ptr queue) override
{
private_detail_te_default_wait_for(char(0), private_detail_te_value, queue);
}
void finish_on(any_ptr queue) override
{
private_detail_te_default_finish_on(char(0), private_detail_te_value, queue);
}
void finish() const override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP
#include <migraphx/any_ptr.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct execution_environment
{
any_ptr queue = any_ptr{};
bool async = false;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif /* MIGRAPHX_GUARD_MIGRAPHLIB_EXECUTION_ENV_HPP */
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include <migraphx/assignment_options.hpp> #include <migraphx/assignment_options.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/execution_environment.hpp>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
...@@ -76,8 +77,8 @@ struct program ...@@ -76,8 +77,8 @@ struct program
std::unordered_map<std::string, shape> get_parameter_shapes() const; std::unordered_map<std::string, shape> get_parameter_shapes() const;
std::vector<argument> eval(parameter_map params) const; std::vector<argument> eval(parameter_map params,
execution_environment exec_env = execution_environment{}) const;
std::size_t size() const; std::size_t size() const;
std::vector<shape> get_output_shapes() const; std::vector<shape> get_output_shapes() const;
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <algorithm> #include <algorithm>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <vector>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -59,28 +60,35 @@ inline stream_range_container<Range> stream_range(const Range& r) ...@@ -59,28 +60,35 @@ inline stream_range_container<Range> stream_range(const Range& r)
namespace detail { namespace detail {
inline void stream_write_value_impl(rank<2>, std::ostream& os, const std::string& x) { os << x; } template <class T>
auto stream_write_value_impl(rank<1>, std::ostream& os, const T& x) -> decltype(os << x, void())
{
os << x;
}
template <class Range> template <class T>
auto stream_write_value_impl(rank<1>, std::ostream& os, const Range& r) void stream_write_value_impl(rank<1>, std::ostream& os, const std::vector<T>& r)
-> decltype(r.begin(), r.end(), void())
{ {
os << "{"; os << "{";
os << stream_range(r); os << stream_range(r);
os << "}"; os << "}";
} }
template <class T> template <class Range>
void stream_write_value_impl(rank<0>, std::ostream& os, const T& x) auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
-> decltype(r.begin(), r.end(), void())
{ {
os << x; os << "{";
os << stream_range(r);
os << "}";
} }
} // namespace detail } // namespace detail
template <class T> template <class T>
void stream_write_value(std::ostream& os, const T& x) void stream_write_value(std::ostream& os, const T& x)
{ {
detail::stream_write_value_impl(rank<2>{}, os, x); detail::stream_write_value_impl(rank<1>{}, os, x);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -184,6 +184,12 @@ struct value ...@@ -184,6 +184,12 @@ struct value
{ {
} }
explicit binary(std::size_t s) : base(s) {} explicit binary(std::size_t s) : base(s) {}
friend std::ostream& operator<<(std::ostream& os, const binary& obj)
{
os << "{binary_object: " << obj.size() << "}";
return os;
}
}; };
value() = default; value() = default;
......
...@@ -385,9 +385,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds ...@@ -385,9 +385,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds
instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst) instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
{ {
this->move_instruction(src, dst);
for(auto ins : src->inputs()) for(auto ins : src->inputs())
this->move_instruction(ins, src); {
if(not contains(this->impl->instructions, ins))
continue;
this->move_instructions(ins, dst);
}
this->move_instruction(src, dst);
return src; return src;
} }
......
...@@ -398,7 +398,7 @@ std::vector<argument> generic_eval(const program& p, ...@@ -398,7 +398,7 @@ std::vector<argument> generic_eval(const program& p,
return generic_eval(mm, ctx, params, {}, make_trace); return generic_eval(mm, ctx, params, {}, make_trace);
} }
std::vector<argument> program::eval(parameter_map params) const std::vector<argument> program::eval(parameter_map params, execution_environment exec_env) const
{ {
auto& ctx = this->impl->ctx; auto& ctx = this->impl->ctx;
#ifndef NDEBUG #ifndef NDEBUG
...@@ -423,6 +423,12 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -423,6 +423,12 @@ std::vector<argument> program::eval(parameter_map params) const
#endif #endif
auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{}); auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
std::vector<argument> ret;
if(exec_env.async)
{
ctx.wait_for(exec_env.queue);
}
if(trace_level > 0) if(trace_level > 0)
{ {
...@@ -434,7 +440,7 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -434,7 +440,7 @@ std::vector<argument> program::eval(parameter_map params) const
ins_out[x] = ss.str(); ins_out[x] = ss.str();
}); });
return generic_eval(*this, ret = generic_eval(*this,
ctx, ctx,
std::move(params), std::move(params),
with_check_context([&](auto& ins, auto f, auto&& check_context) { with_check_context([&](auto& ins, auto f, auto&& check_context) {
...@@ -470,13 +476,20 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -470,13 +476,20 @@ std::vector<argument> program::eval(parameter_map params) const
} }
else else
{ {
return generic_eval(*this, ret = generic_eval(*this,
ctx, ctx,
std::move(params), std::move(params),
with_check_context([&](auto&, auto f, auto&& check_context) { with_check_context([&](auto&, auto f, auto&& check_context) {
return check_context(f); return check_context(f);
})); }));
} }
if(exec_env.async)
{
ctx.finish_on(exec_env.queue);
}
return ret;
} }
const int program_file_version = 5; const int program_file_version = 5;
......
...@@ -355,6 +355,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -355,6 +355,23 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
} }
return p.eval(pm); return p.eval(pm);
}) })
.def("run_async",
[](migraphx::program& p,
py::dict params,
std::uintptr_t stream,
std::string stream_name) {
migraphx::parameter_map pm;
for(auto x : params)
{
std::string key = x.first.cast<std::string>();
py::buffer b = x.second.cast<py::buffer>();
py::buffer_info info = b.request();
pm[key] = migraphx::argument(to_shape(info), info.ptr);
}
migraphx::execution_environment exec_env{
migraphx::any_ptr(reinterpret_cast<void*>(stream), stream_name), true};
return p.eval(pm, exec_env);
})
.def("sort", &migraphx::program::sort) .def("sort", &migraphx::program::sort)
.def("print", [](const migraphx::program& p) { std::cout << p << std::endl; }) .def("print", [](const migraphx::program& p) { std::cout << p << std::endl; })
.def("__eq__", std::equal_to<migraphx::program>{}) .def("__eq__", std::equal_to<migraphx::program>{})
......
...@@ -435,6 +435,24 @@ struct find_concat_op ...@@ -435,6 +435,24 @@ struct find_concat_op
} }
}; };
void move_instructions_back(module& m, instruction_ref pos, std::vector<instruction_ref> inss)
{
auto start = range(m.begin(), pos);
for(auto ins : iterator_for(start))
{
auto it = std::find(inss.begin(), inss.end(), ins);
if(it != inss.end())
inss.erase(it);
}
for(auto ins : inss)
{
if(not m.has_instruction(ins))
continue;
move_instructions_back(m, pos, ins->inputs());
m.move_instruction(ins, pos);
}
}
std::vector<instruction_ref> get_splits(instruction_ref ins) std::vector<instruction_ref> get_splits(instruction_ref ins)
{ {
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
...@@ -610,8 +628,7 @@ struct find_splits ...@@ -610,8 +628,7 @@ struct find_splits
})) }))
return; return;
for(auto data : data_args) move_instructions_back(m, ins, data_args);
m.move_instructions(data, ins);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator()); auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty()); assert(not slice_op.axes.empty());
...@@ -864,8 +881,7 @@ struct find_conv_dot_horiz_fusion ...@@ -864,8 +881,7 @@ struct find_conv_dot_horiz_fusion
concat_axis = axis; concat_axis = axis;
} }
for(auto arg : args) move_instructions_back(m, input, args);
m.move_instructions(arg, input);
// TODO: Check if axes match // TODO: Check if axes match
auto concat = auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args); m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
......
...@@ -239,9 +239,18 @@ endif() ...@@ -239,9 +239,18 @@ endif()
include(CheckLibraryExists) include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION) get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API) check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
if(HAS_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else()
message(STATUS "MIOpen does not have Find-2.0 API")
endif()
if(HAS_FIND_MODE_API) if(HAS_FIND_MODE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_MODE_API) target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_MODE_API)
message(STATUS "MIOpen has find mode api") message(STATUS "MIGraphx is using Find Mode API of MIOpen")
else() else()
message(STATUS "MIOpen does not have find mode api") message(STATUS "MIOpen does not have find mode api")
endif() endif()
......
...@@ -131,7 +131,7 @@ preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs) ...@@ -131,7 +131,7 @@ preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
std::size_t bytes = 0; std::size_t bytes = 0;
for(auto i : preloaded) for(auto i : preloaded)
{ {
auto input = inputs[i]; const auto& input = inputs[i];
bytes += input.bytes(); bytes += input.bytes();
if(bytes > max_lds_bytes) if(bytes > max_lds_bytes)
break; break;
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <miopen/miopen.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -58,11 +59,37 @@ argument miopen_convolution::compute(context& ctx, ...@@ -58,11 +59,37 @@ argument miopen_convolution::compute(context& ctx,
auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape())); auto x_desc = make_tensor(reshape_if_1d(args[0].get_shape()));
auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape())); auto w_desc = make_tensor(reshape_if_1d(args[1].get_shape()));
auto y_desc = make_tensor(reshape_if_1d(output_shape)); auto y_desc = make_tensor(reshape_if_1d(output_shape));
auto* miopen_stream_handle = ctx.get_stream().get_miopen();
auto workspace_size = args[2].get_shape().bytes();
#ifdef MIGRAPHX_HAS_FIND_2_API
{
const miopenTensorArgument_t tensor_args[3] = {
{miopenTensorConvolutionX, nullptr, args[0].implicit()},
{miopenTensorConvolutionW, nullptr, args[1].implicit()},
{miopenTensorConvolutionY, nullptr, args[3].implicit()},
};
if(solution_ptr.get() == nullptr)
MIGRAPHX_THROW("MIOpen Convolution : Load MIOpen Solution before running it");
auto status = miopenRunSolution(miopen_stream_handle,
solution_ptr.get(),
3,
tensor_args,
args[2].implicit(),
workspace_size);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: running convolution using find_2.0 failed");
return args[3];
}
#else
// else use immediate mode
if(solution_id == 0) if(solution_id == 0)
MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID"); MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID");
auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(), auto status = miopenConvolutionForwardImmediate(miopen_stream_handle,
w_desc.get(), w_desc.get(),
args[1].implicit(), args[1].implicit(),
x_desc.get(), x_desc.get(),
...@@ -71,29 +98,66 @@ argument miopen_convolution::compute(context& ctx, ...@@ -71,29 +98,66 @@ argument miopen_convolution::compute(context& ctx,
y_desc.get(), y_desc.get(),
args[3].implicit(), args[3].implicit(),
args[2].implicit(), args[2].implicit(),
args[2].get_shape().bytes(), workspace_size,
solution_id); solution_id);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: running convolution failed"); MIGRAPHX_THROW("MIOpen Convolution: running convolution failed");
return args[3]; return args[3];
#endif
} }
shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vector<shape> inputs) shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vector<shape> inputs)
{ {
shape workspace_shape{}; shape workspace_shape{};
auto x_desc = make_tensor(reshape_if_1d(inputs[0])); auto x_desc = make_tensor(reshape_if_1d(inputs[0]));
auto w_desc = make_tensor(reshape_if_1d(inputs[1])); auto w_desc = make_tensor(reshape_if_1d(inputs[1]));
auto y_desc = make_tensor(reshape_if_1d(output_shape)); auto y_desc = make_tensor(reshape_if_1d(output_shape));
std::size_t workspace_size = 0; std::size_t workspace_size = 0;
miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
#ifdef MIGRAPHX_HAS_FIND_2_API
{
auto conv_problem = make_obj<miopen_problem>(
&miopenCreateConvProblem, cd.get(), miopenProblemDirectionForward);
set_tensor_descriptor(miopenTensorConvolutionX, x_desc, conv_problem);
set_tensor_descriptor(miopenTensorConvolutionW, w_desc, conv_problem);
set_tensor_descriptor(miopenTensorConvolutionY, y_desc, conv_problem);
auto* miopen_stream_handle = ctx.get_stream().get_miopen();
solution_ptr = find_solution(miopen_stream_handle, conv_problem.get());
auto status = miopenGetSolutionWorkspaceSize(solution_ptr.get(), &workspace_size);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution : failed to get solution's workspace size");
std::size_t solution_size;
status = miopenGetSolutionSize(solution_ptr.get(), &solution_size);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: Failed to fetch solution size");
auto solution_binary = std::vector<char>{};
solution_binary.resize(solution_size);
status = miopenSaveSolution(solution_ptr.get(), solution_binary.data());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: Saving solution failed");
solution_object = value::binary{solution_binary.data(), solution_size};
return shape{shape::int8_type, {workspace_size}};
}
#else
// else use immediate find mode
auto status = miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
w_desc.get(), w_desc.get(),
x_desc.get(), x_desc.get(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
&workspace_size); &workspace_size);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: Failed to get forward workspace size");
workspace_shape = shape{shape::int8_type, {workspace_size}}; workspace_shape = shape{shape::int8_type, {workspace_size}};
auto x = to_gpu(generate_argument(inputs[0])); auto x = to_gpu(generate_argument(inputs[0]));
...@@ -103,7 +167,7 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec ...@@ -103,7 +167,7 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec
int algo_count = 1; int algo_count = 1;
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(), status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(), x_desc.get(),
x.implicit(), x.implicit(),
w_desc.get(), w_desc.get(),
...@@ -148,12 +212,33 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec ...@@ -148,12 +212,33 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec
solution_id = solutions.front().solution_id; solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {perf.memory}};
#endif
} }
void miopen_convolution::finalize(context& ctx, void miopen_convolution::finalize(context& ctx,
const shape& output_shape, const shape& output_shape,
std::vector<shape> inputs) const std::vector<shape>& inputs)
{ {
#ifdef MIGRAPHX_HAS_FIND_2_API
{
(void)(ctx); // avoid warnings
(void)(output_shape);
(void)(inputs);
// load solution
if(solution_ptr == nullptr)
{
miopenSolution_t ptr;
auto status = miopenLoadSolution(&ptr,
reinterpret_cast<const char*>(solution_object.data()),
solution_object.size());
solution_ptr = miopen_solution{ptr};
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: loading convolution solution failed");
}
}
#else
// Use immediate mode API
{
if(cd == nullptr) if(cd == nullptr)
cd = make_conv(op); cd = make_conv(op);
if(solution_id == 0) if(solution_id == 0)
...@@ -177,6 +262,8 @@ void miopen_convolution::finalize(context& ctx, ...@@ -177,6 +262,8 @@ void miopen_convolution::finalize(context& ctx,
solution_id); solution_id);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: compile solution failed"); MIGRAPHX_THROW("MIOpen Convolution: compile solution failed");
}
#endif
} }
} // namespace gpu } // namespace gpu
......
...@@ -197,7 +197,9 @@ struct hip_device ...@@ -197,7 +197,9 @@ struct hip_device
struct context struct context
{ {
context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1)) context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
: current_device(std::make_shared<hip_device>(device_id, n)) : current_device(std::make_shared<hip_device>(device_id, n)),
begin_event(create_event()),
finish_event(create_event())
{ {
} }
...@@ -274,6 +276,24 @@ struct context ...@@ -274,6 +276,24 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams); this->current_device = std::make_shared<hip_device>(0, n_streams);
} }
void wait_for(any_ptr queue)
{
auto status = hipEventRecord(begin_event.get(), queue.get<hipStream_t>());
if(status != hipSuccess)
MIGRAPHX_THROW("failed to record " + hip_error(status));
get_stream().wait(begin_event.get());
}
void finish_on(any_ptr queue)
{
get_stream().record(finish_event.get());
auto status = hipStreamWaitEvent(queue.get<hipStream_t>(), finish_event.get(), 0);
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to wait on event " + hip_error(status));
}
any_ptr get_queue() { return get_stream().get(); } any_ptr get_queue() { return get_stream().get(); }
void enable_perf_measurement(bool b = true) void enable_perf_measurement(bool b = true)
...@@ -317,8 +337,12 @@ struct context ...@@ -317,8 +337,12 @@ struct context
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
std::vector<shared<hip_event_ptr>> events; std::vector<shared<hip_event_ptr>> events;
bool measure_perf = false; bool measure_perf = false;
// for event perf timing
shared<hip_event_ptr> start_event = nullptr; shared<hip_event_ptr> start_event = nullptr;
shared<hip_event_ptr> stop_event = nullptr; shared<hip_event_ptr> stop_event = nullptr;
// for stream syncronization
shared<hip_event_ptr> begin_event = nullptr;
shared<hip_event_ptr> finish_event = nullptr;
}; };
inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); } inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment