"examples/vscode:/vscode.git/clone" did not exist on "c5bbc608555014e7b0b7108bc0a5eebb1522768f"
Commit 4ed53bc6 authored by charlie's avatar charlie
Browse files

Initial

parent f888f611
...@@ -29,11 +29,20 @@ ...@@ -29,11 +29,20 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp> #include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace op { namespace op {
/**
* Adds dimensions to a tensor based on the axes attribute.
* `axes` are based on the number of output shape dimensions and should not contain duplicates.
* `steps` are for modifying dimensions added to the middle of the original shape.
* Each step must be a factor of the original dimension.
* ex: unsqueeze(shape = [3, 4, 10], axes = [2, 4, 5], steps = [2]) -> shape = [3, 4, 2, 5, 1, 1]
* Dynamic shape version does not handle `steps`.
*/
struct unsqueeze struct unsqueeze
{ {
std::vector<int64_t> axes; std::vector<int64_t> axes;
...@@ -56,63 +65,90 @@ struct unsqueeze ...@@ -56,63 +65,90 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; } std::string name() const { return "unsqueeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this, true}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens(); if(input_shape.dynamic())
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{ {
if(old_lens.size() == 1 and old_lens.front() == 1) if(not steps.empty())
return shape{type, old_lens}; {
else MIGRAPHX_THROW("UNSQUEEZE_dyn: nonempty steps attribute");
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar"); }
std::vector<shape::dynamic_dimension> dyn_dims = {};
auto new_ndim = input_shape.ndim() + axes.size();
std::size_t k = 0;
for(auto i : range(new_ndim))
{
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{
dyn_dims.push_back({1, 1, 0});
}
else
{
dyn_dims.push_back(input_shape.dyn_dims().at(k++));
}
}
return {input_shape.type(), dyn_dims};
} }
else
{
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
if(steps.size() > axes.size()) if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis"); MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
std::size_t new_size = old_lens.size() + axes.size(); std::size_t new_size = old_lens.size() + axes.size();
std::vector<std::size_t> new_lens(new_size); std::vector<std::size_t> new_lens(new_size);
std::vector<std::size_t> new_strides(new_size); std::vector<std::size_t> new_strides(new_size);
std::size_t p = 0; std::size_t p = 0;
for(auto i : range(new_size)) for(auto i : range(new_size))
{
auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < axes.size())
{ {
std::int64_t step = 1; auto axis_idx = std::find(axes.begin(), axes.end(), i) - axes.begin();
if(axis_idx < steps.size()) if(axis_idx < axes.size())
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{ {
if((old_lens[p] % step) != 0) std::int64_t step = 1;
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step"); if(axis_idx < steps.size())
old_lens[p] /= step; step = steps[axis_idx];
new_strides[i] = old_strides[p] * old_lens[p]; if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
new_lens[i] = step;
if(p < old_strides.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
}
else
{
if(step != 1)
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes");
new_strides[i] = 1;
}
} }
else else
{ {
if(step != 1) new_lens[i] = old_lens[p];
MIGRAPHX_THROW("UNSQUEEZE: Step must be 1 for extra axes"); new_strides[i] = old_strides[p++];
new_strides[i] = 1;
} }
} }
else return shape{type, new_lens, new_strides};
{
new_lens[i] = old_lens[p];
new_strides[i] = old_strides[p++];
}
} }
return shape{type, new_lens, new_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{ {
return args[0].reshape(output_shape); return args[0].reshape(dyn_out.computed_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment