"composable_kernel/include/utility/sequence.hpp" did not exist on "df73287b820c5eb801480a1e6b957b8c717d35b8"
Commit 7f65a88e authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into mlir-c

parents 79bfe69f b20e3d4d
...@@ -19,7 +19,7 @@ jobs: ...@@ -19,7 +19,7 @@ jobs:
# In this step, this action saves a list of existing images, # In this step, this action saves a list of existing images,
# the cache is created without them in the post run. # the cache is created without them in the post run.
# It also restores the cache if it exists. # It also restores the cache if it exists.
- uses: satackey/action-docker-layer-caching@v0.0.8 - uses: satackey/action-docker-layer-caching@v0.0.11
# Ignore the failure of a step and avoid terminating the job. # Ignore the failure of a step and avoid terminating the job.
continue-on-error: true continue-on-error: true
...@@ -65,7 +65,7 @@ jobs: ...@@ -65,7 +65,7 @@ jobs:
# In this step, this action saves a list of existing images, # In this step, this action saves a list of existing images,
# the cache is created without them in the post run. # the cache is created without them in the post run.
# It also restores the cache if it exists. # It also restores the cache if it exists.
- uses: satackey/action-docker-layer-caching@v0.0.8 - uses: satackey/action-docker-layer-caching@v0.0.11
# Ignore the failure of a step and avoid terminating the job. # Ignore the failure of a step and avoid terminating the job.
continue-on-error: true continue-on-error: true
...@@ -108,7 +108,7 @@ jobs: ...@@ -108,7 +108,7 @@ jobs:
# In this step, this action saves a list of existing images, # In this step, this action saves a list of existing images,
# the cache is created without them in the post run. # the cache is created without them in the post run.
# It also restores the cache if it exists. # It also restores the cache if it exists.
- uses: satackey/action-docker-layer-caching@v0.0.8 - uses: satackey/action-docker-layer-caching@v0.0.11
# Ignore the failure of a step and avoid terminating the job. # Ignore the failure of a step and avoid terminating the job.
continue-on-error: true continue-on-error: true
...@@ -142,10 +142,12 @@ jobs: ...@@ -142,10 +142,12 @@ jobs:
with: with:
python-version: 3.6 python-version: 3.6
- name: Install pyflakes - name: Install pyflakes
run: pip install pyflakes==2.3.1 run: pip install pyflakes==2.3.1 mypy==0.931
- name: Run pyflakes - name: Run pyflakes
run: pyflakes examples/ tools/ src/ test/ doc/ run: |
pyflakes examples/ tools/ src/ test/ doc/
mypy tools/api.py
linux: linux:
......
...@@ -200,6 +200,7 @@ rocm_enable_cppcheck( ...@@ -200,6 +200,7 @@ rocm_enable_cppcheck(
RULE_FILE RULE_FILE
${CMAKE_CURRENT_SOURCE_DIR}/cppcheck.rules ${CMAKE_CURRENT_SOURCE_DIR}/cppcheck.rules
SOURCES SOURCES
examples/
src/ src/
test/ test/
INCLUDE INCLUDE
......
# Partially evaluate the generator expressions commonly found in target
# properties and strip whatever remains, producing a plain flag string in
# OUTPUT_VAR (set in the parent scope).  The replacements are applied in a
# deliberate order: language predicates first, then $<BOOL:...> forms, then
# $<NOT:...>, then the $<0:...>/$<1:...> selectors that the earlier steps
# may have produced.  Reordering these steps would change the result.
function(eval_and_strip_genex OUTPUT_VAR INPUT)
# Assume the consumer is a C++ compile/link line, so these predicates are true.
string(REPLACE "$<LINK_LANGUAGE:CXX>" "1" INPUT "${INPUT}")
string(REPLACE "$<COMPILE_LANGUAGE:CXX>" "1" INPUT "${INPUT}")
# Drop the SHELL: prefix wrapper; only the flags themselves are wanted.
string(REPLACE "SHELL:" "" INPUT "${INPUT}")
# $<BOOL:> with an empty argument is false.
string(REPLACE "$<BOOL:>" "0" INPUT "${INPUT}")
# CMake's false-like constants evaluate to 0 ...
string(REGEX REPLACE "\\$<BOOL:(0|FALSE|false|OFF|off|N|n|IGNORE|ignore|NOTFOUND|notfound)>" "0" INPUT "${INPUT}")
# ... as does anything ending in -NOTFOUND (failed find_* results).
string(REGEX REPLACE "\\$<BOOL:[^<>]*-NOTFOUND>" "0" INPUT "${INPUT}")
# Any remaining non-empty, non-nested $<BOOL:...> argument is truthy.
string(REGEX REPLACE "\\$<BOOL:[^$<>]*>" "1" INPUT "${INPUT}")
# Collapse the negations now that BOOL has been reduced to 0/1.
string(REPLACE "$<NOT:0>" "1" INPUT "${INPUT}")
string(REPLACE "$<NOT:1>" "0" INPUT "${INPUT}")
# $<0:...> selects nothing; $<1:x> selects x.
string(REGEX REPLACE "\\$<0:[^<>]*>" "" INPUT "${INPUT}")
string(REGEX REPLACE "\\$<1:([^<>]*)>" "\\1" INPUT "${INPUT}")
# Remove any genex we could not evaluate rather than leaking it into flags.
string(GENEX_STRIP "${INPUT}" INPUT)
set(${OUTPUT_VAR} "${INPUT}" PARENT_SCOPE)
endfunction()
function(get_target_property2 VAR TARGET PROPERTY) function(get_target_property2 VAR TARGET PROPERTY)
get_target_property(_pflags ${TARGET} ${PROPERTY}) get_target_property(_pflags ${TARGET} ${PROPERTY})
if(_pflags) if(_pflags)
eval_and_strip_genex(_pflags "${_pflags}")
set(${VAR} ${_pflags} PARENT_SCOPE) set(${VAR} ${_pflags} PARENT_SCOPE)
else() else()
set(${VAR} "" PARENT_SCOPE) set(${VAR} "" PARENT_SCOPE)
endif() endif()
endfunction() endfunction()
# Report in OUTPUT_VAR (1/0, parent scope) whether FLAG is a compiler option
# that consumes the following list element as its argument, so the caller
# must not interpret that follower as a standalone flag or library/target.
function(flags_requires_arg OUTPUT_VAR FLAG)
    if("${FLAG}" STREQUAL "-x" OR "${FLAG}" STREQUAL "-isystem")
        set(${OUTPUT_VAR} 1 PARENT_SCOPE)
    else()
        set(${OUTPUT_VAR} 0 PARENT_SCOPE)
    endif()
endfunction()
macro(append_flags FLAGS TARGET PROPERTY PREFIX) macro(append_flags FLAGS TARGET PROPERTY PREFIX)
get_target_property2(_pflags ${TARGET} ${PROPERTY}) get_target_property2(_pflags ${TARGET} ${PROPERTY})
set(_requires_arg 0)
foreach(FLAG ${_pflags}) foreach(FLAG ${_pflags})
if(TARGET ${FLAG}) string(STRIP "${FLAG}" FLAG)
if(FLAG)
if(TARGET ${FLAG} AND NOT _requires_arg)
target_flags(_pflags2 ${FLAG}) target_flags(_pflags2 ${FLAG})
string(APPEND ${FLAGS} " ${_pflags2}") string(APPEND ${FLAGS} " ${_pflags2}")
else() else()
string(APPEND ${FLAGS} " ${PREFIX}${FLAG}") string(APPEND ${FLAGS} " ${PREFIX}${FLAG}")
endif() endif()
flags_requires_arg(_requires_arg "${FLAG}")
endif()
endforeach() endforeach()
endmacro() endmacro()
macro(append_link_flags FLAGS TARGET PROPERTY) macro(append_link_flags FLAGS TARGET PROPERTY)
get_target_property2(_pflags ${TARGET} ${PROPERTY}) get_target_property2(_pflags ${TARGET} ${PROPERTY})
set(_requires_arg 0)
foreach(FLAG ${_pflags}) foreach(FLAG ${_pflags})
if(TARGET ${FLAG}) string(STRIP "${FLAG}" FLAG)
if(FLAG)
if(TARGET ${FLAG} AND NOT _requires_arg)
target_flags(_pflags2 ${FLAG}) target_flags(_pflags2 ${FLAG})
string(APPEND ${FLAGS} " ${_pflags2}") string(APPEND ${FLAGS} " ${_pflags2}")
elseif(FLAG MATCHES "^-.*") elseif(FLAG MATCHES "^-.*")
...@@ -34,6 +67,8 @@ macro(append_link_flags FLAGS TARGET PROPERTY) ...@@ -34,6 +67,8 @@ macro(append_link_flags FLAGS TARGET PROPERTY)
else() else()
string(APPEND ${FLAGS} " -l${FLAG}") string(APPEND ${FLAGS} " -l${FLAG}")
endif() endif()
flags_requires_arg(_requires_arg "${FLAG}")
endif()
endforeach() endforeach()
endmacro() endmacro()
......
...@@ -24,7 +24,6 @@ int main(int argc, char** argv) ...@@ -24,7 +24,6 @@ int main(int argc, char** argv)
return 0; return 0;
} }
char* parse_arg = getCmdOption(argv + 2, argv + argc, "--parse");
char* load_arg = getCmdOption(argv + 2, argv + argc, "--load"); char* load_arg = getCmdOption(argv + 2, argv + argc, "--load");
char* save_arg = getCmdOption(argv + 2, argv + argc, "--save"); char* save_arg = getCmdOption(argv + 2, argv + argc, "--save");
const char* input_file = argv[1]; const char* input_file = argv[1];
......
#include <migraphx/cpp_generator.hpp> #include <migraphx/cpp_generator.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/builtin.hpp> #include <migraphx/builtin.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -51,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m) ...@@ -51,6 +52,7 @@ cpp_generator::function& cpp_generator::function::set_types(const module& m)
cpp_generator::function& cpp_generator::function&
cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse) cpp_generator::function::set_types(const module& m, const std::function<std::string(shape)>& parse)
{ {
this->params.clear();
auto pmap = m.get_parameter_shapes(); auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end()); std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform( std::transform(
...@@ -63,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str ...@@ -63,11 +65,30 @@ cpp_generator::function::set_types(const module& m, const std::function<std::str
return *this; return *this;
} }
cpp_generator::function& cpp_generator::function::set_generic_types(const module& m)
{
this->params.clear();
auto pmap = m.get_parameter_shapes();
std::map<std::string, shape> input_map(pmap.begin(), pmap.end());
std::transform(
input_map.begin(), input_map.end(), std::back_inserter(this->params), [&](auto&& p) {
return param{p.first, "T" + p.first};
});
std::transform(input_map.begin(),
input_map.end(),
std::back_inserter(this->tparams),
[&](auto&& p) { return "class T" + p.first; });
this->return_type = "auto";
return *this;
}
struct cpp_generator_impl struct cpp_generator_impl
{ {
std::stringstream fs{}; std::stringstream fs{};
std::size_t function_count = 0; std::size_t function_count = 0;
std::function<std::string(std::string)> fmap = nullptr; std::function<std::string(std::string)> fmap = nullptr;
std::unordered_map<std::string, std::string> point_op_map = {};
}; };
cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {} cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
...@@ -83,15 +104,28 @@ cpp_generator::~cpp_generator() noexcept = default; ...@@ -83,15 +104,28 @@ cpp_generator::~cpp_generator() noexcept = default;
void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; } void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
{
impl->point_op_map[op_name] = code;
}
std::string cpp_generator::generate_point_op(const operation& op, std::string cpp_generator::generate_point_op(const operation& op,
const std::vector<std::string>& args) const std::vector<std::string>& args)
{ {
auto v = op.to_value(); auto v = op.to_value();
std::string code;
if(contains(impl->point_op_map, op.name()))
{
code = impl->point_op_map.at(op.name());
}
else
{
auto attributes = op.attributes(); auto attributes = op.attributes();
if(not attributes.contains("point_op")) if(not attributes.contains("point_op"))
MIGRAPHX_THROW("op is missing point_op attribute: " + op.name()); MIGRAPHX_THROW("op is missing point_op attribute: " + op.name());
return interpolate_string(attributes["point_op"].to<std::string>(), code = attributes["point_op"].to<std::string>();
[&](auto start, auto last) -> std::string { }
return interpolate_string(code, [&](auto start, auto last) -> std::string {
auto key = trim({start, last}); auto key = trim({start, last});
if(key.empty()) if(key.empty())
MIGRAPHX_THROW("Empty parameter"); MIGRAPHX_THROW("Empty parameter");
...@@ -148,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m) ...@@ -148,6 +182,8 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
std::string cpp_generator::create_function(const cpp_generator::function& f) std::string cpp_generator::create_function(const cpp_generator::function& f)
{ {
impl->function_count++; impl->function_count++;
if(not f.tparams.empty())
impl->fs << "template<" << join_strings(f.tparams, ", ") << ">\n";
std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name; std::string name = f.name.empty() ? "f" + std::to_string(impl->function_count) : f.name;
impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name; impl->fs << join_strings(f.attributes, " ") << " " << f.return_type << " " << name;
char delim = '('; char delim = '(';
......
...@@ -34,6 +34,7 @@ struct cpp_generator ...@@ -34,6 +34,7 @@ struct cpp_generator
std::string return_type = "void"; std::string return_type = "void";
std::string name = ""; std::string name = "";
std::vector<std::string> attributes = {}; std::vector<std::string> attributes = {};
std::vector<std::string> tparams = {};
function& set_body(const module& m, const generate_module_callback& g); function& set_body(const module& m, const generate_module_callback& g);
function& set_body(const std::string& s) function& set_body(const std::string& s)
{ {
...@@ -52,6 +53,7 @@ struct cpp_generator ...@@ -52,6 +53,7 @@ struct cpp_generator
} }
function& set_types(const module& m); function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse); function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
}; };
cpp_generator(); cpp_generator();
...@@ -66,6 +68,8 @@ struct cpp_generator ...@@ -66,6 +68,8 @@ struct cpp_generator
void fmap(const std::function<std::string(std::string)>& f); void fmap(const std::function<std::string(std::string)>& f);
void add_point_op(const std::string& op_name, const std::string& code);
std::string generate_point_op(const operation& op, const std::vector<std::string>& args); std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
std::string str() const; std::string str() const;
......
...@@ -89,6 +89,13 @@ inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x, ...@@ -89,6 +89,13 @@ inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
return x.index - y.index; return x.index - y.index;
} }
template <class F, class Iterator>
inline basic_iota_iterator<F, Iterator> operator-(basic_iota_iterator<F, Iterator> x,
std::ptrdiff_t y)
{
return x -= y;
}
template <class F, class Iterator> template <class F, class Iterator>
inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y) inline bool operator==(basic_iota_iterator<F, Iterator> x, basic_iota_iterator<F, Iterator> y)
{ {
......
...@@ -35,7 +35,7 @@ struct argmax ...@@ -35,7 +35,7 @@ struct argmax
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
lens[axis] = 1; lens[axis] = 1;
......
...@@ -35,7 +35,7 @@ struct argmin ...@@ -35,7 +35,7 @@ struct argmin
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
lens[axis] = 1; lens[axis] = 1;
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -21,25 +21,26 @@ struct clip ...@@ -21,25 +21,26 @@ struct clip
{ {
std::string name() const { return "clip"; } std::string name() const { return "clip"; }
value attributes() const
{
return {{"pointwise", true},
{"point_op", "${function:min}(${function:max}(${1}, ${0}), ${2})"}};
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(3).same_type(); check_shapes{inputs, *this}.has(3).same_type().same_dims();
return inputs.front(); return inputs.front();
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
visit_all(result, args[0], args[1], args[2])([&](auto output, auto x, auto min, auto max) {
visit_all(result, args[0], args[1], args[2])( par_for(output_shape.elements(),
[&](auto output, auto input, auto min_val, auto max_val) { [&](auto i) { output[i] = std::min(std::max(min[i], x[i]), max[i]); });
auto max = max_val.front();
auto min = min_val.front();
std::transform(input.begin(), input.end(), output.begin(), [max, min](auto x) {
using type = decltype(x);
return std::min(std::max(type(min), x), type(max));
});
}); });
return result; return result;
} }
}; };
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses the ONNX GreaterOrEqual operator by rewriting it as the logical
// negation of "less": (a >= b) == not(a < b), with standard broadcasting.
struct parse_greaterorequal : op_parser<parse_greaterorequal>
{
    std::vector<op_desc> operators() const { return {{"GreaterOrEqual"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        auto less_ins = info.add_broadcastable_binary_op("less", args[0], args[1]);
        // "not" expects a bool input; insert a convert when "less" produced
        // some other element type.
        if(less_ins->get_shape().type() != shape::bool_type)
            less_ins = info.add_instruction(
                make_op("convert", {{"target_type", shape::bool_type}}), less_ins);
        return info.add_instruction(make_op("not"), less_ins);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses ONNX HardSigmoid, y = clip(alpha * x + beta, 0, 1), and HardSwish,
// y = x * HardSigmoid(x) with fixed alpha = 1/6 and beta = 0.5.
struct parse_hardsigmoid : op_parser<parse_hardsigmoid>
{
    std::vector<op_desc> operators() const { return {{"HardSigmoid"}, {"HardSwish"}}; }

    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        // ONNX defaults for HardSigmoid.
        float alpha = 0.2;
        float beta  = 0.5;
        if(opd.onnx_name == "HardSwish")
        {
            // HardSwish fixes alpha to 1/6 (beta stays 0.5); its attributes
            // are not consulted.
            alpha = 1.0 / 6.0;
        }
        else
        {
            if(contains(info.attributes, "alpha"))
                alpha = info.attributes.at("alpha").f();
            if(contains(info.attributes, "beta"))
                beta = info.attributes.at("beta").f();
        }

        auto input_lens = args[0]->get_shape().lens();
        auto input_type = args[0]->get_shape().type();
        // Helper: materialize a scalar literal of the input's element type and
        // broadcast it to the input's dimensions.
        auto bcast_scalar = [&](float v) {
            return info.add_instruction(
                migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}),
                info.add_literal(migraphx::literal{migraphx::shape{input_type}, {v}}));
        };
        auto mb_alpha = bcast_scalar(alpha);
        auto mb_beta  = bcast_scalar(beta);
        auto mb_zero  = bcast_scalar(0);
        auto mb_one   = bcast_scalar(1);

        auto scaled  = info.add_instruction(migraphx::make_op("mul"), mb_alpha, args[0]);
        auto shifted = info.add_instruction(migraphx::make_op("add"), mb_beta, scaled);
        auto hardsigmoid =
            info.add_instruction(migraphx::make_op("clip"), shifted, mb_zero, mb_one);
        if(opd.onnx_name == "HardSwish")
            return info.add_instruction(migraphx::make_op("mul"), args[0], hardsigmoid);
        return hardsigmoid;
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses the ONNX Mean operator: the element-wise mean of N >= 1
// broadcastable input tensors, computed as sum(args[i] / N).
struct parse_mean : op_parser<parse_mean>
{
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto num_data = args.size();
// A single input is already its own mean.
if(num_data == 1)
return args[0];
// Scalar literal N in the inputs' element type; broadcast by the binary op.
// NOTE(review): for integer element types this pre-division truncates —
// confirm that matches the intended reference semantics.
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
// data_i aliases the element inside args.  On the first iteration the
// assignment below makes args[0] equal to data_i, so the comparison
// data_i != args[0] is false and the (divided) first input replaces the
// initial accumulator instead of being added to it a second time.
return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
// Pre-divide each tensor element-wise by n to reduce risk of overflow during summation
data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
if(data_i != args[0])
return info.add_broadcastable_binary_op("add", mean, data_i);
return data_i;
});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -27,11 +27,6 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -27,11 +27,6 @@ struct parse_multinomial : op_parser<parse_multinomial>
if(contains(info.attributes, "sample_size")) if(contains(info.attributes, "sample_size"))
sample_size = info.attributes.at("sample_size").i(); sample_size = info.attributes.at("sample_size").i();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
// Subtract the per-batch maximum log-probability, making the per-batch max 0 // Subtract the per-batch maximum log-probability, making the per-batch max 0
auto maxes = auto maxes =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]); info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {1}}}), args[0]);
...@@ -46,7 +41,10 @@ struct parse_multinomial : op_parser<parse_multinomial> ...@@ -46,7 +41,10 @@ struct parse_multinomial : op_parser<parse_multinomial>
migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf); migraphx::make_op("prefix_scan_sum", {{"axis", 1}, {"exclusive", false}}), cdf);
// Pre-compute random distribution // Pre-compute random distribution
std::mt19937 gen(seed); std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> dis(0.0, 1.0); std::uniform_real_distribution<> dis(0.0, 1.0);
size_t batch_size = args[0]->get_shape().lens().front(); size_t batch_size = args[0]->get_shape().lens().front();
migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}}; migraphx::shape dist_shape{migraphx::shape::float_type, {batch_size, sample_size}};
......
...@@ -42,11 +42,6 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops> ...@@ -42,11 +42,6 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
if(contains(info.attributes, "scale")) if(contains(info.attributes, "scale"))
scale = info.attributes.at("scale").f(); scale = info.attributes.at("scale").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape; shape out_shape;
if(contains(info.attributes, "shape")) if(contains(info.attributes, "shape"))
{ {
...@@ -75,7 +70,10 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops> ...@@ -75,7 +70,10 @@ struct parse_randomnormal_ops : op_parser<parse_randomnormal_ops>
": cannot deduce shape without shape attribute or argument."); ": cannot deduce shape without shape attribute or argument.");
} }
std::mt19937 gen(seed); std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::normal_distribution<> d(mean, scale); std::normal_distribution<> d(mean, scale);
std::vector<double> rand_vals(out_shape.elements()); std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
...@@ -42,11 +42,6 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops> ...@@ -42,11 +42,6 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
if(contains(info.attributes, "low")) if(contains(info.attributes, "low"))
low = info.attributes.at("low").f(); low = info.attributes.at("low").f();
float seed = static_cast<float>(
std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
seed = info.attributes.at("seed").f();
shape out_shape; shape out_shape;
if(contains(info.attributes, "shape")) if(contains(info.attributes, "shape"))
{ {
...@@ -75,7 +70,10 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops> ...@@ -75,7 +70,10 @@ struct parse_randomuniform_ops : op_parser<parse_randomuniform_ops>
": cannot deduce shape without shape attribute or argument."); ": cannot deduce shape without shape attribute or argument.");
} }
std::mt19937 gen(seed); std::mt19937 gen(std::chrono::high_resolution_clock::now().time_since_epoch().count());
if(contains(info.attributes, "seed"))
gen.seed(info.attributes.at("seed").f());
std::uniform_real_distribution<> d(high, low); std::uniform_real_distribution<> d(high, low);
std::vector<double> rand_vals(out_shape.elements()); std::vector<double> rand_vals(out_shape.elements());
std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); }); std::generate(rand_vals.begin(), rand_vals.end(), [&]() { return d(gen); });
......
...@@ -163,9 +163,9 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr) ...@@ -163,9 +163,9 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
struct parse_resize : op_parser<parse_resize> struct parse_resize : op_parser<parse_resize>
{ {
std::vector<op_desc> operators() const { return {{"Resize"}}; } std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
...@@ -183,7 +183,7 @@ struct parse_resize : op_parser<parse_resize> ...@@ -183,7 +183,7 @@ struct parse_resize : op_parser<parse_resize>
if(contains(info.attributes, "exclude_outside") and if(contains(info.attributes, "exclude_outside") and
info.attributes.at("exclude_outside").i() == 1) info.attributes.at("exclude_outside").i() == 1)
{ {
MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!"); MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!");
} }
// input data shape info // input data shape info
...@@ -215,12 +215,14 @@ struct parse_resize : op_parser<parse_resize> ...@@ -215,12 +215,14 @@ struct parse_resize : op_parser<parse_resize>
if(type == shape::int64_type) if(type == shape::int64_type)
{ {
auto arg_out_s = arg->eval(); auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!"); check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); }); arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size()) if(out_lens.size() != in_lens.size())
{ {
MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size"); MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output size does not match input size");
} }
// compute the scale // compute the scale
...@@ -239,12 +241,14 @@ struct parse_resize : op_parser<parse_resize> ...@@ -239,12 +241,14 @@ struct parse_resize : op_parser<parse_resize>
{ {
auto arg_scale = arg->eval(); auto arg_scale = arg->eval();
check_arg_empty(arg_scale, check_arg_empty(arg_scale,
"PARSE_RESIZE: dynamic input scale is not supported!"); "PARSE_" + opd.op_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); }); arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_lens.size() != vec_scale.size()) if(in_lens.size() != vec_scale.size())
{ {
MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!"); MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
} }
std::transform(in_lens.begin(), std::transform(in_lens.begin(),
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses ONNX Softplus: y = ln(exp(x) + 1), assembled from elementwise
// primitives with the constant 1 broadcast to the input shape.
struct parse_softplus : op_parser<parse_softplus>
{
    std::vector<op_desc> operators() const { return {{"Softplus"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        const auto& in_shape = args[0]->get_shape();
        // Broadcast a scalar 1 of the input's element type to its dimensions.
        auto ones = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_shape.lens()}}),
            info.add_literal(migraphx::literal{migraphx::shape{in_shape.type()}, {1}}));
        auto exp_x = info.add_instruction(migraphx::make_op("exp"), args[0]);
        auto sum   = info.add_instruction(migraphx::make_op("add"), exp_x, ones);
        return info.add_instruction(migraphx::make_op("log"), sum);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses ONNX Softsign: y = x / (1 + |x|), assembled from elementwise
// primitives with the constant 1 broadcast to the input shape.
struct parse_softsign : op_parser<parse_softsign>
{
    std::vector<op_desc> operators() const { return {{"Softsign"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref> args) const
    {
        const auto& in_shape = args[0]->get_shape();
        // Broadcast a scalar 1 of the input's element type to its dimensions.
        auto ones = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", in_shape.lens()}}),
            info.add_literal(migraphx::literal{migraphx::shape{in_shape.type()}, {1}}));
        auto abs_x = info.add_instruction(migraphx::make_op("abs"), args[0]);
        auto denom = info.add_instruction(migraphx::make_op("add"), abs_x, ones);
        return info.add_instruction(migraphx::make_op("div"), args[0], denom);
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
// Parses ONNX Upsample (nearest mode only) by precomputing, at parse time, a
// gather index for every output element and lowering to reshape + gather.
// Requires the scales input (args[1]) to be a compile-time constant.
struct parse_upsample : op_parser<parse_upsample>
{
std::vector<op_desc> operators() const { return {{"Upsample"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
// Only the (default) nearest-neighbor mode is handled.
if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
if(mode != "nearest")
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
}
}
// Scales must be evaluable now; dynamic scales cannot be lowered to a
// constant gather index.
auto arg_scale = args[1]->eval();
check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
std::vector<float> vec_scale;
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
auto in_s = args[0]->get_shape();
auto in_lens = in_s.lens();
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
}
// Output dims: floor(input_dim * scale) per axis.
std::vector<std::size_t> out_lens(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
// Per-axis factor mapping an output coordinate back to an input
// coordinate: (in_dim - 1) / (out_dim - 1), or 1 when the axis is
// unchanged (endpoints of the axes line up).
std::vector<float> idx_scale(in_lens.size());
std::transform(
out_lens.begin(),
out_lens.end(),
in_lens.begin(),
idx_scale.begin(),
[](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });
shape out_s{in_s.type(), out_lens};
// One flat input index per output element, consumed by gather below.
std::vector<int> ind(out_s.elements());
// map out_idx to in_idx
shape_for_each(out_s, [&](auto idx) {
auto in_idx = idx;
std::transform(idx.begin(),
idx.end(),
idx_scale.begin(),
in_idx.begin(),
// nearest mode
[](auto index, auto scale) {
return static_cast<std::size_t>(std::round(index * scale));
});
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
});
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
// Index literal carries the output dims, so the gather result already has
// the upsampled shape.
shape ind_s{shape::int32_type, out_lens};
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = info.add_literal(literal(ind_s, ind));
return info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment