Unverified Commit 521b57a2 authored by Shucai Xiao, committed by GitHub

Topk op (#877)



* add topk operator for ref, cpu and gpu
* Hash modules for quicker lookup
* add onnx unit test
* add unit tests for the topk operator
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 1f741f73
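
For orientation: the new operator returns a two-element tuple of (values, indices), which callers unpack with get_tuple_elem. A minimal sketch of building it through the C++ API, mirroring the instructions the ONNX parser in this diff emits (the shape, k, and axis here are illustrative, and the headers are the usual ones for these calls):

#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // illustrative input: pick the top 2 values along axis 1 of a {2, 5} tensor
    migraphx::shape s{migraphx::shape::float_type, {2, 5}};
    auto data = mm->add_parameter("data", s);
    auto out  = mm->add_instruction(
        migraphx::make_op("topk", {{"k", 2}, {"axis", 1}, {"largest", 1}}), data);

    // the result is a tuple; unpack values and indices separately
    auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    mm->add_return({val, ind});
}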
@@ -166,6 +166,7 @@ register_migraphx_ops(
    sub
    tanh
    tan
    topk
    transpose
    unary_not
    undefined
......
File mode changed from 100755 to 100644
@@ -3,6 +3,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

@@ -13,7 +14,7 @@ void eliminate_data_type::apply(module& m) const
    {
        if(ins->name()[0] == '@')
            continue;
-       if(ins->name() == "convert")
+       if(contains({"convert", "get_tuple_elem"}, ins->name()))
            continue;
        auto inputs = ins->inputs();
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
......
@@ -6,17 +6,42 @@ inline namespace MIGRAPHX_INLINE_NS {
argument fill_argument(shape s, unsigned long value)
{
    argument result;
    if(s.type() == shape::tuple_type)
    {
        std::vector<argument> sub_args;
        const auto& sub_ss = s.sub_shapes();
        std::transform(sub_ss.begin(), sub_ss.end(), std::back_inserter(sub_args), [&](auto ss) {
            return fill_argument(ss, value);
        });
        result = argument(sub_args);
    }
    else
    {
        s.visit_type([&](auto as) {
            using type = typename decltype(as)::type;
            auto v = fill_tensor_data<type>(s, value);
            result = {s, v};
        });
    }
    return result;
}

argument generate_argument(shape s, unsigned long seed)
{
    argument result;
    if(s.type() == shape::tuple_type)
    {
        const auto& sub_ss = s.sub_shapes();
        std::vector<argument> sub_args;
        std::transform(sub_ss.begin(), sub_ss.end(), std::back_inserter(sub_args), [&](auto ss) {
            return generate_argument(ss, seed);
        });
        result = argument(sub_args);
    }
    else
    {
        s.visit_type([&](auto as) {
            // we use char type to store bool type internally, so bool_type
            // needs special processing to generate data

@@ -32,6 +57,8 @@ argument generate_argument(shape s, unsigned long seed)
            result = {s, v};
        }
        });
    }
    return result;
}
......
@@ -45,6 +45,8 @@ struct get_tuple_elem
        assert(index < vec_args.size());
        return vec_args.at(index);
    }
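
    // declares the output an alias of input 0 (the tuple), so the memory
    // planner schedules no new allocation for the extracted element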
    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};

} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_TOPK_HPP
#define MIGRAPHX_GUARD_OPERATORS_TOPK_HPP

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct topk
{
    int64_t k = 1;
    int64_t axis = 0;
    bool largest = true;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.k, "k"), f(self.axis, "axis"), f(self.largest, "largest"));
    }

    value attributes() const
    {
        value normalize;
        normalize["axis"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    std::string name() const { return "topk"; }

    shape normalize_compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens = inputs.at(0).lens();
        auto type = inputs.at(0).type();
        lens[axis] = k;
        shape s_val{type, lens};
        shape s_ind{shape::int64_type, lens};
        return shape({s_val, s_ind});
    }

    template <class T, class Compare>
    struct heap_vector
    {
        std::vector<T> data;
        Compare compare;

        heap_vector(const std::vector<T>& val, Compare comp) : data(val), compare(std::move(comp))
        {
            std::make_heap(data.begin(), data.end(), compare);
        }

        void try_push(T val)
        {
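            // admit val only if it beats the current weakest kept element,
            // which sits at the heap front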
            if(not compare(val, data.front()))
                return;
            std::pop_heap(data.begin(), data.end(), compare);
            data.back() = val;
            std::push_heap(data.begin(), data.end(), compare);
        }

        std::vector<T> sort()
        {
            auto sorted_data = data;
            std::sort_heap(sorted_data.begin(), sorted_data.end(), compare);
            return sorted_data;
        }
    };

    template <class T, class Compare>
    heap_vector<T, Compare> make_heap(std::vector<T> val, Compare compare) const
    {
        return {std::move(val), std::move(compare)};
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto vec_ss = output_shape.sub_shapes();
        argument res_val{vec_ss.front()};
        argument res_ind{vec_ss.back()};
        auto in_s = args.front().get_shape();
        auto out_s = vec_ss.front();
        auto comp_lens = in_s.lens();
        auto axis_dim = comp_lens[axis];

        // iteration shape: collapse the topk axis so each element of comp_s
        // addresses one 1-D slice along that axis
        comp_lens[axis] = 1;
        shape comp_s{in_s.type(), comp_lens};

        visit_all(res_val, args.front())([&](auto out_val, auto input) {
            auto* out_ind = res_ind.cast<int64_t>();
            par_for(comp_s.elements(), [&](auto i) {
                auto idx = comp_s.multi(i);
                std::vector<std::size_t> indices(k);
                std::iota(indices.begin(), indices.end(), 0);
                auto comp = [&](auto i1, auto i2) {
                    auto idx1 = idx;
                    auto idx2 = idx;
                    idx1[axis] = i1;
                    idx2[axis] = i2;
                    return this->largest
                               ? std::greater<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)])
                               : std::less<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)]);
                };
                auto hp = this->make_heap(indices, comp);
                for(std::size_t ii = indices.size(); ii < axis_dim; ++ii)
                {
                    hp.try_push(ii);
                }

                auto sorted_indices = hp.sort();
                auto out_idx = idx;
                auto in_idx = idx;
                for(auto j : range(sorted_indices.size()))
                {
                    out_idx[axis] = j;
                    in_idx[axis] = sorted_indices[j];
                    out_val[out_s.index(out_idx)] = input[in_s.index(in_idx)];
                    out_ind[out_s.index(out_idx)] = sorted_indices[j];
                }
            });
        });

        return argument({res_val, res_ind});
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
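
The reference compute above selects the top k entries of each 1-D slice with a k-element heap whose front is the current weakest candidate, streaming the remaining axis entries through try_push: O(n log k) per slice instead of a full sort. The same bounded-heap idea on a flat vector, as a self-contained sketch (illustrative only, not part of this commit):

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

// keep the k largest values: std::greater yields a min-heap, so the front is
// the weakest of the current top-k and gets replaced by any better value
// (assumes v.size() >= k)
std::vector<float> topk_largest(const std::vector<float>& v, std::size_t k)
{
    auto comp = std::greater<>{};
    std::vector<float> heap(v.begin(), v.begin() + k);
    std::make_heap(heap.begin(), heap.end(), comp);
    for(std::size_t i = k; i < v.size(); ++i)
    {
        if(not comp(v[i], heap.front()))
            continue; // no better than the current weakest kept value
        std::pop_heap(heap.begin(), heap.end(), comp);
        heap.back() = v[i];
        std::push_heap(heap.begin(), heap.end(), comp);
    }
    std::sort_heap(heap.begin(), heap.end(), comp); // descending order
    return heap;
}

int main()
{
    for(auto x : topk_largest({3.f, 1.f, 4.f, 1.f, 5.f, 9.f, 2.f, 6.f}, 3))
        std::cout << x << " "; // prints: 9 6 5
}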
@@ -95,6 +95,7 @@
#include <migraphx/op/sub.hpp>
#include <migraphx/op/tanh.hpp>
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_topk : op_parser<parse_topk>
{
    std::vector<op_desc> operators() const { return {{"TopK"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       const onnx_parser& parser,
                                       onnx_parser::node_info info,
                                       std::vector<instruction_ref> args) const
    {
        int64_t k = 0;
        if(args.size() == 2)
        {
            auto arg_k = args.at(1)->eval();
            check_arg_empty(arg_k, "PARSE_TopK: k input must be constant");
            k = arg_k.at<int>();
        }
        else if(contains(info.attributes, "k"))
        {
            k = info.attributes.at("k").i();
        }

        bool largest = true;
        if(contains(info.attributes, "largest"))
        {
            largest = static_cast<bool>(info.attributes.at("largest").i());
        }

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = parser.parse_value(info.attributes.at("axis")).at<int>();
        }
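
        // note: the ONNX `sorted` attribute is not read here; the operator
        // always returns its results sorted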
        auto topk_ret = info.add_instruction(
            make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0));
        auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret);
        auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret);
        return {ret_val, ret_ind};
    }
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -40,7 +40,7 @@ bool memory_coloring_impl::allocate(interval_ptr interval)
    std::size_t size = s.bytes();
    if(size == 0)
        return false;
-   std::size_t element_size = size / s.elements();
+   std::size_t element_size = (s.elements() == 0 ? 4 : (size / s.elements()));
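    // note: tuple shapes report zero elements, hence the guard against a
    // division by zero; 4 is presumably just a safe default element size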
    live_range& segment = interval->segment;
    int vn = segment.vn;
    std::priority_queue<live_range*, std::vector<live_range*>, ordering> conflict_queue;
......
@@ -84,6 +84,7 @@ add_library(migraphx_device
    device/sub.cpp
    device/tan.cpp
    device/tanh.cpp
    device/topk.cpp
    device/unary_not.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)

@@ -152,6 +153,7 @@ add_library(migraphx_gpu
    softmax.cpp
    sync_device.cpp
    target.cpp
    topk.cpp
    write_literals.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)

@@ -218,6 +220,7 @@ register_migraphx_gpu_ops(hip_
    sub
    tanh
    tan
    topk
    unary_not
)
register_migraphx_gpu_ops(miopen_
......
#include <functional>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/topk.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, class Index, class Compare>
struct hip_heap_vector
{
    MIGRAPHX_DEVICE_CONSTEXPR hip_heap_vector(T* val, index_int n, Index v_idx, Compare comp)
        : data(val), size(n), data_index(v_idx), compare(comp)
    {
        make_heap(size);
    }

    MIGRAPHX_DEVICE_CONSTEXPR void try_push(const T val)
    {
        if(compare(val, data[data_index(0)]))
            return;

        pop_heap(size - 1);
        data[data_index(size - 1)] = val;
        push_heap(size - 1);
    }

    MIGRAPHX_DEVICE_CONSTEXPR void sort() { sort_heap(size); }

    private:
    MIGRAPHX_DEVICE_CONSTEXPR inline static void swap(T& v1, T& v2)
    {
        T v = v1;
        v1 = v2;
        v2 = v;
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_down(index_int n, index_int index)
    {
        while(index < n)
        {
            auto pre_index = index;
            index_int l = 2 * index + 1;
            index_int r = 2 * index + 2;
            if(l < n && compare(data[data_index(l)], data[data_index(index)]))
            {
                index = l;
            }

            if(r < n && compare(data[data_index(r)], data[data_index(index)]))
            {
                index = r;
                if(compare(data[data_index(l)], data[data_index(r)]))
                {
                    index = l;
                }
            }

            if(index == pre_index)
            {
                break;
            }
            swap(data[data_index(index)], data[data_index(pre_index)]);
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_up(index_int index)
    {
        while(index > 0)
        {
            auto parent_idx = (index - 1) / 2;
            if(not compare(data[data_index(index)], data[data_index(parent_idx)]))
            {
                break;
            }

            swap(data[data_index(index)], data[data_index(parent_idx)]);
            index = parent_idx;
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void make_heap(index_int n)
    {
        for(int j = n / 2 - 1; j >= 0; --j)
        {
            heapify_down(n, j);
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void push_heap(index_int loc) { heapify_up(loc); }

    MIGRAPHX_DEVICE_CONSTEXPR inline void pop_heap(index_int loc)
    {
        swap(data[data_index(0)], data[data_index(loc)]);
        heapify_down(loc, 0);
    }

    MIGRAPHX_DEVICE_CONSTEXPR inline void sort_heap(index_int n)
    {
        for(int j = n - 1; j > 0; --j)
        {
            swap(data[data_index(0)], data[data_index(j)]);
            heapify_down(j, 0);
        }
    }

    T* data = nullptr;
    index_int size;
    Index data_index;
    Compare compare;
};

template <class T, class Index, class Compare>
__device__ hip_heap_vector<T, Index, Compare>
make_heap(T* data, index_int n, Index idx, Compare compare)
{
    return {data, n, idx, compare};
}
template <class Compare>
std::vector<argument> topk(hipStream_t stream,
                           const argument& val_res,
                           const argument& ind_res,
                           const argument& arg,
                           int64_t k,
                           int64_t axis,
                           Compare compare)
{
    auto in_s = arg.get_shape();
    auto in_lens = in_s.lens();
    auto out_s = val_res.get_shape();
    auto axis_dim = in_s.lens()[axis];
    auto comp_lens = in_lens;
    comp_lens[axis] = 1;
    shape comp_s{in_s.type(), comp_lens};
    std::size_t elem_num = comp_s.elements();

    hip_visit_all(val_res, arg, out_s, in_s, comp_s)(
        [&](auto out_val, auto input, auto oss, auto iss, auto css) {
            auto* data = device_cast(input.data());
            auto* out = device_cast(out_val.data());
            auto* const ind = ind_res.cast<int64_t>();
            gs_launch(stream, elem_num)([=](auto i) __device__ {
                auto idx = css.multi(i);

                auto in_idx = [&](int ii) {
                    auto iidx = idx;
                    iidx[axis] = ii;
                    return iss.index(iidx);
                };

                auto out_idx = [&](int ii) {
                    auto iidx = idx;
                    iidx[axis] = ii;
                    return oss.index(iidx);
                };

                auto data_compare = [=](auto ii, auto jj) {
                    return compare(data[in_idx(ii)], data[in_idx(jj)]);
                };

                for(int j = 0; j < k; ++j)
                {
                    ind[out_idx(j)] = j;
                }

                auto hp = make_heap(ind, k, out_idx, data_compare);
                for(int j = k; j < axis_dim; ++j)
                {
                    hp.try_push(j);
                }
                hp.sort();

                for(int j = 0; j < k; ++j)
                {
                    out[out_idx(j)] = data[in_idx(ind[out_idx(j)])];
                }
            });
        });

    return {val_res, ind_res};
}
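
// note: the comparator sense below is inverted relative to the reference
// operator: hip_heap_vector::try_push discards a candidate when
// compare(val, root) holds, so std::less keeps the k largest values and
// std::greater the k smallest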
argument topk_largest(hipStream_t stream,
                      const argument& val_res,
                      const argument& ind_res,
                      const argument& arg,
                      int64_t k,
                      int64_t axis)
{
    return {topk(stream, val_res, ind_res, arg, k, axis, std::less<>{})};
}

argument topk_smallest(hipStream_t stream,
                       const argument& val_res,
                       const argument& ind_res,
                       const argument& arg,
                       int64_t k,
                       int64_t axis)
{
    return {topk(stream, val_res, ind_res, arg, k, axis, std::greater<>{})};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_TOPK_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_TOPK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
argument topk_smallest(hipStream_t stream,
                       const argument& val_res,
                       const argument& ind_res,
                       const argument& arg,
                       int64_t k,
                       int64_t axis);

argument topk_largest(hipStream_t stream,
                      const argument& val_res,
                      const argument& ind_res,
                      const argument& arg,
                      int64_t k,
                      int64_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_TOPK_HPP
#define MIGRAPHX_GUARD_RTGLIB_TOPK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_topk
{
    op::topk op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::topk"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
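
    // gpu ops receive a preallocated output buffer as their last input
    // argument; aliasing it here avoids allocating the result inside compute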
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <iterator>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>

@@ -175,6 +176,7 @@ struct miopen_apply
        add_extend_op("rnn_var_sl_shift_sequence");
        add_extend_op("scatter");
        add_extend_op("softmax");
        add_extend_op("topk");

        add_gemm_op<op::dot>("dot");
        add_gemm_op<op::quant_dot>("quant_dot");

@@ -426,7 +428,7 @@ struct miopen_apply
        });
    }

-   // replace the if operator with gpu_if operator
+   // add the input and output arguments for the if operator
    void add_if_op()
    {
        apply_map.emplace("if", [=](instruction_ref ins) {
......
#include <migraphx/gpu/topk.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/topk.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_topk::compute_shape(std::vector<shape> inputs) const
{
    return op.normalize_compute_shape({inputs.front()});
}

argument hip_topk::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
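    // the last argument is the preallocated tuple output; split it into the
    // value and index buffers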
    auto outputs = args.back().get_sub_objects();
    return op.largest ? device::topk_largest(ctx.get_stream().get(),
                                             outputs.front(),
                                             outputs.back(),
                                             args[0],
                                             op.k,
                                             op.axis)
                      : device::topk_smallest(ctx.get_stream().get(),
                                              outputs.front(),
                                              outputs.back(),
                                              args[0],
                                              op.k,
                                              op.axis);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -8,4 +8,34 @@ TEST_CASE(generate)
    EXPECT(migraphx::generate_literal(s, 1) != migraphx::generate_argument(s, 0));
}

TEST_CASE(fill_tuple)
{
    migraphx::shape s0{migraphx::shape::float_type, {4, 4, 1, 1}};
    migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
    migraphx::shape s2{migraphx::shape::bool_type, {3, 2}};
    migraphx::shape s({s0, s1, s2});

    auto arg = migraphx::fill_argument(s, 1);
    const auto& args = arg.get_sub_objects();
    EXPECT(args.at(0) == migraphx::fill_argument(s0, 1));
    EXPECT(args.at(1) == migraphx::fill_argument(s1, 1));
    EXPECT(args.at(2) == migraphx::fill_argument(s2, 1));
}

TEST_CASE(generate_tuple)
{
    migraphx::shape s0{migraphx::shape::float_type, {4, 4, 1, 1}};
    migraphx::shape s1{migraphx::shape::int32_type, {2, 3}};
    migraphx::shape s2{migraphx::shape::bool_type, {3, 2}};
    migraphx::shape s({s0, s1, s2});

    auto arg = migraphx::generate_argument(s, 1);
    const auto& args = arg.get_sub_objects();
    EXPECT(args.at(0) == migraphx::generate_argument(s0, 1));
    EXPECT(args.at(1) == migraphx::generate_argument(s1, 1));
    EXPECT(args.at(2) == migraphx::generate_argument(s2, 1));

    EXPECT(args.at(0) != migraphx::generate_argument(s0, 0));
    EXPECT(args.at(1) != migraphx::generate_argument(s1, 2));
    EXPECT(args.at(2) != migraphx::generate_argument(s2, 0));
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -4151,6 +4151,61 @@ def tile_test_3x2():

@onnx_test
def topk_attrk_test():
    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 5, 3, 2])
    val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [2, 2, 3, 2])
    ind = helper.make_tensor_value_info('indices', TensorProto.INT64,
                                        [2, 2, 3, 2])

    node = onnx.helper.make_node('TopK',
                                 inputs=['data'],
                                 outputs=['val', 'indices'],
                                 k=2)

    return ([node], [x], [val, ind])


@onnx_test
def topk_neg_axis_test():
    k = np.array([3])
    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3, 4, 5, 6])
    val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [3, 3, 5, 6])
    ind = helper.make_tensor_value_info('indices', TensorProto.INT64,
                                        [3, 3, 5, 6])
    k_tensor = helper.make_tensor(name='k',
                                  data_type=TensorProto.INT64,
                                  dims=k.shape,
                                  vals=k.astype(np.int64))

    node = onnx.helper.make_node('TopK',
                                 inputs=['data', 'k'],
                                 outputs=['val', 'indices'],
                                 axis=-2,
                                 sorted=0)

    return ([node], [x], [val, ind], [k_tensor])


@onnx_test
def topk_test():
    k = np.array([4])
    x = helper.make_tensor_value_info('data', TensorProto.FLOAT, [2, 5, 3, 2])
    val = helper.make_tensor_value_info('val', TensorProto.FLOAT, [2, 4, 3, 2])
    ind = helper.make_tensor_value_info('indices', TensorProto.INT64,
                                        [2, 4, 3, 2])
    k_tensor = helper.make_tensor(name='k',
                                  data_type=TensorProto.INT64,
                                  dims=k.shape,
                                  vals=k.astype(np.int64))

    node = onnx.helper.make_node('TopK',
                                 inputs=['data', 'k'],
                                 outputs=['val', 'indices'],
                                 largest=0,
                                 axis=1)

    return ([node], [x], [val, ind], [k_tensor])

def transpose_default_perm_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 5, 2, 3])
    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2, 5, 1])
......
@@ -3842,6 +3842,60 @@ TEST_CASE(transpose_test)
    EXPECT(p == prog);
}

TEST_CASE(topk_attrk_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}};
    auto data = mm->add_parameter("data", s);
    auto out = mm->add_instruction(migraphx::make_op("topk", {{"k", 2}, {"axis", -1}}), data);
    auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    mm->add_return({val, ind});

    auto prog = migraphx::parse_onnx("topk_attrk_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(topk_neg_axis_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape sk{migraphx::shape::int64_type, {1}};
    mm->add_literal(migraphx::literal(sk, {3}));
    migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
    auto data = mm->add_parameter("data", s);
    auto out = mm->add_instruction(
        migraphx::make_op("topk", {{"k", 3}, {"axis", -2}, {"largest", 1}}), data);
    auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    mm->add_return({val, ind});

    auto prog = migraphx::parse_onnx("topk_neg_axis_test.onnx");
    EXPECT(p == prog);
}

TEST_CASE(topk_test)
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape sk{migraphx::shape::int64_type, {1}};
    mm->add_literal(migraphx::literal(sk, {4}));
    migraphx::shape s{migraphx::shape::float_type, {2, 5, 3, 2}};
    auto data = mm->add_parameter("data", s);
    auto out = mm->add_instruction(
        migraphx::make_op("topk", {{"k", 4}, {"axis", 1}, {"largest", 0}}), data);
    auto val = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), out);
    auto ind = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), out);
    mm->add_return({val, ind});

    auto prog = migraphx::parse_onnx("topk_test.onnx");
    EXPECT(p == prog);
}
TEST_CASE(transpose_gather_test)
{
    migraphx::program p;
......
topk_attrk_test.onnx: binary ONNX protobuf (a TopK node with attribute k, input `data`, outputs `val` and `indices`); not representable as text.