Commit e1efa548 authored by jerryyin's avatar jerryyin
Browse files

Addressing review feedbacks

parent e7f93b38
...@@ -177,10 +177,9 @@ std::string cpp_generator::generate_point_op(const operation& op, ...@@ -177,10 +177,9 @@ std::string cpp_generator::generate_point_op(const operation& op,
// For an optional argument where i >= args.size(), treat // For an optional argument where i >= args.size(), treat
// the optional argument as a straight zero. This will // the optional argument as a straight zero. This will
// cacel out the optional bias, if it exists. // cacel out the optional bias, if it exists.
if(i < args.size()) if(i >= args.size())
return args.at(i); MIGRAPHX_THROW("Invalid argument index: " + key);
else return args.at(i);
return "0";
} }
else if(v.contains(key)) else if(v.contains(key))
{ {
......
...@@ -39,8 +39,10 @@ struct dequantizelinear ...@@ -39,8 +39,10 @@ struct dequantizelinear
{ {
value attributes() const { value attributes() const {
return {{"pointwise", true}, {"point_op", // Note: point_op attribute is not used in this op. Instead, in
"${1} * (${function:convert}<float>(${0}) - ${function:convert}<float>(${2}))"}}; // gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return {{"pointwise", true}, {"point_op", ""}};
} }
std::string name() const { return "dequantizelinear"; } std::string name() const { return "dequantizelinear"; }
......
...@@ -39,8 +39,12 @@ struct quantizelinear ...@@ -39,8 +39,12 @@ struct quantizelinear
{ {
std::string name() const { return "quantizelinear"; } std::string name() const { return "quantizelinear"; }
value attributes() const { return {{"pointwise", true}, {"point_op", value attributes() const {
"${function:max}(${function:min}(${function:round}(${function:convert}<float>(${0}) / ${1}) + ${function:convert}<float>(${2}), 127.0), -128.0)"}}; } // Note: point_op attribute is not used in this op. Instead, in
// gpu compilation pipeline, rewrite_quantization will be invoked
// from generate_pointwise() to rewrite this op.
return {{"pointwise", true}, {"point_op", ""}};
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers) ...@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name) void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{ {
module m = pm; module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}}); run_passes(m,
{rewrite_quantization{}, eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g; cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment