Commit 8623590a authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into tf-transpose_ss

parents 0a8dc926 bc80dee8
@@ -269,6 +269,12 @@ struct match_fold_f
return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
}
template <class Pack>
static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
{
return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
}
template <class... Ts>
auto operator()(Ts... ms) const
{
@@ -283,18 +289,14 @@ struct match_fold_f
template <class Selector>
auto operator[](Selector select) const
{
-return [=](auto... mms) {
+return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object
-auto mpack = pack(mms...);
+auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
-auto fm = [&] {
-return mpack([&](auto... ms) {
-return match_fold_f::fold_matchers(ctx, ins, ms...);
-});
-};
+auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm);
});
if(matches == Matches)
@@ -328,6 +330,10 @@ inline auto outputs()
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
return not ins->get_shape().standard();
}
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
...
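The new fold_matchers_pack helper is the core of the gcc ICE workaround above: the matchers are packed into a single callable object and unpacked again inside a separate static function, so the variadic expansion no longer sits in a deeply nested generic lambda. A minimal standalone sketch of the same pack/unpack pattern (sum_pack is a hypothetical stand-in; assumes C++17 fold expressions for brevity):

#include <iostream>

// pack(xs...) returns a callable p such that p(f) invokes f(xs...).
template <class... Xs>
auto pack(Xs... xs)
{
    return [=](auto f) { return f(xs...); };
}

// Unpacking in a free function keeps the expansion out of nested lambdas,
// the same trick matcher.hpp uses to sidestep the gcc ICE.
template <class Pack>
int sum_pack(Pack p)
{
    return p([](auto... xs) { return (0 + ... + xs); });
}

int main() { std::cout << sum_pack(pack(1, 2, 3)) << "\n"; } // prints 6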
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct argmax
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmax"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(max_val < cur_val)
{
max_val = cur_val;
max_index = i;
}
}
return max_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
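A note on the reference semantics above: calc_argmax scans the reduction axis once, keeping the first occurrence of the maximum, since the comparison max_val < cur_val is strict and ties therefore keep the lower index. A standalone sketch of that loop on a concrete 2x3 input (illustrative only; plain arrays instead of migraphx tensor views):

#include <cstdint>
#include <iostream>

int main()
{
    // argmax over axis 1 of a 2x3 input; output shape is {2, 1}.
    float in[2][3] = {{1.0f, 5.0f, 5.0f}, {9.0f, 2.0f, 3.0f}};
    for(int row = 0; row < 2; ++row)
    {
        std::int64_t best = 0;
        for(int col = 1; col < 3; ++col)
            if(in[row][best] < in[row][col]) // strict '<' keeps the first maximum
                best = col;
        std::cout << best << "\n"; // prints 1, then 0
    }
}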
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct argmin
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmin"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(min_val > cur_val)
{
min_val = cur_val;
min_index = i;
}
}
return min_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::size_t batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
-#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
-#include <cmath>
-#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
-#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
-#include <migraphx/stringutils.hpp>
-#include <migraphx/streamutils.hpp>
-#include <migraphx/literal.hpp>
-#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
-#include <cmath>
-#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
...
@@ -5,6 +5,8 @@
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
...
@@ -63,6 +63,8 @@ struct onnx_parser
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin);
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
@@ -93,6 +95,7 @@ struct onnx_parser
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
// init the activation function map
init_actv_func();
@@ -274,6 +277,60 @@ struct onnx_parser
return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
}
instruction_ref parse_argmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(op::argmax{axis}, std::move(args));
}
}
instruction_ref parse_argmin(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int64_t axis = 0;
if(contains(attributes, "axis"))
{
axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 0)
{
auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
return prog.add_instruction(op::squeeze{{axis}}, ins);
}
else
{
return prog.add_instruction(op::argmin{axis}, std::move(args));
}
}
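Both parsers implement the ONNX keepdims contract the same way: op::argmax and op::argmin always keep the reduced axis as a size-1 dimension, and a squeeze is appended only when keepdims=0. A hypothetical shape walkthrough (header paths and shapes are assumptions for illustration, not part of the patch):

#include <migraphx/program.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/squeeze.hpp>

// ArgMax with axis=1, keepdims=0 on a {2, 3, 4} float input:
//   op::argmax{1}    : {2, 3, 4} -> {2, 1, 4} (int64 indices)
//   op::squeeze{{1}} : {2, 1, 4} -> {2, 4}
void build(migraphx::program& p)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
    auto x    = p.add_parameter("x", s);
    auto amax = p.add_instruction(migraphx::op::argmax{1}, x);
    p.add_instruction(migraphx::op::squeeze{{1}}, amax);
}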
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
@@ -1230,6 +1287,40 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output};
}
instruction_ref parse_reduce_sum(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<std::size_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
}
}
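parse_reduce_sum differs from the arg-op parsers in one respect: when the axes attribute is absent it defaults to reducing over every dimension, and with keepdims=0 all reduced axes are squeezed in one shot. A worked shape example (values chosen for illustration):

// ReduceSum on a {2, 3, 4} input:
//   no axes, keepdims=1     : reduce over {0, 1, 2}               -> {1, 1, 1}
//   axes={0, 2}, keepdims=1 : op::reduce_sum{{0, 2}}              -> {1, 3, 1}
//   axes={0, 2}, keepdims=0 : reduce_sum then op::squeeze{{0, 2}} -> {3}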
void parse_from(std::istream& is)
{
onnx::ModelProto model;
...
@@ -91,7 +91,7 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names())));
}
-void apply(program& p, match::matcher_result mr) const
+void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
@@ -132,7 +132,7 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0)));
}
-void apply(program& p, match::matcher_result mr) const
+void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front());
@@ -147,7 +147,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
-void apply(program& p, match::matcher_result mr) const
+void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto x = ins;
@@ -181,7 +181,7 @@ struct find_concat_transpose
match::all_of[match::inputs()](match::transpose_shape()));
}
-void apply(program& p, match::matcher_result mr) const
+void apply(program& p, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto s = ins->inputs().front()->get_shape();
...
@@ -13,6 +13,8 @@
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
...
@@ -12,6 +12,8 @@ endif()
add_library(migraphx_device
device/add.cpp
device/argmax.cpp
device/argmin.cpp
device/max.cpp
device/min.cpp
device/exp.cpp
@@ -44,6 +46,8 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
add_library(migraphx_gpu
argmax.cpp
argmin.cpp
eliminate_workspace.cpp
fuse_ops.cpp
hip.cpp
...
#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmax_op{}, stream, result, arg, axis);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/arg_op.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
arg_op(argmin_op{}, stream, result, arg, axis);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -128,7 +128,7 @@ __device__ T dpp_mov(T& x)
template <class T, class Op>
__device__ void dpp_reduce(T& in, Op op)
{
-T out;
+T out{};
out = dpp_mov<dpp_row_shr(1)>(in);
in = op(in, out);
out = dpp_mov<dpp_row_shr(2)>(in);
...
@@ -200,12 +200,33 @@ struct hip_add_relu
}
};
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
// Ensure the last argument is the broadcasted one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
if(it != args.end())
std::swap(*it, *std::prev(args.end(), 2));
}
void move_standard_front(std::vector<instruction_ref>& args)
{
// Ensure the first argument is the standard one
auto it = std::find_if(
args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
if(it != args.end())
std::swap(*it, args.front());
}
struct find_add_relu
{
auto matcher() const
{
-return match::name("gpu::relu")(match::arg(0)(
-match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
+return match::name("gpu::relu")(
+match::arg(0)(match::any_of(match::name("gpu::add"),
+match::name("hip::triadd"),
+match::any_of[match::inputs()](match::standard_shape()))
+.bind("add")));
}
void apply(program& p, match::matcher_result r) const
@@ -213,6 +234,9 @@ struct find_add_relu
auto add_ins = r.instructions["add"];
auto ins = r.result;
auto args = add_ins->inputs();
move_standard_front(args);
move_broadcasted_back(args);
// Use the allocation from the relu operator
args.back() = ins->inputs().back();
if(add_ins->name() == "gpu::add")
@@ -226,8 +250,9 @@ struct find_triadd
{
auto matcher() const
{
-return match::name("gpu::add")(match::either_arg(0, 1)(match::name("gpu::add").bind("add"),
-match::any().bind("input")));
+return match::name("gpu::add")(match::either_arg(0, 1)(
+match::name("gpu::add").bind("add"),
+match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
}
void apply(program& p, match::matcher_result r) const
@@ -242,10 +267,9 @@ struct find_triadd
if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
return;
args.insert(args.begin(), input_ins);
-// Ensure the last arguments is the broadcasted one
-auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
-if(it != args.end())
-std::swap(*it, *std::prev(args.end(), 2));
+move_standard_front(args);
+move_broadcasted_back(args);
args.back() = ins->inputs().back();
p.replace_instruction(ins, hip_triadd{}, args);
}
@@ -404,8 +428,8 @@ void fuse_ops::apply(program& p) const
// clang-format off
match::find_matches(p, find_triadd{});
match::find_matches(p,
-// find_conv_bias_relu{ctx},
-// find_conv_bias{ctx},
+find_conv_bias_relu{ctx},
+find_conv_bias{ctx},
find_add_relu{}
);
// clang-format on
...
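For context, the two helpers introduced above canonicalize the fused argument order: move_standard_front swaps a standard-shaped argument into slot 0, and move_broadcasted_back swaps a broadcasted argument into the second-to-last slot, leaving the final slot for the output allocation. A standalone sketch of the same swap idiom (hypothetical; ints standing in for instruction_refs):

#include <algorithm>
#include <iostream>
#include <vector>

int main()
{
    // Negative values stand in for broadcasted args; the last element is the
    // output allocation and must stay last.
    std::vector<int> args = {-7, 3, 5, 99};
    auto it = std::find_if(args.begin(), args.end(), [](int a) { return a < 0; });
    if(it != args.end())
        std::swap(*it, *std::prev(args.end(), 2)); // broadcasted -> second-to-last
    for(int a : args)
        std::cout << a << ' '; // prints: 5 3 -7 99
}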
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmax
{
op::argmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmin
{
op::argmin op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
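Both wrappers follow the GPU target's preallocated-output convention, which explains the details above: compute_shape expects 2 shapes because the output buffer is appended as the final argument, compute writes into args.back() and returns it, and output_alias reports that the result aliases that last input so the memory planner does not allocate a second buffer. In pseudocode (illustrative, not from the patch):

// How the target drives hip_argmax / hip_argmin:
//   inputs = { input_shape, output_shape }       // allocation appended last
//   compute_shape(inputs)                        // validates, then forwards
//     -> op.compute_shape({inputs.at(0)})        //    only the data shape
//   compute(ctx, out_shape, {input, out_buf})    // kernel fills out_buf...
//     -> returns args.back()                     // ...and returns that buffer
//   output_alias(shapes) -> shapes.size() - 1    // result aliases the last arg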
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARG_OP_HPP
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/hip.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T>
struct val_index
{
T val;
int64_t index;
};
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
return {v, -1};
}
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
return {v, i};
}
struct argmax_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val > y.val)
return x;
else if(x.val < y.val)
return y;
else
{
return (x.index < y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
};
struct argmin_op
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{
if(x.val < y.val)
return x;
else if(x.val > y.val)
return y;
else
{
return (x.index < y.index) ? x : y;
}
}
MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
};
template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{
auto arg_shape = arg.get_shape();
auto lens = arg_shape.lens();
auto batch_lens = lens;
size_t batch_item_num = lens[axis];
batch_lens[axis] = 1;
migraphx::shape batch_shape{arg_shape.type(), batch_lens};
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch.
const size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto batch_idx = batch_s.multi(i / block_size);
auto data_idx = batch_idx;
auto init = make_val_index<type>(op.init());
auto op_output =
block_reduce<max_block_size>(idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return make_val_index(input[arg_s.index(data_idx)], j);
});
if(idx.local == 0)
{
output[batch_s.index(batch_idx)] = op_output.index;
}
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
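arg_op maps one thread block to each element of the reduced batch shape: the block's lanes stride across the batch_item_num entries along axis, folding (value, index) pairs with the op, and block_reduce combines the per-lane partials so that lane 0 can write the winning index. A CPU model of one block's reduction with argmax_op semantics (illustrative only; block_reduce itself is the device routine in reduce.hpp):

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

struct val_index_t { float val; std::int64_t index; };

int main()
{
    // Reduce the axis slice {3, 9, 9, 1} the way argmax_op would.
    std::vector<float> slice = {3.0f, 9.0f, 9.0f, 1.0f};
    val_index_t acc{std::numeric_limits<float>::lowest(), -1}; // op.init()
    for(std::int64_t j = 0; j < static_cast<std::int64_t>(slice.size()); ++j)
    {
        val_index_t cur{slice[j], j};
        // argmax_op: larger value wins; ties keep the smaller index.
        if(cur.val > acc.val || (cur.val == acc.val && cur.index < acc.index))
            acc = cur;
    }
    std::cout << acc.index << "\n"; // prints 1 (first of the tied maxima)
}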
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ARGMAX_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif