Unverified commit 549cfe72, authored by Paul Fultz II and committed by GitHub
Browse files

Normalize compute methods (#723)



* Normalize compute functions

* Formatting

* Save normalization flag to the file

* Formatting

* Remove tuned functions

* Formatting

* Use in_index
Co-authored-by: Shucai Xiao <shucai@gmail.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e4bc095e
......@@ -80,6 +80,13 @@ struct instruction
static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);
void set_normalized(bool value = true);
bool is_normalized() const;
bool need_normalization() const;
operation normalized_operator() const;
void debug_print() const;
private:
......@@ -99,6 +106,7 @@ struct instruction
std::vector<instruction_ref> output;
std::vector<instruction_ref> arguments;
literal lit;
bool normalized = false;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -36,28 +36,22 @@ struct argmax
/// Compute the output shape for argmax: identical lengths to the input,
/// except the reduced dimension collapses to 1; element type is int64.
/// Expects exactly one standard-layout input. `axis` is assumed to have
/// been normalized (made non-negative) before this is called, so no
/// tune_axis adjustment is needed here.
shape normalize_compute_shape(std::vector<shape> inputs) const
{
    check_shapes{inputs, *this}.has(1).standard();
    // Fix: remove the stale pre-normalization lines (redeclared `lens`,
    // leftover `n_dim`/`tuned_axis` computation) so the block compiles
    // and indexes directly with the normalized axis.
    auto lens  = inputs[0].lens();
    lens[axis] = 1;
    return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmax(T& input,
int64_t tuned_axis,
std::vector<std::size_t>& indices,
size_t item_num) const
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[tuned_axis] = i;
auto cur_val = input(indices.begin(), indices.end());
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(max_val < cur_val)
{
max_val = cur_val;
......@@ -71,15 +65,13 @@ struct argmax
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = tune_axis(n_dim, axis, name());
auto batch_item_num = args.front().get_shape().lens()[tuned_axis];
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmax(input, tuned_axis, data_idx, batch_item_num);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
});
......
......@@ -36,27 +36,22 @@ struct argmin
/// Compute the output shape for argmin: same lengths as the single input,
/// with the reduced dimension set to 1; element type is int64.
/// `axis` is assumed normalized (non-negative) before this runs, so the
/// former tune_axis step is unnecessary.
shape normalize_compute_shape(std::vector<shape> inputs) const
{
    check_shapes{inputs, *this}.has(1).standard();
    // Fix: drop the duplicated `lens` declaration and the removed
    // `n_dim`/`tuned_axis` lines left over from the pre-normalization
    // version; index with the already-normalized axis.
    auto lens  = inputs[0].lens();
    lens[axis] = 1;
    return {shape::int64_type, lens};
}
template <class T>
int64_t calc_argmin(T& input,
int64_t tuned_axis,
std::vector<std::size_t>& indices,
size_t item_num) const
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[tuned_axis] = i;
auto cur_val = input(indices.begin(), indices.end());
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(min_val > cur_val)
{
min_val = cur_val;
......@@ -70,15 +65,13 @@ struct argmin
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = axis < 0 ? axis + n_dim : axis;
std::size_t batch_item_num = args.front().get_shape().lens()[tuned_axis];
std::size_t batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmin(input, tuned_axis, data_idx, batch_item_num);
output[i] = this->calc_argmin(input, data_idx, batch_item_num);
});
});
});
......
......@@ -39,15 +39,14 @@ struct concat
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument>& args) const
{
auto n_dims = args[0].get_shape().lens().size();
std::size_t axis_index = tune_axis(n_dims, axis, name());
auto n_dims = args[0].get_shape().lens().size();
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(n_dims, 0);
offset[axis_index] = 0;
offset[axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[axis_index] += arg.get_shape().lens()[axis_index];
offset[axis] += arg.get_shape().lens()[axis];
}
return offsets;
}
......
......@@ -70,7 +70,7 @@ struct gather
{
auto in_index = indices.front();
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
output[0] = data[indices.front()];
output[0] = data[in_index];
}
else
{
......
......@@ -46,54 +46,6 @@ struct slice
std::string name() const { return "slice"; }
/// Bring slice attributes into canonical form, in place:
///  - validate every axis lies in [-rank, rank), then wrap negatives;
///  - clamp each start/end to [-dim, dim], then wrap negatives into [0, dim];
///  - reject attribute sets where ends compares lexicographically less
///    than starts.
/// Throws MIGRAPHX_THROW on out-of-range axes or mismatched starts/ends.
void tune_attributes(std::vector<int64_t>& tuned_axes,
                     std::vector<int64_t>& tuned_starts,
                     std::vector<int64_t>& tuned_ends,
                     const std::vector<std::size_t>& lens) const
{
    const int64_t rank = static_cast<int64_t>(lens.size());
    // Validate before wrapping so the error message shows the raw axes.
    for(const auto a : tuned_axes)
    {
        if(a >= rank or a < -rank)
        {
            MIGRAPHX_THROW("SLICE: input axis " + to_string_range(tuned_axes) +
                           " out of range");
        }
    }
    // Wrap negative axes into [0, rank).
    for(auto& a : tuned_axes)
    {
        if(a < 0)
            a += rank;
    }
    // Length of the dimension each tuned axis refers to.
    std::vector<int64_t> axis_lens;
    axis_lens.reserve(tuned_axes.size());
    for(const auto a : tuned_axes)
        axis_lens.push_back(lens[a]);
    // Clamp an index to [-dim, dim], then map negatives into [0, dim].
    auto canonical_index = [](int64_t idx, int64_t dim) {
        if(idx < -dim)
            idx = -dim;
        else if(idx > dim)
            idx = dim;
        return (idx < 0) ? (idx + dim) : idx;
    };
    for(std::size_t i = 0; i < tuned_axes.size(); ++i)
    {
        tuned_starts[i] = canonical_index(tuned_starts[i], axis_lens[i]);
        tuned_ends[i]   = canonical_index(tuned_ends[i], axis_lens[i]);
    }
    // Same lexicographic vector comparison as the original check.
    if(tuned_starts > tuned_ends)
    {
        MIGRAPHX_THROW("SLICE: starts and ends does not match");
    }
}
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
......@@ -104,27 +56,22 @@ struct slice
auto compute_offset(const shape& s) const
{
std::vector<int64_t> tuned_axes = axes;
std::vector<int64_t> tuned_starts = starts;
std::vector<int64_t> tuned_ends = ends;
const std::vector<std::size_t>& lens = s.lens();
tune_attributes(tuned_axes, tuned_starts, tuned_ends, lens);
const std::vector<std::size_t>& lens = s.lens();
const std::vector<std::size_t>& strides = s.strides();
auto offset = 0;
if(!tuned_axes.empty())
if(!axes.empty())
{
for(std::size_t i = 0; i < tuned_axes.size(); i++)
for(std::size_t i = 0; i < axes.size(); i++)
{
auto axis = tuned_axes[i];
offset += fix_index(lens, axis, tuned_starts[i]) * strides[axis];
auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis];
}
}
else
{
for(std::size_t axis = 0; axis < lens.size(); axis++)
{
offset += fix_index(lens, axis, tuned_starts[axis]) * strides[axis];
offset += fix_index(lens, axis, starts[axis]) * strides[axis];
}
}
return offset;
......
......@@ -33,7 +33,8 @@ void instruction::replace(const shape& r)
/// Replace this instruction's operation and recompute its output shape.
/// Clears the normalized flag first, since the new operator's attributes
/// have not been normalized yet.
void instruction::replace(operation o)
{
    normalized = false;
    // Fix: the parameter was moved from twice in the presented text; a
    // single move-assignment is the correct behavior.
    op = std::move(o);
    recompute_shape();
}
......@@ -158,7 +159,8 @@ void instruction::replace(instruction_ref ins,
/// Replace this instruction's operation, result shape, and arguments.
/// Clears the normalized flag because the incoming operator's attributes
/// are not yet normalized.
void instruction::replace(operation o, const shape& r, std::vector<instruction_ref> args)
{
    normalized = false;
    // Fix: `o` was moved from twice in the presented text; move exactly once.
    op = std::move(o);
    replace(r);
    replace(std::move(args));
}
......@@ -208,7 +210,7 @@ argument instruction::eval(bool check_eval) const
this->inputs().end(),
std::back_inserter(args),
[](auto arg) { return arg->eval(false); });
return op.compute(result, args);
return normalized_operator().compute(result, args);
}
return {};
}
......@@ -260,6 +262,27 @@ instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
return get_output_alias(ins->inputs().at(i));
}
// Mark (or unmark, with value == false) this instruction as having
// normalized operator attributes.
void instruction::set_normalized(bool value)
{
    normalized = value;
}
// True when this instruction's operator attributes have already been
// normalized (see set_normalized / normalize_ops).
bool instruction::is_normalized() const
{
    return normalized;
}
// An instruction needs normalization when its operator opts in to attribute
// normalization and the instruction has not been normalized yet.
bool instruction::need_normalization() const
{
    // Query the operator first, matching the original evaluation order.
    const bool op_wants_it = this->get_operator().need_normalization();
    return op_wants_it and not normalized;
}
// Return a copy of this instruction's operator with its attributes
// normalized against the first input's lengths. If no normalization is
// needed, or normalize_attributes reports failure, the operator is
// returned as-is.
operation instruction::normalized_operator() const
{
    operation result = this->get_operator();
    if(not this->need_normalization())
        return result;
    auto input_lens = this->inputs().front()->get_shape().lens();
    if(not normalize_attributes(result, input_lens))
    {
        // Normalization failed: fall back to the untouched operator,
        // discarding any partial mutation of `result`.
        return this->get_operator();
    }
    return result;
}
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
......
......@@ -389,6 +389,11 @@ void module::finalize(context& ctx)
{
ins->finalize(ctx);
}
// Warn when an instruction is not normalized
auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
if(ins != end())
std::cerr << "WARNING: Instruction needs normalization, performance may be affected."
<< std::endl;
}
value module::to_value() const
......@@ -397,9 +402,10 @@ value module::to_value() const
value nodes;
this->print([&](auto ins, const auto& names) {
value node;
node["output"] = names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
node["output"] = names.at(ins);
node["name"] = ins->name();
node["shape"] = migraphx::to_value(ins->get_shape());
node["normalized"] = ins->is_normalized();
if(ins->name() == "@literal")
node["literal"] = migraphx::to_value(ins->get_literal());
node["operator"] = ins->get_operator().to_value();
......@@ -421,8 +427,9 @@ void module::from_value(const value& v)
for(const value& node : v.at("nodes"))
{
instruction_ref output;
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
auto name = node.at("name").to<std::string>();
auto fields = node.at("operator");
auto normalized = node.at("normalized").to<bool>();
if(name == "@param")
{
output = this->add_parameter(fields["parameter"].to<std::string>(),
......@@ -445,6 +452,7 @@ void module::from_value(const value& v)
else
output = this->add_instruction(op, inputs);
}
output->set_normalized(normalized);
instructions[node.at("output").to<std::string>()] = output;
}
}
......
......@@ -25,6 +25,7 @@ void normalize_ops::apply(module& m) const
if(normalize_attributes(tuned_op, lens))
{
m.replace_instruction(ins, tuned_op, inputs);
ins->set_normalized();
}
}
}
......
......@@ -228,7 +228,8 @@ std::vector<argument> generic_eval(const module& p,
return results[i];
});
results.emplace(ins, trace(ins, [&] {
return ins->get_operator().compute(ctx, ins->get_shape(), values);
return ins->normalized_operator().compute(
ctx, ins->get_shape(), values);
}));
}
assert(results.find(ins) != results.end());
......@@ -284,7 +285,7 @@ std::vector<argument> program::eval(parameter_map params) const
}
}
const int program_file_version = 3;
const int program_file_version = 4;
value program::to_value() const
{
......
......@@ -87,6 +87,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_workspace{},
eliminate_allocation{"hip::allocate"},
check_context<context>{},
normalize_ops{},
dead_code_elimination{},
eliminate_identity{}
};
......
......@@ -21,7 +21,7 @@ void expect_shape(const migraphx::shape& expected, const migraphx::operation& op
mm->add_instruction(op, args);
if(p.get_output_shapes().back() != expected)
{
std::cout << "FAILED: Incorrect shape for " << op.name() << ": ";
std::cout << "FAILED: Incorrect shape for " << op << ": ";
std::cout << expected << " != " << p.get_output_shapes().back() << std::endl;
for(auto&& s : shapes)
std::cout << " " << s << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment