Unverified Commit 3ae5f9ed authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into concat2

parents 6d5a34d2 785ff7d7
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_ONNX_POOLING_HPP
#include <migraphx/config.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
value handle_pooling_values(const op_desc& opd,
onnx_parser::node_info info,
const shape& in_shape,
value values);
instruction_ref add_pooling_op(const op_desc& opd, onnx_parser::node_info info, instruction_ref l0);
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -60,7 +60,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Neg", "neg"},
{"Reciprocal", "recip"},
{"Relu", "relu"},
{"Round", "round"},
{"Round", "nearbyint"},
{"Sigmoid", "sigmoid"},
{"Sign", "sign"},
{"Sin", "sin"},
......
......@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
void lstm_transpose_inputs(onnx_parser::node_info& info, std::vector<instruction_ref>& args)
{
std::vector<int64_t> perm{1, 0, 2};
args[0] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[0]);
if(args.size() >= 6 and not args[5]->is_undefined())
{
args[5] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[5]);
}
if(args.size() >= 7 and not args[6]->is_undefined())
{
args[6] = info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[6]);
}
}
void lstm_transpose_outputs(onnx_parser::node_info& info,
instruction_ref& hidden_states,
instruction_ref& last_output,
instruction_ref& last_cell_output)
{
std::vector<int64_t> perm_hs{2, 0, 1, 3};
hidden_states =
info.add_instruction(make_op("transpose", {{"permutation", perm_hs}}), hidden_states);
std::vector<int64_t> perm_last{1, 0, 2};
last_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_output);
last_cell_output =
info.add_instruction(make_op("transpose", {{"permutation", perm_last}}), last_cell_output);
}
struct parse_lstm : op_parser<parse_lstm>
{
std::vector<op_desc> operators() const { return {{"LSTM"}}; }
......@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget = parser.parse_value(info.attributes.at("input_forget")).at<int>();
}
int layout = 0;
if(contains(info.attributes, "layout"))
{
layout = parser.parse_value(info.attributes.at("layout")).at<int>();
}
// append undefined opeator to make 6 arguments
if(args.size() < 8)
{
......@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args.insert(args.end(), 8 - args.size(), ins);
}
if(layout != 0)
{
lstm_transpose_inputs(info, args);
}
// first output for concatenation of hidden states
auto hidden_states = info.add_instruction(make_op("lstm",
{{"hidden_size", hidden_size},
......@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto last_cell_output =
info.add_instruction(make_op("rnn_last_cell_output"), hidden_states);
if(layout != 0)
{
lstm_transpose_outputs(info, hidden_states, last_output, last_cell_output);
}
return {hidden_states, last_output, last_cell_output};
}
};
......
......@@ -127,9 +127,9 @@ struct parse_multinomial : op_parser<parse_multinomial>
// use literal. The array populated by random_uniform may have any shape, as long its
// number of elements is batch_size * sample_size .
size_t batch_size = s0.lens().front();
auto rand_dummy = info.add_literal(
migraphx::literal{migraphx::shape::float_type, {batch_size * sample_size}});
auto rand_dummy = info.add_literal(migraphx::literal{
migraphx::shape{migraphx::shape::float_type, {batch_size, sample_size}},
std::vector<float>(batch_size * sample_size)});
randoms =
info.add_instruction(migraphx::make_op("random_uniform"), seed_input, rand_dummy);
}
......
This diff is collapsed.
......@@ -36,7 +36,7 @@ namespace onnx {
/*
*********************************************************************************
* Reference: see QLinearAdd in *
* Reference: see QLinearAdd, QLinearMul in *
* https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md *
*********************************************************************************
......@@ -49,6 +49,17 @@ namespace onnx {
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
com.microsoft.QLinearMul
Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting
support).
C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point
Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator
set.
General definition of binary QLinear* ops:
Inputs (7 - 8)
A : T
First operand.
......@@ -88,15 +99,18 @@ namespace onnx {
*/
struct parse_qlinearadd : op_parser<parse_qlinearadd>
struct parse_qlinearbinary : op_parser<parse_qlinearbinary>
{
std::vector<op_desc> operators() const { return {{"QLinearAdd"}}; }
std::vector<op_desc> operators() const
{
return {{"QLinearAdd", "add"}, {"QLinearMul", "mul"}};
}
// basic type checking for QLinearAdd Operator
void check_inputs(const std::vector<instruction_ref>& args) const
// basic type checking for binary QLinear Operator
void check_inputs(const std::vector<instruction_ref>& args, const std::string& op_name) const
{
if(args.size() < 7)
MIGRAPHX_THROW("QLINEARADD: missing inputs");
MIGRAPHX_THROW(op_name + ": missing inputs");
const auto& in_a = args[0];
const auto& in_b = args[3];
......@@ -107,19 +121,19 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
auto type_a = sh_a.type();
auto type_b = sh_b.type();
if(type_a != migraphx::shape::int8_type and type_a != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type");
MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_b != migraphx::shape::int8_type and type_b != migraphx::shape::uint8_type)
MIGRAPHX_THROW("QLINEARADD: unsupported input type");
MIGRAPHX_THROW(op_name + ": unsupported input type");
if(type_a != type_b)
MIGRAPHX_THROW("QLINEARADD: mismatched input types");
MIGRAPHX_THROW(op_name + ": mismatched input types");
}
instruction_ref parse(const op_desc& /* opd */,
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
check_inputs(args);
check_inputs(args, opd.op_name);
// A
const auto& in_a = args[0];
......@@ -134,8 +148,8 @@ struct parse_qlinearadd : op_parser<parse_qlinearadd>
const auto& in_zero_pt_b = args[5];
auto dquant_b = bcast_qdq_instr("dequantizelinear", in_b, in_scale_b, in_zero_pt_b, info);
// C = A + B
auto out_c = info.add_common_op("add", dquant_a, dquant_b);
// C = op(A, B)
auto out_c = info.add_common_op(opd.op_name, dquant_a, dquant_b);
const auto& in_scale_c = args[6];
......
This diff is collapsed.
This diff is collapsed.
......@@ -39,15 +39,17 @@ struct parse_scatternd : op_parser<parse_scatternd>
const onnx_parser::node_info& info,
std::vector<instruction_ref>& args) const
{
std::string reduction = "none";
if(contains(info.attributes, "reduction"))
{
if(info.attributes.at("reduction").s() == "add")
return info.add_instruction(migraphx::make_op("scatternd_add"), args);
if(info.attributes.at("reduction").s() == "mul")
return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
reduction = info.attributes.at("reduction").s();
if(not contains({"none", "add", "mul", "min", "max"}, reduction))
{
MIGRAPHX_THROW("PARSE_SCATTERND: unsupported reduction mode " + reduction);
}
}
return info.add_instruction(migraphx::make_op("scatternd_none"), args);
return info.add_instruction(migraphx::make_op("scatternd_" + reduction), args);
}
};
......
......@@ -46,6 +46,9 @@ struct parse_slice : op_parser<parse_slice>
void always_insert(instruction_ref arg) { op_args.insert(op_args.begin(), arg); }
/**
* Either insert argument into `this->op_args` or return the constant value of the argument
*/
std::vector<int64_t> insert(instruction_ref arg)
{
std::vector<int64_t> result;
......@@ -144,16 +147,15 @@ struct parse_slice : op_parser<parse_slice>
sd.op.axes = axes;
}
if(not sd.steps.empty())
if(std::any_of(sd.steps.begin(), sd.steps.end(), [](auto s) { return s != 1; }))
{
if(sd.op.starts.empty() or sd.op.ends.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable starts and ends is not supported");
MIGRAPHX_THROW(
"PARSE_SLICE: steps and variable starts and/or ends is not supported");
if(sd.op.axes.empty())
MIGRAPHX_THROW("PARSE_SLICE: steps and variable axes is not supported");
}
assert(sd.steps.empty() or sd.steps.size() == sd.op.axes.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(sd.steps.size()))
{
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -56,7 +56,11 @@ target make_target(const std::string& name)
{
if(not contains(target_map(), name))
{
#ifdef _WIN32
std::string target_name = "migraphx_" + name + ".dll";
#else
std::string target_name = "libmigraphx_" + name + ".so";
#endif
store_target_lib(dynamic_loader(target_name));
}
const auto it = target_map().find(name);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment