Commit 7e297b13 authored by Paul's avatar Paul
Browse files

Merge

parents 86ea5e91 aa7ff911
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct gathernd_compiler : compiler<gathernd_compiler>
{
std::vector<std::string> names() const { return {"gathernd"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "gathernd_kernel";
options.virtual_inputs = inputs;
// batch_dims
assert(v.contains("batch_dims"));
auto batch_dims = v.at("batch_dims").to<int64_t>();
options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);
return compile_hip_code_object(gathernd_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
pointwise(idx, ${transformers})(${lambda}, ${args});
}
}
} // namespace migraphx
)__migraphx__";
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise"}; }
static std::size_t oversubscribe_if(bool b)
{
if(b)
return 256;
else
return 1;
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params(
v,
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)},
{"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign",
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(
compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void reduce_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
});
}
}
} // namespace migraphx
)__migraphx__";
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
std::vector<std::size_t> reduce_lens;
std::transform(output_lens.begin(),
output_lens.end(),
input_lens.begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
return reduce_lens;
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride
auto min_stride = std::inner_product(
rlens.begin(),
rlens.end(),
inputs.front().strides().begin(),
init,
[](auto x, auto y) { return std::min(x, y); },
[](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2)
return "lane";
return "block";
}
struct reduce_compiler : compiler<reduce_compiler>
{
std::vector<std::string> names() const
{
return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)},
{"write", v.get("write", identity)},
{"algo", algo},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
value v = value::object{};
auto reduce_elements = get_reduce_elements(ins->inputs());
if(op.name() == "reduce_sum")
{
v["reduction"] = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
v["reduction"] = "op::sum{}";
v["write"] = "op::mean{" + std::to_string(reduce_elements) + "}";
}
else if(op.name() == "reduce_max")
{
v["reduction"] = "op::max{}";
v["init"] = "lowest{}";
}
else if(op.name() == "reduce_min")
{
v["reduction"] = "op::min{}";
v["init"] = "highest{}";
}
else if(op.name() == "reduce_prod")
{
v["reduction"] = "op::product{}";
v["init"] = "1";
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const roialign_kernel = R"__migraphx__(
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
{
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
_c<bool{IS_AVG_POOLING}>,
_c<int64_t{SAMPLING_RATIO}>,
MIGRAPHX_MAKE_CONSTANT(float{SPATIAL_SCALE}));
roialign(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct roialign_compiler : compiler<roialign_compiler>
{
std::vector<std::string> names() const { return {"roialign"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 128);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "roialign_kernel";
// sampling_ratio
options.params += " -DSAMPLING_RATIO=" + v.at("sampling_ratio").to<std::string>();
// pooling_mode
auto mode = v.at("mode").to<migraphx::op::pooling_mode>();
std::string is_avg_pooling =
(mode == migraphx::op::pooling_mode::average) ? "true" : "false";
options.params += " -DIS_AVG_POOLING=" + is_avg_pooling;
// coord_trans_mode
auto ctm = v.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale
options.params += " -DSPATIAL_SCALE=" + v.at("spatial_scale").to<std::string>();
return compile_hip_code_object(roialign_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., ${reduction}{});
});
}
}
} // namespace migraphx
)__migraphx__";
struct scatternd_compiler : compiler<scatternd_compiler>
{
std::vector<std::string> names() const
{
return {"scatternd_none", "scatternd_add", "scatternd_mul"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(ctx,
to_shapes({ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& op) const
{
return [=](module& m, instruction_ref ins) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
};
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun,
void* kernargs,
std::size_t size)
{
assert(global > 0);
assert(local > 0);
void* config[] = {
// HIP_LAUNCH_PARAM_* are macros that do horrible things
#ifdef MIGRAPHX_USE_CLANG_TIDY
......
......@@ -21,6 +21,26 @@ struct greater
}
};
template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for(; first != last; ++first)
{
init = op(std::move(init), *first);
}
return init;
}
template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
while(first != last)
{
*d_first++ = *first++;
}
return d_first;
}
template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{
......@@ -96,6 +116,35 @@ constexpr Iterator1 search(Iterator1 first, Iterator1 last, Iterator2 s_first, I
}
}
template <class InputIt1, class InputIt2, class T, class BinaryOperation1, class BinaryOperation2>
constexpr T inner_product(InputIt1 first1,
InputIt1 last1,
InputIt2 first2,
T init,
BinaryOperation1 op1,
BinaryOperation2 op2)
{
while(first1 != last1)
{
init = op1(init, op2(*first1, *first2));
++first1;
++first2;
}
return init;
}
template <class InputIt1, class InputIt2, class T>
constexpr T inner_product(InputIt1 first1, InputIt1 last1, InputIt2 first2, T init)
{
return inner_product(
first1,
last1,
first2,
init,
[](auto x, auto y) { return x + y; },
[](auto x, auto y) { return x * y; });
}
} // namespace migraphx
#endif
......@@ -2,40 +2,51 @@
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_ARRAY_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
constexpr array& operator op(const array& x) \
{ \
for(index_int i = 0; i < N; i++) \
d[i] op x[i]; \
return *this; \
} \
constexpr array& operator op(const T& x) \
{ \
for(index_int i = 0; i < N; i++) \
d[i] op x; \
return *this; \
} \
friend constexpr array operator binary_op(const array& x, const array& y) \
{ \
auto z = x; \
return z op y; \
} \
friend constexpr array operator binary_op(const array& x, const T& y) \
{ \
auto z = x; \
return z op y; \
} \
friend constexpr array operator binary_op(const T& x, const array& y) \
{ \
for(index_int i = 0; i < N; i++) \
y[i] = x op y[i]; \
return y; \
#define MIGRAPHX_DEVICE_ARRAY_OP(op, binary_op) \
template <class U> \
constexpr array& operator op(const array<U, N>& x) \
{ \
for(index_int i = 0; i < N; i++) \
d[i] op x[i]; \
return *this; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
constexpr array& operator op(const U& x) \
{ \
for(index_int i = 0; i < N; i++) \
d[i] op x; \
return *this; \
} \
template <class U> \
friend constexpr auto operator binary_op(const array& x, const array<U, N>& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x[i] binary_op y[i]; \
return z; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const array& x, const U& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x[i] binary_op y; \
return z; \
} \
template <class U, MIGRAPHX_REQUIRES(is_convertible<U, T>{})> \
friend constexpr auto operator binary_op(const U& x, const array& y) \
{ \
array<decltype(T {} binary_op U{}), N> z{}; \
for(index_int i = 0; i < N; i++) \
z[i] = x binary_op y[i]; \
return z; \
}
template <class T, index_int N>
......@@ -63,6 +74,7 @@ struct array
constexpr const T* data() const { return d; }
constexpr index_constant<N> size() const { return {}; }
constexpr auto empty() const { return size() == _c<0>; }
constexpr T* begin() { return d; }
constexpr const T* begin() const { return d; }
......@@ -134,8 +146,8 @@ struct array
constexpr array carry(array result) const
{
uint32_t overflow = 0;
for(std::ptrdiff_t i = result.size() - 1; i > 0; i--)
index_int overflow = 0;
for(diff_int i = result.size() - 1; i > 0; i--)
{
auto z = result[i] + overflow;
// Reset overflow
......@@ -165,23 +177,23 @@ struct array
}
};
template <class T, T... xs>
struct integral_const_array : array<T, sizeof...(xs)>
template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(Xs)>
{
using base_array = array<T, sizeof...(xs)>;
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({xs...}) {}
using base_array = array<T, sizeof...(Xs)>;
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({Xs...}) {}
};
template <class T, T... xs, class F>
constexpr auto transform(integral_const_array<T, xs...>, F f)
template <class T, T... Xs, class F>
constexpr auto transform(integral_const_array<T, Xs...>, F f)
{
return integral_const_array<T, f(xs)...>{};
return integral_const_array<T, f(Xs)...>{};
}
template <class T, T... xs, class U, U... ys, class F>
constexpr auto transform(integral_const_array<T, xs...>, integral_const_array<U, ys...>, F f)
template <class T, T... Xs, class U, U... Ys, class F>
constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
{
return integral_const_array<T, f(xs, ys)...>{};
return integral_const_array<T, f(Xs, Ys)...>{};
}
template <index_int... Ns>
......
#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/hip.hpp>
namespace migraphx {
inline __host__ __device__ void
assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
// Workaround hip's broken abort on device code
#ifdef __HIP_DEVICE_COMPILE__
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN
#else
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN [[noreturn]]
#endif
namespace debug {
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
template <size_t N>
struct print_buffer
{
char buffer[N + 1] = {0};
char* pos = buffer;
constexpr void append(char c)
{
if(c == 0)
return;
if(pos < buffer + N)
{
*pos = c;
pos++;
}
}
template <class T, class = decltype(T{} % 10, -T{})>
constexpr void append(T i)
{
if(i < 0)
{
append('-');
i = -i;
}
char c = (i % 10) + '0';
if(i > 9)
append(i / 10);
append(c);
}
constexpr void append(const char* str)
{
if(str == nullptr)
return;
int i = 512;
while(*str != 0 and i > 0)
{
append(*str);
str++;
i--;
}
}
template <size_t M>
constexpr void append(const char (&array)[M])
{
for(int i = 0; i < M; i++)
append(array[i]);
}
};
template <class... Ts>
__host__ __device__ void print(const Ts&... xs)
{
printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
print_buffer<1024> buffer;
swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer);
}
} // namespace debug
struct source_location
{
int line = __builtin_LINE();
const char* file = __builtin_FILE();
const char* function = __builtin_FUNCTION();
};
template <class T>
struct source_location_capture
{
T x;
source_location loc;
template <class U, class = decltype(T(U{}))>
constexpr source_location_capture(U px, source_location ploc = source_location{})
: x(px), loc(ploc)
{
}
constexpr operator source_location() const { return loc; }
constexpr operator T() const { return x; }
};
// noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function)
{
// printf is broken on hip with more than one argument, so use a simple print functions instead
debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n");
// printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort();
}
template <class... Ts>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_location& loc,
Ts... xs)
{
debug::print(loc.file, ":", loc.line, ": ", loc.function, ": error: ", xs..., "\n");
abort();
}
// NOLINTNEXTLINE
#define MIGRAPHX_ASSERT_FAIL(cond, ...) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
}(__VA_ARGS__))
// NOLINTNEXTLINE
#define MIGRAPHX_CHECK(cond) \
MIGRAPHX_ASSERT_FAIL(cond, #cond, __FILE__, __LINE__, __PRETTY_FUNCTION__)
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto... xs) { \
assert_fail(xs...); \
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) source_location_capture<T>
#define MIGRAPHX_WARN(cond, loc, ...) MIGRAPHX_ASSERT_FAIL(cond, loc, __VA_ARGS__)
#define MIGRAPHX_ASSERT MIGRAPHX_CHECK
#define MIGRAPHX_ASSUME MIGRAPHX_CHECK
#define MIGRAPHX_UNREACHABLE() MIGRAPHX_ASSERT(false)
#else
// NOLINTNEXTLINE
#define MIGRAPHX_CAPTURE_SOURCE_LOCATION(T) T
#define MIGRAPHX_ASSUME __builtin_assume
#define MIGRAPHX_UNREACHABLE __builtin_unreachable
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_WARN(...)
#endif
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_DFOR_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_DFOR_HPP
namespace migraphx {
// Multidimensional for loop
inline constexpr auto dfor()
{
return [](auto f) { f(); };
}
template <class T, class... Ts>
constexpr auto dfor(T x, Ts... xs)
{
return [=](auto f) {
for(T i = 0; i < x; i++)
{
dfor(xs...)([&](Ts... is) { f(i, is...); });
}
};
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_DPP_HPP
#define MIGRAPHX_GUARD_KERNELS_DPP_HPP
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/debug.hpp>
namespace migraphx {
#ifndef MIGRAPHX_HAS_DPP
#define MIGRAPHX_HAS_DPP 1
#endif
#if MIGRAPHX_HAS_DPP
constexpr unsigned int dpp_row_shr(unsigned int x) { return 0x110u | x; }
constexpr unsigned int dpp_row_bcast(unsigned int x)
{
unsigned int y = 0;
switch(x)
{
case 15: y = 0x142; break;
case 31: y = 0x143; break;
default: MIGRAPHX_UNREACHABLE();
}
return y;
}
template <unsigned int DppCtrl,
unsigned int RowMask = 0xf,
unsigned int BankMask = 0xf,
bool BoundCtrl = false,
class T>
__device__ T dpp_mov(T& x)
{
static const index_int n = sizeof(T) < 4 ? 1 : sizeof(T) / 4;
union type
{
uint32_t reg[n];
T data;
};
type output{};
type input{};
// cppcheck-suppress unreadVariable
input.data = x;
for(index_int i = 0; i < n; i++)
{
output.reg[i] = __hip_move_dpp(input.reg[i], DppCtrl, RowMask, BankMask, BoundCtrl);
}
return output.data;
}
#endif
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DPP_HPP
......@@ -3,6 +3,14 @@
#include <migraphx/kernels/array.hpp>
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
namespace migraphx {
struct swallow
......@@ -16,6 +24,19 @@ struct swallow
template <index_int>
using ignore = swallow;
template <class... Fs>
struct overloaded : Fs...
{
using Fs::operator()...;
overloaded(Fs... fs) : Fs(fs)... {}
};
template <class... Fs>
overloaded<Fs...> overload(Fs... fs)
{
return {fs...};
}
namespace detail {
template <class R>
......@@ -116,7 +137,7 @@ constexpr auto by(F f)
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
swallow{(f(std::forward<Ts>(xs)), 0)...};
swallow{(f(static_cast<Ts&&>(xs)), 0)...};
}
template <class F>
......@@ -124,12 +145,60 @@ constexpr void each_args(F)
{
}
template <class F, class T>
constexpr auto fold_impl(F&&, T&& x)
{
return static_cast<T&&>(x);
}
template <class F, class T, class U, class... Ts>
constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
return fold_impl(f, f(static_cast<T&&>(x), static_cast<U&&>(y)), static_cast<Ts&&>(xs)...);
}
template <class F>
constexpr auto fold(F f)
{
return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
}
template <class... Ts>
auto pack(Ts... xs)
constexpr auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
template <class G, class F>
constexpr auto join(G g, F f)
{
return f([=](auto... xs) { return g(xs...); });
}
template <class G, class F, class... Fs>
constexpr auto join(G g, F f, Fs... fs)
{
return f([=](auto... xs) { return join([=](auto... ys) { return g(xs..., ys...); }, fs...); });
}
template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{
return p1([&](auto... xs) {
return p2([&](auto... ys) {
auto c = [&](auto x, auto y) -> int {
if(compare(x, y))
return 1;
else if(compare(y, x))
return -1;
else
return 0;
};
return fold([](auto x, auto y) { return x ? x : y; })(c(xs, ys)..., 0);
});
});
}
template <index_int N>
constexpr auto arg_c()
{
......@@ -142,34 +211,45 @@ constexpr auto arg(IntegralConstant ic)
return arg_c<ic>();
}
inline constexpr auto rotate_last()
template <class F>
constexpr auto make_transform(F f)
{
return [](auto... xs) {
return [=](auto&& f) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
};
};
return [=](auto... xs) { return [=](auto g) { return f(g, xs...); }; };
}
// An arg transformation takes the arguments and then a function to take the new arguments:
// transform(xs...)([](auto... ys) { ... })
// The transform_args function takes a list of transformations and continually applies them
template <class F>
constexpr auto transform_args(F f)
{
return [=](auto... xs) {
return [=](auto g) { return f(xs...)([&](auto... ys) { return g(ys...); }); };
};
return f;
}
template <class F, class... Fs>
constexpr auto transform_args(F f, Fs... fs)
{
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
return make_transform([=](auto g, auto... xs) {
return f(xs...)([=](auto... ys) { return transform_args(fs...)(ys...)(g); });
});
}
#define MIGRAPHX_LIFT(...) \
([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
// identity transform
inline constexpr auto transform_args()
{
return make_transform([](auto f, auto... xs) { return f(xs...); });
}
// Rotate the first argument to the last argument
inline constexpr auto rotate_last()
{
return make_transform([](auto f, auto... xs) {
return sequence_c<sizeof...(xs)>([&](auto... is) {
constexpr auto size = sizeof...(is);
return f(arg_c<(is + size - 1) % size>()(xs...)...);
});
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T>
struct gathernd_settings
{
T batch_dims{};
};
template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
auto ind = make_index();
auto batch_dims = s.batch_dims;
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
auto data_shape = data_t.get_shape();
auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const std::size_t num_batches = accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const auto num_slices_per_batch = num_slices / num_batches;
ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;
auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
assert(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
std::multiplies<std::size_t>());
relative_slice_offset += index * size_from_slice_dims;
}
auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
output_t[i] = data_t[slice_offset + i % slice_size];
});
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
#define MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
namespace migraphx {
template <class F>
struct generic_constant
{
static constexpr auto value = F{}();
using value_type = decltype(value);
using type = generic_constant;
constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; }
};
template <class F>
constexpr generic_constant<F> make_generic_constant(F)
{
return {};
}
// NOLINTNEXTLINE
#define MIGRAPHX_MAKE_CONSTANT(x) \
make_generic_constant([] { \
struct fun \
{ \
constexpr auto operator()() const { return x; } \
}; \
return fun{}; \
}())
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
namespace migraphx {
......@@ -12,23 +13,23 @@ struct index
index_int local = 0;
index_int group = 0;
__device__ index_int nglobal() const
{
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; }
#else
return blockDim.x * gridDim.x;
#endif
__device__ index_int nglobal() const
{
return blockDim.x * gridDim.x; // NOLINT
}
#endif
__device__ index_int nlocal() const
{
#ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL;
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; }
#else
return blockDim.x;
#endif
__device__ index_int nlocal() const
{
return blockDim.x; // NOLINT
}
#endif
template <class F>
__device__ void global_stride(index_int n, F f) const
......@@ -53,7 +54,7 @@ struct index
inline __device__ index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x};
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
}
} // namespace migraphx
......
......@@ -5,28 +5,31 @@
namespace migraphx {
template <class T, T v>
template <class T, T V>
struct integral_constant
{
static constexpr T value = v;
static constexpr T value = V;
using value_type = T;
using type = integral_constant;
constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; }
static constexpr type to() { return {}; }
};
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
template <class T, T v, class U, U w> \
constexpr inline integral_constant<decltype(v op w), (v op w)> operator op( \
integral_constant<T, v>, integral_constant<U, w>) noexcept \
template <class T, T V, class U, U w> \
constexpr inline integral_constant<decltype(V op w), (V op w)> operator op( \
integral_constant<T, V>, integral_constant<U, w>) noexcept \
{ \
return {}; \
}
// NOLINTNEXTLINE
#define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
template <class T, T v> \
constexpr inline integral_constant<decltype(op v), (op v)> operator op( \
integral_constant<T, v>) noexcept \
template <class T, T V> \
constexpr inline integral_constant<decltype(op V), (op V)> operator op( \
integral_constant<T, V>) noexcept \
{ \
return {}; \
}
......@@ -45,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP (^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
......@@ -64,8 +67,14 @@ using false_type = bool_constant<false>;
template <index_int N>
using index_constant = integral_constant<index_int, N>;
template <auto v>
static constexpr auto _c = integral_constant<decltype(v), v>{};
template <auto V>
static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT
template <class F>
constexpr auto return_c(F f)
{
return _c<f()>;
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#define MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
namespace migraphx {
template <class F, class Iterator = diff_int>
struct basic_iota_iterator
{
Iterator index;
F f;
using difference_type = diff_int;
using reference = decltype(f(declval<Iterator>()));
using value_type = remove_reference_t<reference>;
using pointer = add_pointer_t<value_type>;
constexpr basic_iota_iterator& operator+=(diff_int n)
{
index += n;
return *this;
}
constexpr basic_iota_iterator& operator-=(diff_int n)
{
index -= n;
return *this;
}
constexpr basic_iota_iterator& operator++()
{
index++;
return *this;
}
constexpr basic_iota_iterator& operator--()
{
index--;
return *this;
}
constexpr basic_iota_iterator operator++(int) // NOLINT
{
basic_iota_iterator it = *this;
index++;
return it;
}
constexpr basic_iota_iterator operator--(int) // NOLINT
{
basic_iota_iterator it = *this;
index--;
return it;
}
// TODO: operator->
constexpr reference operator*() const { return f(index); }
template <class T>
constexpr reference operator[](T x) const
{
return f(index + x);
}
};
template <class T, class F>
constexpr basic_iota_iterator<F, T> make_basic_iota_iterator(T x, F f)
{
return basic_iota_iterator<F, T>{x, f};
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(basic_iota_iterator<F, Iterator> x, diff_int y)
{
return x += y;
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator+(diff_int x, basic_iota_iterator<F, Iterator> y)
{
return y + x;
}
template <class F, class Iterator>
constexpr diff_int operator-(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index - y.index;
}
template <class F, class Iterator>
constexpr basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x, diff_int y)
{
return x -= y;
}
template <class F, class Iterator>
constexpr bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index == y.index;
}
template <class F, class Iterator>
constexpr bool operator!=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index != y.index;
}
template <class F, class Iterator>
constexpr bool operator<(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index < y.index;
}
template <class F, class Iterator>
constexpr bool operator>(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index > y.index;
}
template <class F, class Iterator>
constexpr bool operator>=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index >= y.index;
}
template <class F, class Iterator>
constexpr bool operator<=(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{
return x.index <= y.index;
}
struct defaul_iota_iterator
{
template <class T>
constexpr auto operator()(T x) const
{
return x;
}
};
using iota_iterator = basic_iota_iterator<defaul_iota_iterator>;
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_IOTA_ITERATOR_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
namespace migraphx {
namespace math {
constexpr float as_float(migraphx::half x) { return x; }
template <class T>
constexpr T as_float(T x)
{
return x;
}
} // namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_VEC(name) \
template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) \
{ \
return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); }
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number.
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF2(name, fname) \
template <class... Ts> \
auto __device__ name(migraphx::vec<migraphx::half, 2> x, Ts... xs) \
MIGRAPHX_RETURNS(migraphx::vec<migraphx::half, 2>{fname(x, xs...)}); \
template <class... Ts, index_int N, MIGRAPHX_REQUIRES(N % 2 == 0 && (N > 2))> \
auto __device__ name(migraphx::vec<migraphx::half, N> x, Ts... xs) \
{ \
return vec_packed_transform<2>(x, xs...)( \
[](auto... ys) -> migraphx::vec<migraphx::half, 2> { return fname(ys...); }); \
}
MIGRAPHX_DEVICE_MATH(abs, ::abs)
MIGRAPHX_DEVICE_MATH(acos, ::acos)
MIGRAPHX_DEVICE_MATH(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH(asin, ::asin)
MIGRAPHX_DEVICE_MATH(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH(atan, ::atan)
MIGRAPHX_DEVICE_MATH(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH(cos, ::cos)
MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin)
MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
// Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf)
MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf)
MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf)
MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf)
MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf)
MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf)
MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos)
MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh)
MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// Most but not all of these math ops have operators of the same names. Ones not yet implemented
// at this time are: exp2, exp10, log2, log10, isinf
MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp)
MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10)
MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
template <class T, class U>
constexpr auto where(bool cond, const T& a, const U& b)
{
return cond ? a : b;
}
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, min, ::min)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(double, min, ::min)
// Add overloads for half that calls the float version
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, max, ::fmaxf)
MIGRAPHX_DEVICE_MATH_BINARY_FOR(migraphx::half, min, ::fminf)
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto max(const T& a, const T& b)
{
return where(a < b, b, a);
}
template <class T, MIGRAPHX_REQUIRES(not is_any_vec<T>())>
constexpr auto min(const T& a, const T& b)
{
return where(a < b, a, b);
}
template <class T, class U, MIGRAPHX_REQUIRES(not is_same<T, U>{} and not is_any_vec<T, U>())>
constexpr auto max(const T& a, const U& b)
{
return max<common_type_t<T, U>>(a, b);
}
template <class T, class U, MIGRAPHX_REQUIRES(not is_same<T, U>{} and not is_any_vec<T, U>())>
constexpr auto min(const T& a, const U& b)
{
return min<common_type_t<T, U>>(a, b);
}
MIGRAPHX_DEVICE_MATH_VEC(abs)
MIGRAPHX_DEVICE_MATH_VEC(acos)
MIGRAPHX_DEVICE_MATH_VEC(acosh)
MIGRAPHX_DEVICE_MATH_VEC(asin)
MIGRAPHX_DEVICE_MATH_VEC(asinh)
MIGRAPHX_DEVICE_MATH_VEC(atan)
MIGRAPHX_DEVICE_MATH_VEC(atanh)
MIGRAPHX_DEVICE_MATH_VEC(ceil)
MIGRAPHX_DEVICE_MATH_VEC(cos)
MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin)
MIGRAPHX_DEVICE_MATH_VEC(sinh)
MIGRAPHX_DEVICE_MATH_VEC(sqrt)
MIGRAPHX_DEVICE_MATH_VEC(tan)
MIGRAPHX_DEVICE_MATH_VEC(tanh)
MIGRAPHX_DEVICE_MATH_VEC(where)
template <class T, class U>
constexpr auto convert(U v)
{
return vec_transform(v)([](auto x) -> T { return x; });
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_OPS_HPP
#define MIGRAPHX_GUARD_KERNELS_OPS_HPP
#include <migraphx/kernels/math.hpp>
namespace migraphx {
namespace op {
struct sum
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x + y;
}
};
struct product
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x * y;
}
};
struct id
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x;
}
};
struct mean
{
index_int item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x / static_cast<T>(item_num);
}
};
struct max
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return migraphx::max(x, y);
}
};
struct min
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return migraphx::min(x, y);
}
};
} // namespace op
struct lowest
{
template <class T>
constexpr operator T() const
{
return numeric_lowest<T>();
}
};
struct highest
{
template <class T>
constexpr operator T() const
{
return numeric_max<T>();
}
};
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_OPS_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment