"...git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "9cf59e751a3762225b87859fdaea014b89eb2292"
Commit a2534e6c authored by Brian Pickrell's avatar Brian Pickrell
Browse files

first working op, 2 tests

parent 34bb4112
...@@ -105,10 +105,9 @@ struct resize ...@@ -105,10 +105,9 @@ struct resize
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// check_shapes{{inputs[0]}, *this, true}.has(2);
check_shapes{inputs, *this, true}.has(2); check_shapes{inputs, *this, true}.has(2);
// I get to DECIDE what the inputs are. inputs are X, sizes or scale, ROI not supported // Inputs are X, sizes or scale, ROI and axes not supported
if((sizes.empty()) == (scales.empty())) if((sizes.empty()) == (scales.empty()))
MIGRAPHX_THROW("RESIZE: One and only one of max_size or scales attributes must be given"); MIGRAPHX_THROW("RESIZE: One and only one of max_size or scales attributes must be given");
...@@ -119,6 +118,9 @@ struct resize ...@@ -119,6 +118,9 @@ struct resize
if(inputs.front().ndim() != inputs.back().to_static(1).lens()[0]) if(inputs.front().ndim() != inputs.back().to_static(1).lens()[0])
MIGRAPHX_THROW("RESIZE: size/scale input's size must match rank of input X"); MIGRAPHX_THROW("RESIZE: size/scale input's size must match rank of input X");
// TODO: this placeholder logic is for adding another way to tell whether we're
// interpreting second input as sizes or scales.
if(not sizes.empty()) if(not sizes.empty())
{ {
// the second shape is sizes // the second shape is sizes
...@@ -129,10 +131,6 @@ struct resize ...@@ -129,10 +131,6 @@ struct resize
// the second shape is scales // the second shape is scales
} }
// if(std::any_of(
// inputs.cbegin(), inputs.cend(), [](auto input) { return input->get_shape().dynamic(); }))
// {
// }
// No matter what the inputs, the output shape is dynamic, with an unlimited size range. // No matter what the inputs, the output shape is dynamic, with an unlimited size range.
// TODO: How can we tell if the input shape is a literal? If it is, and input X is static, // TODO: How can we tell if the input shape is a literal? If it is, and input X is static,
...@@ -141,8 +139,6 @@ struct resize ...@@ -141,8 +139,6 @@ struct resize
std::vector<shape::dynamic_dimension> dyn_dims(inputs.back().lens().at(0), std::vector<shape::dynamic_dimension> dyn_dims(inputs.back().lens().at(0),
shape::dynamic_dimension{0, max_val}); shape::dynamic_dimension{0, max_val});
return {inputs.front().type(), dyn_dims}; return {inputs.front().type(), dyn_dims};
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
...@@ -160,36 +156,23 @@ struct resize ...@@ -160,36 +156,23 @@ struct resize
// calculate output shape from scales or sizes // calculate output shape from scales or sizes
if(not sizes.empty()) if(not sizes.empty())
{ {
// read sizes from args[1]
// out_lens = args[1].get_shape().to_static(1).lens(); // <===
// Compute the scales from the given output dimensions
// Copy the output size
args[1].visit([&](auto size_input) { args[1].visit([&](auto size_input) {
for(auto aa : size_input ) std::cout << aa << " sizes \n"; // Copy the output size from args[1]
std::transform(size_input.begin(), size_input.end(), out_lens.begin(), std::transform(size_input.begin(), size_input.end(), out_lens.begin(),
[](auto size_i) { [](auto size_i) {
std::cout << size_i << " transform \n";
return size_i; return size_i;
}); });
std::cout << "***\n";
for(auto aa : out_lens ) std::cout << aa << " out_lens \n";
std::cout << "***\n";
// Deduce the scales for each axis // Deduce the scales for each axis
std::transform(size_input.begin(), size_input.end(), in_lens.begin(), vec_scale.begin(), std::transform(size_input.begin(), size_input.end(), in_lens.begin(), vec_scale.begin(),
[](auto sz, size_t in_len) { [](auto sz, size_t in_len) {
return static_cast<double>(sz)/in_len; return static_cast<double>(sz)/in_len;
}); });
}); });
for(auto aa : vec_scale ) std::cout << aa << " vec_scale \n";
} }
else else
{ {
args[1].visit([&](auto scale_input) { args[1].visit([&](auto scale_input) {
for(auto aa : scale_input ) std::cout << aa << " scale_input \n";
// read the scale from args[1]-- vec_scale = scale_input; // read the scale from args[1]-- vec_scale = scale_input;
// //
std::transform(scale_input.begin(), scale_input.end(), vec_scale.begin(), std::transform(scale_input.begin(), scale_input.end(), vec_scale.begin(),
...@@ -197,7 +180,9 @@ std::cout << "***\n"; ...@@ -197,7 +180,9 @@ std::cout << "***\n";
return scale_i; return scale_i;
}); });
// compute the output dimensions from the given scale // compute the output dimensions from the given scales. This computation
// always rounds down, unlike the internal computation in Nearest mode
// which has several options as given in nearest_mode.
std::transform(scale_input.begin(), scale_input.end(), in_lens.begin(), out_lens.begin(), std::transform(scale_input.begin(), scale_input.end(), in_lens.begin(), out_lens.begin(),
[](auto scale_i, size_t in_len) { [](auto scale_i, size_t in_len) {
return static_cast<size_t>(scale_i*in_len); return static_cast<size_t>(scale_i*in_len);
...@@ -211,45 +196,20 @@ std::cout << "***\n"; ...@@ -211,45 +196,20 @@ std::cout << "***\n";
auto nearest_op = get_nearest_op(nearest_mode); auto nearest_op = get_nearest_op(nearest_mode);
auto idx_op = get_original_idx_op(coordinate_transformation_mode); auto idx_op = get_original_idx_op(coordinate_transformation_mode);
// temp. This is a placeholder for reading the desired dimensions or scale // Populate each element in output by selecting "nearest" item in input.
// max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
// the size input shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) {
// args[1].visit([&](auto indices) { std::vector<size_t> in_idx(out_idx_v.size());
// for(auto aa : indices ) std::cout << aa << " indices \n"; for(auto ii = 0; ii < out_idx_v.size(); ++ii)
// if(dyn_out.computed_shape.scalar()) {
// { auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
// std::cout << " scalar output\n"; in_idx[ii] = nearest_op(in_lens[ii], idx_val);
// } }
// else // TODO: use index function instead?
// { output[out_idx] = data(in_idx.begin(), in_idx.end());
// for each element in output, calculate index in input });
for(auto bb : data) std::cout << bb << " zzz data \n";
migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) {
std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < out_idx_v.size(); ++ii)
{
auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
in_idx[ii] = nearest_op(in_lens[ii], idx_val);
}
std::cout << "\n";
std::cout <<out_idx << " out_index\n";
auto zap = data(in_idx.begin(), in_idx.end());
for(auto gg : output) std::cout << gg << " "; std::cout <<"ggg\n";
// use index function instead?
output[out_idx] = data(in_idx.begin(), in_idx.end());
std::cout << zap << "\n";
});
// }
// });
}); });
std::cout << " finish resize\n";
return result; return result;
} }
......
...@@ -54,7 +54,53 @@ TEST_CASE(resize_test_1) ...@@ -54,7 +54,53 @@ TEST_CASE(resize_test_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> res_data(1*1*5*8); std::vector<float> res_data(1*1*5*8);
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; std::vector<float> golden = {0.5f, 0.5f, 0.5f, 0.5f, 1.5f, 1.5f, 1.5f, 2.5f,
0.5f, 0.5f, 0.5f, 0.5f, 1.5f, 1.5f, 1.5f, 2.5f,
3.5f, 3.5f, 3.5f, 3.5f, 4.5f, 4.5f, 4.5f, 5.5f,
3.5f, 3.5f, 3.5f, 3.5f, 4.5f, 4.5f, 4.5f, 5.5f,
6.5f, 6.5f, 6.5f, 6.5f, 7.5f, 7.5f, 7.5f, 8.5};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
for(auto aa : res_data) std::cout << aa << ", "; std::cout << " result \n";
EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
}
TEST_CASE(resize_upsample_test_2)
{
// batch size 2, 1 color channel, resize 3x5 by 1.6x
// same input/output as resize_upsample_f_dyn_test
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2*3*5);
std::iota(data.begin(), data.end(), 0.1);
// should upscale to 2x1x4x8
migraphx::shape s{migraphx::shape::float_type, {2, 1, 3, 5}};
// to do: non-literal
auto a0 = mm->add_literal(migraphx::literal{s, data});
// scale input
migraphx::shape scale_input{migraphx::shape::float_type, {4}};
std::vector<float> scale_values = {1.0, 1.0, 1.601, 1.601};
auto a1 = mm->add_literal(migraphx::literal{scale_input, scale_values});
// a0 = input data
// a1 = scales
mm->add_instruction(migraphx::make_op("resize", {{"sizes", {}}, {"scales", {1}}, {"nearest_mode", "round_prefer_ceil"}
, {"coordinate_transformation_mode", "half_pixel"}}), a0, a1);
p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back();
std::vector<float> res_data(2*1*4*8);
// clang-format off
std::vector<float> golden = {
0.1f, 0.1f, 1.1f, 2.1f, 2.1f, 3.1f, 4.1f, 4.1f,
0.1f, 0.1f, 1.1f, 2.1f, 2.1f, 3.1f, 4.1f, 4.1f,
5.1f, 5.1f, 6.1f, 7.1f, 7.1f, 8.1f, 9.1f, 9.1f,
10.1f, 10.1f, 11.1f, 12.1f, 12.1f, 13.1f, 14.1f, 14.1f,
15.1f, 15.1f, 16.1f, 17.1f, 17.1f, 18.1f, 19.1f, 19.1f,
15.1f, 15.1f, 16.1f, 17.1f, 17.1f, 18.1f, 19.1f, 19.1f,
20.1f, 20.1f, 21.1f, 22.1f, 22.1f, 23.1f, 24.1f, 24.1f,
25.1f, 25.1f, 26.1f, 27.1f, 27.1f, 28.1f, 29.1f, 29.1f};
// clang-format on
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
for(auto aa : res_data) std::cout << aa << ", "; std::cout << " result \n"; for(auto aa : res_data) std::cout << aa << ", "; std::cout << " result \n";
EXPECT(migraphx::verify::verify_rms_range(res_data, golden)); EXPECT(migraphx::verify::verify_rms_range(res_data, golden));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment