Commit 988bf26b authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/jit-softmax' into HEAD

parents f99a3036 bb0fff52
...@@ -4,7 +4,7 @@ CheckOptions: ...@@ -4,7 +4,7 @@ CheckOptions:
- key: bugprone-unused-return-value.CheckedFunctions - key: bugprone-unused-return-value.CheckedFunctions
value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product' value: '::std::async;::std::launder;::std::remove;::std::remove_if;::std::unique;::std::unique_ptr::release;::std::basic_string::empty;::std::vector::empty;::std::find;::std::find_if;::std::find_if_not;::std::all_of;::std::any_of;::std::none_of;::std::count;::std::count_if;::std::mismatch;::std::find_end;::std::find_first_of;::std::adjacent_find;::std::search;::std::search_n;::std::nth_element;::std::lower_bound;::std::upper_bound;::std::binary_search;::std::equal_range;::std::max;::std::max_element;::std::min;::std::min_element;::std::minmax;::std::minmax_element;::std::equal;::std::lexicographical_compare;::std::accumulate;::std::inner_product'
- key: cppcoreguidelines-macro-usage.AllowedRegexp - key: cppcoreguidelines-macro-usage.AllowedRegexp
value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_' value: 'DEBUG|ASSERT|ASSUME|UNREACHABLE|FALLTHROUGH|DEPRECATED|STRINGIZE|_HAS_|_THROW|_REQUIRES|_DECLARE_|_VISIT_|_REGISTER_|_GENERATE_|_DETAIL_|_TIDY_|_MANAGE_PTR|_MATCHER|DEVICE_SHARED|_WORKAROUND_'
- key: modernize-loop-convert.MinConfidence - key: modernize-loop-convert.MinConfidence
value: risky value: risky
- key: modernize-loop-convert.NamingStyle - key: modernize-loop-convert.NamingStyle
......
...@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED) ...@@ -42,7 +42,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 2.2) rocm_setup_version(VERSION 2.3)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......
...@@ -109,6 +109,7 @@ register_migraphx_ops( ...@@ -109,6 +109,7 @@ register_migraphx_ops(
flatten flatten
floor floor
gather gather
gathernd
get_tuple_elem get_tuple_elem
greater greater
gru gru
......
...@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction; ...@@ -401,7 +401,8 @@ extern "C" struct migraphx_instruction;
struct migraphx_instruction struct migraphx_instruction
{ {
template <class... Ts> template <class... Ts>
migraphx_instruction(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_instruction(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
migraphx::instruction_ref object; migraphx::instruction_ref object;
...@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions; ...@@ -411,7 +412,8 @@ extern "C" struct migraphx_instructions;
struct migraphx_instructions struct migraphx_instructions
{ {
template <class... Ts> template <class... Ts>
migraphx_instructions(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_instructions(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::instruction_ref> object; std::vector<migraphx::instruction_ref> object;
...@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules; ...@@ -421,7 +423,8 @@ extern "C" struct migraphx_modules;
struct migraphx_modules struct migraphx_modules
{ {
template <class... Ts> template <class... Ts>
migraphx_modules(Ts&&... xs) : object(std::forward<Ts>(xs)...) migraphx_modules(Ts&&... xs)
: object(std::forward<Ts>(xs)...) // NOLINT(readability-redundant-member-init)
{ {
} }
std::vector<migraphx::module*> object; std::vector<migraphx::module*> object;
...@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont ...@@ -1691,6 +1694,16 @@ extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t cont
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
*out = (context->object).get_queue().unsafe_get();
});
return api_error_result;
}
extern "C" migraphx_status extern "C" migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op) migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op)
{ {
......
...@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -433,6 +433,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_status migraphx_context_finish(const_migraphx_context_t context); migraphx_status migraphx_context_finish(const_migraphx_context_t context);
migraphx_status migraphx_context_get_queue(void** out, migraphx_context_t context);
migraphx_status migraphx_status
migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op); migraphx_experimental_custom_op_destroy(migraphx_experimental_custom_op_t experimental_custom_op);
......
...@@ -15,6 +15,16 @@ namespace migraphx { ...@@ -15,6 +15,16 @@ namespace migraphx {
inline namespace api { // NOLINT inline namespace api { // NOLINT
#endif #endif
#ifdef __has_cpp_attribute
#if __has_cpp_attribute(deprecated)
#define MIGRAPHX_DEPRECATED(...) [[deprecated(__VA_ARGS__)]]
#endif
#endif
#ifndef MIGRAPHX_DEPRECATED
#define MIGRAPHX_DEPRECATED(...)
#endif
template <int N> template <int N>
struct rank : rank<N - 1> struct rank : rank<N - 1>
{ {
...@@ -99,34 +109,22 @@ struct iota_iterator ...@@ -99,34 +109,22 @@ struct iota_iterator
return it; return it;
} }
// TODO: operator-> // TODO: operator->
reference operator*() const { return (*f)(index); } reference operator*() const { return f(index); }
};
template <class F, class Iterator> friend iota_iterator operator+(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator+(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index + y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index + y.index, x.f);
}
template <class F, class Iterator> friend iota_iterator operator-(iota_iterator x, iota_iterator y)
inline iota_iterator<F, Iterator> operator-(iota_iterator<F, Iterator> x, {
iota_iterator<F, Iterator> y) return iota_iterator(x.index - y.index, x.f);
{ }
return iota_iterator<F, Iterator>(x.index - y.index, x.f);
}
template <class F, class Iterator> friend bool operator==(iota_iterator x, iota_iterator y) { return x.index == y.index; }
inline bool operator==(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator> friend bool operator!=(iota_iterator x, iota_iterator y) { return x.index != y.index; }
inline bool operator!=(iota_iterator<F, Iterator> x, iota_iterator<F, Iterator> y) };
{
return x.index != y.index;
}
template <class Derived> template <class Derived>
struct array_base struct array_base
...@@ -136,8 +134,20 @@ struct array_base ...@@ -136,8 +134,20 @@ struct array_base
template <class T> template <class T>
using value_type_t = decltype(std::declval<T>()[0]); using value_type_t = decltype(std::declval<T>()[0]);
struct iterator_read
{
const Derived* self;
template <class D = Derived>
value_type_t<D> operator()(size_t pidx) const
{
return (*self)[pidx];
}
};
template <class T> template <class T>
using iterator_t = iota_iterator<typename T::iterator_read>; using iterator_t = iota_iterator<iterator_read>;
bool empty() const { return derived().size() == 0; }
template <class D = Derived> template <class D = Derived>
value_type_t<D> front() const value_type_t<D> front() const
...@@ -154,13 +164,13 @@ struct array_base ...@@ -154,13 +164,13 @@ struct array_base
template <class D = Derived> template <class D = Derived>
iterator_t<D> begin() const iterator_t<D> begin() const
{ {
return {0, {derived().get_handle_ptr()}}; return {0, {&derived()}};
} }
template <class D = Derived> template <class D = Derived>
iterator_t<D> end() const iterator_t<D> end() const
{ {
return {derived().size(), {derived().get_handle_ptr()}}; return {derived().size(), {&derived()}};
} }
}; };
...@@ -200,9 +210,25 @@ struct borrow ...@@ -200,9 +210,25 @@ struct borrow
{ {
}; };
template <class T>
struct share
{
share(std::shared_ptr<T> p) : ptr(std::move(p)) {}
template <class U>
std::shared_ptr<U> alias(U* p) const
{
return std::shared_ptr<U>{ptr, p};
}
private:
std::shared_ptr<T> ptr;
};
template <class Derived, class T, class D, D Deleter, class A, A Assigner> template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{ {
using handle_type = T;
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
template <class F, class... Ts> template <class F, class... Ts>
void make_handle(F f, Ts&&... xs) void make_handle(F f, Ts&&... xs)
...@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -231,6 +257,14 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
m_handle = std::shared_ptr<U>{ptr, [](U*) {}}; m_handle = std::shared_ptr<U>{ptr, [](U*) {}};
} }
template <class U, class V>
void set_handle(U* ptr, share<V> b)
{
m_handle = std::shared_ptr<T>{ptr, [b](U*) {}};
}
share<T> share_handle() const { return {m_handle}; }
template <class U> template <class U>
void assign_to_handle(U* x) void assign_to_handle(U* x)
{ {
...@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>> ...@@ -241,6 +275,17 @@ struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
std::shared_ptr<T> m_handle; std::shared_ptr<T> m_handle;
}; };
// NOLINTNEXTLINE
#define MIGRAPHX_HANDLE_CONSTRUCTOR(name) \
template <class HandleType, \
class Lifetime, \
class = \
typename std::enable_if<std::is_convertible<HandleType*, handle_type*>{}>::type> \
name(HandleType* p, Lifetime lifetime) \
{ \
this->set_handle(p, std::move(lifetime)); \
}
template <class Base> template <class Base>
struct interface_base : Base struct interface_base : Base
{ {
...@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape) ...@@ -398,11 +443,10 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
{ {
shape() {} shape() {}
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); } shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }
shape(migraphx_shape* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
shape(migraphx_shape* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a scalar shape /// Construct a scalar shape
shape(migraphx_shape_datatype_t type) shape(migraphx_shape_datatype_t type)
...@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -479,10 +523,9 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
argument() {} argument() {}
argument(migraphx_argument* p, borrow) { this->set_handle(p, borrow{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(argument);
argument(migraphx_argument* p, own) { this->set_handle(p, own{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); } argument(const migraphx_argument* p) { this->set_handle(p, borrow{}); }
argument(shape pshape, void* pbuffer) argument(shape pshape, void* pbuffer)
...@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) ...@@ -494,7 +537,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_argument_shape, &pout, this->get_handle_ptr()); call(&migraphx_argument_shape, &pout, this->get_handle_ptr());
return {pout}; return {pout, this->share_handle()};
} }
char* data() const char* data() const
...@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) ...@@ -526,9 +569,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target)
{ {
target() {} target() {}
target(migraphx_target* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(target);
target(migraphx_target* p, borrow) { this->set_handle(p, borrow{}); }
/// Construct a target from its name /// Construct a target from its name
target(const char* name) { this->make_handle(&migraphx_target_create, name); } target(const char* name) { this->make_handle(&migraphx_target_create, name); }
...@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -538,15 +579,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
program_parameter_shapes() {} program_parameter_shapes() {}
program_parameter_shapes(migraphx_program_parameter_shapes* p, own) MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameter_shapes);
{
this->set_handle(p, own{});
}
program_parameter_shapes(migraphx_program_parameter_shapes* p, borrow)
{
this->set_handle(p, borrow{});
}
size_t size() const size_t size() const
{ {
...@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -559,7 +592,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname); call(&migraphx_program_parameter_shapes_get, &pout, this->get_handle_ptr(), pname);
return {pout}; return {pout, this->share_handle()};
} }
std::vector<const char*> names() const std::vector<const char*> names() const
...@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) ...@@ -576,10 +609,9 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes)
/// A class to construct the inputs parameters for a program /// A class to construct the inputs parameters for a program
struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
{ {
program_parameters(migraphx_program_parameters* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program_parameters);
program_parameters(migraphx_program_parameters* p, borrow) { this->set_handle(p, borrow{}); }
MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); } program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); } program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
...@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) ...@@ -604,9 +636,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
arguments(migraphx_arguments* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(arguments);
arguments(migraphx_arguments* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -619,27 +649,13 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_arguments_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
} }
struct iterator_read
{
migraphx_arguments* self;
argument operator()(size_t pidx) const
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return {pout};
}
};
}; };
struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
shapes(migraphx_shapes* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(shapes);
shapes(migraphx_shapes* p, borrow) { this->set_handle(p, borrow{}); }
size_t size() const size_t size() const
{ {
...@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -652,26 +668,13 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
{ {
const_migraphx_shape_t pout; const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx); call(&migraphx_shapes_get, &pout, this->get_handle_ptr(), pidx);
return {pout}; return {pout, this->share_handle()};
} }
struct iterator_read
{
migraphx_shapes* self;
shape operator()(size_t pidx) const
{
const_migraphx_shape_t pout;
call(&migraphx_shapes_get, &pout, self, pidx);
return {pout};
}
};
}; };
struct operation : MIGRAPHX_HANDLE_BASE(operation) struct operation : MIGRAPHX_HANDLE_BASE(operation)
{ {
operation(migraphx_operation* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(operation);
operation(migraphx_operation* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
operation(const char* name, const char* attributes = nullptr, Ts... xs) operation(const char* name, const char* attributes = nullptr, Ts... xs)
...@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) ...@@ -689,15 +692,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation)
struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction) struct instruction : MIGRAPHX_CONST_HANDLE_BASE(instruction)
{ {
instruction(migraphx_instruction* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(instruction);
}; };
struct instructions : MIGRAPHX_HANDLE_BASE(instructions) struct instructions : MIGRAPHX_HANDLE_BASE(instructions)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(instructions);
instructions(migraphx_instructions* p, own) { this->set_handle(p, own{}); }
instructions(migraphx_instructions* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
instructions(Ts... xs) instructions(Ts... xs)
...@@ -711,33 +711,36 @@ struct module; ...@@ -711,33 +711,36 @@ struct module;
struct modules : MIGRAPHX_HANDLE_BASE(modules) struct modules : MIGRAPHX_HANDLE_BASE(modules)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(modules);
modules(migraphx_modules* p, own) { this->set_handle(p, own{}); }
modules(migraphx_modules* p, borrow) { this->set_handle(p, borrow{}); }
template <class... Ts> template <class... Ts>
modules(Ts... xs) modules(Ts... xs)
{ {
std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.mm...}; std::array<migraphx_module_t, sizeof...(Ts)> a = {xs.get_handle_ptr()...};
this->make_handle(&migraphx_modules_create, a.data(), a.size()); this->make_handle(&migraphx_modules_create, a.data(), a.size());
} }
}; };
struct module struct module
{ {
migraphx_module_t mm; MIGRAPHX_DEPRECATED("Constructor without lifetime annotation is deprecated.")
module(migraphx_module* m) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(migraphx_module* m, borrow) : mm(std::shared_ptr<migraphx_module*>(), m) {}
module(const migraphx_module_t& m) : mm(m) {} template <class T>
module(migraphx_module* m, share<T> b) : mm(b.alias(m))
{
}
void print() const { call(&migraphx_module_print, mm); } void print() const { call(&migraphx_module_print, mm.get()); }
instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args) instruction add_instruction(const migraphx::operation& op, const migraphx::instructions& args)
{ {
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction, call(&migraphx_module_add_instruction,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr()); args.get_handle_ptr());
return instruction(op_ins, own{}); return instruction(op_ins, own{});
...@@ -750,7 +753,7 @@ struct module ...@@ -750,7 +753,7 @@ struct module
migraphx_instruction_t op_ins; migraphx_instruction_t op_ins;
call(&migraphx_module_add_instruction_with_mod_args, call(&migraphx_module_add_instruction_with_mod_args,
&op_ins, &op_ins,
mm, mm.get(),
op.get_handle_ptr(), op.get_handle_ptr(),
args.get_handle_ptr(), args.get_handle_ptr(),
module_args.get_handle_ptr()); module_args.get_handle_ptr());
...@@ -760,30 +763,53 @@ struct module ...@@ -760,30 +763,53 @@ struct module
instruction add_parameter(const std::string& name, shape s) instruction add_parameter(const std::string& name, shape s)
{ {
migraphx_instruction_t param_ins; migraphx_instruction_t param_ins;
call(&migraphx_module_add_parameter, &param_ins, mm, name.c_str(), s.get_handle_ptr()); call(
&migraphx_module_add_parameter, &param_ins, mm.get(), name.c_str(), s.get_handle_ptr());
return instruction(param_ins, own{}); return instruction(param_ins, own{});
} }
instruction add_return(const migraphx::instructions& args) instruction add_return(const migraphx::instructions& args)
{ {
migraphx_instruction_t ret_ins; migraphx_instruction_t ret_ins;
call(&migraphx_module_add_return, &ret_ins, mm, args.get_handle_ptr()); call(&migraphx_module_add_return, &ret_ins, mm.get(), args.get_handle_ptr());
return instruction(ret_ins, own{}); return instruction(ret_ins, own{});
} }
migraphx_module_t get_handle_ptr() const { return mm.get(); }
private:
std::shared_ptr<migraphx_module> mm;
}; };
struct context struct context
{ {
migraphx_context_t ctx; context(migraphx_context* p, borrow) : ctx(std::shared_ptr<migraphx_context*>(), p) {}
template <class T>
context(migraphx_context* p, share<T> b) : ctx(b.alias(p))
{
}
void finish() const { call(&migraphx_context_finish, ctx.get()); }
template <class T>
T get_queue()
{
void* out;
call(&migraphx_context_get_queue, &out, ctx.get());
// TODO: check type here
return reinterpret_cast<T>(out);
}
void finish() const { call(&migraphx_context_finish, ctx); } private:
std::shared_ptr<migraphx_context> ctx;
}; };
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
compile_options(migraphx_compile_options* p, own) { this->set_handle(p, own()); } MIGRAPHX_HANDLE_CONSTRUCTOR(compile_options);
/// For targets with offloaded memory(such as the gpu), this will insert /// For targets with offloaded memory(such as the gpu), this will insert
/// instructions during compilation to copy the input parameters to the /// instructions during compilation to copy the input parameters to the
...@@ -807,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -807,9 +833,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() { this->make_handle(&migraphx_program_create); } program() { this->make_handle(&migraphx_program_create); }
program(migraphx_program* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(program);
program(migraphx_program* p, borrow) { this->set_handle(p, borrow{}); }
/// Compile the program for a specific target to be ran on /// Compile the program for a specific target to be ran on
void compile(const target& ptarget, const compile_options& poptions) const void compile(const target& ptarget, const compile_options& poptions) const
...@@ -872,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -872,21 +896,21 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr()); call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
context experimental_get_context() context experimental_get_context()
{ {
migraphx_context_t ctx; migraphx_context_t ctx;
call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr()); call(&migraphx_program_experimental_get_context, &ctx, this->get_handle_ptr());
return context{ctx}; return context{ctx, this->share_handle()};
} }
module create_module(const std::string& name) module create_module(const std::string& name)
{ {
migraphx_module_t p_modu; migraphx_module_t p_modu;
call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data()); call(&migraphx_program_create_module, &p_modu, this->get_handle_ptr(), name.data());
return module{p_modu}; return module{p_modu, this->share_handle()};
} }
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return !(px == py); }
...@@ -895,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -895,10 +919,9 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
// options for migraphx file format options // options for migraphx file format options
struct file_options : MIGRAPHX_HANDLE_BASE(file_options) struct file_options : MIGRAPHX_HANDLE_BASE(file_options)
{ {
MIGRAPHX_HANDLE_CONSTRUCTOR(file_options);
file_options() { this->make_handle(&migraphx_file_options_create); } file_options() { this->make_handle(&migraphx_file_options_create); }
file_options(migraphx_file_options* p, own) { this->set_handle(p, own()); }
// set file format // set file format
void set_file_format(const char* format) void set_file_format(const char* format)
{ {
...@@ -938,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) ...@@ -938,7 +961,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{ {
onnx_options() { this->make_handle(&migraphx_onnx_options_create); } onnx_options() { this->make_handle(&migraphx_onnx_options_create); }
onnx_options(migraphx_onnx_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(onnx_options);
/// Make onnx parser treat an inputs with a certain dimensions /// Make onnx parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1020,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) ...@@ -1020,7 +1043,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options)
{ {
tf_options() { this->make_handle(&migraphx_tf_options_create); } tf_options() { this->make_handle(&migraphx_tf_options_create); }
tf_options(migraphx_tf_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(tf_options);
/// Make tf parser treat an inputs with a certain dimensions /// Make tf parser treat an inputs with a certain dimensions
void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim) void set_input_parameter_shape(const std::string& name, std::vector<std::size_t> dim)
...@@ -1073,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) ...@@ -1073,7 +1096,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{ {
quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); } quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }
quantize_op_names(migraphx_quantize_op_names* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_op_names);
void add(const std::string& name) void add(const std::string& name)
{ {
...@@ -1098,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) ...@@ -1098,12 +1121,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{ {
quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); } quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }
quantize_int8_options(migraphx_quantize_int8_options* p, own) { this->set_handle(p, own{}); } MIGRAPHX_HANDLE_CONSTRUCTOR(quantize_int8_options);
quantize_int8_options(migraphx_quantize_int8_options* p, borrow)
{
this->set_handle(p, borrow{});
}
/// Add an operator that should be quantized /// Add an operator that should be quantized
void add_op_name(const std::string& name) void add_op_name(const std::string& name)
......
...@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8', ...@@ -403,6 +403,7 @@ api.add_function('migraphx_quantize_int8',
@auto_handle(ref=True) @auto_handle(ref=True)
def context(h): def context(h):
h.method('finish', const=True) h.method('finish', const=True)
h.method('get_queue', returns='void*', fname='get_queue().unsafe_get')
@api.interface('migraphx_experimental_custom_op', @api.interface('migraphx_experimental_custom_op',
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct gathernd
{
int batch_dims = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.batch_dims, "batch_dims"));
}
std::string name() const { return "gathernd"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
auto indices_shape_lens = indices_shape.lens();
auto data_shape = data.get_shape();
auto data_shape_lens = data_shape.lens();
auto k = indices_shape.lens().back();
const auto num_slice_dims = k;
std::size_t num_slices = std::accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
std::size_t num_batches = std::accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
std::size_t data_batch_stride =
std::accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
auto num_slices_per_batch = num_slices / num_batches;
std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims);
{
auto running_product = slice_size;
for(std::size_t i = 0; i < num_slice_dims; ++i)
{
sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product;
running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i];
}
}
std::vector<std::size_t> input_slice_offsets(num_slices);
par_for(num_slices, [&](const auto i) {
std::size_t batch_idx = i / num_slices_per_batch;
auto slice_indices = indices.begin() + (i * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx)
{
int64_t index = *(slice_indices + dim_idx);
const std::size_t input_dim_idx = batch_dims + dim_idx;
const auto input_dim = data_shape_lens[input_dim_idx];
if(index < -static_cast<int64_t>(input_dim) or
index >= static_cast<int64_t>(input_dim))
MIGRAPHX_THROW("GatherND: index " + std::to_string(index) +
" is out of bounds for dim of len " +
std::to_string(input_dim));
if(index < 0)
index += input_dim;
relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
}
input_slice_offsets[i] =
(batch_idx * data_batch_stride) + relative_slice_offset;
});
par_for(num_slices * slice_size, [&](const auto i) {
auto slice_offset = input_slice_offsets[i / slice_size];
output[i] = data[slice_offset + i % slice_size];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp> #include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp> #include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp> #include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp> #include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp> #include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp> #include <migraphx/op/gru.hpp>
......
...@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Flatten", "flatten"}, {"Flatten", "flatten"},
{"Floor", "floor"}, {"Floor", "floor"},
{"Gather", "gather"}, {"Gather", "gather"},
{"GatherND", "gathernd"},
{"Identity", "identity"}, {"Identity", "identity"},
{"IsNaN", "isnan"}, {"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"}, {"LeakyRelu", "leaky_relu"},
......
...@@ -20,7 +20,7 @@ struct run_op : action<run_op> ...@@ -20,7 +20,7 @@ struct run_op : action<run_op>
auto op = make_op(name); auto op = make_op(name);
if(v.contains("fields")) if(v.contains("fields"))
op.from_value(v.at("fields")); op.from_value(v.at("fields"));
double t = time_op(ctx, op, inputs); double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl; std::cout << op << ": " << t << "ms" << std::endl;
} }
}; };
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct gathernd_compiler : compiler<gathernd_compiler>
{
std::vector<std::string> names() const { return {"gathernd"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "gathernd_kernel";
options.virtual_inputs = inputs;
// batch_dims
assert(v.contains("batch_dims"));
auto batch_dims = v.at("batch_dims").to<int64_t>();
options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);
return compile_hip_code_object(gathernd_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p) ...@@ -30,7 +30,7 @@ __global__ void kernel(void* input_p, void* output_p)
{ {
make_tensors()(input_p, output_p)([](auto input, auto output) { make_tensors()(input_p, output_p)([](auto input, auto output) {
simple_reduce(${reduction}, ${init}, input, output, ${read}, ${write}); simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
}); });
} }
...@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input ...@@ -57,6 +57,40 @@ static std::size_t get_reduce_elements(const std::vector<instruction_ref>& input
return get_reduce_elements(to_shapes(inputs)); return get_reduce_elements(to_shapes(inputs));
} }
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
std::vector<std::size_t> reduce_lens;
std::transform(output_lens.begin(),
output_lens.end(),
input_lens.begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
return reduce_lens;
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride
auto min_stride = std::inner_product(
rlens.begin(),
rlens.end(),
inputs.front().strides().begin(),
init,
[](auto x, auto y) { return std::min(x, y); },
[](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2)
return "lane";
return "block";
}
struct reduce_compiler : compiler<reduce_compiler> struct reduce_compiler : compiler<reduce_compiler>
{ {
std::vector<std::string> names() const std::vector<std::string> names() const
...@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -68,20 +102,33 @@ struct reduce_compiler : compiler<reduce_compiler>
{ {
hip_compile_options options; hip_compile_options options;
auto reduce_elements = get_reduce_elements(inputs); auto reduce_elements = get_reduce_elements(inputs);
auto block_size = compute_block_size(reduce_elements, 256); auto algo = v.get("algo", get_reduce_algo(inputs));
options.set_launch_params( if(algo == "block")
v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size); {
auto block_size = compute_block_size(reduce_elements, 256);
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements() * block_size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
std::string identity = "[](auto x) { return x; }"; std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel, auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()}, {{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})}, {"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)}, {"read", v.get("read", identity)},
{"write", v.get("write", identity)}, {"write", v.get("write", identity)},
{"algo", algo},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const softmax_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__";
constexpr std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024)
{
size_t block_size = 128;
while(block_size <= max_block_size and block_size <= n)
block_size *= 2;
return block_size / 2;
}
struct softmax_compiler : compiler<softmax_compiler>
{
std::vector<std::string> names() const { return {"softmax"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto axis = v.at("axis").to<int64_t>();
auto relements = inputs[0].lens()[axis];
auto nelements = inputs.back().elements() / relements;
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "softmax_kernel";
auto src = interpolate_string(softmax_kernel, {{"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -21,6 +21,16 @@ struct greater ...@@ -21,6 +21,16 @@ struct greater
} }
}; };
template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for(; first != last; ++first)
{
init = op(std::move(init), *first);
}
return init;
}
template <class InputIt, class OutputIt> template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first) constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{ {
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp> #include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp> #include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp> #include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp> #include <migraphx/kernels/debug.hpp>
namespace migraphx { namespace migraphx {
...@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f) ...@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return integral_const_array<T, f(Xs)...>{}; return integral_const_array<T, f(Xs)...>{};
} }
template <class T, T... Xs, class F>
constexpr auto transform_i(integral_const_array<T, Xs...>, F f)
{
return sequence_c<sizeof...(Xs)>(
[=](auto... is) { return integral_const_array<T, f(Xs, is)...>{}; });
}
template <class T, T... Xs, class U, U... Ys, class F> template <class T, T... Xs, class U, U... Ys, class F>
constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f) constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
{ {
......
...@@ -42,6 +42,32 @@ struct print_buffer ...@@ -42,6 +42,32 @@ struct print_buffer
pos++; pos++;
} }
} }
template <class T, class = decltype(T{} % 10, -T{})>
constexpr void append(T i)
{
if(i < 0)
{
append('-');
i = -i;
}
char c = (i % 10) + '0';
if(i > 9)
append(i / 10);
append(c);
}
constexpr void append(const char* str)
{
if(str == nullptr)
return;
int i = 512;
while(*str != 0 and i > 0)
{
append(*str);
str++;
i--;
}
}
template <size_t M> template <size_t M>
constexpr void append(const char (&array)[M]) constexpr void append(const char (&array)[M])
...@@ -54,14 +80,36 @@ struct print_buffer ...@@ -54,14 +80,36 @@ struct print_buffer
template <class... Ts> template <class... Ts>
__host__ __device__ void print(const Ts&... xs) __host__ __device__ void print(const Ts&... xs)
{ {
const auto size = (sizeof(xs) + ...); print_buffer<1024> buffer;
print_buffer<size> buffer;
swallow{(buffer.append(xs), 0)...}; swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer); printf("%s", buffer.buffer);
} }
} // namespace debug } // namespace debug
struct source_location
{
int line = __builtin_LINE();
const char* file = __builtin_FILE();
const char* function = __builtin_FUNCTION();
};
template <class T>
struct source_location_capture
{
T x;
source_location loc;
template <class U, class = decltype(T(U{}))>
constexpr source_location_capture(U px, source_location ploc = source_location{})
: x(px), loc(ploc)
{
}
constexpr operator source_location() const { return loc; }
constexpr operator T() const { return x; }
};
// noreturn cannot be used on this function because abort in hip is broken // noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4> template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
...@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct ...@@ -73,20 +121,38 @@ assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& funct
abort(); abort();
} }
template <class... Ts>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_location& loc,
Ts... xs)
{
debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n");
abort();
}
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \ #define MIGRAPHX_ASSERT_FAIL(cond, ...) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \ ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \ assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__)) }(__VA_ARGS__))
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
#ifdef MIGRAPHX_DEBUG #ifdef MIGRAPHX_DEBUG
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK #define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK #define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false) #define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else #else
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume #define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable #define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond) #define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif #endif
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP #define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/array.hpp> #include <migraphx/kernels/integral_constant.hpp>
namespace migraphx { namespace migraphx {
......
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T>
struct gathernd_settings
{
T batch_dims{};
};
template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
auto ind = make_index();
auto batch_dims = s.batch_dims;
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
auto data_shape = data_t.get_shape();
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const std::size_t num_batches = accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
assert(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
std::multiplies<std::size_t>());
relative_slice_offset += index * size_from_slice_dims;
}
auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
output_t[i] = data_t[slice_offset + i % slice_size];
});
}
} // namespace migraphx
#endif
...@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -124,8 +124,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
} }
#endif #endif
template <class Input, class T, class Output> template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i, Output) constexpr auto reduce_slice(Input input, T i)
{ {
constexpr auto lens = transform(get_shape_c<Input>{}.lens, constexpr auto lens = transform(get_shape_c<Input>{}.lens,
get_shape_c<Output>{}.lens, get_shape_c<Output>{}.lens,
...@@ -136,23 +136,160 @@ constexpr auto reduce_slice(Input input, T i, Output) ...@@ -136,23 +136,160 @@ constexpr auto reduce_slice(Input input, T i, Output)
}); });
; ;
constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides); constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides);
MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <=
input.get_shape().element_space());
return make_tensor_view(&input[i], s); return make_tensor_view(&input[i], s);
} }
template <class Op, class T, class Input, class Output, class ReadInput, class WriteOuput> namespace reduce {
template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f)
{
return [=](auto x, auto... xs) {
// TODO: assert all elements are the same
return f(slicer(x), slicer(xs)...);
};
}
template <class Input, index_int Axis>
constexpr auto compute_reduce_axis()
{
constexpr auto lens =
transform_i(get_shape_c<Input>{}.lens, [](index_int x, index_int i) -> index_int {
if(i == Axis)
return 1;
return x;
});
return make_shape(lens, get_shape_c<Input>{}.strides);
}
template <class Input, index_int Axis>
using with_axis = decltype(compute_reduce_axis<Input, Axis>());
struct block
{
template <class Slicer>
struct reducer
{
index idx;
Slicer slicer;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
return block_reduce(idx, op, init, x.get_shape().elements(), [&](auto j) {
return read(x[j], xs[j]...);
});
});
}
template <class F>
__device__ void outer(F f) const
{
if(idx.local == 0)
f();
}
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
});
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal());
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
struct lane
{
template <class Slicer>
struct reducer
{
index idx;
Slicer slicer;
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
using type = typename decltype(x)::type;
type r = init;
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
r = op(r, read(x[j], xs[j]...));
}
return r;
});
}
template <class F>
__device__ void outer(F f) const
{
f();
}
template <class F>
__device__ auto inner(F f) const
{
return sliced(slicer, [=](auto x, auto... xs) {
for(index_int j = 0; j < x.get_shape().elements(); j++)
{
f(x[j], xs[j]...);
}
});
}
};
template <class Slicer>
static __device__ auto make(index idx, Slicer slicer)
{
return reducer<Slicer>{idx, slicer};
}
template <class Output, class F>
static __device__ void run(F f)
{
auto idx = make_index();
constexpr auto nelements = get_shape_c<Output>{}.elements();
idx.global_stride(nelements, [&](auto i) {
const auto out_idx = get_shape_c<Output>{}.multi(i);
f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
});
}
};
} // namespace reduce
template <class Algo,
class Op,
class T,
class Input,
class Output,
class ReadInput,
class WriteOuput>
__device__ void __device__ void
simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write) simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOuput write)
{ {
auto idx = make_index(); Algo::template run<Output>([&](auto out_idx, auto r) {
constexpr auto nelements = get_shape_c<Output>{}.elements(); auto x = r.reduce(op, init, read)(input);
constexpr auto relements = get_shape_c<Input>{}.elements() / get_shape_c<Output>{}.elements(); r.outer([&] { output[out_idx] = write(x); });
idx.global_stride(nelements * idx.nlocal(), [&](auto i) {
const auto out_idx = output.get_shape().multi(i / idx.nlocal());
auto rs = reduce_slice(input, out_idx, output);
MIGRAPHX_ASSERT(relements == rs.get_shape().elements());
auto r = block_reduce(idx, op, init, relements, [&](auto j) { return read(rs[j]); });
if(idx.local == 0)
output[out_idx] = write(r);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment