Commit 8343d47d authored by charlie

Initial

parent 9550f6e9
@@ -42,6 +42,7 @@ struct topk
     int64_t k = 1;
     int64_t axis = 0;
     bool largest = true;
+    bool use_dynamic = false;
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -60,15 +61,40 @@ struct topk
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
+        if(use_dynamic)
+        {
+            // inputs are the data tensor and the runtime k value
+            check_shapes{inputs, *this, true}.has(2);
+            auto type = inputs.at(0).type();
+            std::vector<shape::dynamic_dimension> dyn_dims;
+            if(inputs.at(0).dynamic())
+            {
+                dyn_dims = inputs.at(0).dyn_dims();
+            }
+            else
+            {
+                auto dyn_shape = fixed_to_dynamic(inputs.at(0));
+                dyn_dims = dyn_shape.dyn_dims();
+            }
+            shape s_val{type, dyn_dims};
+            shape s_ind{shape::int64_type, dyn_dims};
+            return {{s_val, s_ind}};
+        }
+        else
+        {
+            return fixed_compute_shape(inputs, this->k);
+        }
+    }
+
+    shape fixed_compute_shape(std::vector<shape> inputs, int64_t in_k) const
+    {
+        // TODO: standard layout should not be required here
         check_shapes{inputs, *this}.has(1).standard();
         auto lens = inputs.at(0).lens();
         auto type = inputs.at(0).type();
-        lens[axis] = k;
+        lens[axis] = in_k;
         shape s_val{type, lens};
         shape s_ind{shape::int64_type, lens};
         return {{s_val, s_ind}};
     }
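In the dynamic branch above, k is unknown at compile time, so both output sub-shapes simply reuse the input's dynamic dimensions. A minimal sketch of the shapes involved, using illustrative dimension values and writing dynamic_dimension as {min, max, opt} exactly as this commit does:

    // Illustrative only: an input whose first dim ranges over [1, 4].
    std::vector<migraphx::shape::dynamic_dimension> dims = {{1, 4, 0}, {10, 10, 0}};
    migraphx::shape input{migraphx::shape::float_type, dims};
    // normalize_compute_shape forwards these dims to both outputs:
    migraphx::shape s_val{input.type(), input.dyn_dims()};
    migraphx::shape s_ind{migraphx::shape::int64_type, input.dyn_dims()};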
@@ -109,7 +135,10 @@ struct topk
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
-        auto vec_ss = output_shape.sub_shapes();
+        auto vec_ss =
+            use_dynamic ? fixed_compute_shape({args.front().get_shape()}, args.at(1).at<int64_t>())
+                              .sub_shapes()
+                        : output_shape.sub_shapes();
         argument res_val{vec_ss.front()};
         argument res_ind{vec_ss.back()};
         auto in_s = args.front().get_shape();
...
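With use_dynamic set, compute() cannot trust the compile-time output_shape, so it rebuilds the concrete sub-shapes from the actual input shape and the k value carried in args.at(1). A hedged sketch of how such an instruction might be constructed, assuming the usual make_op/module C++ API; the attribute names come from the reflect list above, and all other values are illustrative:

    migraphx::program p;
    auto* mm = p.get_main_module();
    std::vector<migraphx::shape::dynamic_dimension> dims = {{1, 4, 0}, {10, 10, 0}};
    auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
    // k arrives as a second input holding a single int64 value
    auto k = mm->add_parameter("k", migraphx::shape{migraphx::shape::int64_type, {1}});
    mm->add_instruction(
        migraphx::make_op("topk", {{"axis", 1}, {"use_dynamic", true}}), x, k);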
@@ -332,6 +332,11 @@ struct shape
 void migraphx_to_value(value& v, const shape& s);
 void migraphx_from_value(const value& v, shape& s);

+/*!
+ * Make a dynamic shape from a fixed shape; every fixed extent l becomes the
+ * degenerate dynamic_dimension {l, l, 0}.
+ */
+shape fixed_to_dynamic(const shape& s);
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -577,5 +577,15 @@ void migraphx_from_value(const value& v, shape& s)
     }
 }

+shape fixed_to_dynamic(const shape& s)
+{
+    auto fixed_lens = s.lens();
+    std::vector<shape::dynamic_dimension> dyn_dims;
+    dyn_dims.reserve(fixed_lens.size());
+    // note: transforming into dyn_dims.begin() on an empty vector is UB;
+    // append through std::back_inserter instead
+    std::transform(fixed_lens.cbegin(), fixed_lens.cend(), std::back_inserter(dyn_dims), [](auto l) {
+        return shape::dynamic_dimension{l, l, 0};
+    });
+    return {s.type(), dyn_dims};
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
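In short, fixed_to_dynamic maps each fixed extent l to the degenerate range {l, l, 0}, so the resulting shape reports dynamic() == true while describing exactly one concrete size. A quick usage sketch with illustrative dims:

    // A fixed {2, 3} shape becomes a dynamic shape with dims {2,2,0}, {3,3,0}.
    migraphx::shape fixed{migraphx::shape::float_type, {2, 3}};
    migraphx::shape dyn = migraphx::fixed_to_dynamic(fixed);
    assert(dyn.dynamic());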