Unverified Commit 102c6bdb authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Dyn slice (#1503)

Add dynamic shape support to slice operator.

First draft of this feature doesn't support ops slicing non-fixed, dynamic axes. Resulting shape in such cases is not guaranteed.* Also, onnx parsing doesn't support any arguments other than "axes".
parent 67f23675
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -60,6 +61,7 @@ struct reverse ...@@ -60,6 +61,7 @@ struct reverse
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1);
return inputs[0].with_lens(inputs[0].lens()); return inputs[0].with_lens(inputs[0].lens());
} }
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
...@@ -46,6 +47,10 @@ struct slice ...@@ -46,6 +47,10 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends")); return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
} }
/**
* Ensure that attribute vectors axes, starts, and ends are all the same size and values are in
* limits.
*/
value attributes() const value attributes() const
{ {
value normalize = value::object{}; value normalize = value::object{};
...@@ -65,14 +70,6 @@ struct slice ...@@ -65,14 +70,6 @@ struct slice
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
{
int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
if(r < 0)
r += lens[axis];
return std::size_t(r);
}
auto compute_offset(const shape& s) const auto compute_offset(const shape& s) const
{ {
const std::vector<std::size_t>& lens = s.lens(); const std::vector<std::size_t>& lens = s.lens();
...@@ -83,14 +80,14 @@ struct slice ...@@ -83,14 +80,14 @@ struct slice
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
offset += fix_index(lens, axis, starts[i]) * strides[axis]; offset += starts[i] * strides[axis];
} }
} }
else else
{ {
for(std::size_t axis = 0; axis < lens.size(); axis++) for(std::size_t axis = 0; axis < lens.size(); axis++)
{ {
offset += fix_index(lens, axis, starts[axis]) * strides[axis]; offset += starts[axis] * strides[axis];
} }
} }
return offset; return offset;
...@@ -98,37 +95,81 @@ struct slice ...@@ -98,37 +95,81 @@ struct slice
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto t = input_shape.type(); auto t = input_shape.type();
const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides();
if(std::any_of( // TODO: When support for dynamic shapes is added to normalize_attributes,
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); })) // remove this restriction.
if(input_shape.dynamic() and std::any_of(axes.begin(), axes.end(), [&](auto axis) {
return not input_shape.dyn_dims()[axis].is_fixed();
}))
{ {
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range"); MIGRAPHX_THROW("SLICE: slicing is not allowed on non-fixed dynamic input axis ");
} }
if(starts.size() != axes.size() or axes.size() != ends.size()) // For a static shape, old_lens will be adjusted to a new size
// for those axes that are sliced.
// For dynamic shape, the adjusted old_lens become the new max values,
// while updating the old mins and opts if possible.
std::vector<std::size_t> new_mins;
std::vector<std::size_t> new_opts;
std::vector<std::size_t> old_lens;
std::vector<std::size_t> old_strides;
if(input_shape.dynamic())
{ {
MIGRAPHX_THROW("SLICE: inconsistent sizes"); old_lens = input_shape.max_lens();
new_mins = input_shape.min_lens();
new_opts = input_shape.opt_lens();
}
else
{
old_lens = input_shape.lens();
// For static shape (including during eval step after a dynamic input) the strides are
// indexed into the pre-slice array, so they are larger than the apparent size of the
// resulting shape.
old_strides = input_shape.strides();
} }
std::vector<std::size_t> new_lens = old_lens; std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
auto axis = axes[i]; auto axis = axes[i];
new_lens[axis] = size_t sliced_length = ends[i] - starts[i];
fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]); // A Numpy indexing convention: a slice size larger than the actual dimension
// is legal and the "ends" value is clipped to the axis size
new_lens[axis] = std::min(new_lens[axis], sliced_length);
if(input_shape.dynamic())
{
// TODO: when non-fixed shape slicing is allowed, this will be different than
// sliced_length, making use of TBD start/end values.
std::size_t sliced_min_length = ends[i] - starts[i];
// if the slice size is smaller than maxes but larger than mins
new_mins[axis] = std::min(sliced_min_length, new_mins[axis]);
auto sliced_opt_length = ends[i] - starts[i];
if(new_opts[axis] != 0)
new_opts[axis] = sliced_opt_length;
if(new_opts[axis] < new_mins[axis] or new_opts[axis] > new_lens[axis])
new_opts[axis] = 0;
}
}
if(input_shape.dynamic())
{
return shape{t, new_mins, new_lens, new_opts};
} }
else
{
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
}
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
return {std::move(output_shape), [=] { return input.data() + offset; }}; auto offset = compute_offset(input.get_shape()) * dyn_out.computed_shape.type_size();
return {dyn_out.computed_shape, [=] { return input.data() + offset; }};
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
...@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice> ...@@ -46,7 +46,7 @@ struct parse_slice : op_parser<parse_slice>
std::vector<int64_t> steps; std::vector<int64_t> steps;
// slice can have up to 5 inputs, we first check the 5th one // slice can have up to 5 inputs, we first check the 5th one
// to decide whether MIGRAPHX can handle this slice // to decide whether MIGRAPHX can handle this slice.
if(args.size() == 5) if(args.size() == 5)
{ {
migraphx::argument step_arg = args.back()->eval(); migraphx::argument step_arg = args.back()->eval();
...@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice> ...@@ -90,9 +90,10 @@ struct parse_slice : op_parser<parse_slice>
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
} }
// If axes arg is not given, the default is all of them.
if(op.axes.empty()) if(op.axes.empty())
{ {
std::vector<int64_t> axes(args[0]->get_shape().lens().size()); std::vector<int64_t> axes(args[0]->get_shape().ndim());
std::iota(axes.begin(), axes.end(), int64_t{0}); std::iota(axes.begin(), axes.end(), int64_t{0});
op.axes = axes; op.axes = axes;
} }
...@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice> ...@@ -103,6 +104,7 @@ struct parse_slice : op_parser<parse_slice>
assert(op.axes.size() == op.starts.size()); assert(op.axes.size() == op.starts.size());
assert(op.axes.size() == op.ends.size()); assert(op.axes.size() == op.ends.size());
// If any axes have negative step, prepare to add a "reverse" op
for(auto i : range(steps.size())) for(auto i : range(steps.size()))
{ {
if(steps[i] >= 0) if(steps[i] >= 0)
...@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice> ...@@ -117,7 +119,10 @@ struct parse_slice : op_parser<parse_slice>
auto ins = info.add_instruction(op, args[0]); auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty()) if(not raxes.empty())
{
ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins); ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
}
// If any steps are other than default 1, add a "steps" op
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; })) if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{ {
std::vector<int64_t> nsteps; std::vector<int64_t> nsteps;
......
##################################################################################### #####################################################################################
# The MIT License (MIT) # The MIT License (MIT)
# #
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
# #
# Permission is hereby granted, free of charge, to any person obtaining a copy # Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal # of this software and associated documentation files (the "Software"), to deal
...@@ -6184,6 +6184,132 @@ def slice_test(): ...@@ -6184,6 +6184,132 @@ def slice_test():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test()
def slice_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, None, 2])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, None, 2])
node = onnx.helper.make_node('Slice',
inputs=['0'],
axes=[0],
starts=[1],
ends=[2],
outputs=['1'])
return ([node], [x], [y])
@onnx_test
def slice_step_dyn_test():
# A slice command with non - default steps will have a "Step"
# instruction added in parsing.
step = np.array([2, 1])
step_tensor = helper.make_tensor(name="step",
data_type=TensorProto.INT32,
dims=step.shape,
vals=step.astype(int))
arg_step = helper.make_node("Constant",
inputs=[],
outputs=['arg_step'],
value=step_tensor)
axis = np.array([-1, -2])
axis_tensor = helper.make_tensor(name="axis",
data_type=TensorProto.INT32,
dims=axis.shape,
vals=axis.astype(int))
arg_axis = helper.make_node("Constant",
inputs=[],
outputs=['arg_axis'],
value=axis_tensor)
end = np.array([-1, -1])
end_tensor = helper.make_tensor(name="end",
data_type=TensorProto.INT32,
dims=end.shape,
vals=end.astype(int))
arg_end = helper.make_node("Constant",
inputs=[],
outputs=['arg_end'],
value=end_tensor)
start = np.array([-5, -3])
start_tensor = helper.make_tensor(name="start",
data_type=TensorProto.INT32,
dims=start.shape,
vals=start.astype(int))
arg_start = helper.make_node("Constant",
inputs=[],
outputs=['arg_start'],
value=start_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 2])
node = onnx.helper.make_node(
'Slice',
inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'],
outputs=['1'])
return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y])
@onnx_test
def slice_reverse_dyn_test():
# A slice command with negative step on any axis will have
# a "Reverse" instruction added in parsing.
step = np.array([-1, 1])
step_tensor = helper.make_tensor(name="step",
data_type=TensorProto.INT32,
dims=step.shape,
vals=step.astype(int))
arg_step = helper.make_node("Constant",
inputs=[],
outputs=['arg_step'],
value=step_tensor)
axis = np.array([-1, -2])
axis_tensor = helper.make_tensor(name="axis",
data_type=TensorProto.INT32,
dims=axis.shape,
vals=axis.astype(int))
arg_axis = helper.make_node("Constant",
inputs=[],
outputs=['arg_axis'],
value=axis_tensor)
end = np.array([-1, -1])
end_tensor = helper.make_tensor(name="end",
data_type=TensorProto.INT32,
dims=end.shape,
vals=end.astype(int))
arg_end = helper.make_node("Constant",
inputs=[],
outputs=['arg_end'],
value=end_tensor)
start = np.array([-5, -3])
start_tensor = helper.make_tensor(name="start",
data_type=TensorProto.INT32,
dims=start.shape,
vals=start.astype(int))
arg_start = helper.make_node("Constant",
inputs=[],
outputs=['arg_start'],
value=start_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 2])
node = onnx.helper.make_node(
'Slice',
inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'],
outputs=['1'])
return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y])
@onnx_test() @onnx_test()
def slice_3arg_test(): def slice_3arg_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5])
......
...@@ -6010,6 +6010,44 @@ TEST_CASE(slice_test) ...@@ -6010,6 +6010,44 @@ TEST_CASE(slice_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(slice_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0", migraphx::shape{migraphx::shape::float_type, {{3, 3, 0}, {1, 3, 0}, {2, 2, 0}}});
auto ret = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), l0);
mm->add_return({ret});
migraphx::onnx_options options;
// Parser converts the dynamic input shape to static unless there is at least one non-fixed
// dynamic dimension. Slicing is not allowed along the non-fixed axis 1.
options.map_dyn_input_dims["0"] = {{3, 3, 0}, {1, 3, 0}, {2, 2, 0}};
auto prog = migraphx::parse_onnx("slice_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(slice_step_dyn_test)
{
// A slice command with non-default steps will have a "Step" instruction added in parsing.
// At the time of writing, Step doesn't support dynamic shape input.
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_step_dyn_test.onnx", options); }));
}
TEST_CASE(slice_reverse_dyn_test)
{
// A slice command with negative step on any axis will have a "Reverse" instruction added in
// parsing. At the time of writing, Reverse doesn't support dynamic shape input.
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_reverse_dyn_test.onnx", options); }));
}
TEST_CASE(slice_3arg_test) TEST_CASE(slice_3arg_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -2374,6 +2374,70 @@ TEST_CASE(slice_shape) ...@@ -2374,6 +2374,70 @@ TEST_CASE(slice_shape)
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}}, expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}), migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {10}}}),
input); input);
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 1}, {6, 3, 1}},
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {-1}}, {"ends", {10}}}),
input);
}
TEST_CASE(slice_dyn_shape0)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
// Slice axis 1 to size 4-1=3
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {3, 3, 0}, {2, 3, 0}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {4}}}),
input);
}
TEST_CASE(slice_dyn_shape1)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
// Slice axis 1 with negative index
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {2, 3, 0}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {-4}}}),
input);
}
TEST_CASE(slice_dyn_shape2)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 7, 0}, {2, 3, 0}}};
// Sliced range max bigger than dimension; is clipped
expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {6, 6, 0}, {2, 3, 0}}},
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1}}, {"ends", {10}}}),
input);
}
TEST_CASE(slice_dyn_shape3)
{
// TODO: When variable dimension slicing is allowed, Slice to a size smaller than min.
// Until then, this action is an error.
migraphx::shape input{migraphx::shape::int32_type, {{2, 3, 0}, {7, 8, 0}, {2, 3, 0}}};
throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
input);
// clang-format off
// expect_shape(migraphx::shape{migraphx::shape::int32_type, {{2, 3, 0}, {1, 1, 0}, {2, 3, 0}}},
// migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
// input);
// clang-format on
}
TEST_CASE(slice_dyn_shape4)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 2, 0}, {7, 7, 0}, {2, 3, 0}}};
// Slice multiple axes: axis 0 to size 2-1=1 and axis 1 to size 4-1=3
expect_shape(
migraphx::shape{migraphx::shape::int32_type, {{1, 1, 0}, {3, 3, 0}, {2, 3, 0}}},
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 1}}, {"ends", {2, 4}}}),
input);
}
TEST_CASE(slice_dyn_shape5)
{
// Axis out of range.
migraphx::shape input{migraphx::shape::int32_type, {{2, 2, 0}, {7, 7, 0}, {2, 3, 0}}};
throws_shape(
migraphx::make_op("slice", {{"axes", {0, 20}}, {"starts", {1, 1}}, {"ends", {2, 4}}}),
input);
} }
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); } TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
......
...@@ -7521,6 +7521,69 @@ TEST_CASE(slice_test) ...@@ -7521,6 +7521,69 @@ TEST_CASE(slice_test)
} }
} }
TEST_CASE(slice_dyn_test0)
{
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is too
// large
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 3, 0}, {2, 2, 0}, {3, 3, 0}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1, 2}}, {"starts", {0, 1}}, {"ends", {1, 6}}}), x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 3, 0}, {1, 1, 0}, {2, 2, 0}}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::ref::target{});
// the strides of sresult are those of the original shape, not
// reduced to sliced size.
migraphx::shape sresult{migraphx::shape::int32_type, {2, 1, 2}, {6, 3, 1}};
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {2, 2, 3}};
migraphx::parameter_map params;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 7, 8};
std::vector<int> results_vector(2 * 1 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
TEST_CASE(slice_dyn_test1)
{
// Slice all three dynamic dimensions
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {{2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
auto x = mm->add_parameter("x", s);
mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}),
x);
migraphx::shape s2{migraphx::shape::int32_type, {{2, 2, 0}, {2, 2, 0}, {2, 2, 0}}};
EXPECT(p.get_output_shapes().back() == s2);
p.compile(migraphx::ref::target{});
migraphx::shape sresult{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}};
migraphx::shape input_fixed_shape{migraphx::shape::int32_type, {2, 2, 3}};
migraphx::parameter_map params;
std::vector<int> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
params["x"] = migraphx::argument(input_fixed_shape, data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(result.get_shape() == sresult);
}
TEST_CASE(softmax_simple_test) TEST_CASE(softmax_simple_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment