"...resnet50_tensorflow.git" did not exist on "92484dfec686658fd62913f646e86f72006a5cd2"
Unverified Commit 651ea160 authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Refactor to use tune_axis function (#713)

* initial testing

* initial testing

* add dequantize

* formatting

* add tests

* formatting

* revert file

* add parse files

* formatting

* add axis tuning and fix tests

* formatting

* add tests and fix int8

* formatting

* fix tidy

* test with int32

* add default name and change string to upper

* formatting

* remove boost call

* refactor to use tune_axis)

* formatting
parent 4d28180c
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -33,10 +34,9 @@ void eliminate_concat::apply(module& p) const ...@@ -33,10 +34,9 @@ void eliminate_concat::apply(module& p) const
// axis OR the sizes to the left of this axis are all equal to 1 // axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical // Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input // we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens(); auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator()); auto concat_op = concat_opt.get_concat(ins->get_operator());
std::size_t axis_index = std::size_t axis_index = tune_axis(lens.size(), concat_op.axis, concat_op.name());
(concat_op.axis < 0) ? (concat_op.axis + lens.size()) : concat_op.axis;
if(axis_index == 0 || if(axis_index == 0 ||
std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; })) std::all_of(lens.begin(), lens.begin() + axis_index, [](auto x) { return x == 1; }))
{ {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,12 +38,8 @@ struct argmax ...@@ -37,12 +38,8 @@ struct argmax
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
}
int64_t tuned_axis = (axis < 0) ? axis + n_dim : axis; int64_t tuned_axis = tune_axis(n_dim, axis, name());
lens[tuned_axis] = 1; lens[tuned_axis] = 1;
...@@ -75,7 +72,7 @@ struct argmax ...@@ -75,7 +72,7 @@ struct argmax
{ {
argument result{output_shape}; argument result{output_shape};
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = axis < 0 ? axis + n_dim : axis; auto tuned_axis = tune_axis(n_dim, axis, name());
auto batch_item_num = args.front().get_shape().lens()[tuned_axis]; auto batch_item_num = args.front().get_shape().lens()[tuned_axis];
result.visit([&](auto output) { result.visit([&](auto output) {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,12 +38,8 @@ struct argmin ...@@ -37,12 +38,8 @@ struct argmin
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size()); int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
}
int64_t tuned_axis = (axis < 0) ? axis + n_dim : axis; int64_t tuned_axis = tune_axis(n_dim, axis, name());
lens[tuned_axis] = 1; lens[tuned_axis] = 1;
return {shape::int64_type, lens}; return {shape::int64_type, lens};
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,7 +40,7 @@ struct concat ...@@ -39,7 +40,7 @@ struct concat
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto n_dims = args[0].get_shape().lens().size(); auto n_dims = args[0].get_shape().lens().size();
std::size_t axis_index = (axis < 0) ? axis + n_dims : axis; std::size_t axis_index = tune_axis(n_dims, axis, name());
std::vector<std::size_t> offsets; std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(n_dims, 0); std::vector<std::size_t> offset(n_dims, 0);
offset[axis_index] = 0; offset[axis_index] = 0;
......
...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR") inline int tune_axis(const int n_dim, const int axis, const std::string& op_name = "OPERATOR")
{ {
if(axis >= n_dim || abs(axis) > n_dim) if(axis >= n_dim || std::abs(axis) > n_dim)
{ {
MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range."); MIGRAPHX_THROW(to_upper(op_name) + ": axis is out of range.");
} }
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -11,7 +12,7 @@ struct parse_gather_elements : op_parser<parse_gather_elements> ...@@ -11,7 +12,7 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
{ {
std::vector<op_desc> operators() const { return {{"GatherElements"}}; } std::vector<op_desc> operators() const { return {{"GatherElements"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
...@@ -35,7 +36,7 @@ struct parse_gather_elements : op_parser<parse_gather_elements> ...@@ -35,7 +36,7 @@ struct parse_gather_elements : op_parser<parse_gather_elements>
} }
int n_rank = static_cast<int>(data_s.lens().size()); int n_rank = static_cast<int>(data_s.lens().size());
int tuned_axis = (axis < 0) ? (axis + n_rank) : axis; int tuned_axis = tune_axis(n_rank, axis, opd.op_name);
auto axis_stride = data_s.strides()[tuned_axis]; auto axis_stride = data_s.strides()[tuned_axis];
int64_t data_elem_num = static_cast<int64_t>(data_s.elements()); int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -12,7 +13,7 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -12,7 +13,7 @@ struct parse_onehot : op_parser<parse_onehot>
{ {
std::vector<op_desc> operators() const { return {{"OneHot"}}; } std::vector<op_desc> operators() const { return {{"OneHot"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
...@@ -39,12 +40,8 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -39,12 +40,8 @@ struct parse_onehot : op_parser<parse_onehot>
auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]}); auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]});
// Finally, we need a transpose to move the inner most dim to the axis dim // Finally, we need a transpose to move the inner most dim to the axis dim
int n_rank = gather_out->get_shape().lens().size(); int n_rank = gather_out->get_shape().lens().size();
if(axis < -n_rank or axis >= n_rank) int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name);
{
MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
}
int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
std::vector<int64_t> perm(n_rank - 1); std::vector<int64_t> perm(n_rank - 1);
std::iota(perm.begin(), perm.end(), 0); std::iota(perm.begin(), perm.end(), 0);
perm.insert(perm.begin() + tuned_axis, n_rank - 1); perm.insert(perm.begin() + tuned_axis, n_rank - 1);
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -11,7 +12,7 @@ struct parse_split : op_parser<parse_split> ...@@ -11,7 +12,7 @@ struct parse_split : op_parser<parse_split>
{ {
std::vector<op_desc> operators() const { return {{"Split"}}; } std::vector<op_desc> operators() const { return {{"Split"}}; }
std::vector<instruction_ref> parse(const op_desc& /*opd*/, std::vector<instruction_ref> parse(const op_desc& opd,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
...@@ -22,13 +23,9 @@ struct parse_split : op_parser<parse_split> ...@@ -22,13 +23,9 @@ struct parse_split : op_parser<parse_split>
axis = parser.parse_value(info.attributes.at("axis")).at<int>(); axis = parser.parse_value(info.attributes.at("axis")).at<int>();
} }
auto lens = args[0]->get_shape().lens(); auto lens = args[0]->get_shape().lens();
int64_t n_rank = static_cast<int64_t>(lens.size()); int64_t n_rank = static_cast<int64_t>(lens.size());
if((axis < -n_rank) || (axis >= n_rank)) int64_t tuned_axis = tune_axis(n_rank, axis, opd.op_name);
{
MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
}
int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
std::vector<int64_t> vec_splits; std::vector<int64_t> vec_splits;
if(contains(info.attributes, "split")) if(contains(info.attributes, "split"))
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <unordered_set> #include <unordered_set>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <map> #include <map>
...@@ -251,7 +252,7 @@ struct find_concat_transpose ...@@ -251,7 +252,7 @@ struct find_concat_transpose
// axis could be a negative value // axis could be a negative value
int64_t n_dim = static_cast<int64_t>(s.lens().size()); int64_t n_dim = static_cast<int64_t>(s.lens().size());
op.axis = (op.axis < 0) ? (op.axis + n_dim) : op.axis; op.axis = tune_axis(n_dim, op.axis, op.name());
auto ipermutation = invert_permutation(permutation); auto ipermutation = invert_permutation(permutation);
op.axis = ipermutation[op.axis]; op.axis = ipermutation[op.axis];
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/tune_axis.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <iostream> #include <iostream>
...@@ -407,9 +408,9 @@ struct cpu_softmax : auto_register_op<cpu_softmax<Op>> ...@@ -407,9 +408,9 @@ struct cpu_softmax : auto_register_op<cpu_softmax<Op>>
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
int64_t tuned_axis = (op.axis < 0) ? op.axis + args[0].get_shape().lens().size() : op.axis; int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name());
std::size_t n_dims = batch_lens[tuned_axis]; std::size_t n_dims = batch_lens[tuned_axis];
batch_lens[tuned_axis] = 1; batch_lens[tuned_axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
......
#include <migraphx/gpu/argmax.hpp> #include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp> #include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -15,7 +16,7 @@ shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const ...@@ -15,7 +16,7 @@ shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis; int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back(); return args.back();
} }
......
#include <migraphx/gpu/argmin.hpp> #include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp> #include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -15,7 +16,7 @@ shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const ...@@ -15,7 +16,7 @@ shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
int64_t tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis; int64_t tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::argmin(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back(); return args.back();
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/op/logsoftmax.hpp> #include <migraphx/op/logsoftmax.hpp>
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/miopen.hpp>
#include <migraphx/tune_axis.hpp>
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
...@@ -19,7 +20,7 @@ argument ...@@ -19,7 +20,7 @@ argument
hip_logsoftmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const hip_logsoftmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis; auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::logsoftmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back(); return args.back();
} }
......
#include <migraphx/gpu/softmax.hpp> #include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/device/softmax.hpp> #include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -15,7 +16,7 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const ...@@ -15,7 +16,7 @@ shape hip_softmax::compute_shape(const std::vector<shape>& inputs) const
argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_softmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
auto n_dim = args.front().get_shape().lens().size(); auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = (op.axis < 0) ? op.axis + n_dim : op.axis; auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::softmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis); device::softmax(ctx.get_stream().get(), args.back(), args.front(), tuned_axis);
return args.back(); return args.back();
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/ref/gemm.hpp> #include <migraphx/ref/gemm.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <iostream> #include <iostream>
...@@ -791,9 +792,9 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>> ...@@ -791,9 +792,9 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
auto batch_lens = output_shape.lens(); auto batch_lens = output_shape.lens();
int64_t tuned_axis = (op.axis < 0) ? op.axis + args[0].get_shape().lens().size() : op.axis; int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name());
std::size_t n_dims = batch_lens[tuned_axis]; std::size_t n_dims = batch_lens[tuned_axis];
batch_lens[tuned_axis] = 1; batch_lens[tuned_axis] = 1;
shape batch_shape{shape::int32_type, batch_lens}; shape batch_shape{shape::int32_type, batch_lens};
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/pad_calc.hpp> #include <migraphx/pad_calc.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -876,10 +877,8 @@ struct tf_parser ...@@ -876,10 +877,8 @@ struct tf_parser
{ {
axis = static_cast<int>(attributes.at("axis").i()); axis = static_cast<int>(attributes.at("axis").i());
} }
if(axis < 0)
{ axis = tune_axis(num_dims, axis, "tf_parse_softmax");
axis += num_dims;
}
return mm->add_instruction(Op{axis}, make_contiguous(args[0])); return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment