Commit b6a9b597 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rnn_operator

parents 128b0b65 b5090737
...@@ -117,7 +117,9 @@ rocm_enable_cppcheck( ...@@ -117,7 +117,9 @@ rocm_enable_cppcheck(
passedByValue passedByValue
unusedStructMember unusedStructMember
functionStatic functionStatic
functionConst functionConst:*program.*
shadowFunction
shadowVar
definePrefix:*test/include/test.hpp definePrefix:*test/include/test.hpp
FORCE FORCE
INCONCLUSIVE INCONCLUSIVE
......
...@@ -8,10 +8,16 @@ def rocmtestnode(variant, name, body) { ...@@ -8,10 +8,16 @@ def rocmtestnode(variant, name, body) {
mkdir build mkdir build
cd build cd build
CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake ${flags} .. CXX=${compiler} CXXFLAGS='-Werror -Wno-fallback' cmake ${flags} ..
CTEST_PARALLEL_LEVEL=32 make -j32 generate all doc check CTEST_PARALLEL_LEVEL=32 make -j32 generate all doc package check
""" """
echo cmd echo cmd
sh cmd sh cmd
if (compiler == "hcc") {
// Only archive from master or develop
if (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master") {
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
}
}
} }
node(name) { node(name) {
stage("checkout ${variant}") { stage("checkout ${variant}") {
......
ignore:
- "test/"
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
<message> <message>
<id>definePrefix</id> <id>definePrefix</id>
<severity>style</severity> <severity>style</severity>
<summary>Macros must be prefixed with MIGRAPH_</summary> <summary>Macros must be prefixed with MIGRAPHX_</summary>
</message> </message>
</rule> </rule>
<rule> <rule>
......
pfultz2/rocm-recipes pfultz2/rocm-recipes
pcre pcre
danmar/cppcheck@f965e5873 -DHAVE_RULES=1 danmar/cppcheck@575f62f39c1130f412d3cc11b0138c5057c451c0 -DHAVE_RULES=1
ROCm-Developer-Tools/HIP@fc22ef991ce7cb15821c8ccb4f03cdfc3e7e43cf ROCm-Developer-Tools/HIP@fc22ef991ce7cb15821c8ccb4f03cdfc3e7e43cf
python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36 python/cpython@v3.6.6 -X autotools -H sha256:92aa914572c695c0aeb01b0a214813f414da4b51a371234df514a74761f2bb36
-f requirements.txt -f requirements.txt
...@@ -17,9 +17,9 @@ struct iterator_for_range ...@@ -17,9 +17,9 @@ struct iterator_for_range
struct iterator struct iterator
{ {
base_iterator i; base_iterator i;
base_iterator operator*() { return i; } base_iterator operator*() const { return i; }
base_iterator operator++() { return ++i; } base_iterator operator++() { return ++i; }
bool operator!=(const iterator& rhs) { return i != rhs.i; } bool operator!=(const iterator& rhs) const { return i != rhs.i; }
}; };
iterator begin() iterator begin()
......
...@@ -214,7 +214,6 @@ void find_matches(program& p, Ms&&... ms) ...@@ -214,7 +214,6 @@ void find_matches(program& p, Ms&&... ms)
bool match = false; bool match = false;
each_args( each_args(
[&](auto&& m) { [&](auto&& m) {
// cppcheck-suppress knownConditionTrueFalse
if(match) if(match)
return; return;
auto r = match_instruction(p, ins, m.matcher()); auto r = match_instruction(p, ins, m.matcher());
......
...@@ -640,49 +640,58 @@ struct as_shape ...@@ -640,49 +640,58 @@ struct as_shape
struct gather struct gather
{ {
std::size_t axis = 0; int axis = 0;
std::string name() const { return "gather"; } std::string name() const { return "gather"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
auto lens = inputs[0].lens(); auto lens = inputs[0].lens();
if(axis >= lens.size()) int n_dim = static_cast<int>(lens.size());
if(axis >= n_dim || axis < -n_dim)
{ {
MIGRAPHX_THROW("Gather, axis is out of range."); MIGRAPHX_THROW("Gather: axis is out of range.");
} }
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (n_dim + axis) : axis;
auto type = inputs[0].type(); auto type = inputs[0].type();
lens[axis] = inputs[1].elements(); lens[axis_index] = inputs[1].elements();
return {type, lens}; return {type, lens};
} }
template <class T> template <class T>
void compute_index(const T& out_idx, void compute_index(const T& out_idx,
const int axis_index,
const std::vector<std::size_t>& vec_indices, const std::vector<std::size_t>& vec_indices,
const std::size_t max_dim, const std::size_t max_dim,
T& in_idx) const T& in_idx) const
{ {
in_idx = out_idx; in_idx = out_idx;
std::size_t idx = vec_indices.at(out_idx[axis]); std::size_t idx = vec_indices.at(out_idx[axis_index]);
if(idx >= max_dim) if(idx >= max_dim)
{ {
MIGRAPHX_THROW("Gather: indices are out of range in input tensor"); MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
} }
in_idx[axis] = idx; in_idx[axis_index] = idx;
} }
argument compute(const shape& output_shape, std::vector<argument> args) const argument compute(const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
// negative axis means counting dimensions from back
int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;
// max dimension in axis // max dimension in axis
std::size_t max_dim = args[0].get_shape().lens()[axis]; std::size_t max_dim = args[0].get_shape().lens()[axis_index];
std::vector<std::size_t> vec_indices; std::vector<std::size_t> vec_indices;
args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); }); args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
visit_all(result, args[0])([&](auto output, auto input) { visit_all(result, args[0])([&](auto output, auto input) {
std::vector<std::size_t> in_idx; std::vector<std::size_t> in_idx;
shape_for_each(output.get_shape(), [&](const auto& idx) { shape_for_each(output.get_shape(), [&](const auto& idx) {
this->compute_index(idx, vec_indices, max_dim, in_idx); this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end()); output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
}); });
}); });
...@@ -961,7 +970,6 @@ struct scalar ...@@ -961,7 +970,6 @@ struct scalar
{ {
assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1); assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
auto t = inputs.at(0).type(); auto t = inputs.at(0).type();
auto input = inputs.at(0);
std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0); std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
return {t, scalar_bcast.lens(), strides}; return {t, scalar_bcast.lens(), strides};
} }
......
...@@ -54,6 +54,7 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -54,6 +54,7 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
f(i); f(i);
} }
}); });
// cppcheck-suppress unreadVariable
work += grainsize; work += grainsize;
return result; return result;
}); });
......
...@@ -24,7 +24,8 @@ struct onnx_parser ...@@ -24,7 +24,8 @@ struct onnx_parser
{ {
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>; using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>; using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>; using op_func =
std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
...@@ -102,6 +103,15 @@ struct onnx_parser ...@@ -102,6 +103,15 @@ struct onnx_parser
template <class F> template <class F>
void add_op(std::string name, F f) void add_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
});
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{ {
ops.emplace(name, f); ops.emplace(name, f);
} }
...@@ -109,7 +119,7 @@ struct onnx_parser ...@@ -109,7 +119,7 @@ struct onnx_parser
template <class F> template <class F>
void add_mem_op(std::string name, F f) void add_mem_op(std::string name, F f)
{ {
ops.emplace(name, [=](auto&&... xs) { add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); });
} }
...@@ -117,17 +127,15 @@ struct onnx_parser ...@@ -117,17 +127,15 @@ struct onnx_parser
template <class T> template <class T>
void add_binary_op(std::string name, T x) void add_binary_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast")) if(contains(attributes, "broadcast") and contains(attributes, "axis"))
{ {
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>(); uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
if(broadcasted != 0) if(broadcasted != 0)
{ {
uint64_t axis = (contains(attributes, "axis")) uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
? parse_value(attributes.at("axis")).at<uint64_t>()
: 0;
auto l = auto l =
prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]); prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
return prog.add_instruction(x, args[0], l); return prog.add_instruction(x, args[0], l);
...@@ -188,7 +196,7 @@ struct onnx_parser ...@@ -188,7 +196,7 @@ struct onnx_parser
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); });
} }
...@@ -196,7 +204,7 @@ struct onnx_parser ...@@ -196,7 +204,7 @@ struct onnx_parser
template <class T> template <class T>
void add_variadic_op(std::string name, T x) void add_variadic_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()), return std::accumulate(std::next(args.begin()),
args.end(), args.end(),
args.front(), args.front(),
...@@ -376,7 +384,7 @@ struct onnx_parser ...@@ -376,7 +384,7 @@ struct onnx_parser
instruction_ref instruction_ref
parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
std::size_t axis = 0; int axis = 0;
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
axis = parse_value(attributes.at("axis")).at<int>(); axis = parse_value(attributes.at("axis")).at<int>();
...@@ -734,7 +742,7 @@ struct onnx_parser ...@@ -734,7 +742,7 @@ struct onnx_parser
} }
else else
{ {
throw std::runtime_error("Failed reading"); MIGRAPHX_THROW("Failed reading onnx file.");
} }
} }
...@@ -764,7 +772,7 @@ struct onnx_parser ...@@ -764,7 +772,7 @@ struct onnx_parser
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
this->parse_node(get_name(p.second)); this->parse_node(p.first);
} }
} }
...@@ -790,23 +798,37 @@ struct onnx_parser ...@@ -790,23 +798,37 @@ struct onnx_parser
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
{ {
auto&& iname = get_name(nodes.at(input)); assert(name != input);
assert(name != iname); this->parse_node(input);
this->parse_node(iname); args.push_back(instructions.at(input));
args.push_back(instructions.at(iname));
} }
else else
{ {
args.push_back(instructions.at(input)); args.push_back(instructions.at(input));
} }
} }
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
{ {
instructions[name] = prog.add_instruction(unknown{node.op_type()}, args); result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
} }
else else
{ {
instructions[name] = ops[node.op_type()](get_attributes(node), args); result = ops[node.op_type()](get_attributes(node), args);
}
// Even no output nodes produce output in migraphx
if(node.output().empty() and result.size() == 1)
{
instructions[name] = result.front();
}
else
{
assert(node.output().size() >= result.size());
std::transform(result.begin(),
result.end(),
node.output().begin(),
std::inserter(instructions, instructions.end()),
[](auto&& x, auto&& y) { return std::make_pair(y, x); });
} }
} }
} }
...@@ -821,25 +843,24 @@ struct onnx_parser ...@@ -821,25 +843,24 @@ struct onnx_parser
return result; return result;
} }
static std::string get_name(const onnx::NodeProto& node)
{
if(node.name().empty())
{
std::string generated = "migraphx_unnamed_node";
return std::accumulate(node.output().begin(),
node.output().end(),
generated,
[](auto x, auto y) { return x + "_" + y; });
}
return node.name();
}
static node_map get_nodes(const onnx::GraphProto& graph) static node_map get_nodes(const onnx::GraphProto& graph)
{ {
std::unordered_map<std::string, onnx::NodeProto> result; std::unordered_map<std::string, onnx::NodeProto> result;
std::size_t n = 0;
for(auto&& node : graph.node()) for(auto&& node : graph.node())
{ {
result[get_name(node)] = node; if(node.output().empty())
{
if(node.name().empty())
{
result["migraphx_unamed_node_" + std::to_string(n)] = node;
n++;
}
else
{
result[node.name()] = node;
}
}
for(auto&& output : node.output()) for(auto&& output : node.output())
{ {
result[output] = node; result[output] = node;
......
...@@ -84,7 +84,7 @@ struct memory_coloring_impl ...@@ -84,7 +84,7 @@ struct memory_coloring_impl
{ {
return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output"; return is_param(ins) && any_cast<builtin::param>(ins->get_operator()).parameter == "output";
} }
bool is_allocate(const instruction_ref ins) { return ins->name() == allocation_op; } bool is_allocate(const instruction_ref ins) const { return ins->name() == allocation_op; }
static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; } static bool is_outline(const instruction_ref ins) { return ins->name() == "@outline"; }
static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; } static bool is_literal(const instruction_ref ins) { return ins->name() == "@literal"; }
static bool is_check_context(const instruction_ref ins) static bool is_check_context(const instruction_ref ins)
......
...@@ -14,8 +14,9 @@ namespace device { ...@@ -14,8 +14,9 @@ namespace device {
argument gather(hipStream_t stream, argument gather(hipStream_t stream,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
std::size_t axis) int axis)
{ {
int axis_index = (axis < 0) ? (axis + output_shape.lens().size()) : axis;
visit_all(args.back(), args[0])([&](auto output, auto input) { visit_all(args.back(), args[0])([&](auto output, auto input) {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
args[1].visit([&](auto indices) { args[1].visit([&](auto indices) {
...@@ -27,7 +28,7 @@ argument gather(hipStream_t stream, ...@@ -27,7 +28,7 @@ argument gather(hipStream_t stream,
hip_tensor_descriptor<ndim> desc_output(output.get_shape()); hip_tensor_descriptor<ndim> desc_output(output.get_shape());
gs_launch(stream, nelements)([=](auto i) { gs_launch(stream, nelements)([=](auto i) {
auto lens = desc_output.multi(i); auto lens = desc_output.multi(i);
lens[axis] = indices_ptr[lens[axis]]; lens[axis_index] = indices_ptr[lens[axis_index]];
outptr[i] = inptr[desc_input.linear(lens)]; outptr[i] = inptr[desc_input.linear(lens)];
}); });
}); });
......
...@@ -344,7 +344,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r) ...@@ -344,7 +344,7 @@ void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()}; Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
// TODO: Insert ws allocation // TODO: Insert ws allocation
auto ws = cb.get_workspace(ctx); auto ws = cb.get_workspace(ctx);
(void)ws;
p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
} }
......
...@@ -13,7 +13,7 @@ namespace device { ...@@ -13,7 +13,7 @@ namespace device {
argument gather(hipStream_t stream, argument gather(hipStream_t stream,
const migraphx::shape& output_shape, const migraphx::shape& output_shape,
std::vector<migraphx::argument> args, std::vector<migraphx::argument> args,
std::size_t axis); int axis);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -114,7 +114,7 @@ TEST_CASE(gather_test) ...@@ -114,7 +114,7 @@ TEST_CASE(gather_test)
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
...@@ -134,7 +134,27 @@ TEST_CASE(gather_test) ...@@ -134,7 +134,27 @@ TEST_CASE(gather_test)
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}}; migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2}; std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res_data(4 * 5);
std::vector<float> golden = {0.5f, 2.5f, 3.5f, 5.5f, 6.5f, 8.5f};
result.visit([&](auto output) { res_data.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(res_data, golden));
}
{
migraphx::program p;
std::vector<float> data(3 * 3);
std::iota(data.begin(), data.end(), 0.5);
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto a0 = p.add_literal(migraphx::literal{s, data});
migraphx::shape s_indices{migraphx::shape::int32_type, {1, 2}};
std::vector<int> indices{0, 2};
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
......
...@@ -944,7 +944,23 @@ struct test_gather ...@@ -944,7 +944,23 @@ struct test_gather
std::vector<int> indices{1, 2, 2, 1}; std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s); auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices}); auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p;
}
};
struct test_gather_neg_axis
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
migraphx::shape s_indices{migraphx::shape::int32_type, {2, 2}};
std::vector<int> indices{1, 2, 2, 1};
auto a0 = p.add_parameter("data", s);
auto a1 = p.add_literal(migraphx::literal{s_indices, indices});
int axis = -1;
p.add_instruction(migraphx::op::gather{axis}, a0, a1); p.add_instruction(migraphx::op::gather{axis}, a0, a1);
return p; return p;
} }
...@@ -1090,4 +1106,6 @@ int main() ...@@ -1090,4 +1106,6 @@ int main()
verify_program<test_conv_bn_relu_pooling>(); verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>(); verify_program<test_conv_bn_relu_pooling2>();
verify_program<test_slice>(); verify_program<test_slice>();
verify_program<test_gather>();
verify_program<test_gather_neg_axis>();
} }
...@@ -111,7 +111,7 @@ struct lhs_expression ...@@ -111,7 +111,7 @@ struct lhs_expression
struct capture struct capture
{ {
template <class T> template <class T>
auto operator->*(const T& x) auto operator->*(const T& x) const
{ {
return make_lhs_expression(x); return make_lhs_expression(x);
} }
......
...@@ -296,7 +296,7 @@ TEST_CASE(max_test) ...@@ -296,7 +296,7 @@ TEST_CASE(max_test)
auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::max{}, input0, input1);
p.add_instruction(migraphx::op::max{}, l0, input2); p.add_instruction(migraphx::op::max{}, l0, input2);
auto prog = migraphx::parse_onnx("max_test.onnx"); migraphx::parse_onnx("max_test.onnx");
} }
TEST_CASE(acos_test) TEST_CASE(acos_test)
...@@ -319,7 +319,7 @@ TEST_CASE(min_test) ...@@ -319,7 +319,7 @@ TEST_CASE(min_test)
auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1); auto l0 = p.add_instruction(migraphx::op::min{}, input0, input1);
p.add_instruction(migraphx::op::min{}, l0, input2); p.add_instruction(migraphx::op::min{}, l0, input2);
auto prog = migraphx::parse_onnx("min_test.onnx"); migraphx::parse_onnx("min_test.onnx");
} }
TEST_CASE(atan_test) TEST_CASE(atan_test)
...@@ -417,7 +417,7 @@ TEST_CASE(gather_test) ...@@ -417,7 +417,7 @@ TEST_CASE(gather_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}}); auto l0 = p.add_parameter("data", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}}); auto l1 = p.add_parameter("indices", migraphx::shape{migraphx::shape::int32_type, {2, 3}});
std::size_t axis = 1; int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1); p.add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = migraphx::parse_onnx("gather_test.onnx"); auto prog = migraphx::parse_onnx("gather_test.onnx");
...@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test) ...@@ -432,7 +432,7 @@ TEST_CASE(shape_gather_test)
p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens()); p.add_literal(migraphx::shape{migraphx::shape::int64_type, {3}}, l0->get_shape().lens());
migraphx::shape const_shape{migraphx::shape::int32_type, {1}}; migraphx::shape const_shape{migraphx::shape::int32_type, {1}};
auto l2 = p.add_literal(migraphx::literal{const_shape, {1}}); auto l2 = p.add_literal(migraphx::literal{const_shape, {1}});
std::size_t axis = 0; int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l2); p.add_instruction(migraphx::op::gather{axis}, l1, l2);
auto prog = migraphx::parse_onnx("shape_gather.onnx"); auto prog = migraphx::parse_onnx("shape_gather.onnx");
...@@ -558,7 +558,7 @@ TEST_CASE(group_conv_test) ...@@ -558,7 +558,7 @@ TEST_CASE(group_conv_test)
migraphx::op::convolution op; migraphx::op::convolution op;
op.group = 4; op.group = 4;
p.add_instruction(op, l0, l1); p.add_instruction(op, l0, l1);
auto prog = migraphx::parse_onnx("group_conv_test.onnx"); migraphx::parse_onnx("group_conv_test.onnx");
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -217,7 +217,7 @@ TEST_CASE(gather) ...@@ -217,7 +217,7 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 1; int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}}, expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 6, 4, 5}},
migraphx::op::gather{axis}, migraphx::op::gather{axis},
input, input,
...@@ -227,7 +227,24 @@ TEST_CASE(gather) ...@@ -227,7 +227,24 @@ TEST_CASE(gather)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}}; migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
std::size_t axis = 4; int axis = -4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 3, 4, 5}},
migraphx::op::gather{axis},
input,
indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = 4;
throws_shape(migraphx::op::gather{axis}, input, indices);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
migraphx::shape indices{migraphx::shape::int32_type, {2, 3}};
int axis = -5;
throws_shape(migraphx::op::gather{axis}, input, indices); throws_shape(migraphx::op::gather{axis}, input, indices);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment