Unverified Commit 49dd34f8 authored by Chris Austen, committed by GitHub

Merge branch 'develop' into jit-reduce-reg

parents a5c87ec5 1e7ea2b6
@@ -39,10 +39,19 @@ struct parse_gemm : op_parser<parse_gemm>
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
float alpha = 1.0f;
float beta = 1.0f;
bool transa = false;
bool transb = false;
auto a_arg = args[0];
auto b_arg = args[1];
if(a_arg->get_shape().ndim() != 2 or b_arg->get_shape().ndim() != 2)
{
MIGRAPHX_THROW("PARSE_GEMM: A and B should be rank 2, A is rank " +
std::to_string(a_arg->get_shape().ndim()) + ", B is rank " +
std::to_string(b_arg->get_shape().ndim()));
}
float alpha = 1.0f;
float beta = 1.0f;
bool trans_a = false;
bool trans_b = false;
if(contains(info.attributes, "alpha"))
{
alpha = parser.parse_value(info.attributes.at("alpha")).at<float>();
@@ -53,65 +62,73 @@ struct parse_gemm : op_parser<parse_gemm>
}
if(contains(info.attributes, "transA"))
{
transa = parser.parse_value(info.attributes.at("transA")).at<bool>();
trans_a = parser.parse_value(info.attributes.at("transA")).at<bool>();
}
if(contains(info.attributes, "transB"))
{
transb = parser.parse_value(info.attributes.at("transB")).at<bool>();
trans_b = parser.parse_value(info.attributes.at("transB")).at<bool>();
}
std::vector<int64_t> perm(args[0]->get_shape().lens().size());
std::iota(perm.begin(), perm.end(), int64_t{0});
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = args[0];
auto dot_type = l1->get_shape().type();
std::vector<int64_t> perm = {1, 0};
auto dot_type = a_arg->get_shape().type();
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
if(l1->get_shape().type() != dot_type)
a_arg = info.add_broadcastable_binary_op("mul", alpha_literal, a_arg);
if(a_arg->get_shape().type() != dot_type)
{
l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
a_arg =
info.add_instruction(make_op("convert", {{"target_type", dot_type}}), a_arg);
}
}
l1 =
(transa) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), l1) : l1;
auto l2 = (transb)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
a_arg = (trans_a)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), a_arg)
: a_arg;
b_arg = (trans_b)
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), l1, l2);
auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3)
{
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
if(not float_equal(beta, 0.0f))
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
auto l3 = args[2];
auto l3_lens = l3->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
auto c_arg = args[2];
if(dot_ins->get_shape().dynamic())
{
l3 = info.add_instruction(make_op("multibroadcast", {{"out_lens", out_lens}}),
args[2]);
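// dynamic case: the two-argument form of multibroadcast takes its
// output dims from dot_ins at evaluation time, so no out_lens is set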
c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
}
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
if(beta_l3->get_shape().type() != dot_type)
else
{
beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_l3);
auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back();
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(
out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
}
return info.add_instruction(make_op("add"), ret, beta_l3);
if(not float_equal(beta, 1.0f))
{
auto beta_literal = info.add_literal(beta);
c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(c_arg->get_shape().type() != dot_type)
{
c_arg = info.add_instruction(
make_op("convert", {{"target_type", dot_type}}), c_arg);
}
}
return info.add_instruction(make_op("add"), dot_ins, c_arg);
}
}
return ret;
return dot_ins;
}
};
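For reference, ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op is an optional transpose and, as the new check enforces, A and B must be rank 2. A minimal numpy sketch of the lowering emitted above, scaling A by alpha before the transpose just as the parser does (the helper name and defaults are illustrative, not part of the codebase):

```python
import numpy as np

def gemm_reference(a, b, c=None, alpha=1.0, beta=1.0,
                   trans_a=False, trans_b=False):
    # Scale A by alpha first, then transpose, mirroring the parser's order.
    if alpha != 1.0:
        a = alpha * a
    if trans_a:
        a = a.T
    if trans_b:
        b = b.T
    out = a @ b
    # C, when present and beta != 0, is broadcast to the dot output shape
    # (what the multibroadcast instruction does) and scaled by beta.
    if c is not None and beta != 0.0:
        out = out + beta * np.broadcast_to(c, out.shape)
    return out
```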
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_trilu : op_parser<parse_trilu>
{
std::vector<op_desc> operators() const { return {{"Trilu"}}; }
instruction_ref parse(const op_desc&,
const onnx_parser&,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto input_shape = args[0]->get_shape();
assert(input_shape.ndim() >= 2);
auto input_lens = input_shape.lens();
size_t num_rows = *(input_lens.rbegin() + 1);
size_t num_cols = input_lens.back();
int k = 0;
bool upper = true;
if(args.size() > 1)
{
auto arg_k = args[1]->eval();
check_arg_empty(arg_k, "PARSE_TRILU: dynamic k not supported");
k = arg_k.at<int>();
}
if(k < 0)
MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
if(contains(info.attributes, "upper"))
{
upper = static_cast<bool>(info.attributes.at("upper").i());
}
shape::type_t output_type = args[0]->get_shape().type();
// when creating the mask, if upper == 1, the entries below the
// k-th diagonal (those with j - i < k) are set to 0
std::vector<bool> mask_mat(num_rows * num_cols, upper);
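// k is incremented once per row below, so row i flips its first
// min(k0 + i, num_cols) entries, where k0 is the requested offset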
for(size_t i = 0; i < num_rows; i++)
{
for(size_t j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
{
mask_mat[i * num_cols + j] = not upper;
}
k++;
}
auto mask = info.add_literal(
migraphx::literal{migraphx::shape{output_type, {num_rows, num_cols}}, mask_mat});
return info.add_broadcastable_binary_op("mul", mask, args[0]);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
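The mask construction in parse_trilu can be cross-checked with a small numpy sketch that mirrors the same loop (illustrative only):

```python
import numpy as np

def trilu_mask(num_rows, num_cols, k=0, upper=True):
    # Mirror of the parser's loop: start from all-keep (upper) or
    # all-drop (lower), then flip the first k + i entries of row i.
    mask = np.full((num_rows, num_cols), upper, dtype=bool)
    for i in range(num_rows):
        mask[i, :min(k + i, num_cols)] = not upper
    return mask

# upper triangle, k = 0: strictly-below-diagonal entries are cleared,
# matching np.triu on a 3x4 input
assert np.array_equal(trilu_mask(3, 4),
                      np.triu(np.ones((3, 4))).astype(bool))
```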
@@ -100,7 +100,8 @@ struct find_add_layernorm
{
auto matcher() const
{
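// used_once(): only fuse an add whose result has no other consumers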
return match::layernorm()(match::var("x")(match::name("add").bind("add")));
return match::layernorm()(
match::var("x")(match::name("add")(match::used_once()).bind("add")));
}
void apply(module& m, const match::matcher_result& r) const
......
@@ -66,7 +66,7 @@ TEST_CASE(load_and_run_init_list)
TEST_CASE(quantize_fp16)
{
auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx");
auto p1 = migraphx::parse_onnx("gemm_test.onnx");
const auto& p2 = p1;
const auto& p3 = p1;
migraphx::quantize_fp16(p1);
@@ -82,7 +82,7 @@ TEST_CASE(quantize_fp16)
TEST_CASE(quantize_int8)
{
auto p1 = migraphx::parse_onnx("gemm_ex_test.onnx");
auto p1 = migraphx::parse_onnx("gemm_test.onnx");
const auto& p2 = p1;
auto t = migraphx::target("ref");
migraphx::quantize_int8_options options;
......
@@ -2116,71 +2116,136 @@ def gathernd_batch_dims_test():
@onnx_test()
def gemm_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 7])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [11, 5])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [])
a = helper.make_tensor_value_info('3', TensorProto.FLOAT, [7, 11])
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['0', '1', '2'],
outputs=['3'],
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_no_C_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, 7])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [7, 11])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=2.0,
beta=2.0,
transA=1,
transB=1)
return ([node], [x, y, z], [a])
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_ex_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 7])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7])
def gemm_brcst_C_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [5, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_ex_brcst_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 5, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7])
def gemm_half_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT16, [8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT16, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT16, [6, 1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT16, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_half_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 1, 8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 1, 8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [1, 1, 6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 1, 6, 7])
def gemm_dyn_inner_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [None, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [None, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
inputs=['A', 'B'],
outputs=['Y'],
alpha=0.5,
transA=1)
return ([node], [A, B], [Y])
@onnx_test()
def gemm_dyn_outer_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [5, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [11, 5])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 11])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B'],
outputs=['Y'],
alpha=2.0,
transA=1,
transB=1)
return ([node], [A, B], [Y])
@onnx_test()
def gemm_dyn_bias_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 7])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=1.0,
beta=1.0,
transA=1)
return ([node], [A, B, C], [Y])
@onnx_test()
def gemm_rank_error():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [4, 1, 8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [4, 1, 8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [6, 7])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [4, 1, 6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
return ([node], [A, B, C], [Y])
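The Y shapes declared in these generators follow a single rule: Y takes the leading dim of op(A) and the trailing dim of op(B). A hypothetical helper (not part of the test suite) makes the pattern explicit:

```python
def gemm_out_shape(a_shape, b_shape, trans_a=False, trans_b=False):
    # Expected Gemm output dims for rank-2 inputs.
    m = a_shape[1] if trans_a else a_shape[0]
    n = b_shape[0] if trans_b else b_shape[1]
    return [m, n]

assert gemm_out_shape([8, 6], [8, 7], trans_a=True) == [6, 7]   # gemm_test
assert gemm_out_shape([5, 7], [11, 5],
                      trans_a=True, trans_b=True) == [7, 11]    # gemm_no_C_test
# gemm_rank_error uses rank-4 tensors, which parse_gemm now rejects
```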
@onnx_test()
@@ -6706,6 +6771,92 @@ def transpose_gather_test():
return ([td, ti, node], [x, i], [y])
@onnx_test()
def trilu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
node = onnx.helper.make_node(
'Trilu',
inputs=['x'],
outputs=['y'],
)
return ([node], [x], [y])
@onnx_test()
def trilu_batch_diff_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
k = np.array([2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
)
return ([node], [x], [y], [k_tensor])
@onnx_test()
def trilu_lower_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
node = onnx.helper.make_node('Trilu', inputs=['x'], outputs=['y'], upper=0)
return ([node], [x], [y])
@onnx_test()
def trilu_neg_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
k = np.array([-1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('Trilu', inputs=['x', 'k'], outputs=['y'])
return ([node], [x], [y], [k_tensor])
@onnx_test()
def trilu_out_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
k = np.array([5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('Trilu', inputs=['x', 'k'], outputs=['y'])
return ([node], [x], [y], [k_tensor])
@onnx_test()
def trilu_row_one_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
k = np.array([1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node(
'Trilu',
inputs=['x', 'k'],
outputs=['y'],
)
return ([node], [x], [y], [k_tensor])
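A quick numpy cross-check of the k edge cases these generators exercise (illustrative):

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
# trilu_out_k_test: k = 5 lies beyond the last column, so an upper
# mask clears every entry and the expected output is all zeros
assert np.array_equal(np.triu(x, k=5), np.zeros_like(x))
# trilu_neg_k_test: negative k is rejected by parse_trilu
# (MIGRAPHX_THROW), which the matching TEST_CASE verifies
```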
@onnx_test()
def undefined_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
@@ -2135,64 +2135,64 @@ TEST_CASE(gemm_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f;
auto beta = 2.0f;
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {8, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 7}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_ex_test)
TEST_CASE(gemm_no_C_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type});
auto alpha = 2.f;
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
migraphx::make_op("multibroadcast", {{"out_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_test.onnx");
auto prog = optimize_onnx("gemm_no_C_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(gemm_ex_brcst_test)
TEST_CASE(gemm_brcst_C_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 1}});
std::vector<std::size_t> out_lens{1, 1, 6, 7};
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {5, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {6, 1}});
std::vector<std::size_t> out_lens{6, 7};
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
auto b_l = mm->add_literal(beta);
auto l2_b =
@@ -2202,7 +2202,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
auto prog = optimize_onnx("gemm_brcst_C_test.onnx");
EXPECT(p == prog);
}
@@ -2210,17 +2210,17 @@ TEST_CASE(gemm_half_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {1, 1, 6, 1}});
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::half_type, {8, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::half_type, {8, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::half_type, {6, 1}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7};
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
std::vector<std::size_t> lens = {6, 7};
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction(
@@ -2236,6 +2236,73 @@ TEST_CASE(gemm_half_test)
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_inner_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {6, 6, 0}}});
auto l1 = mm->add_parameter(
"B", migraphx::shape{migraphx::shape::float_type, {{1, 10, 8}, {7, 7, 0}}});
auto alpha = 0.5f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 8};
auto prog = migraphx::parse_onnx("gemm_dyn_inner_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_outer_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"A", migraphx::shape{migraphx::shape::float_type, {{5, 5, 0}, {5, 10, 7}}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto alpha = 2.f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, t1}, migraphx::make_op("dot"), 1.0f, 0.0f);
mm->add_return({dot});
migraphx::onnx_options options;
options.default_dyn_dim_value = {5, 10, 7};
auto prog = migraphx::parse_onnx("gemm_dyn_outer_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 =
mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {{8, 8}, {1, 10}}});
auto x1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto x2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {1, 7}});
auto x0_t = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x0);
auto dot = mm->add_instruction(migraphx::make_op("dot"), x0_t, x1);
auto x2_b = mm->add_instruction(migraphx::make_op("multibroadcast"), x2, dot);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, x2_b);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10};
auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_rank_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
}
TEST_CASE(globalavgpool_test)
{
migraphx::program p;
@@ -6485,6 +6552,11 @@ TEST_CASE(transpose_gather_test)
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(trilu_neg_k_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("trilu_neg_k_test.onnx"); }));
}
TEST_CASE(undefined_test)
{
migraphx::program p;
......
(Binary ONNX files for trilu_batch_diff_k_test, trilu_neg_k_test, and trilu_out_k_test — no text preview.)