Unverified Commit 4789b387 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Upsample op (#646)



* code backup for upsample op

* clang format

* fixed a bug

* fix a bug

* clang format

* add unit tests for upsample

* clang format

* clang format
Co-authored-by: default avatarPaul Fultz II <pfultz2@yahoo.com>
parent 3446bea5
...@@ -172,6 +172,7 @@ struct onnx_parser ...@@ -172,6 +172,7 @@ struct onnx_parser
add_mem_op("Split", &onnx_parser::parse_split); add_mem_op("Split", &onnx_parser::parse_split);
add_mem_op("Tile", &onnx_parser::parse_tile); add_mem_op("Tile", &onnx_parser::parse_tile);
add_mem_op("Transpose", &onnx_parser::parse_transpose); add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("Upsample", &onnx_parser::parse_upsample);
add_mem_op("Where", &onnx_parser::parse_where); add_mem_op("Where", &onnx_parser::parse_where);
// init the activation function map // init the activation function map
...@@ -2486,6 +2487,71 @@ struct onnx_parser ...@@ -2486,6 +2487,71 @@ struct onnx_parser
return l; return l;
} }
instruction_ref
parse_upsample(const std::string&, const node_info& info, std::vector<instruction_ref> args)
{
if(contains(info.attributes, "mode"))
{
auto mode = info.attributes.at("mode").s();
if(mode != "nearest")
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
}
}
auto arg_scale = args[1]->eval();
check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
std::vector<float> vec_scale;
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
auto in_s = args[0]->get_shape();
auto in_lens = in_s.lens();
if(in_lens.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
}
std::vector<std::size_t> out_lens(in_lens.size());
std::transform(in_lens.begin(),
in_lens.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
std::vector<float> idx_scale(in_lens.size());
std::transform(
out_lens.begin(),
out_lens.end(),
in_lens.begin(),
idx_scale.begin(),
[](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });
shape out_s{in_s.type(), out_lens};
std::vector<int> ind(out_s.elements());
// map out_idx to in_idx
shape_for_each(out_s, [&](auto idx) {
auto in_idx = idx;
std::transform(idx.begin(),
idx.end(),
idx_scale.begin(),
in_idx.begin(),
// nearest mode
[](auto index, auto scale) {
return static_cast<std::size_t>(std::round(index * scale));
});
ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
});
// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
shape ind_s{shape::int32_type, out_lens};
auto rsp = prog.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
auto ins_ind = prog.add_literal(literal(ind_s, ind));
return prog.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
}
instruction_ref instruction_ref
parse_where(const std::string&, const node_info&, std::vector<instruction_ref> args) parse_where(const std::string&, const node_info&, std::vector<instruction_ref> args)
{ {
......
...@@ -3021,6 +3021,27 @@ def unknown_aten_test(): ...@@ -3021,6 +3021,27 @@ def unknown_aten_test():
return ([node], [x, y], [a]) return ([node], [x, y], [a])
@onnx_test
def upsample_test():
scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 4, 6])
node = onnx.helper.make_node(
'Upsample',
inputs=['X', 'scales'],
outputs=['Y'],
mode='nearest',
)
return ([node], [X], [Y], [scale_tensor])
@onnx_test @onnx_test
def variable_batch_test(): def variable_batch_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
......
...@@ -2187,6 +2187,28 @@ TEST_CASE(unknown_test_throw) ...@@ -2187,6 +2187,28 @@ TEST_CASE(unknown_test_throw)
EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); }));
} }
TEST_CASE(upsample_test)
{
migraphx::program p;
migraphx::shape ss{migraphx::shape::float_type, {4}};
p.add_literal(migraphx::literal(ss, {1.0f, 1.0f, 2.0f, 3.0f}));
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto ix = p.add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = p.add_literal(migraphx::literal(si, ind));
auto rsp = p.add_instruction(migraphx::op::reshape{{4}}, ix);
auto r = p.add_instruction(migraphx::op::gather{0}, rsp, li);
p.add_return({r});
auto prog = migraphx::parse_onnx("upsample_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test_throw_print_error) TEST_CASE(unknown_test_throw_print_error)
{ {
migraphx::onnx_options options; migraphx::onnx_options options;
......
...@@ -127,6 +127,25 @@ TEST_CASE(gather_elements) ...@@ -127,6 +127,25 @@ TEST_CASE(gather_elements)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(upsample_test)
{
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
std::vector<float> x_data = {1, 2, 3, 4};
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
migraphx::program::parameter_map pp;
pp["X"] = migraphx::argument(sx, x_data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2,
3, 3, 3, 4, 4, 4, 3, 3, 3, 4, 4, 4};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(selu_test) TEST_CASE(selu_test)
{ {
migraphx::program p = migraphx::parse_onnx("selu_test.onnx"); migraphx::program p = migraphx::parse_onnx("selu_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment