Commit c7161d99 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

added basic computations; builds but doesn't pass test

parent 9154cbbe
...@@ -18,12 +18,77 @@ namespace migraphx { ...@@ -18,12 +18,77 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
// from parse_resize.cpp
auto& get_nearest_op(const std::string& near_mode)
{
using nearest_op = std::function<std::size_t(std::size_t, double)>;
static std::unordered_map<std::string, nearest_op> const nearest_ops = {
{"round_prefer_floor",
[=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::ceil((val - 0.5)));
}},
{"round_prefer_ceil",
[=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::round((val)));
}},
{"floor",
[=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::floor((val)));
}},
{"ceil", [=](std::size_t d_in, double val) {
val = std::max(0.0, std::min(d_in - 1.0, val));
return static_cast<std::size_t>(std::ceil((val)));
}}};
if(not contains(nearest_ops, near_mode))
{
MIGRAPHX_THROW("RESIZE: nearest_mode " + near_mode + " not supported!");
}
return nearest_ops.at(near_mode);
}
const auto& get_original_idx_op(const std::string& mode)
{
using original_idx_op = std::function<double(std::size_t, std::size_t, std::size_t, double)>;
static std::unordered_map<std::string, original_idx_op> const idx_ops = {
{"half_pixel",
[=](std::size_t, std::size_t, std::size_t idx, double scale) {
return (idx + 0.5) / scale - 0.5;
}},
{"pytorch_half_pixel",
[=](std::size_t, std::size_t l_out, std::size_t idx, double scale) {
return l_out > 1 ? (idx + 0.5) / scale - 0.5 : 0.0;
}},
{"align_corners",
[=](std::size_t l_in, std::size_t l_out, std::size_t idx, double) {
return (l_out == 1) ? 0.0 : (1.0 * idx * (l_in - 1.0) / (l_out - 1.0));
}},
{"asymmetric",
[=](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }},
{"tf_half_pixel_for_nn", [=](std::size_t, std::size_t, std::size_t idx, double scale) {
return (idx + 0.5) / scale;
}}};
if(not contains(idx_ops, mode))
{
MIGRAPHX_THROW("RESIZE: coordinate_transformation_mode " + mode + " not supported!");
}
return idx_ops.at(mode);
}
struct resize struct resize
{ {
// TODO: indicators. The real scales and sizes are inputs, not attributes. // TODO: indicators. The real scales and sizes are inputs, not attributes.
std::vector<float> scales; std::vector<float> scales;
std::vector<int64_t> sizes; std::vector<int64_t> sizes;
int mode = 0; // 1: nereast 2: bilinear/linear 3: cubic std::string nearest_mode;
int mode = 0; // 1: nearest 2: bilinear/linear 3: cubic
std::string coordinate_transformation_mode; std::string coordinate_transformation_mode;
std::string name() const { return "resize"; } std::string name() const { return "resize"; }
...@@ -33,10 +98,13 @@ struct resize ...@@ -33,10 +98,13 @@ struct resize
{ {
return pack(f(self.scales, "scales"), return pack(f(self.scales, "scales"),
f(self.sizes, "sizes"), f(self.sizes, "sizes"),
f(self.nearest_mode,"nearest_mode"),
f(self.mode,"mode"), f(self.mode,"mode"),
f(self.coordinate_transformation_mode,"coordinate_transformation_mode")); f(self.coordinate_transformation_mode,"coordinate_transformation_mode"));
} }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// check_shapes{{inputs[0]}, *this, true}.has(2); // check_shapes{{inputs[0]}, *this, true}.has(2);
...@@ -77,32 +145,6 @@ struct resize ...@@ -77,32 +145,6 @@ struct resize
return {inputs.front().type(), dyn_dims}; return {inputs.front().type(), dyn_dims};
// static input.
// if(!scales.empty())
// {
// // 计算输出blob大小
// auto in_s = inputs[0];
// auto in_lens = in_s.lens();
// if(in_lens.size() != scales.size())
// {
// MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
// }
// std::vector<std::size_t> out_lens(in_lens.size());
// std::transform(in_lens.begin(),
// in_lens.end(),
// scales.begin(),
// out_lens.begin(),
// [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
// return shape{in_s.type(), out_lens};
// }
// else if(!sizes.empty())
// {
// return shape{inputs[0].type(), sizes};
// }
} }
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
...@@ -110,17 +152,21 @@ struct resize ...@@ -110,17 +152,21 @@ struct resize
// See scatter.hpp or gather.hpp for how to do a similar iteration with reduction // See scatter.hpp or gather.hpp for how to do a similar iteration with reduction
// iterate through items in shape // iterate through items in shape
argument result{dyn_out.computed_shape}; argument result{dyn_out.computed_shape};
// negative axis means counting dimensions from back auto nearest_op = get_nearest_op(nearest_mode);
auto lens = args[0].get_shape().lens(); auto idx_op = get_original_idx_op(coordinate_transformation_mode);
//Everything that follows is placeholder logic
auto axis = 2; auto in_lens = args[0].get_shape().lens();
std::size_t axis_dim_size = lens[axis]; auto out_lens = dyn_out.computed_shape.lens();
// temp. This is a placeholder for reading the desired dimensions or scale
std::vector<double> vec_scale={1., 1., 5./3., 8./3.};
// max dimension in axis // max dimension in axis
visit_all(result, args[0])([&](auto output, auto data) { visit_all(result, args[0])([&](auto output, auto data) {
// the size input // the size input
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
for(auto aa : indices ) std::cout << aa << " indices \n"; for(auto aa : indices ) std::cout << aa << " indices \n";
if(dyn_out.computed_shape.scalar()) if(dyn_out.computed_shape.scalar())
{ {
std::cout << " scalar output\n"; std::cout << " scalar output\n";
...@@ -130,21 +176,29 @@ for(auto aa : indices ) std::cout << aa << " indices \n"; ...@@ -130,21 +176,29 @@ for(auto aa : indices ) std::cout << aa << " indices \n";
// for each element in output, calculate index in input // for each element in output, calculate index in input
for(auto bb : data) std::cout << bb << " zzz data \n"; for(auto bb : data) std::cout << bb << " zzz data \n";
// auto out_lens = data.get_shape().lens();
// out_lens[axis] = indices.get_shape().elements();
migraphx::shape out_comp_shape{data.get_shape().type(), indices}; migraphx::shape out_comp_shape{data.get_shape().type(), indices};
shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) { shape_for_each(out_comp_shape, [&](const auto& out_idx_v, size_t out_idx) {
auto data_idx = out_idx_v; // Show the output indices. Last index iterates fastest
auto in_index = indices[data_idx[axis]]; for(auto vv : out_idx_v ) std::cout << vv << " ";
in_index = (in_index < 0) ? in_index + axis_dim_size : in_index; std::cout <<out_idx << " out_index\n";
data_idx[axis] = in_index; std::cout << nearest_mode << "\n";
output[out_idx] = data(data_idx.begin(), data_idx.end());
std::cout << " !!!!! did something\n"; // populate output at this index
// output[out_idx] = data(data_idx.begin(), data_idx.end());
std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < out_idx_v.size(); ++ii)
{
auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
std::cout << in_lens[ii] << " " << out_lens[ii] << " " << out_idx_v[ii] << " " << vec_scale[ii] << "==> " << idx_val << "\n";
in_idx[ii] = nearest_op(in_lens[ii], idx_val);
}
// output[out_idx] = data.at(in_idx);
}); });
} }
}); });
}); });
std::cout << " finish resize\n";
return result; return result;
} }
......
...@@ -46,9 +46,11 @@ TEST_CASE(resize_test_1) ...@@ -46,9 +46,11 @@ TEST_CASE(resize_test_1)
auto a1 = mm->add_literal(migraphx::literal{size_input, size_values}); auto a1 = mm->add_literal(migraphx::literal{size_input, size_values});
mm->add_instruction(migraphx::make_op("resize", {{"sizes", {1}}, {"scales", {}}}), a0, a1); mm->add_instruction(migraphx::make_op("resize", {{"sizes", {1}}, {"scales", {}}, {"nearest_mode", "floor"}
, {"coordinate_transformation_mode", "half_pixel"}}), a0, a1);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> res_data(4 * 5); std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f}; std::vector<float> golden = {0.5f, 1.5f, 2.5f, 6.5f, 7.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment