Commit dfa26315 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_shape_update

parents aa085491 4a5a23a4
...@@ -4,7 +4,7 @@ CheckOptions: ...@@ -4,7 +4,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions - key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product' value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp - key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_' value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|DEPRECATED|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
- key: modernize-loop-convert.MinConfidence - key: modernize-loop-convert.MinConfidence
value: risky value: risky
- key: modernize-loop-convert.NamingStyle - key: modernize-loop-convert.NamingStyle
......
...@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED) ...@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.2) rocm_setup_version(VERSION 2.3)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......
...@@ -109,6 +109,7 @@ register_migraphx_ops( ...@@ -109,6 +109,7 @@ register_migraphx_ops(
flatten flatten
floor floor
gather gather
gathernd
get_tuple_elem get_tuple_elem
greater greater
gru gru
......
...@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction; ...@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction;
struct migraphx_instruction struct migraphx_instruction
{ {
template <class... Ts> template <class... Ts>
migraphx_instruction(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_instruction(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::instruction_ref object; migraphx::instruction_ref object;
...@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions; ...@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions;
struct migraphx_instructions struct migraphx_instructions
{ {
template <class... Ts> template <class... Ts>
migraphx_instructions(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_instructions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::instruction_ref> object; std::vector<migraphx::instruction_ref> object;
...@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules; ...@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules;
struct migraphx_modules struct migraphx_modules
{ {
template <class... Ts> template <class... Ts>
migraphx_modules(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_modules(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::module*> object; std::vector<migraphx::module*> object;
...@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont ...@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
return api_error_result; return api_error_result;
} }
// C API entry point: stores the result of get_queue().unsafe_get() for the
// given context into *out.
// Both pointer arguments are validated inside try_, so a null pointer is
// raised as migraphx_status_bad_param via MIGRAPHX_THROW instead of being
// dereferenced.
extern "C" migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context)
{
    auto api_error_result = migraphx::try_([&] {
        if(context == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
        // The output slot is written unconditionally below, so it must be valid too.
        if(out == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter out: Null pointer");
        *out = (context->object).get_queue().unsafe_get();
    });
    return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op) migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op)
{ {
......
...@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status migraphx_context_finish(const_migraphx_context_t context); migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context);
migraphx_status migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
......
...@@ -15,6 +15,16 @@ namespace migraphx { ...@@ -15,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT inline namespace api { // NOLINT
#endif #endif
// MIGRAPHX_DEPRECATED(...) expands to the [[deprecated(...)]] attribute when the
// compiler provides __has_cpp_attribute and reports support for `deprecated`.
// Otherwise the fallback below defines it as empty, so annotated declarations
// still compile, just without the deprecation diagnostic.
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
// Fallback: no attribute support detected above.
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template <int N> template <int N>
struct rank : rank<N - 1> struct rank : rank<N - 1>
{ {
...@@ -99,34 +109,22 @@ struct iota_iterator ...@@ -99,34 +109,22 @@ struct iota_iterator
return it; return it;
} }
// TODO: operator-> // TODO: operator->
reference operator*() const { return (*f)(index); } reference operator*() const { return f(index); }
};
template <class F, class Iterator> friend iota_iterator operator+(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index + y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
}
template <class F, class Iterator> friend iota_iterator operator-(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator-(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index - y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index - y.index, x.f);
}
template <class F, class Iterator> friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; }
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator> friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; }
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) };
{
return x.index != y.index;
}
template <class Derived> template <class Derived>
struct array_base struct array_base
...@@ -136,8 +134,20 @@ struct array_base ...@@ -136,8 +134,20 @@ struct array_base
template <class T> template <class T>
using value_type_t = decltype(std::declval<T>()[0]); using value_type_t = decltype(std::declval<T>()[0]);
// Element accessor handed to the iterator type: holds a pointer to the derived
// container and maps a position to that container's operator[].
struct iterator_read
{
// Raw pointer to the container being indexed.
const Derived* self;
// Returns (*self)[pidx]. The return type is spelled through the defaulted
// template parameter D -- presumably so it is only resolved when the call
// operator is instantiated (Derived may be incomplete here; TODO confirm).
template <class D = Derived>
value_type_t<D> operator()(size_t pidx) const
{
return (*self)[pidx];
}
};
template <class T> template <class T>
using iterator_t = iota_iterator<typename T::iterator_read>; using iterator_t = iota_iterator<iterator_read>;
bool empty() const { return derived().size() == 0; }
template <class D = Derived> template <class D = Derived>
value_type_t<D> front() const value_type_t<D> front() const
...@@ -154,13 +164,13 @@ struct array_base ...@@ -154,13 +164,13 @@ struct array_base
template <class D = Derived> template <class D = Derived>
iterator_t<D> begin() const iterator_t<D> begin() const
{ {
return {0, {derived().get_handle_ptr()}}; return {0, {&derived()}};
} }
template <class D = Derived> template <class D = Derived>
iterator_t<D> end() const iterator_t<D> end() const
{ {
return {derived().size(), {derived().get_handle_ptr()}}; return {derived().size(), {&derived()}};
} }
}; };
...@@ -200,9 +210,25 @@ struct borrow ...@@ -200,9 +210,25 @@ struct borrow
{ {
}; };
/// Lifetime token that holds a shared_ptr to an owning object and can mint
/// further shared_ptrs that share that ownership while pointing elsewhere.
template <class T>
struct share
{
    /// Takes (a copy of) the owning pointer to keep alive.
    share(std::shared_ptr<T> owner) : m_owner{std::move(owner)} {}

    /// Builds a shared_ptr<U> that points at `target` but shares ownership
    /// with the stored owner (shared_ptr aliasing constructor).
    template <class U>
    std::shared_ptr<U> alias(U* target) const
    {
        std::shared_ptr<U> view(m_owner, target);
        return view;
    }

    private:
    std::shared_ptr<T> m_owner;
};
template <class Derived, class T, class D, D Deleter, class A, A Assigner> template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{ {
using handle_type = T;
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
template <class F, class... Ts> template <class F, class... Ts>
void make_handle(F f, Ts&&... xs) void make_handle(F f, Ts&&... xs)
...@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle = std::shared_ptr<U>{ptr, [](U*) {}}; m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
} }
// Points the handle at `ptr` without taking ownership of it: the deleter is a
// no-op. The deleter captures the share token `b` by value, so whatever `b`
// holds stays alive for as long as this handle (or a copy of it) exists.
template <class U, class V>
void set_handle(U* ptr, share<V> b)
{
m_handle = std::shared_ptr<T>{ptr, [b](U*) {}};
}
// Returns a share token constructed from a copy of this object's m_handle.
share<T> share_handle() const { return {m_handle}; }
template <class U> template <class U>
void assign_to_handle(U* x) void assign_to_handle(U* x)
{ {
...@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
// Declares a constructor `name(HandleType* p, Lifetime lifetime)` inside a
// handle wrapper class. The constructor is only enabled when HandleType*
// converts to handle_type*, and it forwards the pointer together with the
// lifetime tag to set_handle, which selects the overload matching that tag.
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template <class Base> template <class Base>
struct interface_base : Base struct interface_base : Base
{ {
...@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{ {
shape() {} shape() {}
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
shape(migraphx_shape* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
argument(migraphx_argument* p, borrow) { this->set_handle(p, borrow{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
argument(migraphx_argument* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape, void* pbuffer) argument(shape pshape, void* pbuffer)
...@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return {pout}; return {pout, this->share_handle()};
} }
char* data() const char* data() const
...@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
target(migraphx_target* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(target);
target(migraphx_target* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
program_parameter_shapes(migraphx_program_parameter_shapes* p, own) MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
{
this->set_handle(p, own{});
}
program_parameter_shapes(migraphx_program_parameter_shapes* p, borrow)
{
this->set_handle(p, borrow{});
}
size_t size() const size_t size() const
{ {
...@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return {pout}; return {pout, this->share_handle()};
} }
std::vector<const char*> names() const std::vector<const char*> names() const
...@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
program_parameters(migraphx_program_parameters* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
program_parameters(migraphx_program_parameters* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); } program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
...@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
arguments(migraphx_arguments* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
arguments(migraphx_arguments* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
}
struct iterator_read
{
migraphx_arguments* self;
argument operator()(size_t pidx) const
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return {pout};
} }
};
}; };
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
shapes(migraphx_shapes* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
shapes(migraphx_shapes* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
}
struct iterator_read
{
migraphx_shapes* self;
shape operator()(size_t pidx) const
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return {pout};
} }
};
}; };
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -711,33 +711,36 @@ struct module; ...@@ -711,33 +711,36 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
{ {
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...}; std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.get_handle_ptr()...};
this->make_handle(&migraphx_modules_create, a.data(), a.size()); this->make_handle(&migraphx_modules_create, a.data(), a.size());
} }
}; };
struct module struct module
{ {
migraphx_module_t mm; MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.")
module(migraphx_module* m) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(migraphx_module* m, borrow) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(const migraphx_module_t& m) : mm(m) {} template <class T>
module(migraphx_module* m, share<T> b) : mm(b.alias(m))
{
}
void print() const { call(&migraphx_module_print, mm); } void print() const { call(&migraphx_module_print, mm.get()); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args) instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{ {
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction, call(&migraphx_module_add_instruction,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr()); args.get_handle_ptr());
return instruction(op_ins, own{}); return instruction(op_ins, own{});
...@@ -750,7 +753,7 @@ struct module ...@@ -750,7 +753,7 @@ struct module
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args, call(&migraphx_module_add_instruction_with_mod_args,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr(), args.get_handle_ptr(),
module_args.get_handle_ptr()); module_args.get_handle_ptr());
...@@ -760,30 +763,53 @@ struct module ...@@ -760,30 +763,53 @@ struct module
instruction add_parameter(const std::string& name, shape s) instruction add_parameter(const std::string& name, shape s)
{ {
migraphx_instruction_t param_ins; migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr()); call(
&migraphx_module_add_parameter, &param_ins, mm.get(), name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{}); return instruction(param_ins, own{});
} }
instruction add_return(const migraphx::instructions& args) instruction add_return(const migraphx::instructions& args)
{ {
migraphx_instruction_t ret_ins; migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr()); call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr());
return instruction(ret_ins, own{}); return instruction(ret_ins, own{});
} }
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
}; };
struct context struct context
{ {
migraphx_context_t ctx; context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
void finish() const { call(&migraphx_context_finish, ctx); } template <class T>
context(migraphx_context* p, share<T> b) : ctx(b.alias(p))
{
}
void finish() const { call(&migraphx_context_finish, ctx.get()); }
// Fetches the queue pointer for this context through the C API and
// reinterprets it as T. No check is made that T matches the real queue type
// (see the TODO below); the caller is responsible for choosing it.
template <class T>
T get_queue()
{
    // Start from a defined value so an untouched output slot never yields
    // an indeterminate pointer.
    void* out = nullptr;
    call(&migraphx_context_get_queue, &out, ctx.get());
    // TODO: check type here
    return reinterpret_cast<T>(out);
}
private:
std::shared_ptr<migraphx_context> ctx;
}; };
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
compile_options(migraphx_compile_options* p, own) { this->set_handle(p, own()); } MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -807,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -807,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program);
program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); }
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -872,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -872,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr()); call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
context experimental_get_context() context experimental_get_context()
{ {
migraphx_context_t ctx; migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr()); call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx}; return context{ctx, this->share_handle()};
} }
module create_module(const std::string& name) module create_module(const std::string& name)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data()); call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return !(px == py); }
...@@ -895,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -895,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
file_options(migraphx_file_options* p, own) { this->set_handle(p, own()); }
// set file format // set file format
void set_file_format(const char* format) void set_file_format(const char* format)
{ {
...@@ -938,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -938,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
onnx_options(migraphx_onnx_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1020,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1020,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
tf_options(migraphx_tf_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1073,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1073,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
quantize_op_names(migraphx_quantize_op_names* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1098,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1098,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
quantize_int8_options(migraphx_quantize_int8_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
quantize_int8_options(migraphx_quantize_int8_options* p, borrow)
{
this->set_handle(p, borrow{});
}
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
......
...@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8', ...@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8',
@auto_handle(ref=True) @auto_handle(ref=True)
def context(h): def context(h):
h.method('finish', const=True) h.method('finish', const=True)
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
@api.interface('migraphx_experimental_custom_op', @api.interface('migraphx_experimental_custom_op',
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// GatherND operator: gathers slices of the first input (data), addressed by
/// index tuples stored along the last axis of the second input (indices).
struct gathernd
{
    /// Number of leading data dimensions skipped before the index tuple is
    /// applied (treated as batch dimensions).
    int batch_dims = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.batch_dims, "batch_dims"));
    }

    std::string name() const { return "gathernd"; }

    /// Computes the output shape from {data, indices}.
    ///
    /// With r = rank(data), q = rank(indices) and k = last length of indices,
    /// the output lengths are the first q - 1 lengths of indices followed by
    /// the data lengths from position batch_dims + k onwards. The output type
    /// is the data type.
    ///
    /// Throws (MIGRAPHX_THROW) when batch_dims is negative or not smaller than
    /// both ranks, or when k exceeds r - batch_dims.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        const auto& data_lens    = inputs.front().lens();
        const auto& indices_lens = inputs.back().lens();
        auto r                   = data_lens.size();
        auto q                   = indices_lens.size();

        // Validate batch_dims before it takes part in any unsigned arithmetic:
        // a negative or too-large value would otherwise wrap around and push
        // the iterators below out of range. This also rejects rank-0 inputs
        // before back() is called on their lengths.
        if(batch_dims < 0 or static_cast<std::size_t>(batch_dims) >= q or
           static_cast<std::size_t>(batch_dims) >= r)
        {
            MIGRAPHX_THROW("GATHERND: batch_dims=" + std::to_string(batch_dims) +
                           " must be non-negative and less than the rank of both inputs");
        }
        const auto n_batch = static_cast<std::size_t>(batch_dims);

        auto k = indices_lens.back();
        if(k > r - n_batch)
        {
            MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
                           " cannot be used to access data of rank " +
                           std::to_string(r - n_batch));
        }

        // k <= r - n_batch and q >= 1 hold here, so this cannot underflow.
        auto output_lens_size = q + r - k - n_batch - 1;
        std::vector<std::size_t> output_lens(output_lens_size);
        std::copy(indices_lens.begin(), indices_lens.begin() + (q - 1), output_lens.begin());
        if(k < r - n_batch)
        {
            // Partial index tuple: the remaining trailing data dims are
            // carried over into the output.
            std::copy(
                data_lens.begin() + n_batch + k, data_lens.end(), output_lens.begin() + q - 1);
        }
        shape output_shape{inputs.front().type(), output_lens};
        return output_shape;
    }

    /// Reference implementation: first resolves every index tuple into a flat
    /// element offset into data, then copies slice_size contiguous elements
    /// per tuple into the output. Throws (MIGRAPHX_THROW) when an index lies
    /// outside [-dim, dim) for the data dimension it addresses.
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto data) {
            args[1].visit([&](auto indices) {
                auto indices_shape        = indices.get_shape();
                auto indices_shape_lens   = indices_shape.lens();
                auto data_shape           = data.get_shape();
                auto data_shape_lens      = data_shape.lens();
                auto k                    = indices_shape.lens().back();
                const auto num_slice_dims = k;

                // The seed must be std::size_t: std::accumulate accumulates in
                // the type of its init argument, and an int seed would truncate
                // large products.
                const std::size_t one = 1;

                // Number of index tuples (product of all but the last indices dim).
                std::size_t num_slices = std::accumulate(indices_shape_lens.begin(),
                                                         indices_shape_lens.end() - 1,
                                                         one,
                                                         std::multiplies<std::size_t>());
                // Elements copied per index tuple.
                std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims,
                                                         data_shape_lens.end(),
                                                         one,
                                                         std::multiplies<std::size_t>());
                std::size_t num_batches = std::accumulate(data_shape_lens.begin(),
                                                          data_shape_lens.begin() + batch_dims,
                                                          one,
                                                          std::multiplies<std::size_t>());
                // Elements of data belonging to one batch entry.
                std::size_t data_batch_stride =
                    std::accumulate(data_shape_lens.begin() + batch_dims,
                                    data_shape_lens.end(),
                                    one,
                                    std::multiplies<std::size_t>());
                auto num_slices_per_batch = num_slices / num_batches;

                // sizes_from_slice_dims[d]: element stride of the d-th indexed
                // data dimension, built from the innermost dimension outwards.
                std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims);
                {
                    auto running_product = slice_size;
                    for(std::size_t i = 0; i < num_slice_dims; ++i)
                    {
                        sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product;
                        running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i];
                    }
                }

                // Pass 1: turn each index tuple into a flat offset into data.
                std::vector<std::size_t> input_slice_offsets(num_slices);
                par_for(num_slices, [&](const auto i) {
                    std::size_t batch_idx = i / num_slices_per_batch;

                    auto slice_indices                = indices.begin() + (i * num_slice_dims);
                    std::size_t relative_slice_offset = 0;
                    for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx)
                    {
                        int64_t index                   = *(slice_indices + dim_idx);
                        const std::size_t input_dim_idx = batch_dims + dim_idx;
                        const auto input_dim            = data_shape_lens[input_dim_idx];
                        if(index < -static_cast<int64_t>(input_dim) or
                           index >= static_cast<int64_t>(input_dim))
                            MIGRAPHX_THROW("GatherND: index " + std::to_string(index) +
                                           " is out of bounds for dim of len " +
                                           std::to_string(input_dim));
                        // Negative indices count back from the end of the dimension.
                        if(index < 0)
                            index += input_dim;

                        relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
                    }

                    input_slice_offsets[i] =
                        (batch_idx * data_batch_stride) + relative_slice_offset;
                });

                // Pass 2: copy slice_size elements per tuple into the output.
                par_for(num_slices * slice_size, [&](const auto i) {
                    auto slice_offset = input_slice_offsets[i / slice_size];
                    output[i]         = data[slice_offset + i % slice_size];
                });
            });
        });

        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp> #include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp> #include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp> #include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp> #include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp> #include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp> #include <migraphx/op/gru.hpp>
......
...@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Flatten", "flatten"}, {"Flatten", "flatten"},
{"Floor", "floor"}, {"Floor", "floor"},
{"Gather", "gather"}, {"Gather", "gather"},
{"GatherND", "gathernd"},
{"Identity", "identity"}, {"Identity", "identity"},
{"IsNaN", "isnan"}, {"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"}, {"LeakyRelu", "leaky_relu"},
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
//! Parser for ReverseSequence ONNX operator.
/*!
Reverses the data along the time axis for the batches along the batch axis.
The sequence lengths can be given to reverse up to the given length for each batch, keeping the
rest of the sequence in the original order. Variable sequence_lens is not supported in this
version of MIGraphX. You can pass the sequence_lens either as a constant node or an attribute. The
batch axis and time axis must each be 0 or 1, and they must not be the same.
*/
struct parse_reversesequence : op_parser<parse_reversesequence>
{
    //! ONNX operator names handled by this parser.
    std::vector<op_desc> operators() const { return {{"ReverseSequence"}}; }

    //! Lowers a ReverseSequence node into slice / reverse / concat instructions,
    //! one group per entry along the batch axis, and returns the final concat.
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // Reads an integer attribute, falling back to the given default when absent.
        auto axis_attribute = [&info](const char* key, int fallback) -> int {
            if(contains(info.attributes, key))
            {
                return info.attributes.at(key).i();
            }
            return fallback;
        };

        const int batch_axis = axis_attribute("batch_axis", 1);
        if(not(batch_axis == 0 or batch_axis == 1))
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: batch axis not 0 or 1");
        }

        const int time_axis = axis_attribute("time_axis", 0);
        if(not(time_axis == 0 or time_axis == 1))
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis not 0 or 1");
        }

        if(batch_axis == time_axis)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: time axis and batch axis are the same");
        }

        auto input      = args[0];
        auto input_lens = input->get_shape().lens();
        if(input_lens.size() < 2)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: input tensor must have rank >= 2");
        }

        // sequence_lens comes either from a second (constant) input or from an attribute
        std::vector<int64_t> sequence_lens;
        if(args.size() == 2)
        {
            migraphx::argument seq_lens_arg = args.back()->eval();
            check_arg_empty(seq_lens_arg, "REVERSESEQUENCE: cannot handle variable sequence_lens");
            seq_lens_arg.visit([&](auto s) { sequence_lens.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "sequence_lens"))
        {
            literal s = parser.parse_value(info.attributes.at("sequence_lens"));
            s.visit([&](auto v) { sequence_lens.assign(v.begin(), v.end()); });
        }

        auto batch_size = input_lens[batch_axis];
        auto time_size  = input_lens[time_axis];

        // only the element count is compared, so a sequence_lens tensor with a
        // wrong shape but the right number of elements still passes
        if(sequence_lens.size() != batch_size)
        {
            MIGRAPHX_THROW("REVERSESEQUENCE: sequence_lens has incorrect shape");
        }

        // Slices batch entry b over the time range [t_start, t_end).
        auto add_slice = [&info, &input, batch_axis, time_axis](int b, int t_start, int t_end) {
            auto slice_op = make_op("slice",
                                    {{"axes", {batch_axis, time_axis}},
                                     {"starts", {b, t_start}},
                                     {"ends", {b + 1, t_end}}});
            return info.add_instruction(slice_op, input);
        };

        instruction_ref ret;
        for(int b = 0; b < batch_size; ++b)
        {
            instruction_ref piece;
            if(sequence_lens[b] <= 1)
            {
                // nothing to reverse: keep the whole time range as-is
                piece = add_slice(b, 0, time_size);
            }
            else
            {
                auto head = add_slice(b, 0, sequence_lens[b]);
                piece     = info.add_instruction(make_op("reverse", {{"axes", {time_axis}}}), head);
                // when only a prefix was reversed, append the untouched remainder
                if(sequence_lens[b] < time_size)
                {
                    auto tail = add_slice(b, sequence_lens[b], time_size);
                    piece =
                        info.add_instruction(make_op("concat", {{"axis", time_axis}}), piece, tail);
                }
            }

            // stitch the batch entries back together along the batch axis
            ret = (b == 0)
                      ? piece
                      : info.add_instruction(make_op("concat", {{"axis", batch_axis}}), ret, piece);
        }
        return ret;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -20,7 +20,7 @@ struct run_op : action<run_op> ...@@ -20,7 +20,7 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
double t = time_op(ctx, op, inputs); double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl; std::cout << op << ": " << t << "ms" << std::endl;
} }
}; };
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// HIP source for the gathernd kernel; it is handed to compile_hip_code_object()
// by gathernd_compiler::compile_op below. BATCH_DIMS is not defined in this
// text: it is supplied on the compiler command line as -DBATCH_DIMS=<n>.
// NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct gathernd_compiler : compiler<gathernd_compiler>
{
    // Operator names this compiler is registered for.
    std::vector<std::string> names() const { return {"gathernd"}; }

    // Builds the code object for the gathernd kernel.
    // inputs: shapes of the kernel arguments; the last one is the output shape.
    // v: operator fields; must contain "batch_dims".
    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        const auto& output_shape = inputs.back();

        hip_compile_options options;
        options.set_launch_params(v, compute_global_for(ctx, output_shape.elements()));
        options.inputs         = inputs;
        options.output         = output_shape;
        options.kernel_name    = "gathernd_kernel";
        options.virtual_inputs = inputs;

        // batch_dims is forwarded to the kernel source as a preprocessor definition
        assert(v.contains("batch_dims"));
        const auto batch_dims       = v.at("batch_dims").to<int64_t>();
        const std::string bdims_def = " -DBATCH_DIMS=" + std::to_string(batch_dims);
        options.params += bdims_def;

        return compile_hip_code_object(gathernd_kernel, options);
    }

    // Compiles the kernel for the given instruction and wraps it as a replacement.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto input_shapes = to_shapes(ins->inputs());
        return replace(compile_op(ctx, input_shapes, op.to_value()));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p) ...@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p)
{ {
make_tensors()(input_p, output_p)([](auto input, auto output) { make_tensors()(input_p, output_p)([](auto input, auto output) {
simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write}); simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
}); });
} }
...@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input ...@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input
return get_reduce_elements(to_shapes(inputs)); return get_reduce_elements(to_shapes(inputs));
} }
// For each axis, yields 1 when the output length equals the input length
// and the input length otherwise.
// The result has one entry per element of output_lens; input_lens must be at
// least as long as output_lens.
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
                                                const std::vector<std::size_t>& output_lens)
{
    std::vector<std::size_t> reduce_lens(output_lens.size());
    std::transform(output_lens.begin(),
                   output_lens.end(),
                   input_lens.begin(),
                   reduce_lens.begin(),
                   [](std::size_t out_len, std::size_t in_len) -> std::size_t {
                       return (out_len == in_len) ? 1 : in_len;
                   });
    return reduce_lens;
}
// Picks the reduction algorithm name from the strides of the reduced axes:
// "lane" when the smallest stride among axes whose reduce length is not 1
// exceeds 2, otherwise "block".
// inputs.front() is the reduction input shape and inputs.back() the output shape.
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
    const auto& input_shape = inputs.front();
    const auto rlens        = get_reduce_lens(input_shape.lens(), inputs.back().lens());
    const auto init         = std::numeric_limits<std::size_t>::max();

    // Axes with a reduce length of 1 contribute `init`, so they never win the minimum.
    auto smaller_of       = [](auto x, auto y) { return std::min(x, y); };
    auto stride_if_reduce = [](auto len, auto stride) { return len == 1 ? init : stride; };

    // The minimum stride
    const auto min_stride = std::inner_product(rlens.begin(),
                                               rlens.end(),
                                               input_shape.strides().begin(),
                                               init,
                                               smaller_of,
                                               stride_if_reduce);
    return (min_stride > 2) ? "lane" : "block";
}
struct reduce_compiler : compiler<reduce_compiler> struct reduce_compiler : compiler<reduce_compiler>
{ {
std::vector<std::string> names() const std::vector<std::string> names() const
...@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler>
{ {
hip_compile_options options; hip_compile_options options;
auto reduce_elements = get_reduce_elements(inputs); auto reduce_elements = get_reduce_elements(inputs);
auto algo = v.get("algo", get_reduce_algo(inputs));
if(algo == "block")
{
auto block_size = compute_block_size(reduce_elements, 256); auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size); v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})}, {"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)}, {"read", v.get("read", identity)},
{"write", v.get("write", identity)}, {"write", v.get("write", identity)},
{"algo", algo},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -21,6 +21,16 @@ struct greater ...@@ -21,6 +21,16 @@ struct greater
} }
}; };
// Left fold over [first, last): starting from init, repeatedly replaces the
// accumulator with op(accumulator, element). The accumulator has type T, so
// the type of init determines the type the fold is computed in.
template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
    while(first != last)
    {
        init = op(std::move(init), *first);
        ++first;
    }
    return init;
}
template <class InputIt, class OutputIt> template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first) constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{ {
......
...@@ -42,6 +42,32 @@ struct print_buffer ...@@ -42,6 +42,32 @@ struct print_buffer
pos++; pos++;
} }
} }
// Appends the decimal representation of an integer.
// Enabled only for types that support `%` and unary minus.
// Negative values are written as '-' followed by the digits. The digits are
// taken from i / 10 and i % 10 directly rather than from -i, because negating
// the most negative value of a signed type overflows.
template <class T, class = decltype(T{} % 10, -T{})>
constexpr void append(T i)
{
    if(i < 0)
    {
        append('-');
        // |i / 10| is always representable, so negating the quotient is safe
        if(i / 10 != 0)
            append(-(i / 10));
        // i % 10 is in [-9, 0] for negative i
        append(static_cast<char>('0' - (i % 10)));
        return;
    }
    // leading digits first (recursively), then the last digit
    if(i > 9)
        append(i / 10);
    append(static_cast<char>((i % 10) + '0'));
}
// Appends a NUL-terminated string character by character.
// A null pointer appends nothing, and at most 512 characters are copied.
constexpr void append(const char* str)
{
    if(str == nullptr)
        return;
    for(int remaining = 512; *str != 0 and remaining > 0; --remaining, ++str)
    {
        append(*str);
    }
}
template <size_t M> template <size_t M>
constexpr void append(const char (&array)[M]) constexpr void append(const char (&array)[M])
...@@ -54,14 +80,36 @@ struct print_buffer ...@@ -54,14 +80,36 @@ struct print_buffer
template <class... Ts> template <class... Ts>
__host__ __device__ void print(const Ts&... xs) __host__ __device__ void print(const Ts&... xs)
{ {
const auto size = (sizeof(xs) + ...); print_buffer<1024> buffer;
print_buffer<size> buffer;
swallow{(buffer.append(xs), 0)...}; swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer); printf("%s", buffer.buffer);
} }
} // namespace debug } // namespace debug
// A source position: line number, file name and function name.
// Each member defaults to the matching compiler builtin
// (__builtin_LINE / __builtin_FILE / __builtin_FUNCTION).
// NOTE(review): presumably the builtins resolve to the place where the object
// is default-initialized (e.g. a caller's default argument) -- confirm.
struct source_location
{
int line = __builtin_LINE();
const char* file = __builtin_FILE();
const char* function = __builtin_FUNCTION();
};
// Bundles a value of type T with a source_location.
// Converts implicitly back to either T or source_location.
template <class T>
struct source_location_capture
{
T x;
source_location loc;
// Accepts any U for which T(U{}) is well-formed; the location is a defaulted
// second argument (source_location{}), so callers normally pass only the value.
template <class U, class = decltype(T(U{}))>
constexpr source_location_capture(U px, source_location ploc = source_location{})
: x(px), loc(ploc)
{
}
// Yields the stored location.
constexpr operator source_location() const { return loc; }
// Yields the stored value.
constexpr operator T() const { return x; }
};
// noreturn cannot be used on this function because abort in hip is broken // noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4> template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
...@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct ...@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
abort(); abort();
} }
// Reports a failure at a given source location and aborts.
// Prints "<file>:<line>: <function>: error: ", then the extra arguments xs in
// order, then a newline, all through debug::print; afterwards calls abort().
template <class... Ts>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_location& loc,
Ts... xs)
{
debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n");
abort();
}
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \ #define MIGRAPHX_ASSERT_FAIL(cond, ...) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \ ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \ assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__)) }(__VA_ARGS__))
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
#ifdef MIGRAPHX_DEBUG #ifdef MIGRAPHX_DEBUG
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK #define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK #define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false) #define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else #else
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume #define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable #define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond) #define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif #endif
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
// Settings passed to the gathernd kernel.
// batch_dims is value-initialized when not given.
template <class T>
struct gathernd_settings
{
    T batch_dims{};
};

// Builds a gathernd_settings whose member types are deduced from the arguments.
template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
    return gathernd_settings<Ts...>{xs...};
}
// Device-side GatherND.
//
// data_t:    data tensor
// indices_t: indices tensor; its last dimension holds the index tuples
// output_t:  output tensor, written element by element
// s:         settings object providing batch_dims
//
// For every output element i, the index tuple i / slice_size is turned into a
// flat offset into the data, and the element at position i % slice_size of
// that slice is copied. Negative indices count from the end of the dimension;
// the valid range is only checked by an assert.
template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
    auto ind                = make_index();
    auto batch_dims         = s.batch_dims;
    auto output_shape       = output_t.get_shape();
    auto indices_shape      = indices_t.get_shape();
    auto data_shape         = data_t.get_shape();
    auto indices_shape_lens = indices_shape.lens;
    auto data_shape_lens    = data_shape.lens;
    auto num_slice_dims     = indices_shape_lens.back();
    // accumulate() returns the type of its init argument, so the init must be
    // std::size_t; a plain `1` would compute the products in int.
    // Number of index tuples = product of all indices dims except the last.
    std::size_t num_slices = accumulate(indices_shape_lens.begin(),
                                        indices_shape_lens.end() - 1,
                                        std::size_t{1},
                                        std::multiplies<std::size_t>());
    // Elements in one gathered slice = product of the trailing data dims.
    std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
                                        data_shape_lens.end(),
                                        std::size_t{1},
                                        std::multiplies<std::size_t>());
    // Product of the leading batch dims of the data.
    const std::size_t num_batches = accumulate(data_shape_lens.begin(),
                                               data_shape_lens.begin() + batch_dims,
                                               std::size_t{1},
                                               std::multiplies<std::size_t>());
    // Elements of data belonging to a single batch entry.
    const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
                                                     data_shape_lens.end(),
                                                     std::size_t{1},
                                                     std::multiplies<std::size_t>());
    const auto num_slices_per_batch     = num_slices / num_batches;

    ind.global_stride(output_shape.elements(), [&](auto i) {
        const auto* indices_ptr           = indices_t.data();
        const std::size_t j               = i / slice_size;
        const std::size_t batch_idx       = j / num_slices_per_batch;
        auto* slice_indices               = indices_ptr + (j * num_slice_dims);
        std::size_t relative_slice_offset = 0;
        for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
        {
            int64_t index                   = slice_indices[idx];
            const std::size_t input_dim_idx = batch_dims + idx;
            const auto input_dim            = data_shape_lens[input_dim_idx];
            assert(index >= -static_cast<int64_t>(input_dim) and
                   index < static_cast<int64_t>(input_dim));
            // wrap negative indices to count from the end of the dimension
            if(index < 0)
                index += input_dim;
            // element stride of indexed dimension idx within one batch entry
            std::size_t size_from_slice_dims =
                accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
                           data_shape_lens.begin() + batch_dims + num_slice_dims,
                           slice_size,
                           std::multiplies<std::size_t>());
            relative_slice_offset += index * size_from_slice_dims;
        }
        auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
        output_t[i]       = data_t[slice_offset + i % slice_size];
    });
}
} // namespace migraphx
#endif
...@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
} }
#endif #endif
template <class Input, class T, class Output> template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i, Output) constexpr auto reduce_slice(Input input, T i)
{ {
constexpr auto lens = transform(get_shape_c<Input>{}.lens, constexpr auto lens = transform(get_shape_c<Input>{}.lens,
get_shape_c<Output>{}.lens, get_shape_c<Output>{}.lens,
...@@ -136,23 +136,126 @@ constexpr auto reduce_slice(Input input, T i, Output) ...@@ -136,23 +136,126 @@ constexpr auto reduce_slice(Input input, T i, Output)
}); });
; ;
constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides); constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides);
MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <=
input.get_shape().element_space());
return make_tensor_view(&input[i], s); return make_tensor_view(&input[i], s);
} }
template <class Op, class T, class Input, class Output, class ReadInput, class WriteOuput> namespace reduce {
__device__ void
simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write) template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f)
{
return [=](auto x, auto... xs) {
// TODO: assert all elements are the same
return f(slicer(x), slicer(xs)...);
};
}
struct block
{ {
template <class Slicer>
struct reducer
{
index idx;
Slicer slicer;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return read(x[j], xs[j]...);
});
});
}
template <class F>
__device__ void outer(F f) const
{
if(idx.local == 0)
f();
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index(); auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements(); constexpr auto nelements = get_shape_c<Output>{}.elements();
constexpr auto relements = get_shape_c<Input>{}.elements() / get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) { idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = output.get_shape().multi(i / idx.nlocal()); const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
auto rs = reduce_slice(input, out_idx, output); f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
MIGRAPHX_ASSERT(relements == rs.get_shape().elements()); });
auto r = block_reduce(idx, op, init, relements, [&](auto j) { return read(rs[j]); }); }
if(idx.local == 0) };
output[out_idx] = write(r);
// Reduction strategy in which each reduction is computed by a plain sequential
// loop over the elements of the slice (no cooperation through block_reduce).
// NOTE(review): presumably each output element is handled by a single thread;
// that depends on index::global_stride, which is not shown here -- confirm.
struct lane
{
// Handle passed to the user callback in run(); Slicer maps an input tensor to
// the slice that has to be reduced for the current output element.
template <class Slicer>
struct reducer
{
index idx;
Slicer slicer;
// Returns a callable taking one or more tensors. Each tensor is sliced with
// slicer, then op is folded over read(x[j], xs[j]...) for every element j of
// the first slice, starting from init. The accumulator uses the element type
// of the first tensor.
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
using type = typename decltype(x)::type;
type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
r = op(r, read(x[j], xs[j]...));
}
return r;
});
}
// Runs f unconditionally.
template <class F>
__device__ void outer(F f) const
{
f();
}
};
// Builds a reducer from an index and a slicer.
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
}
// Iterates over the elements of the Output shape via idx.global_stride and
// calls f(out_idx, reducer) for each one, where out_idx is the multi-index of
// the output element and the reducer slices inputs with reduce_slice<Output>.
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements, [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i);
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
} // namespace reduce
template <class Algo,
class Op,
class T,
class Input,
class Output,
class ReadInput,
class WriteOuput>
__device__ void
simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write)
{
Algo::template run<Output>([&](auto out_idx, auto r) {
auto x = r.reduce(op, init, read)(input);
r.outer([&] { output[out_idx] = write(x); });
}); });
} }
......
...@@ -29,11 +29,23 @@ struct tensor_view ...@@ -29,11 +29,23 @@ struct tensor_view
constexpr Shape get_shape() const { return Shape{}; } constexpr Shape get_shape() const { return Shape{}; }
constexpr auto size() const { return get_shape().elements(); } constexpr auto size() const { return get_shape().elements(); }
struct index_to_offset
{
index_int offset;
template <class U> template <class U>
constexpr T& operator[](U i) const constexpr index_to_offset(U i) : offset(Shape{}.index(i))
{ {
MIGRAPHX_ASSERT(get_shape().index(i) < get_shape().element_space()); }
return x[get_shape().index(i)]; };
constexpr T& operator[](MIGRAPHX_CAPTURE_SOURCE_LOCATION(index_to_offset) i) const
{
index_to_offset ito = i;
MIGRAPHX_WARN(ito.offset < get_shape().element_space(),
i,
"Out of bounds access at offset: ",
ito.offset);
return x[ito.offset];
} }
constexpr T* data() const { return x; } constexpr T* data() const { return x; }
......
function(add_api_test TEST_NAME TEST_SRC TEST_DIR) function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME}) set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC}) add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
...@@ -10,6 +9,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR) ...@@ -10,6 +9,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
add_dependencies(check ${NAME}) add_dependencies(check ${NAME})
endfunction() endfunction()
add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR}) add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
...@@ -19,7 +19,8 @@ add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) ...@@ -19,7 +19,8 @@ add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR}) add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
# GPU-based tests
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR}) add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
# GPU-based tests target_link_libraries(test_api_gpu migraphx_gpu)
endif() endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment