Commit fae5170b authored by Shucai Xiao

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ref_op_name

parents a3906038 2a79a9ff
@@ -20,7 +20,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
 #include <migraphx/kernels/pointwise.hpp>
 #include <args.hpp>
-using namespace migraphx;
+namespace migraphx {
 ${preamble}
@@ -32,6 +32,8 @@ __global__ void kernel(${params})
 }
+} // namespace migraphx
 int main() {}
 )__migraphx__";
@@ -46,7 +48,7 @@ operation compile_pointwise(context&,
     options.local          = 1024;
     options.inputs         = inputs;
     options.output         = inputs.back();
-    options.reduced_inputs = reduce_dims(inputs);
+    options.virtual_inputs = reduce_dims(inputs);
     options.params         = "-Wno-float-equal";
     auto src = interpolate_string(pointwise_kernel,
                                   {{"params", enum_params(inputs.size(), "void * private_p")},
@@ -60,8 +62,17 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
 {
     run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
     cpp_generator g;
-    auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"}));
-    return compile_pointwise((ctx), inputs, "&" + name, g.str());
+    g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
+    g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
+    g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
+    g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
+    g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
+    g.add_point_op("less", "migraphx::abs(${0} < ${1})");
+    g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
+    g.add_point_op("not", "migraphx::abs(not ${0})");
+    auto name =
+        g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
+    return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
 }
 } // namespace gpu
...
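A note on the point-op templates above: the `${N}` placeholders appear to stand for the instruction's N-th argument, and `${function:where}` for a function name resolved through `g.fmap` (hence the `migraphx::` prefix). Under that reading, an `equal` instruction would lower to a device expression of roughly this shape (a hypothetical rendering for illustration, not the generator's literal output):

    // Hypothetical: what "migraphx::abs(${0} == ${1})" could expand to for a
    // two-input pointwise module whose arguments are named x0 and x1.
    __device__ auto pointwise_equal(float x0, float x1)
    {
        return migraphx::abs(x0 == x1); // the comparison yields bool; abs maps it to 0/1
    }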
@@ -14,17 +14,29 @@ namespace gpu {
 static const char* const roialign_kernel = R"__migraphx__(
 #include <migraphx/kernels/roialign.hpp>
 #include <migraphx/kernels/basic_ops.hpp>
+#include <migraphx/kernels/integral_constant.hpp>
+#include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
-using namespace migraphx;
+namespace migraphx {
 extern "C" {
 __global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
 {
-    make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) { roialign(xs...); });
+    make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
+        auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
+                                              _c<bool{IS_AVG_POOLING}>,
+                                              _c<int64_t{SAMPLING_RATIO}>,
+                                              MIGRAPHX_MAKE_CONSTANT(float{SPATIAL_SCALE}));
+        roialign(xs..., settings);
+    });
 }
 }
+} // namespace migraphx
 int main() {}
 )__migraphx__";
@@ -38,7 +50,7 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
     options.inputs         = io_shapes;
     options.output         = out_s;
     options.kernel_name    = "roialign_kernel";
-    options.reduced_inputs = io_shapes;
+    options.virtual_inputs = io_shapes;
     // sampling_ratio
     assert(val.contains("sampling_ratio"));
...
@@ -75,8 +75,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
 inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
 {
     index_int groups  = (n + local - 1) / local;
-    index_int nglobal = std::min<index_int>(256, groups) * local;
+    // max possible number of blocks is set to 1B (1,073,741,824)
+    index_int nglobal = std::min<index_int>(1073741824, groups) * local;
     return [=](auto f) {
         launch(stream, nglobal, local)([=](auto idx) __device__ {
...
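For context, the `gs_` prefix is a grid-stride launch: even with the block count now effectively uncapped, each thread can cover several elements by striding through the index space. A minimal sketch of the pattern (illustrative, not the MIGraphX helper itself):

    // Each thread starts at its global id and advances by the total number of
    // launched threads until all n elements are processed.
    __global__ void grid_stride_example(const float* in, float* out, int n)
    {
        const int start    = blockIdx.x * blockDim.x + threadIdx.x;
        const int nthreads = blockDim.x * gridDim.x;
        for(int i = start; i < n; i += nthreads)
            out[i] = in[i] * 2.0f;
    }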
@@ -20,34 +20,58 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
     migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
     hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
-        const index_int max_block_size = 256;
+        const index_int max_block_size = 128;
         const index_int block_size     = compute_block_size(batch_item_num, max_block_size);
-        gs_launch(stream,
-                  batch_shape.elements() * block_size,
-                  block_size)([=](auto i, auto idx) __device__ {
-            auto data_idx = batch.multi(i / block_size);
-            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            type init  = lowest();
-
-            auto batch_max = block_reduce<max_block_size>(
-                idx, max{}, init, batch_item_num, [&](auto j) __device__ {
-                    data_idx[axis] = j;
-                    return input[data_idx];
-                });
-
-            auto batch_sum =
-                block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
-                    data_idx[axis] = j;
-                    auto val       = input[data_idx] - batch_max;
-                    return ::exp(to_hip_type(val));
-                });
-
-            idx.local_stride(batch_item_num, [&](auto j) __device__ {
-                data_idx[axis]   = j;
-                auto val         = input[data_idx] - batch_max;
-                output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
-            });
-        });
+        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
+        type init  = lowest();
+
+        if(axis == batch_lens.size() - 1)
+        {
+            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
+                [=](auto i, auto idx) __device__ {
+                    auto start_loc = i / block_size * batch_item_num;
+                    auto batch_max = block_reduce<max_block_size>(
+                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                            return input[start_loc + j];
+                        });
+                    auto batch_sum = block_reduce<max_block_size>(
+                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                            auto val = input[start_loc + j] - batch_max;
+                            return ::exp(to_hip_type(val));
+                        });
+                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
+                        auto val              = input[start_loc + j] - batch_max;
+                        output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
+                    });
+                });
+        }
+        else
+        {
+            gs_launch(stream, batch_shape.elements() * block_size, block_size)(
+                [=](auto i, auto idx) __device__ {
+                    auto data_idx  = batch.multi(i / block_size);
+                    auto batch_max = block_reduce<max_block_size>(
+                        idx, max{}, init, batch_item_num, [&](auto j) __device__ {
+                            data_idx[axis] = j;
+                            return input[data_idx];
+                        });
+                    auto batch_sum = block_reduce<max_block_size>(
+                        idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
+                            data_idx[axis] = j;
+                            auto val       = input[data_idx] - batch_max;
+                            return ::exp(to_hip_type(val));
+                        });
+                    idx.local_stride(batch_item_num, [&](auto j) __device__ {
+                        data_idx[axis]   = j;
+                        auto val         = input[data_idx] - batch_max;
+                        output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
+                    });
+                });
+        }
     });
 }
...
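Both branches above compute the standard numerically stable softmax: shift by the per-batch max before exponentiating so `exp` cannot overflow, then normalize by the sum; the new fast path merely replaces multi-index addressing with contiguous `start_loc + j` offsets when the reduction runs over the innermost axis. A host-side sketch of the same math (illustrative only):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> softmax_ref(const std::vector<float>& x)
    {
        const float m = *std::max_element(x.begin(), x.end());
        float sum     = 0.0f;
        for(float v : x)
            sum += std::exp(v - m); // shifted by the max for stability
        std::vector<float> y(x.size());
        for(std::size_t i = 0; i < x.size(); i++)
            y[i] = std::exp(x[i] - m) / sum;
        return y;
    }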
@@ -62,6 +62,8 @@ struct fusion
         keep_alive(std::move(t));
     }

+    bool empty() const { return fp == nullptr; }
+
     op_t operator[](std::size_t i) const
     {
         assert(fp);
@@ -125,12 +127,11 @@ struct fusion
         return shape{shape::int8_type, {ws_size}};
     }

-    void compile(context& ctx)
+    bool compile(context& ctx)
     {
         assert(fp);
-        auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
-        if(status != miopenStatusSuccess)
-            MIGRAPHX_THROW("Compiling fusion plan failed");
+        return miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get()) ==
+               miopenStatusSuccess;
     }

     argument execute(context& ctx,
@@ -169,7 +170,7 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
 MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
 {
-    const auto device_name = split_string(get_device_name(), ':').front();
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
     if(not contains(get_supported_archs(), device_name))
         return false;
     if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}))
@@ -561,6 +562,117 @@ struct find_mul_add_relu
     }
 };

+struct miopen_fusion
+{
+    struct fuse_op_data
+    {
+        operation op;
+        float alpha = 1;
+        float beta  = 0;
+    };
+    struct fuse_op : fuse_op_data, reflect_equality<fuse_op>, reflect_stream<fuse_op>
+    {
+        template <class Self, class F>
+        static auto reflect(Self& self, F f)
+        {
+            return pack(f(self.op, "op"), f(self.alpha, "alpha"), f(self.beta, "beta"));
+        }
+    };
+    std::vector<fuse_op> ops = {};
+    fusion f                 = {};
+    std::function<void(context&, const fusion&, const std::vector<argument>&)> execute;
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.ops, "ops"));
+    }
+    value compile(context& ctx, const shape&, std::vector<shape> inputs)
+    {
+        // Compensate for allocation
+        inputs.pop_back();
+        std::size_t i = 0;
+        f             = fusion(inputs[i]);
+        i++;
+        std::vector<std::function<void(const fused_operator_args&, const std::vector<argument>&)>>
+            invokers;
+        for(auto&& fop : ops)
+        {
+            if(i > inputs.size())
+            {
+                f = {};
+                return {};
+            }
+            if(fop.op.name() == "convolution")
+            {
+                auto* mop = f.create_conv(any_cast<op::convolution>(fop.op), inputs[i]);
+                invokers.push_back(
+                    [=](const fused_operator_args& fargs, const std::vector<argument>& args) {
+                        miopenSetOpArgsConvForward(
+                            fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
+                    });
+                i++;
+            }
+            else if(fop.op.name() == "add")
+            {
+                auto* mop = f.create_bias(inputs[i]);
+                invokers.push_back(
+                    [=](const fused_operator_args& fargs, const std::vector<argument>& args) {
+                        miopenSetOpArgsBiasForward(
+                            fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
+                    });
+                i++;
+            }
+            else if(fop.op.name() == "relu")
+            {
+                auto* mop = f.create_relu();
+                invokers.push_back([=](const fused_operator_args& fargs,
+                                       const std::vector<argument>&) {
+                    miopenSetOpArgsActivForward(fargs.get(), mop, &fop.alpha, &fop.beta, 0, 0, 0);
+                });
+            }
+            else
+            {
+                f = {};
+                return {};
+            }
+        }
+        if(not f.compile(ctx))
+        {
+            f = {};
+            return {};
+        }
+        execute = [invokers](context& c, const fusion& ff, const std::vector<argument>& args) {
+            auto fargs = make_fused_args();
+            for(auto&& invoker : invokers)
+                invoker(fargs, args);
+            ff.execute(c, fargs, args.front(), args.back());
+        };
+        return {{"workspace", f.get_workspace(ctx).bytes()}};
+    }
+    void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& inputs)
+    {
+        if(not f.empty())
+            return;
+        auto v = compile(ctx, output_shape, inputs);
+        if(not v.is_object())
+            MIGRAPHX_THROW("Failed to compile fusion plan");
+    }
+    std::string name() const { return "gpu::miopen_fusion"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        if(ops.empty())
+            return {};
+        // TODO: Check number of arguments
+        return ops.front().op.compute_shape({inputs[0], inputs[1]});
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        execute(ctx, f, args);
+        return args.back();
+    }
+};
+
 struct miopen_conv_bias
 {
     op::convolution op;
@@ -596,7 +708,8 @@ struct miopen_conv_bias
         f    = fusion(inputs[0]);
         conv = f.create_conv(op, inputs[1]);
         bias = f.create_bias(inputs[3]);
-        f.compile(ctx);
+        if(not f.compile(ctx))
+            MIGRAPHX_THROW("Failed to compile fusion plan");
     }

     shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
@@ -683,6 +796,25 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
     p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }

+inline auto precompile_name(std::string s) // NOLINT
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(ins->name() != "gpu::precompile_op")
+            return false;
+        auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
+        return (op.name() == s);
+    });
+}
+
+template <class... Ms>
+auto conv_bias_pointwise(Ms... ms)
+{
+    return precompile_name("pointwise")(
+        match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
+                                fusable_conv(match::used_once()).bind("conv")),
+        ms...);
+}
+
 struct find_conv_bias
 {
     context* ctx = nullptr;
@@ -709,6 +841,46 @@ struct find_conv_bias_relu
     }
 };

+struct find_conv_pointwise
+{
+    context* ctx = nullptr;
+    auto matcher() const
+    {
+        return precompile_name("pointwise")(
+            match::nargs(3),
+            match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
+                                    fusable_conv(match::used_once()).bind("conv")));
+    }
+
+    void apply(module& m, match::matcher_result r) const
+    {
+        auto conv_ins    = r.instructions["conv"];
+        auto bias_ins    = r.instructions["bias"];
+        auto ins         = r.result;
+        auto input_ins   = conv_ins->inputs().at(0);
+        auto weights_ins = conv_ins->inputs().at(1);
+        auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
+        auto alloc_ins   = ins->inputs().back();
+        module_ref pm    = ins->module_inputs().front();
+        miopen_fusion op{};
+        op.ops.push_back({{conv_op}});
+        for(auto&& i : *pm)
+        {
+            if(i.name()[0] == '@')
+                continue;
+            auto inputs = to_shapes(i.inputs());
+            op.ops.push_back({{i.get_operator()}});
+        }
+        std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins};
+        auto v = op.compile(*ctx, ins->get_shape(), to_shapes(inputs));
+        if(not v.is_object())
+            return;
+        m.replace_instruction(ins, op, inputs);
+    }
+};
+
 struct find_gemm_add
 {
     auto matcher() const
@@ -778,6 +950,7 @@ void fuse_ops::apply(module& p) const
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
                         find_layernorm{},
+                        find_conv_pointwise{ctx},
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
                         find_add_gelu{},
...
@@ -16,7 +16,7 @@ struct hip_compile_options
     shape output;
     std::string kernel_name           = "kernel";
    std::string params                = "";
-    std::vector<shape> reduced_inputs = {};
+    std::vector<shape> virtual_inputs = {};
 };

 operation compile_hip_code_object(const std::string& content, hip_compile_options options);
...
@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
     size_t batch_item_num = batch_lens[axis];
     batch_lens[axis]      = 1;
     migraphx::shape batch_shape{arg_shape.type(), batch_lens};
+    migraphx::shape std_arg_shape{arg_shape.type(), arg_shape.lens()};

-    hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
+    hip_visit_all(arg, std_arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
         auto* output = device_cast(result.get<int64_t>().data());
         using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
         // use one block for items in one batch.
...
@@ -176,23 +176,23 @@ struct array
     }
 };

-template <class T, T... xs>
-struct integral_const_array : array<T, sizeof...(xs)>
+template <class T, T... Xs>
+struct integral_const_array : array<T, sizeof...(Xs)>
 {
-    using base_array = array<T, sizeof...(xs)>;
-    MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({xs...}) {}
+    using base_array = array<T, sizeof...(Xs)>;
+    MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({Xs...}) {}
 };

-template <class T, T... xs, class F>
-constexpr auto transform(integral_const_array<T, xs...>, F f)
+template <class T, T... Xs, class F>
+constexpr auto transform(integral_const_array<T, Xs...>, F f)
 {
-    return integral_const_array<T, f(xs)...>{};
+    return integral_const_array<T, f(Xs)...>{};
 }

-template <class T, T... xs, class U, U... ys, class F>
-constexpr auto transform(integral_const_array<T, xs...>, integral_const_array<U, ys...>, F f)
+template <class T, T... Xs, class U, U... Ys, class F>
+constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
 {
-    return integral_const_array<T, f(xs, ys)...>{};
+    return integral_const_array<T, f(Xs, Ys)...>{};
 }

 template <index_int... Ns>
...
 #ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
 #define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP

-#include <hip/hip_runtime.h>
+#include <migraphx/kernels/hip.hpp>

 namespace migraphx {

-inline __host__ __device__ void
-assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
+#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
+#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
+
+// Workaround hip's broken abort on device code
+#ifdef __HIP_DEVICE_COMPILE__
+// NOLINTNEXTLINE
+#define MIGRAPHX_HIP_NORETURN
+#else
+// NOLINTNEXTLINE
+#define MIGRAPHX_HIP_NORETURN [[noreturn]]
+#endif
+
+namespace debug {
+
+struct swallow
+{
+    template <class... Ts>
+    constexpr swallow(Ts&&...)
+    {
+    }
+};
+
+template <size_t N>
+struct print_buffer
+{
+    char buffer[N + 1] = {0};
+    char* pos          = buffer;
+
+    constexpr void append(char c)
+    {
+        if(c == 0)
+            return;
+        if(pos < buffer + N)
+        {
+            *pos = c;
+            pos++;
+        }
+    }
+
+    template <size_t M>
+    constexpr void append(const char (&array)[M])
+    {
+        for(int i = 0; i < M; i++)
+            append(array[i]);
+    }
+};
+
+template <class... Ts>
+__host__ __device__ void print(const Ts&... xs)
+{
+    const auto size = (sizeof(xs) + ...);
+    print_buffer<size> buffer;
+    swallow{(buffer.append(xs), 0)...};
+    printf("%s", buffer.buffer);
+}
+
+} // namespace debug
+
+// noreturn cannot be used on this function because abort in hip is broken
+template <class T1, class T2, class T3, class T4>
+MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
+assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function)
 {
-    printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
+    // printf is broken on hip with more than one argument, so use a simple print functions instead
+    debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n");
+    // printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
     abort();
 }

 #ifdef MIGRAPHX_DEBUG
 #define MIGRAPHX_ASSERT(cond) \
-    ((cond) ? void(0) : [](auto... xs) { \
-        assert_fail(xs...); \
-    }(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
+    ((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
+        assert_fail(private_migraphx_xs...); \
+    }(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
 #else
 #define MIGRAPHX_ASSERT(cond)
 #endif
...
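A note on the device print above: it sidesteps HIP's broken multi-argument printf by appending every piece into one stack buffer and emitting a single printf("%s", ...). The fold (sizeof(xs) + ...) sums the sizes of the string-literal arrays, which is also why MIGRAPHX_ASSERT now passes MIGRAPHX_STRINGIZE(__LINE__) rather than the integer __LINE__: every argument must be a character array. A usage sketch (illustrative):

    // Each argument is a char array; the pieces are appended into print_buffer
    // and flushed with one printf call.
    migraphx::debug::print("foo.cpp", ":", "42", ": assertion 'x > 0' failed.\n");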
@@ -16,6 +16,19 @@ struct swallow
 template <index_int>
 using ignore = swallow;

+template <class... Fs>
+struct overloaded : Fs...
+{
+    using Fs::operator()...;
+    overloaded(Fs... fs) : Fs(fs)... {}
+};
+
+template <class... Fs>
+overloaded<Fs...> overload(Fs... fs)
+{
+    return {fs...};
+}
+
 namespace detail {

 template <class R>
@@ -124,12 +137,48 @@ constexpr void each_args(F)
 {
 }

+template <class F, class T>
+constexpr auto fold_impl(F&&, T&& x)
+{
+    return static_cast<T&&>(x);
+}
+
+template <class F, class T, class U, class... Ts>
+constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
+{
+    return fold_impl(f, f(static_cast<T&&>(x), static_cast<U&&>(y)), static_cast<Ts&&>(xs)...);
+}
+
+template <class F>
+constexpr auto fold(F f)
+{
+    return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
+}
+
 template <class... Ts>
-auto pack(Ts... xs)
+constexpr auto pack(Ts... xs)
 {
     return [=](auto f) { return f(xs...); };
 }

+template <class Compare, class P1, class P2>
+constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
+{
+    return p1([&](auto... xs) {
+        return p2([&](auto... ys) {
+            auto c = [&](auto x, auto y) -> int {
+                if(compare(x, y))
+                    return 1;
+                else if(compare(y, x))
+                    return -1;
+                else
+                    return 0;
+            };
+            return fold([](auto x, auto y) { return x ? x : y; })(c(xs, ys)..., 0);
+        });
+    });
+}
+
 template <index_int N>
 constexpr auto arg_c()
 {
@@ -168,8 +217,13 @@ constexpr auto transform_args(F f, Fs... fs)
     return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
 }

+// NOLINTNEXTLINE
+#define MIGRAPHX_RETURNS(...) \
+    ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
+
+// NOLINTNEXTLINE
 #define MIGRAPHX_LIFT(...) \
-    ([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
+    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))

 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
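MIGRAPHX_RETURNS deduces the return type and propagates expression SFINAE, and MIGRAPHX_LIFT wraps an overload set or function template (which has no single address that can be passed around) in one generic lambda; this is what lets compile_pointwise above hand the generated device function to the kernel as "MIGRAPHX_LIFT(" + name + ")". A small usage sketch under those assumptions:

    #include <algorithm>

    int demo()
    {
        // std::max is a function template, so "&std::max" is ill-formed;
        // lifting it yields one object that forwards to whichever
        // instantiation the arguments select.
        auto lifted_max = MIGRAPHX_LIFT(std::max);
        int x = 1, y = 2;
        return lifted_max(x, y); // forwards to std::max<int>, returns 2
    }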
+#ifndef MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
+#define MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
+
+namespace migraphx {
+
+template <class F>
+struct generic_constant
+{
+    static constexpr auto value = F{}();
+    using value_type            = decltype(value);
+    using type                  = generic_constant;
+    constexpr operator value_type() const noexcept { return value; }
+    constexpr value_type operator()() const noexcept { return value; }
+};
+
+template <class F>
+constexpr generic_constant<F> make_generic_constant(F)
+{
+    return {};
+}
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_MAKE_CONSTANT(x) \
+    make_generic_constant([] { \
+        struct fun \
+        { \
+            constexpr auto operator()() const { return x; } \
+        }; \
+        return fun{}; \
+    }())
+
+} // namespace migraphx
+#endif // MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
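MIGRAPHX_MAKE_CONSTANT wraps an expression in a unique empty type whose call operator is constexpr, so the value travels through the type system; this is how the roialign kernel above turns -D-injected literals such as ROIS_OFFSET into compile-time constants. A usage sketch:

    // The value lives in the type, not in the object, so it remains a
    // compile-time constant after being passed by value, as in
    // make_roalign_settings above.
    constexpr auto offset = MIGRAPHX_MAKE_CONSTANT(0.5f);
    static_assert(offset() == 0.5f);
    float runtime_copy = offset; // implicit conversion to value_type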
+#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
+#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
+
+// Workaround macro redefinition issue with clang tidy
+#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
+#undef __HIP_PLATFORM_HCC__ // NOLINT
+#endif
+
+#include <hip/hip_runtime.h>
+
+#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
 #ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
 #define MIGRAPHX_GUARD_KERNELS_INDEX_HPP

-#include <hip/hip_runtime.h>
+#include <migraphx/kernels/hip.hpp>
 #include <migraphx/kernels/types.hpp>

 namespace migraphx {
@@ -17,7 +17,7 @@ struct index
 #ifdef MIGRAPHX_NGLOBAL
         return MIGRAPHX_NGLOBAL;
 #else
-        return blockDim.x * gridDim.x;
+        return blockDim.x * gridDim.x; // NOLINT
 #endif
     }
@@ -26,7 +26,7 @@ struct index
 #ifdef MIGRAPHX_NLOCAL
         return MIGRAPHX_NLOCAL;
 #else
-        return blockDim.x;
+        return blockDim.x; // NOLINT
 #endif
     }
@@ -53,7 +53,7 @@ struct index
 inline __device__ index make_index()
 {
-    return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x};
+    return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
 }

 } // namespace migraphx
...
@@ -5,28 +5,31 @@
 namespace migraphx {

-template <class T, T v>
+template <class T, T V>
 struct integral_constant
 {
-    static constexpr T value = v;
+    static constexpr T value = V;
     using value_type         = T;
     using type               = integral_constant;
     constexpr operator value_type() const noexcept { return value; }
     constexpr value_type operator()() const noexcept { return value; }
+    static constexpr type to() { return {}; }
 };

+// NOLINTNEXTLINE
 #define MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(op) \
-    template <class T, T v, class U, U w> \
-    constexpr inline integral_constant<decltype(v op w), (v op w)> operator op( \
-        integral_constant<T, v>, integral_constant<U, w>) noexcept \
+    template <class T, T V, class U, U w> \
+    constexpr inline integral_constant<decltype(V op w), (V op w)> operator op( \
+        integral_constant<T, V>, integral_constant<U, w>) noexcept \
     { \
         return {}; \
     }

+// NOLINTNEXTLINE
 #define MIGRAPHX_INTEGRAL_CONSTANT_UNARY_OP(op) \
-    template <class T, T v> \
-    constexpr inline integral_constant<decltype(op v), (op v)> operator op( \
-        integral_constant<T, v>) noexcept \
+    template <class T, T V> \
+    constexpr inline integral_constant<decltype(op V), (op V)> operator op( \
+        integral_constant<T, V>) noexcept \
     { \
         return {}; \
     }
@@ -64,8 +67,8 @@ using false_type = bool_constant<false>;
 template <index_int N>
 using index_constant = integral_constant<index_int, N>;

-template <auto v>
-static constexpr auto _c = integral_constant<decltype(v), v>{};
+template <auto V>
+static constexpr auto _c = integral_constant<decltype(V), V>{}; // NOLINT

 } // namespace migraphx
 #endif // MIGRAPHX_GUARD_KERNELS_INTEGRAL_CONSTANT_HPP
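Since every _c<V> is a distinct type and the operator macros rebuild the result as another integral_constant, constant arithmetic stays in the type system and can feed template arguments or if constexpr. A short sketch, assuming the usual operator instantiations follow the macros:

    constexpr auto n = _c<4> + _c<2>;       // integral_constant<int, 6>
    static_assert(n == 6);                  // via operator value_type()
    static_assert(decltype(n)::value == 6); // or read the value off the type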
+#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
+#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
+
+#include <migraphx/kernels/types.hpp>
+#include <migraphx/kernels/vec.hpp>
+#include <migraphx/kernels/functional.hpp>
+#include <migraphx/kernels/type_traits.hpp>
+
+#include <hip/hip_fp16.h>
+#include <hip/math_functions.h>
+
+namespace migraphx {
+
+namespace math {
+constexpr float as_float(migraphx::half x) { return x; }
+template <class T>
+constexpr T as_float(T x)
+{
+    return x;
+}
+} // namespace math
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_DEVICE_MATH(name, fname) \
+    template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
+    auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_DEVICE_MATH_VEC(name) \
+    template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
+    auto __device__ name(Ts... xs) \
+    { \
+        return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
+    }
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
+    template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
+    auto __device__ name(type x, Ts... xs)->type \
+    { \
+        return fname(x, xs...); \
+    }
+
+// NOLINTNEXTLINE
+#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
+    template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
+    auto __device__ name(migraphx::half x, Ts... xs) \
+        MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
+
+MIGRAPHX_DEVICE_MATH(abs, ::abs)
+MIGRAPHX_DEVICE_MATH(acos, ::acos)
+MIGRAPHX_DEVICE_MATH(acosh, ::acosh)
+MIGRAPHX_DEVICE_MATH(asin, ::asin)
+MIGRAPHX_DEVICE_MATH(asinh, ::asinh)
+MIGRAPHX_DEVICE_MATH(atan, ::atan)
+MIGRAPHX_DEVICE_MATH(atanh, ::atanh)
+MIGRAPHX_DEVICE_MATH(ceil, ::ceil)
+MIGRAPHX_DEVICE_MATH(cos, ::cos)
+MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
+MIGRAPHX_DEVICE_MATH(erf, ::erf)
+MIGRAPHX_DEVICE_MATH(exp, ::exp)
+MIGRAPHX_DEVICE_MATH(floor, ::floor)
+MIGRAPHX_DEVICE_MATH(log, ::log)
+MIGRAPHX_DEVICE_MATH(pow, ::pow)
+MIGRAPHX_DEVICE_MATH(round, ::round)
+MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
+MIGRAPHX_DEVICE_MATH(sin, ::sin)
+MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
+MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
+MIGRAPHX_DEVICE_MATH(tan, ::tan)
+MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
+
+// Float overloads
+MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
+MIGRAPHX_DEVICE_MATH_FOR(float, acosh, ::acoshf)
+MIGRAPHX_DEVICE_MATH_FOR(float, asin, ::asinf)
+MIGRAPHX_DEVICE_MATH_FOR(float, asinh, ::asinhf)
+MIGRAPHX_DEVICE_MATH_FOR(float, atan, ::atanf)
+MIGRAPHX_DEVICE_MATH_FOR(float, atanh, ::atanhf)
+MIGRAPHX_DEVICE_MATH_FOR(float, cos, ::cosf)
+MIGRAPHX_DEVICE_MATH_FOR(float, cosh, ::coshf)
+MIGRAPHX_DEVICE_MATH_FOR(float, rsqrt, ::rsqrtf)
+MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
+MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
+MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
+MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
+
+// Builtin half functions
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
+
+// Use float to compute half overload
+MIGRAPHX_DEVICE_MATH_HALF(acos, ::acos)
+MIGRAPHX_DEVICE_MATH_HALF(acosh, ::acosh)
+MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
+MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
+MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
+MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
+MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
+MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
+MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
+MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
+MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
+MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
+MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
+MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
+MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
+MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
+MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
+
+template <class T, class U>
+constexpr auto where(bool cond, const T& a, const U& b)
+{
+    return cond ? a : b;
+}
+
+MIGRAPHX_DEVICE_MATH_VEC(abs)
+MIGRAPHX_DEVICE_MATH_VEC(acos)
+MIGRAPHX_DEVICE_MATH_VEC(acosh)
+MIGRAPHX_DEVICE_MATH_VEC(asin)
+MIGRAPHX_DEVICE_MATH_VEC(asinh)
+MIGRAPHX_DEVICE_MATH_VEC(atan)
+MIGRAPHX_DEVICE_MATH_VEC(atanh)
+MIGRAPHX_DEVICE_MATH_VEC(ceil)
+MIGRAPHX_DEVICE_MATH_VEC(cos)
+MIGRAPHX_DEVICE_MATH_VEC(cosh)
+MIGRAPHX_DEVICE_MATH_VEC(erf)
+MIGRAPHX_DEVICE_MATH_VEC(exp)
+MIGRAPHX_DEVICE_MATH_VEC(floor)
+MIGRAPHX_DEVICE_MATH_VEC(log)
+MIGRAPHX_DEVICE_MATH_VEC(pow)
+MIGRAPHX_DEVICE_MATH_VEC(round)
+MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
+MIGRAPHX_DEVICE_MATH_VEC(sin)
+MIGRAPHX_DEVICE_MATH_VEC(sinh)
+MIGRAPHX_DEVICE_MATH_VEC(sqrt)
+MIGRAPHX_DEVICE_MATH_VEC(tan)
+MIGRAPHX_DEVICE_MATH_VEC(tanh)
+MIGRAPHX_DEVICE_MATH_VEC(where)
+
+template <class T, class U>
+constexpr auto max(const T& a, const U& b)
+{
+    return where(a < b, b, a);
+}
+
+template <class T, class U>
+constexpr auto min(const T& a, const U& b)
+{
+    return where(a > b, b, a);
+}
+
+template <class T, class U>
+constexpr auto convert(U v)
+{
+    return vec_transform(v)([](auto x) -> T { return x; });
+}
+
+} // namespace migraphx
+#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
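For scalars, where, max, and min above are ordinary constexpr functions (the _VEC overloads extend them elementwise through vec_transform), so their behavior can be pinned down at compile time:

    static_assert(migraphx::where(true, 1, 2) == 1);
    static_assert(migraphx::max(3, 5) == 5); // defined as where(a < b, b, a)
    static_assert(migraphx::min(3, 5) == 3); // defined as where(a > b, b, a)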
@@ -3,19 +3,45 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/functional.hpp>
+#include <migraphx/kernels/math.hpp>
 #include <migraphx/kernels/preload.hpp>
 #include <migraphx/kernels/vectorize.hpp>
 #include <migraphx/kernels/args.hpp>

 namespace migraphx {

+template <class T>
+struct implicit_conversion_op
+{
+    T x;
+
+    template <index_int N, class U>
+    constexpr operator vec<U, N>() const
+    {
+        static_assert(vec_size<T>() == N, "Vector mismatch size");
+        return __builtin_convertvector(x, vec<U, N>);
+    }
+
+    template <class U>
+    constexpr operator U() const
+    {
+        return x;
+    }
+};
+
+template <class T>
+constexpr implicit_conversion_op<T> implicit_conversion(T x)
+{
+    return {x};
+}
+
 template <class F, class T, class... Ts>
 __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 {
     preload<typename T::type>(idx, xs...)([&](auto... ps) {
         idx.global_stride(out.get_shape().elements(), [&](auto i) {
             auto multi_idx = out.get_shape().multi(i);
-            out[multi_idx] = f(ps[multi_idx]...);
+            out[multi_idx] = implicit_conversion(f(ps[multi_idx]...));
         });
     });
 }
@@ -23,7 +49,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
 template <class F, class... Ts>
 __device__ void pointwise(F f, Ts*... ps)
 {
-    auto t = transform_args(make_tensors(), rotate_last());
+    auto t = transform_args(make_tensors(), rotate_last(), auto_vectorize());
     t(ps...)([&](auto... xs) {
         auto idx = make_index();
         pointwise_tensor(idx, f, xs...);
...
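A note on implicit_conversion above: it defers the conversion of the pointwise result until the assignment into out, so the assignment's target type selects the path; scalar results go through the templated operator U, while vector results go through __builtin_convertvector, which clang requires for vector-to-vector casts. A rough sketch of the two paths (illustrative; assumes migraphx::vec is clang's ext_vector_type alias):

    // Scalar path: operator U performs an ordinary conversion.
    float f = implicit_conversion(3);
    // Vector path: the static_assert checks the lane count, then
    // __builtin_convertvector converts lane by lane.
    vec<float, 4> v{};
    vec<int, 4> i = implicit_conversion(v);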
@@ -14,9 +14,7 @@ constexpr auto traverse_preload(Shapes... ss)
     auto each = [&](auto x) {
         constexpr auto s    = decltype(x.get_shape()){};
         constexpr auto size = _c<s.element_space()>;
-        if constexpr(not s.broadcasted())
-            return f(x, offset, false_type{});
-        else if constexpr((s.elements() - size) < 64)
+        if constexpr(not s.broadcasted() or (s.elements() - size) < 64)
             return f(x, offset, false_type{});
         else
         {
@@ -31,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
 }

 template <class T, class... Shapes>
-constexpr index_int compute_preload_size(Shapes...)
+constexpr index_int compute_preload_size_c(Shapes...)
 {
     index_int size = 0;
     traverse_preload<T>(Shapes{}...)(
@@ -39,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...)
     return size;
 }

+template <class T, class... Shapes>
+constexpr auto compute_preload_size(Shapes...)
+{
+    return _c<compute_preload_size_c<T>(Shapes{}...)>;
+}
+
 template <class F, class T, class... Ts>
 __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
 {
@@ -50,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
         [&](auto x, auto offset, auto copy) {
             if constexpr(copy)
             {
-                auto v = vectorize(x);
-                auto b = as_vec(tensor_vec_size(v), buffer + offset);
-                idx.local_stride(v.get_shape().element_space(),
-                                 [&](auto i) { b[i] = v.data()[i]; });
-                return x.with(buffer + offset);
+                if constexpr(decltype(tensor_vec_size(x)){} == 0)
+                {
+                    auto v = vectorize(x);
+                    auto b = as_vec(tensor_vec_size(v), buffer + offset);
+                    idx.local_stride(v.get_shape().element_space(),
+                                     [&](auto i) { b[i] = v.data()[i]; });
+                    return x.with(buffer + offset);
+                }
+                else
+                {
+                    auto b = as_vec(tensor_vec_size(x), buffer + offset);
+                    idx.local_stride(x.get_shape().element_space(),
+                                     [&](auto i) { b[i] = x.data()[i]; });
+                    return x.with(b);
+                }
             }
             else
             {
@@ -80,7 +94,7 @@ template <class T, class... Ts>
 __device__ auto preload(index idx, Ts... xs)
 {
     using type          = typename remove_vec<T>::type;
-    constexpr auto size = compute_preload_size<type>(xs.get_shape()...);
+    constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){};
     const index_int max_size = 512 * sizeof(type);
     return [=](auto f) {
         if constexpr(size > 0 and size < max_size)
...
 #ifndef MIGRAPHX_GUARD_KERNELS_PRINT_HPP
 #define MIGRAPHX_GUARD_KERNELS_PRINT_HPP

-#include <hip/hip_runtime.h>
+#include <migraphx/kernels/hip.hpp>
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/functional.hpp>
 #include <migraphx/kernels/algorithm.hpp>
...
@@ -4,7 +4,7 @@
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/dfor.hpp>
 #include <migraphx/kernels/basic_ops.hpp>
-#include <args.hpp>
+#include <migraphx/kernels/array.hpp>

 namespace migraphx {
@@ -104,14 +104,24 @@ MIGRAPHX_DEVICE_CONSTEXPR T calc_pooling(const T*& data,
     return op.final(output_val, count);
 }

-template <class T, class U, class V, class W>
-__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t)
-{
-    const float roi_offset       = ROIS_OFFSET;
-    const bool is_avg_pooling    = IS_AVG_POOLING;
-    const int64_t sampling_ratio = SAMPLING_RATIO;
-    const float spatial_scale    = SPATIAL_SCALE;
+template <class T1, class T2, class T3, class T4>
+struct roalign_settings
+{
+    T1 roi_offset{};
+    T2 is_avg_pooling{};
+    T3 sampling_ratio{};
+    T4 spatial_scale{};
+};
+
+template <class... Ts>
+constexpr roalign_settings<Ts...> make_roalign_settings(Ts... xs)
+{
+    return {xs...};
+}
+
+template <class T, class U, class V, class W, class Settings>
+__device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W& y_t, Settings s)
+{
     auto index       = make_index();
     const auto* x    = x_t.data();
     const auto* rois = rois_t.data();
@@ -146,9 +156,10 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
         const auto* offset_rois = rois + (n * roi_column_num);
         const int batch_ind     = ind[n];

-        array<float, 2> roi_starts = {offset_rois[1] * spatial_scale,
-                                      offset_rois[0] * spatial_scale};
-        array<float, 2> roi_ends = {offset_rois[3] * spatial_scale, offset_rois[2] * spatial_scale};
+        array<float, 2> roi_starts = {offset_rois[1] * s.spatial_scale,
+                                      offset_rois[0] * s.spatial_scale};
+        array<float, 2> roi_ends   = {offset_rois[3] * s.spatial_scale,
+                                      offset_rois[2] * s.spatial_scale};

         array<float, 2> roi_size{};
         array<float, 2> bin_size{};
@@ -161,11 +172,11 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
             bin_size[ii] = roi_size[ii] / out_dims[ii];
             bin_grid_size[ii] =
-                (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
+                (s.sampling_ratio > 0) ? s.sampling_ratio : std::ceil(roi_size[ii] / out_dims[ii]);
         }

         const auto* offset_x = x + ((batch_ind * channel_num + c) * in_dims[0] * in_dims[1]);

-        if constexpr(is_avg_pooling)
+        if constexpr(s.is_avg_pooling)
         {
             out_ptr[i] = calc_pooling(offset_x,
                                       roi_starts,
@@ -173,7 +184,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
                                       {ph, pw},
                                       bin_grid_size,
                                       in_dims,
-                                      roi_offset,
+                                      s.roi_offset,
                                       avg_pool{});
         }
         else
@@ -184,7 +195,7 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, const W&
                                       {ph, pw},
                                       bin_grid_size,
                                       in_dims,
-                                      roi_offset,
+                                      s.roi_offset,
                                       max_pool{});
         }
     }
...
 #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP
 #define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_TYPES_HPP

-#include <hip/hip_runtime.h>
+#include <migraphx/kernels/hip.hpp>

 namespace migraphx {
...