"sgl-router/vscode:/vscode.git/clone" did not exist on "a18161875c8a0c2a4a430c01a655f63b2fd97ab1"
Unverified commit 3e0888eb authored by Charlie Lin, committed by GitHub

Merge branch 'develop' into dyn_reshape

parents f6e4ff35 c6efdf8c
name: Add items to GH project
on:
  pull_request:
    types:
      - opened
  issues:
    types:
      - opened
jobs:
  add-to-project:
    name: Add PRs and issues to MIGX project
    runs-on: ubuntu-latest
    steps:
      - uses: actions/add-to-project@v0.4.0
        with:
          project-url: https://github.com/orgs/ROCmSoftwarePlatform/projects/20
          github-token: ${{ secrets.TEST_PR_WORKFLOW }}
@@ -110,9 +110,19 @@ instruction_ref onnx_parser::node_info::add_bias(const std::vector<instruction_ref
 {
     if(args.size() == 3)
     {
-        auto bias_bcast = mod->add_instruction(
-            make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
-            args[2]);
+        instruction_ref bias_bcast;
+        // if curr_ins has a dynamic output shape use 2 input broadcast
+        if(curr_ins->get_shape().dynamic())
+        {
+            bias_bcast =
+                mod->add_instruction(make_op("broadcast", {{"axis", axis}}), args[2], curr_ins);
+        }
+        else
+        {
+            bias_bcast = mod->add_instruction(
+                make_op("broadcast", {{"axis", axis}, {"out_lens", curr_ins->get_shape().lens()}}),
+                args[2]);
+        }
         return mod->add_instruction(make_op("add"), curr_ins, bias_bcast);
     }
     return curr_ins;
......
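Aside: the one-input broadcast above bakes out_lens into the operator at parse time, which is impossible when the convolution output shape is dynamic; the two-input form takes its output shape from the second argument at evaluation time instead. A minimal sketch of the dynamic form (shapes and variable names are hypothetical, mirroring the conv_dynamic_bias_test expected program further down):

#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // dynamic batch: each dimension is a {min, max, opt} triple
    auto x = mm->add_parameter(
        "x", {migraphx::shape::float_type, {{1, 6, 0}, {2, 2, 0}, {28, 28, 0}, {28, 28, 0}}});
    auto bias = mm->add_parameter("bias", {migraphx::shape::float_type, {2}});
    // two-input broadcast: the second input supplies the output shape at evaluation time
    auto bcast = mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), bias, x);
    mm->add_instruction(migraphx::make_op("add"), x, bcast);
    return 0;
}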
@@ -43,55 +43,79 @@ struct parse_matmul : op_parser<parse_matmul>
                           const onnx_parser::node_info& info,
                           std::vector<instruction_ref> args) const
     {
-        auto l0      = args[0];
-        auto l1      = args[1];
-        auto l0_lens = l0->get_shape().lens();
-        auto l1_lens = l1->get_shape().lens();
+        auto a0 = args[0];
+        auto a1 = args[1];
+        auto s0 = a0->get_shape();
+        auto s1 = a1->get_shape();
 
         // args[0] is a vector, prepend 1 to the shape
+        instruction_ref dot_res;
         bool is_a_prepended = false;
-        if(l0_lens.size() == 1)
+        bool is_b_appended  = false;
+        if(s0.ndim() == 1)
         {
             is_a_prepended = true;
-            l0_lens.insert(l0_lens.begin(), 1);
-            l0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
+            a0 = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), args[0]);
         }
-        bool is_b_appended = false;
-        if(l1_lens.size() == 1)
+        if(s1.ndim() == 1)
         {
             is_b_appended = true;
-            l1_lens.push_back(1);
-            l1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
+            a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
         }
 
-        instruction_ref bl0 = l0;
-        instruction_ref bl1 = l1;
-        if(not std::equal(
-               l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
+        if(s0.dynamic() or s1.dynamic())
         {
-            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
-            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
-            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
-            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
-            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
-            l0_broadcasted_lens = output_lens;
-            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
-            l1_broadcasted_lens = output_lens;
-            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
-            if(l0_lens != l0_broadcasted_lens)
+            if(opd.op_name == "quant_dot")
             {
-                bl0 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), l0);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
             }
-            if(l1_lens != l1_broadcasted_lens)
+            auto s0_dds = a0->get_shape().to_dynamic().dyn_dims();
+            auto s1_dds = a1->get_shape().to_dynamic().dyn_dims();
+            // TODO: handling this case requires a new multibroadcast mode
+            if(not std::equal(
+                   s0_dds.rbegin() + 2, s0_dds.rend(), s1_dds.rbegin() + 2, s1_dds.rend()))
             {
-                bl1 = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), l1);
+                MIGRAPHX_THROW("PARSE_MATMUL: dynamic shape broadcasting not supported");
             }
+            dot_res = info.add_instruction(make_op(opd.op_name), a0, a1);
+        }
+        else
+        {
+            auto s0_lens        = a0->get_shape().lens();
+            auto s1_lens        = a1->get_shape().lens();
+            instruction_ref ba0 = a0;
+            instruction_ref ba1 = a1;
+            // try broadcasting if dimensions other than last two do not match
+            if(not std::equal(
+                   s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend()))
+            {
+                auto l0_it = s0_lens.begin() + s0_lens.size() - 2;
+                std::vector<std::size_t> l0_broadcasted_lens(s0_lens.begin(), l0_it);
+                auto l1_it = s1_lens.begin() + s1_lens.size() - 2;
+                std::vector<std::size_t> l1_broadcasted_lens(s1_lens.begin(), l1_it);
+                auto output_lens =
+                    compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
+                l0_broadcasted_lens = output_lens;
+                l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end());
+                l1_broadcasted_lens = output_lens;
+                l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end());
+                if(s0_lens != l0_broadcasted_lens)
+                {
+                    ba0 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0);
+                }
+                if(s1_lens != l1_broadcasted_lens)
+                {
+                    ba1 = info.add_instruction(
+                        make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
+                }
+            }
+            dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
         }
 
-        instruction_ref dot_res = info.add_instruction(make_op(opd.op_name), bl0, bl1);
-        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
+        // squeeze the appended or prepended dimensions
+        int64_t num_axis = dot_res->get_shape().ndim();
         if(is_a_prepended)
         {
             dot_res = info.add_instruction(make_op("squeeze", {{"axes", {num_axis - 2}}}), dot_res);
......
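Aside: a minimal sketch of the vector-operand rule this parser implements, per ONNX MatMul semantics: a 1-D right operand is unsqueezed into a column, and the appended dimension is squeezed back off the dot result. Static hypothetical shapes are used here for illustration; compare the matmul_dyn_mv_test expected program below.

#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto a = mm->add_parameter("a", {migraphx::shape::float_type, {3, 7}});
    auto b = mm->add_parameter("b", {migraphx::shape::float_type, {7}});
    // 1-D operand gets a trailing 1 appended: {7} -> {7, 1}
    auto col = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), b);
    auto dot = mm->add_instruction(migraphx::make_op("dot"), a, col); // {3, 1}
    // the appended dimension is squeezed back off: {3, 1} -> {3}
    mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), dot);
    return 0;
}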
@@ -1065,11 +1065,23 @@ struct find_split_reshape
             return;
         }
 
+        // Only want to apply this optimization if each split output is followed by
+        // a contiguous op and a reshape
+        if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
+               if(i->outputs().size() == 1)
+               {
+                   auto cont = i->outputs().front();
+                   return cont->outputs().size() != 1;
+               }
+               return false;
+           }))
+        {
+            return;
+        }
+
         std::vector<instruction_ref> vec_rsp(split_outputs.size());
         std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) {
             assert(i->outputs().size() == 1);
             auto cont = i->outputs().front();
             assert(cont->outputs().size() == 1);
             return cont->outputs().front();
         });
......
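Aside: the new std::any_of check turns what used to be a hard assert into an early exit: the rewrite is skipped whenever a split output's single consumer (the contiguous op) fans out to more than one instruction. A toy model of that predicate using plain fan-out counts (names and types hypothetical):

#include <algorithm>
#include <utility>
#include <vector>

// each pair: (consumers of a split output, consumers of its contiguous op)
bool should_bail(const std::vector<std::pair<int, int>>& fanouts)
{
    return std::any_of(fanouts.begin(), fanouts.end(), [](const auto& f) {
        if(f.first == 1)
            return f.second != 1; // the contiguous op must feed exactly one reshape
        return false;
    });
}

int main()
{
    bool ok   = should_bail({{1, 1}, {1, 1}}); // false: pattern can apply
    bool bail = should_bail({{1, 1}, {1, 2}}); // true: a contiguous op fans out
    return (not ok and bail) ? 0 : 1;
}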
@@ -763,16 +763,23 @@ struct find_transpose_slice
         // Compute axis before transpose to use for unsqueeze
         auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
         auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
-        // Make unsqeeze
+        // Make unsqueeze
+        std::vector<int64_t> steps(sdistance.size());
+        std::transform(
+            slice.axes.begin(),
+            slice.axes.end(),
+            sdistance.begin(),
+            steps.begin(),
+            [&](const auto ax, const auto sdis) { return ins->get_shape().lens().at(ax) / sdis; });
         auto unsqueeze = m.insert_instruction(
-            ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
+            ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", steps}}), ins->inputs());
         // Make transpose
         std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
-            if(i > preaxis)
+            if(i >= preaxis)
                 return i + 1;
             return i;
         });
-        perm.insert(perm.begin(), preaxis + 1);
+        perm.insert(perm.begin(), preaxis);
         auto transpose =
             m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
         // Slice and squeeze
......
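Aside: the fix above stops passing the raw slice distances as unsqueeze steps; the step count along each sliced axis is the axis length divided by the slice distance. A worked example with hypothetical numbers:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<std::size_t> lens      = {2, 12, 5}; // shape before transpose
    std::vector<std::size_t> axes      = {1};        // sliced axis
    std::vector<std::size_t> sdistance = {4};        // distance between slice starts
    std::vector<std::size_t> steps(sdistance.size());
    // same transform as the patch: steps = axis length / slice distance
    std::transform(axes.begin(), axes.end(), sdistance.begin(), steps.begin(),
                   [&](auto ax, auto sdis) { return lens.at(ax) / sdis; });
    std::cout << steps[0] << "\n"; // 12 / 4 = 3
    return 0;
}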
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>

#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// NOLINTNEXTLINE
static const char* const gather_kernel = R"__migraphx__(
#include <migraphx/kernels/gather.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {
__global__ void gather_kernel(void* in_data, void* in_indices, void* output)
{
    make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
        gather<${axis}>(xs...);
    });
}
}

} // namespace migraphx

)__migraphx__";

struct gather_compiler : compiler<gather_compiler>
{
    std::vector<std::string> names() const { return {"gather"}; }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        hip_compile_options options;
        const auto& out_s = inputs.back();
        options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
        options.inputs         = inputs;
        options.output         = out_s;
        options.kernel_name    = "gather_kernel";
        options.virtual_inputs = inputs;

        auto axis = v.at("axis").to<std::string>();
        auto src  = interpolate_string(gather_kernel, {{"axis", axis}});
        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
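Aside: the only templated piece of the kernel source is ${axis}, which compile_op stamps into the string before handing it to the hipRTC pipeline. A sketch of that substitution step in isolation (the include path for interpolate_string is an assumption):

#include <iostream>
#include <string>
#include <migraphx/stringutils.hpp> // assumed home of interpolate_string

int main()
{
    std::string templ = "gather<${axis}>(xs...);";
    // the JIT compiler stamps the gather axis into the kernel source this way
    auto src = migraphx::interpolate_string(templ, {{"axis", "2"}});
    std::cout << src << "\n"; // prints: gather<2>(xs...);
    return 0;
}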
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_GATHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHER_HPP

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/shape.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/tensor_view.hpp>

namespace migraphx {

template <int Axis, class Input, class Indices>
constexpr auto gather_shape(Input input, Indices indices)
{
    auto lengths  = input.lens;
    lengths[Axis] = indices.elements();
    return make_shape(lengths, input.strides);
}

template <int Axis, class Input, class Indices, class Output>
__device__ void gather(Input input, Indices indices, Output output)
{
    auto ind           = make_index();
    auto axis_dim_size = input.get_shape().lens[Axis];

    constexpr auto out_comp = gather_shape<Axis>(get_shape_c<Input>{}, get_shape_c<Indices>{});
    ind.global_stride(output.get_shape().elements(), [&](auto i) {
        auto idx          = out_comp.multi(i);
        auto in_index     = indices[idx[Axis]];
        auto new_in_index = (in_index < 0) ? in_index + axis_dim_size : in_index;
        idx[Axis]         = new_in_index;
        output[i]         = input[idx];
    });
}
} // namespace migraphx
#endif
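Aside: a host-side toy model of the kernel's per-element logic for Axis = 0, showing the negative-index wrap-around that the last lines implement (values hypothetical):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> input = {10.f, 20.f, 30.f, 40.f};
    std::vector<int> indices = {3, -1, 0};
    std::ptrdiff_t axis_dim  = input.size();
    for(auto in_index : indices)
    {
        // negative indices wrap around the gathered axis, as in ONNX Gather
        auto wrapped = (in_index < 0) ? in_index + axis_dim : in_index;
        std::cout << input[wrapped] << " "; // prints: 40 40 10
    }
    std::cout << "\n";
    return 0;
}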
@@ -132,9 +132,14 @@ MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
 // Builtin half functions
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, ceil, ::hceil)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, cos, ::hcos)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, exp, ::hexp)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
+MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, isnan, ::__hisnan)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, log, ::hlog)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, rsqrt, ::hrsqrt)
+// MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
 MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sqrt, ::hsqrt)
 
 // Use float to compute half overload
@@ -144,16 +149,11 @@ MIGRAPHX_DEVICE_MATH_HALF(asin, ::asin)
 MIGRAPHX_DEVICE_MATH_HALF(asinh, ::asinh)
 MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
 MIGRAPHX_DEVICE_MATH_HALF(atanh, ::atanh)
-MIGRAPHX_DEVICE_MATH_HALF(ceil, ::ceil)
-MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
 MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
 MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
-MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
-MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
 MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
 MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
 MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
-MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
 MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
 MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
 MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
@@ -166,19 +166,19 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
 // at this time are: exp2, exp10, log2, log10, isinf
 MIGRAPHX_DEVICE_MATH_HALF2(abs, ::__habs2)
 MIGRAPHX_DEVICE_MATH_HALF2(ceil, ::h2ceil)
-MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
-MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
 MIGRAPHX_DEVICE_MATH_HALF2(cos, ::h2cos)
 MIGRAPHX_DEVICE_MATH_HALF2(exp, ::h2exp)
-MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
 MIGRAPHX_DEVICE_MATH_HALF2(exp10, ::h2exp10)
-MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
+MIGRAPHX_DEVICE_MATH_HALF2(exp2, ::h2exp2)
+MIGRAPHX_DEVICE_MATH_HALF2(floor, ::h2floor)
+MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
+MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
 MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
 MIGRAPHX_DEVICE_MATH_HALF2(log10, ::h2log10)
+MIGRAPHX_DEVICE_MATH_HALF2(log2, ::h2log2)
 MIGRAPHX_DEVICE_MATH_HALF2(rsqrt, ::h2rsqrt)
+// MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
 MIGRAPHX_DEVICE_MATH_HALF2(sqrt, ::h2sqrt)
-MIGRAPHX_DEVICE_MATH_HALF2(isinf, ::__hisinf2)
-MIGRAPHX_DEVICE_MATH_HALF2(isnan, ::__hisnan2)
 
 template <class T, class U>
 constexpr auto where(bool cond, const T& a, const U& b)
@@ -218,6 +218,14 @@ constexpr auto min(const T& a, const U& b)
     return min<common_type_t<T, U>>(a, b);
 }
 
+// Sin for half is broken on hip, so use cos instead
+template <class T, MIGRAPHX_REQUIRES(is_same<vec_type<T>, half>{})>
+constexpr T sin(T x)
+{
+    constexpr const T shift = M_PI_2;
+    return migraphx::cos(shift - x);
+}
+
 MIGRAPHX_DEVICE_MATH_VEC(abs)
 MIGRAPHX_DEVICE_MATH_VEC(acos)
 MIGRAPHX_DEVICE_MATH_VEC(acosh)
......
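Aside: the workaround added above relies on the identity sin(x) = cos(π/2 − x). A quick double-precision spot check of that identity (illustrative only, not the device code path):

#include <cmath>
#include <iostream>

int main()
{
    const double shift = M_PI_2; // pi/2, the same constant the device code uses
    for(double x : {0.0, 0.5, 1.0, 2.0})
        std::cout << std::sin(x) - std::cos(shift - x) << "\n"; // each ~0
    return 0;
}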
@@ -128,6 +128,7 @@ struct shape
         result[0] = tidx;
         return result;
     }
+    /// Convert multi-index into a single index
     constexpr index_int single(index_array idx) const
     {
......
@@ -90,7 +90,6 @@ struct miopen_apply
         add_extend_op("argmax");
         add_extend_op("argmin");
         add_extend_op("gather");
-        add_extend_op("logsoftmax");
         add_extend_op("lrn");
         add_extend_op("multinomial");
......
@@ -1121,6 +1121,24 @@ def conv_dynamic_batch_test():
     return ([node], [x, y], [out])
 
 
+@onnx_test()
+def conv_dynamic_bias_test():
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
+                                      [None, 3, 32, 32])
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 3, 5, 5])
+    z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1])
+    out = helper.make_tensor_value_info('3', TensorProto.FLOAT,
+                                        [None, 2, 28, 28])
+    node = onnx.helper.make_node('Conv',
+                                 inputs=['0', '1', '2'],
+                                 outputs=['3'],
+                                 dilations=[1, 1],
+                                 strides=[1, 1])
+
+    return ([node], [x, y, z], [out])
+
+
 @onnx_test()
 def conv_dynamic_img_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
@@ -3545,6 +3563,81 @@ def matmul_vv_test():
     return ([node], [m1, m2], [y])
 
 
+@onnx_test()
+def matmul_dyn_mm_test():
+    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 7])
+    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7, None])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, None])
+
+    node = onnx.helper.make_node(
+        'MatMul',
+        inputs=['1', '2'],
+        outputs=['y'],
+    )
+
+    return ([node], [m1, m2], [y])
+
+
+@onnx_test()
+def matmul_dyn_mv_test():
+    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 7])
+    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None, 1])
+
+    node = onnx.helper.make_node(
+        'MatMul',
+        inputs=['1', '2'],
+        outputs=['y'],
+    )
+
+    return ([node], [m1, m2], [y])
+
+
+@onnx_test()
+def matmul_dyn_vm_test():
+    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7])
+    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [7, None])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, None])
+
+    node = onnx.helper.make_node(
+        'MatMul',
+        inputs=['1', '2'],
+        outputs=['y'],
+    )
+
+    return ([node], [m1, m2], [y])
+
+
+@onnx_test()
+def matmul_dyn_vv_test():
+    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None])
+    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [None])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1])
+
+    node = onnx.helper.make_node(
+        'MatMul',
+        inputs=['1', '2'],
+        outputs=['y'],
+    )
+
+    return ([node], [m1, m2], [y])
+
+
+@onnx_test()
+def matmul_dyn_broadcast_error():
+    m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [7])
+    m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [5, 7, None])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [5, None])
+
+    node = onnx.helper.make_node(
+        'MatMul',
+        inputs=['1', '2'],
+        outputs=['y'],
+    )
+
+    return ([node], [m1, m2], [y])
+
+
 @onnx_test()
 def matmulinteger_test():
     m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [3, 6, 16])
@@ -3560,6 +3653,21 @@ def matmulinteger_test():
     return ([node], [m1, m2], [y])
 
 
+@onnx_test()
+def matmulinteger_dyn_error():
+    m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [None, 6, 16])
+    m2 = helper.make_tensor_value_info('2', TensorProto.INT8, [None, 16, 8])
+    y = helper.make_tensor_value_info('y', TensorProto.INT32, [None, 6, 8])
+
+    node = onnx.helper.make_node(
+        'MatMulInteger',
+        inputs=['1', '2'],
+        outputs=['y'],
+    )
+
+    return ([node], [m1, m2], [y])
+
+
 @onnx_test()
 def max_test():
     a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
@@ -1118,6 +1118,25 @@ TEST_CASE(conv_dynamic_batch_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(conv_dynamic_bias_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto x0  = mm->add_parameter(
+        "0", {migraphx::shape::float_type, {{1, 6, 0}, {3, 3, 0}, {32, 32, 0}, {32, 32, 0}}});
+    auto x1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 5, 5}});
+    auto x2 = mm->add_parameter("2", {migraphx::shape::float_type, {1}});
+    auto x3 = mm->add_instruction(migraphx::make_op("convolution"), x0, x1);
+    auto x4 = mm->add_instruction(migraphx::make_op("broadcast", {{"axis", 1}}), x2, x3);
+    auto x5 = mm->add_instruction(migraphx::make_op("add"), x3, x4);
+    mm->add_return({x5});
+
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = {1, 6, 0};
+    auto prog                     = migraphx::parse_onnx("conv_dynamic_bias_test.onnx", options);
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(conv_dynamic_img_test)
 {
     migraphx::program p;
@@ -3413,6 +3432,92 @@ TEST_CASE(matmul_vv_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(matmul_dyn_mm_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0  = mm->add_parameter(
+        "1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
+    auto l1 = mm->add_parameter(
+        "2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {1, 5, 3}}});
+    auto ret = migraphx::add_apply_alpha_beta(*mm, {l0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
+    options.map_dyn_input_dims["2"] = {{7, 7, 0}, {1, 5, 3}};
+    auto prog                       = parse_onnx("matmul_dyn_mm_test.onnx", options);
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(matmul_dyn_mv_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0  = mm->add_parameter(
+        "1", migraphx::shape{migraphx::shape::float_type, {{4, 8, 6}, {7, 7, 0}}});
+    auto l1  = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {7}});
+    auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
+    auto res = migraphx::add_apply_alpha_beta(*mm, {l0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
+    auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), res);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.map_dyn_input_dims["1"] = {{4, 8, 6}, {7, 7, 0}};
+    auto prog                       = parse_onnx("matmul_dyn_mv_test.onnx", options);
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(matmul_dyn_vm_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7}});
+    auto l1  = mm->add_parameter(
+        "2", migraphx::shape{migraphx::shape::float_type, {{7, 7, 0}, {4, 10, 8}}});
+    auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
+    auto res = migraphx::add_apply_alpha_beta(*mm, {sl0, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
+    auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.map_dyn_input_dims["2"] = {{7, 7, 0}, {4, 10, 8}};
+    auto prog                       = parse_onnx("matmul_dyn_vm_test.onnx", options);
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(matmul_dyn_vv_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape::dynamic_dimension dd{5, 8, 7};
+    auto l0  = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {dd}});
+    auto l1  = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {dd}});
+    auto sl0 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), l0);
+    auto sl1 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), l1);
+    auto res =
+        migraphx::add_apply_alpha_beta(*mm, {sl0, sl1}, migraphx::make_op("dot"), 1.0f, 0.0f);
+    auto sr0 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), res);
+    auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), sr0);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = dd;
+    auto prog                     = parse_onnx("matmul_dyn_vv_test.onnx", options);
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(matmul_dyn_broadcast_error)
+{
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = {1, 4, 0};
+    EXPECT(test::throws([&] { migraphx::parse_onnx("matmul_dyn_broadcast_error.onnx", options); }));
+}
+
 TEST_CASE(matmulinteger_test)
 {
     migraphx::program p;
@@ -3426,6 +3531,13 @@ TEST_CASE(matmulinteger_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(matmulinteger_dyn_error)
+{
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = {1, 4, 0};
+    EXPECT(test::throws([&] { migraphx::parse_onnx("matmulinteger_dyn_error.onnx", options); }));
+}
+
 TEST_CASE(max_test)
 {
     migraphx::program p;
......
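Aside: reading the shape triples used throughout these tests: as matmul_dyn_vv_test suggests, a migraphx::shape::dynamic_dimension is written {min, max, opt}, so {4, 8, 6} allows lengths 4 through 8 with 6 as an optimization hint, while {7, 7, 0} pins the dimension at exactly 7. A minimal sketch (the interpretation of the third field is an assumption):

#include <migraphx/shape.hpp>

int main()
{
    // {min, max, opt}: this dimension may be 4..8, with 6 as the optimizing hint
    migraphx::shape::dynamic_dimension dd{4, 8, 6};
    migraphx::shape s{migraphx::shape::float_type, {dd, {7, 7, 0}}};
    return s.dynamic() ? 0 : 1; // a shape with any dynamic_dimension reports dynamic()
}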
@@ -94,6 +94,16 @@ def disabled_tests_onnx_1_8_1(backend_test):
     backend_test.exclude(r'test_unsqueeze_unsorted_axes_cpu')
 
 
+def disabled_tests_onnx_1_10_0(backend_test):
+    # unsupported shape attributes
+    backend_test.exclude(r'test_shape_end_1_cpu')
+    backend_test.exclude(r'test_shape_end_negative_1_cpu')
+    backend_test.exclude(r'test_shape_start_1_cpu')
+    backend_test.exclude(r'test_shape_start_1_end_2_cpu')
+    backend_test.exclude(r'test_shape_start_1_end_negative_1_cpu')
+    backend_test.exclude(r'test_shape_start_negative_1_cpu')
+
+
 def create_backend_test(testname=None, target_device=None):
     if target_device is not None:
         c2.set_device(target_device)
@@ -314,6 +324,9 @@ def create_backend_test(testname=None, target_device=None):
     if version.parse(onnx.__version__) >= version.parse("1.8.0"):
         disabled_tests_onnx_1_8_1(backend_test)
 
+    if version.parse(onnx.__version__) >= version.parse("1.10.0"):
+        disabled_tests_onnx_1_10_0(backend_test)
+
     # import all test cases at global scope to make
     # them visible to python.unittest.
......