Commit 8343d47d authored by charlie's avatar charlie
Browse files

Initial

parent 9550f6e9
......@@ -39,9 +39,10 @@ namespace op {
struct topk
{
int64_t k = 1;
int64_t axis = 0;
bool largest = true;
int64_t k = 1;
int64_t axis = 0;
bool largest = true;
bool use_dynamic = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -60,15 +61,40 @@ struct topk
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.at(0).lens();
auto type = inputs.at(0).type();
lens[axis] = k;
if(use_dynamic)
{
// input tensor and k value
check_shapes{inputs, *this, true}.has(2);
auto type = inputs.at(0).type();
std::vector<shape::dynamic_dimension> dyn_dims;
if(inputs.at(0).dynamic())
{
dyn_dims = inputs.at(0).dyn_dims();
}
else
{
auto dyn_shape = fixed_to_dynamic(inputs.at(0));
dyn_dims = dyn_shape.dyn_dims();
}
shape s_val{type, dyn_dims};
shape s_ind{shape::int64_type, dyn_dims};
return {{s_val, s_ind}};
}
else
{
return fixed_compute_shape(inputs, this->k);
}
}
shape fixed_compute_shape(std::vector<shape> inputs, int64_t in_k) const
{
// TODO: is standard layout needed here? No, it should not be needed.
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.at(0).lens();
auto type = inputs.at(0).type();
lens[axis] = in_k;
shape s_val{type, lens};
shape s_ind{shape::int64_type, lens};
return {{s_val, s_ind}};
}
......@@ -109,7 +135,10 @@ struct topk
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto vec_ss = output_shape.sub_shapes();
auto vec_ss =
use_dynamic ? fixed_compute_shape({args.front().get_shape()}, args.at(1).at<int64_t>())
.sub_shapes()
: output_shape.sub_shapes();
argument res_val{vec_ss.front()};
argument res_ind{vec_ss.back()};
auto in_s = args.front().get_shape();
......
......@@ -332,6 +332,11 @@ struct shape
void migraphx_to_value(value& v, const shape& s);
void migraphx_from_value(const value& v, shape& s);
/*!
* Make a dynamic shape from a fixed shape.
*/
shape fixed_to_dynamic(const shape& s);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -577,5 +577,15 @@ void migraphx_from_value(const value& v, shape& s)
}
}
shape fixed_to_dynamic(const shape& s)
{
auto fixed_lens = s.lens();
std::vector<shape::dynamic_dimension> dyn_dims;
std::transform(fixed_lens.cbegin(), fixed_lens.cend(), dyn_dims.begin(), [](auto l) {
return shape::dynamic_dimension{l, l, 0};
});
return {s.type(), dyn_dims};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment