Unverified Commit 476ed17c authored by Brian Pickrell's avatar Brian Pickrell Committed by GitHub
Browse files

Merge branch 'develop' into rand_uniform

parents f4f9d711 6f1c947f
 slice_var_input_static1:
)
data
starts
ends
axesoutput"Sliceslice_var_input_static1Z
data


Z
starts

Z
ends

Z
axes

b
output


B
\ No newline at end of file
 slice_var_input_steps_error:
0arg_step"Constant*
value**Bstep
3
data
starts
ends
axes
arg_stepoutput"Sliceslice_var_input_steps_errorZ
data


Z
starts

Z
ends

Z
axes

b
output


B
\ No newline at end of file
......@@ -24,7 +24,6 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/pass_manager.hpp>
......
......@@ -24,7 +24,8 @@
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/op/common.hpp>
#include <sstream>
#include <migraphx/make_op.hpp>
......@@ -156,13 +157,13 @@ TEST_CASE(broadcast)
{
std::vector<std::size_t> lens{1, 1};
migraphx::shape input{migraphx::shape::float_type, {2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
}
{
std::vector<std::size_t> lens{2, 2};
migraphx::shape input{migraphx::shape::float_type, {1, 2}};
throws_shape(migraphx::op::broadcast{1, lens}, input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), input);
}
{
......@@ -1252,36 +1253,45 @@ TEST_CASE(inconsistent_attr_shape)
input);
}
template <class T>
void test_softmax_variations()
void test_softmax_variations(const std::string& name)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{0}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{1}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{2}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}, T{3}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::make_op(name, {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
throws_shape(T{axis}, input);
throws_shape(migraphx::make_op(name, {{"axis", axis}}), input);
}
}
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations("logsoftmax"); }
TEST_CASE(softmax) { test_softmax_variations("softmax"); }
TEST_CASE(lstm)
{
......@@ -2342,47 +2352,54 @@ TEST_CASE(dqlinear_mismatch_type)
throws_shape(migraphx::make_op("dequantizelinear"), input, scales, zeros);
}
template <class T>
void test_reduce_ops()
void test_reduce_ops(const std::string& name)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::make_op(name),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}}, T{{0, 1, 2, 3}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 1}},
migraphx::make_op(name, {{"axes", {0, 1, 2, 3}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}}, T{{2, 3}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 1, 1}},
migraphx::make_op(name, {{"axes", {2, 3}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}}, T{{0}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 3, 4, 5}},
migraphx::make_op(name, {{"axes", {0}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}}, T{{-1}}, input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 1}},
migraphx::make_op(name, {{"axes", {-1}}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{{4}}, input);
throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
}
}
// dynamic shape
template <class T>
void test_dyn_reduce_ops()
void test_dyn_reduce_ops(const std::string& name)
{
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{2, 3, {3}}, {1, 1}})},
T{{-1}},
migraphx::make_op(name, {{"axes", {-1}}}),
input);
}
{
......@@ -2390,7 +2407,7 @@ void test_dyn_reduce_ops()
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {2, 4, {4}}})},
T{{0}},
migraphx::make_op(name, {{"axes", {0}}}),
input);
}
{
......@@ -2399,24 +2416,24 @@ void test_dyn_reduce_ops()
expect_shape(
migraphx::shape{migraphx::shape::float_type,
std::vector<migraphx::shape::dynamic_dimension>({{1, 1}, {1, 1}})},
T{{}},
migraphx::make_op(name),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {{2, 3, {3}}, {2, 4, {4}}}};
throws_shape(T{{4}}, input);
throws_shape(migraphx::make_op(name, {{"axes", {4}}}), input);
}
}
TEST_CASE(reduce_max) { test_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean) { test_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod) { test_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum) { test_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_max) { test_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean) { test_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod) { test_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum) { test_reduce_ops("reduce_sum"); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_max>(); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_mean>(); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_prod>(); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops<migraphx::op::reduce_sum>(); }
TEST_CASE(reduce_max_dyn) { test_dyn_reduce_ops("reduce_max"); }
TEST_CASE(reduce_mean_dyn) { test_dyn_reduce_ops("reduce_mean"); }
TEST_CASE(reduce_prod_dyn) { test_dyn_reduce_ops("reduce_prod"); }
TEST_CASE(reduce_sum_dyn) { test_dyn_reduce_ops("reduce_sum"); }
TEST_CASE(reshape_shape)
{
......@@ -2836,7 +2853,7 @@ TEST_CASE(select_module_dyn)
input);
}
TEST_CASE(slice_shape)
TEST_CASE(slice_static_shape)
{
migraphx::shape input{migraphx::shape::int32_type, {2, 2, 3}};
expect_shape(migraphx::shape{migraphx::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
......@@ -2854,6 +2871,67 @@ TEST_CASE(slice_shape)
input);
}
TEST_CASE(slice_var_inputs_static_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"axes", {1, 2}}}),
input,
starts,
ends);
}
TEST_CASE(slice_var_inputs_static_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 3}, {0, 4}, {0, 4}}},
migraphx::make_op("slice"),
input,
starts,
ends,
axes);
}
TEST_CASE(slice_var_inputs_static_error0)
{
migraphx::shape input{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {3}};
throws_shape(migraphx::make_op("slice"), input, starts, ends, axes);
}
TEST_CASE(slice_var_inputs_dyn_shape0)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{3, 6}, {0, 4}, {0, 4}}},
migraphx::make_op("slice", {{"axes", {1, 2}}}),
input,
starts,
ends);
}
TEST_CASE(slice_var_inputs_dyn_shape1)
{
migraphx::shape input{migraphx::shape::float_type, {{3, 6}, {2, 4, {2, 4}}, {2, 4, {2, 4}}}};
migraphx::shape starts{migraphx::shape::int64_type, {2}};
migraphx::shape ends{migraphx::shape::int64_type, {2}};
migraphx::shape axes{migraphx::shape::int64_type, {2}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{0, 6}, {0, 4}, {0, 4}}},
migraphx::make_op("slice"),
input,
starts,
ends,
axes);
}
TEST_CASE(slice_dyn_shape0)
{
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 7}, {2, 3}}};
......@@ -2884,7 +2962,7 @@ TEST_CASE(slice_dyn_shape2)
TEST_CASE(slice_dyn_shape3)
{
// TODO: When variable dimension slicing is allowed, Slice to a size smaller than min.
// TODO: When non-fixed dimension slicing is allowed, Slice to a size smaller than min.
// Until then, this action is an error.
migraphx::shape input{migraphx::shape::int32_type, {{2, 3}, {7, 8}, {2, 3}}};
throws_shape(migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}),
......@@ -2915,8 +2993,6 @@ TEST_CASE(slice_dyn_shape5)
input);
}
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}};
......
......@@ -22,7 +22,6 @@
* THE SOFTWARE.
*/
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/pad_calc.hpp>
#include "test.hpp"
......
......@@ -24,7 +24,6 @@
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/register_target.hpp>
......
......@@ -8308,6 +8308,115 @@ TEST_CASE(slice_test)
}
}
TEST_CASE(slice_var_inputs_static0)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {1};
std::vector<int32_t> end_data = {3};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static1)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int32_t> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::int32_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), l0, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int32_t> start_data = {-2};
std::vector<int32_t> end_data = {2831};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int32_t> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int32_t> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_static2)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<float> data(2 * 2 * 3);
std::iota(data.begin(), data.end(), 0);
migraphx::shape s0{migraphx::shape::float_type, {2, 2, 3}};
auto l0 = mm->add_literal(migraphx::literal{s0, data});
migraphx::shape s1{migraphx::shape::int64_type, {3}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
auto axes = mm->add_parameter("axes", s1);
mm->add_instruction(migraphx::make_op("slice"), l0, starts, ends, axes);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
std::vector<int64_t> start_data = {0, 0, 0};
std::vector<int64_t> end_data = {2, 2, 2};
std::vector<int64_t> axes_data = {0, 1, 2};
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
params["axes"] = migraphx::argument(s1, axes_data.data());
auto result = p.eval(params).back();
std::vector<float> gold = {0, 1, 3, 4, 6, 7, 9, 10};
std::vector<float> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_var_inputs_dyn)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::int32_type, {{2, 4, {2, 4}}, {2, 4, {2, 4}}, {3, 8}}};
auto input = mm->add_parameter("input", s0);
migraphx::shape s1{migraphx::shape::int32_type, {1}};
auto starts = mm->add_parameter("starts", s1);
auto ends = mm->add_parameter("ends", s1);
mm->add_instruction(migraphx::make_op("slice", {{"axes", {2}}}), input, starts, ends);
p.compile(migraphx::make_target("ref"));
migraphx::parameter_map params;
migraphx::shape s2{migraphx::shape::int32_type, {2, 2, 3}};
std::vector<int> input_data(2 * 2 * 3);
std::iota(input_data.begin(), input_data.end(), 0);
std::vector<int> start_data = {1};
std::vector<int> end_data = {3};
params["input"] = migraphx::argument(s2, input_data.data());
params["starts"] = migraphx::argument(s1, start_data.data());
params["ends"] = migraphx::argument(s1, end_data.data());
auto result = p.eval(params).back();
std::vector<int> gold = {1, 2, 4, 5, 7, 8, 10, 11};
std::vector<int> results_vector(2 * 2 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold));
}
TEST_CASE(slice_dyn_test0)
{
// Slice a single dynamic dimension. ax1 slice limits are smaller than min; ax2 "ends" is
......
......@@ -37,6 +37,17 @@
bool is_quantizelinear(migraphx::instruction& ins) { return ins.name() == "quantizelinear"; }
bool is_dequantizelinear(migraphx::instruction& ins) { return ins.name() == "dequantizelinear"; }
bool is_clip_scalar(migraphx::instruction& ins)
{
if(ins.name() == "clip")
{
assert(ins.inputs().size() > 1);
return (std::all_of(ins.inputs().begin() + 1, ins.inputs().end(), [](auto input) {
return input->get_shape().scalar();
}));
}
return false;
}
void run_pass(migraphx::module& m) { migraphx::run_passes(m, {migraphx::rewrite_quantization{}}); }
......@@ -70,6 +81,8 @@ TEST_CASE(quantizelinear)
EXPECT(eval(p1) == eval(p2));
EXPECT(any_of(*p1.get_main_module(), &is_quantizelinear));
EXPECT(none_of(*p2.get_main_module(), &is_quantizelinear));
// ensure clip literals created in quantized program are scalar
EXPECT(any_of(*p2.get_main_module(), &is_clip_scalar));
}
TEST_CASE(dequantizelinear)
......
......@@ -24,7 +24,7 @@
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
......@@ -153,7 +153,7 @@ TEST_CASE(simplify_add_broadcast1)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", outer);
......@@ -188,7 +188,7 @@ TEST_CASE(simplify_add_broadcast2)
{
migraphx::shape inner{migraphx::shape::int32_type, {2}};
migraphx::shape outer{migraphx::shape::int32_type, {1, 2, 3, 3}};
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 2, 3, 3}}});
auto create_program = [&] {
migraphx::module m;
auto x = m.add_parameter("x", outer);
......@@ -539,7 +539,7 @@ TEST_CASE(simplify_conv_add)
TEST_CASE(simplify_inner_broadcast1)
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1}});
......@@ -564,7 +564,7 @@ TEST_CASE(simplify_inner_broadcast1)
TEST_CASE(simplify_inner_broadcast2)
{
auto b = migraphx::op::multibroadcast{{2, 1, 4, 5}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 4, 5}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 1, 1, 1}});
......@@ -589,7 +589,7 @@ TEST_CASE(simplify_inner_broadcast2)
TEST_CASE(simplify_inner_broadcast_scalar)
{
auto b = migraphx::op::multibroadcast{{32, 384}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {32, 384}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
......@@ -605,7 +605,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{1, 384}}, y);
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
......@@ -615,7 +616,7 @@ TEST_CASE(simplify_inner_broadcast_scalar)
TEST_CASE(simplify_inner_broadcast_different_dims)
{
auto b = migraphx::op::multibroadcast{{2, 384, 768}};
auto b = migraphx::make_op("multibroadcast", {{"out_lens", {2, 384, 768}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
......@@ -631,7 +632,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y);
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
auto sumb = m2.add_instruction(b, sum);
m2.add_instruction(pass_op{}, sumb);
......@@ -641,8 +643,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
TEST_CASE(simplify_inner_broadcast_different_broadcasts)
{
auto b = migraphx::op::broadcast{1, {1, 24, 112, 112}};
auto mb = migraphx::op::multibroadcast{{1, 24, 112, 112}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 24, 112, 112}}});
auto mb = migraphx::make_op("multibroadcast", {{"out_lens", {1, 24, 112, 112}}});
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::int32_type, {24}});
......@@ -891,7 +893,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -907,7 +909,7 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -926,7 +928,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -944,7 +946,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 2, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -964,7 +966,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto one = m1.add_literal(1);
......@@ -982,7 +984,7 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 1, 4, 5}}});
auto x = m2.add_parameter("x", s);
auto y = m2.add_parameter("y", s);
auto one = m2.add_literal(1);
......@@ -1695,7 +1697,7 @@ TEST_CASE(simplify_split_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1716,7 +1718,7 @@ TEST_CASE(simplify_split_add_relu)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -1846,8 +1848,8 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto r = migraphx::make_op("reshape", {{"dims", {3, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1870,7 +1872,7 @@ TEST_CASE(simplify_split_add_relu_reshape)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -1894,7 +1896,7 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::module m1;
{
auto r = migraphx::op::reshape{{3, 2, 4}};
auto r = migraphx::make_op("reshape", {{"dims", {3, 2, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -1926,7 +1928,7 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
......@@ -1954,7 +1956,7 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {3}}}), input);
......@@ -1982,7 +1984,7 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2010,7 +2012,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2031,7 +2033,7 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto one = m2.add_literal(1);
auto two = m2.add_literal(2);
......@@ -2049,7 +2051,7 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4, 3}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1, 3}}, {"starts", {0, 0}}, {"ends", {1, 3}}}),
......@@ -2078,7 +2080,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2100,7 +2102,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2126,7 +2128,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::module m1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 1, 4}}});
auto input = m1.add_parameter("input", s);
auto x = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2149,7 +2151,7 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
migraphx::module m2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto b = migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {3, 2, 4}}});
auto input = m2.add_parameter("input", s);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), input);
......@@ -2189,16 +2191,16 @@ TEST_CASE(simplify_split_between_add)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_dot_horiz)
void test_dot_horiz(migraphx::shape::type_t type, const std::string& dot_type)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
auto s = migraphx::shape{type, {3, 2, 2}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s);
auto a = m1.add_literal(migraphx::generate_literal(s, 0));
auto b = m1.add_literal(migraphx::generate_literal(s, 1));
auto x = m1.add_instruction(migraphx::make_op("dot"), input, a);
auto y = m1.add_instruction(migraphx::make_op("dot"), input, b);
auto x = m1.add_instruction(migraphx::make_op(dot_type), input, a);
auto y = m1.add_instruction(migraphx::make_op(dot_type), input, b);
auto sum = m1.add_instruction(migraphx::make_op("add"), x, y);
m1.add_instruction(pass_op{}, sum);
}
......@@ -2210,7 +2212,7 @@ TEST_CASE(simplify_dot_horiz)
auto a = m2.add_literal(migraphx::generate_literal(s, 0));
auto b = m2.add_literal(migraphx::generate_literal(s, 1));
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), a, b);
auto dot = m2.add_instruction(migraphx::make_op("dot"), input, concat);
auto dot = m2.add_instruction(migraphx::make_op(dot_type), input, concat);
auto x = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {2}}}), dot);
auto y = m2.add_instruction(
......@@ -2221,6 +2223,10 @@ TEST_CASE(simplify_dot_horiz)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(simplify_dot_horiz) { test_dot_horiz(migraphx::shape::int32_type, "dot"); }
TEST_CASE(simplify_quant_dot_horiz) { test_dot_horiz(migraphx::shape::int8_type, "quant_dot"); }
TEST_CASE(simplify_dot_horiz_same_constant)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
......
......@@ -24,7 +24,6 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/generate.hpp>
#include <basic_ops.hpp>
......@@ -477,7 +476,7 @@ TEST_CASE(concat_multibroadcasts1)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_multibroadcasts2)
......@@ -500,7 +499,7 @@ TEST_CASE(concat_multibroadcasts2)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 0);
}
TEST_CASE(concat_multibroadcasts3)
......@@ -523,7 +522,7 @@ TEST_CASE(concat_multibroadcasts3)
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 2);
}
TEST_CASE(concat_multibroadcasts4)
......@@ -559,7 +558,7 @@ TEST_CASE(concat_transpose1)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 3);
}
TEST_CASE(concat_transpose2)
......@@ -583,7 +582,7 @@ TEST_CASE(concat_transpose2)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_transpose3)
......@@ -607,7 +606,7 @@ TEST_CASE(concat_transpose3)
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
EXPECT(new_concat->get_operator().to_value()["axis"].to<int>() == 1);
}
TEST_CASE(concat_transpose4)
......
......@@ -37,7 +37,6 @@
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/serialize.hpp>
......@@ -840,12 +839,8 @@ TEST_CASE(slice_test)
mm->add_literal(migraphx::literal{s0, {1, 0}});
mm->add_literal(migraphx::literal{s0, {2, -1}});
migraphx::op::slice op;
op.starts = {1, 0};
op.ends = {3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
mm->add_instruction(op, l0);
mm->add_instruction(
migraphx::make_op("slice", {{"starts", {1, 0}}, {"ends", {3, 10}}, {"axes", {0, 1}}}), l0);
auto prog = optimize_tf("slice_test.pb", false);
EXPECT(p == prog);
......@@ -975,13 +970,10 @@ TEST_CASE(stridedslice_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
auto l2 = mm->add_instruction(op, l1);
auto l2 = mm->add_instruction(
migraphx::make_op(
"slice", {{"starts", {0, 0, 0, 0}}, {"ends", {1, 1, 1, 5}}, {"axes", {0, 1, 2, 3}}}),
l1);
auto shrink_axis = 1;
mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {shrink_axis}}}), l2);
auto prog = optimize_tf("stridedslice_test.pb", true);
......@@ -995,12 +987,6 @@ TEST_CASE(stridedslice_masks_test)
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 1, 1, 0};
op.ends = {1, 3, 3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{0, 1, 1, 0});
......@@ -1011,7 +997,10 @@ TEST_CASE(stridedslice_masks_test)
auto l1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
auto l2 = mm->add_instruction(op, l1);
auto l2 = mm->add_instruction(
migraphx::make_op(
"slice", {{"starts", {0, 1, 1, 0}}, {"ends", {1, 3, 3, 10}}, {"axes", {0, 1, 2, 3}}}),
l1);
auto l3 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), l2);
mm->add_return({l3});
......
......@@ -25,7 +25,7 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/make_op.hpp>
struct gemm_literal : verify_program<gemm_literal>
{
......@@ -38,7 +38,7 @@ struct gemm_literal : verify_program<gemm_literal>
auto a = mm->add_literal(migraphx::generate_literal(a_shape));
auto b = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::op::dot{}, a, b);
mm->add_instruction(migraphx::make_op("dot"), a, b);
return p;
}
......
......@@ -31,7 +31,7 @@ pip3 install -r requirements-dev.txt
# Add newer cmake to the path
export PATH="/opt/cmake/bin:$PATH"
export CXXFLAGS="-D__HIP_PLATFORM_AMD__=1 -w"
./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root
./build.sh --config Release --cmake_extra_defines CMAKE_HIP_COMPILER=/opt/rocm/llvm/bin/clang++ --update --build --build_wheel --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --skip_tests --rocm_home /opt/rocm --use_migraphx --migraphx_home /opt/rocm --rocm_version=`cat /opt/rocm/.info/version-dev` --allow_running_as_root
cd build/Linux/Release
#Add test launcher for onnxrt tests
......
FROM registry.suse.com/suse/sle15:15.4
RUN sh -c 'echo -e "\
[rocm]\n\
name=rocm\n\
baseurl=https://repo.radeon.com/rocm/zyp/5.5/main\n\
enabled=1\n\
gpgcheck=1\n\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\
" > /etc/zypp/repos.d/rocm.repo'
RUN cat /etc/zypp/repos.d/rocm.repo
RUN zypper -n --gpg-auto-import-keys refresh
RUN zypper install -y -t pattern devel_basis enhanced_base
RUN zypper --gpg-auto-import-keys install -y \
doxygen \
gcc-c++ \
gdb \
git \
python3-pip
# Workaround broken rocm packages
RUN ln -s /opt/rocm-* /opt/rocm
RUN echo "/opt/rocm/lib" > /etc/ld.so.conf.d/rocm.conf
RUN echo "/opt/rocm/llvm/lib" > /etc/ld.so.conf.d/rocm-llvm.conf
RUN ldconfig
ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8
# Install yapf
RUN pip3 install yapf==0.28.0
# Install doc requirements
# ADD docs/.sphinx/requirements.txt /doc-requirements.txt
# RUN pip3 install -r /doc-requirements.txt
# Install dependencies
ADD dev-requirements.txt /dev-requirements.txt
ADD requirements.txt /requirements.txt
ADD rbuild.ini /rbuild.ini
COPY ./tools/install_prereqs.sh /
RUN /install_prereqs.sh /usr/local / && rm /install_prereqs.sh
......@@ -31,9 +31,30 @@ set -e
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
source /etc/os-release
if [[ ("${ID}" == "sles") ]]; then
zypper -n --gpg-auto-import-keys install -y \
cmake \
miopen-hip-devel \
openmp-extras-devel \
python3-devel \
python3-pip \
rocblas-devel \
rocm-cmake
else
# Need pip3 and Python headers to build dependencies
apt update && apt install -y \
cmake \
libnuma-dev \
miopen-hip-dev \
openmp-extras \
python3-dev \
python3-pip \
rocblas-dev \
rocm-cmake
fi
# Need pip3 and Python headers to build dependencies
apt update && apt install -y python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras
# Needed for cmake to build various pip packages
pip3 install setuptools wheel
......@@ -56,9 +77,11 @@ echo "Dependencies are installed at $PREFIX"
# Install deps with rbuild
rbuild prepare -d $PREFIX -s develop
if [[ ("${ID}" != "sles") ]]; then
export CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=ON"
pip3 install onnx==1.10.2 numpy==1.21.6 typing==3.7.4 pytest==6.0.1 packaging==23.0
# pin version of protobuf in Python for onnx runtime unit tests between dist versions
pip3 install protobuf==3.20.0
fi
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment