Commit 50cfbcda authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into scatter-op

parents 4a24a2dd f60c3815
...@@ -151,6 +151,9 @@ jobs: ...@@ -151,6 +151,9 @@ jobs:
- debug - debug
- release - release
- codecov - codecov
exclude:
- os: ubuntu-16.04
configuration: debug
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
......
...@@ -36,7 +36,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED) ...@@ -36,7 +36,7 @@ find_package(nlohmann_json 3.8.0 REQUIRED)
include(ROCMSetupVersion) include(ROCMSetupVersion)
rocm_setup_version(VERSION 1.2) rocm_setup_version(VERSION 1.3)
set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}) set(MIGRAPHX_SO_VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR})
option( BUILD_SHARED_LIBS "Build as a shared library" ON ) option( BUILD_SHARED_LIBS "Build as a shared library" ON )
......
...@@ -32,7 +32,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const ...@@ -32,7 +32,7 @@ std::vector<char> src_compiler::compile(const std::vector<src_file>& srcs) const
} }
} }
params += " -o" + out; params += " -o " + out;
td.execute(compiler, params); td.execute(compiler, params);
......
File mode changed from 100644 to 100755
...@@ -39,9 +39,7 @@ struct reverse ...@@ -39,9 +39,7 @@ struct reverse
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto lens = inputs[0].lens(); return inputs[0].with_lens(inputs[0].lens());
auto type = inputs[0].type();
return shape{type, lens};
} }
argument compute(const shape& s, std::vector<argument> args) const argument compute(const shape& s, std::vector<argument> args) const
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -25,8 +26,15 @@ struct step ...@@ -25,8 +26,15 @@ struct step
return pack(f(self.axes, "axes"), f(self.steps, "steps")); return pack(f(self.axes, "axes"), f(self.steps, "steps"));
} }
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "step"; } std::string name() const { return "step"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0); auto input = inputs.at(0);
......
...@@ -29,10 +29,6 @@ struct parse_slice : op_parser<parse_slice> ...@@ -29,10 +29,6 @@ struct parse_slice : op_parser<parse_slice>
migraphx::argument step_arg = args.back()->eval(); migraphx::argument step_arg = args.back()->eval();
check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice"); check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); }); step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return abs(s) == 1; }))
{
MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1 or -1");
}
} }
if(args.size() >= 4) if(args.size() >= 4)
...@@ -98,7 +94,16 @@ struct parse_slice : op_parser<parse_slice> ...@@ -98,7 +94,16 @@ struct parse_slice : op_parser<parse_slice>
auto ins = info.add_instruction(op, args[0]); auto ins = info.add_instruction(op, args[0]);
if(not raxes.empty()) if(not raxes.empty())
return info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins); ins = info.add_instruction(make_op("reverse", {{"axes", raxes}}), ins);
if(std::any_of(steps.begin(), steps.end(), [](auto s) { return std::abs(s) != 1; }))
{
std::vector<int64_t> nsteps;
std::transform(steps.begin(), steps.end(), std::back_inserter(nsteps), [](auto s) {
return std::abs(s);
});
return ins = info.add_instruction(
make_op("step", {{"axes", op.axes}, {"steps", nsteps}}), ins);
}
else else
return ins; return ins;
} }
......
...@@ -9,6 +9,10 @@ ...@@ -9,6 +9,10 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#ifdef __linux__
#include <unistd.h>
#endif
#ifndef MIGRAPHX_GUARD_TEST_TEST_HPP #ifndef MIGRAPHX_GUARD_TEST_TEST_HPP
#define MIGRAPHX_GUARD_TEST_TEST_HPP #define MIGRAPHX_GUARD_TEST_TEST_HPP
...@@ -264,6 +268,32 @@ struct capture ...@@ -264,6 +268,32 @@ struct capture
} }
}; };
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
template <class T, class F> template <class T, class F>
void failed(T x, const char* msg, const char* func, const char* file, int line, F f) void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
{ {
...@@ -271,7 +301,7 @@ void failed(T x, const char* msg, const char* func, const char* file, int line, ...@@ -271,7 +301,7 @@ void failed(T x, const char* msg, const char* func, const char* file, int line,
{ {
std::cout << func << std::endl; std::cout << func << std::endl;
std::cout << file << ":" << line << ":" << std::endl; std::cout << file << ":" << line << ":" << std::endl;
std::cout << " FAILED: " << msg << " " std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " "
<< "[ " << x << " ]" << std::endl; << "[ " << x << " ]" << std::endl;
f(); f();
} }
...@@ -315,7 +345,7 @@ auto near(T px, U py, double ptol = 1e-6f) ...@@ -315,7 +345,7 @@ auto near(T px, U py, double ptol = 1e-6f)
using string_map = std::unordered_map<std::string, std::vector<std::string>>; using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class Keyword> template <class Keyword>
string_map parse(std::vector<std::string> as, Keyword keyword) string_map generic_parse(std::vector<std::string> as, Keyword keyword)
{ {
string_map result; string_map result;
...@@ -331,19 +361,22 @@ string_map parse(std::vector<std::string> as, Keyword keyword) ...@@ -331,19 +361,22 @@ string_map parse(std::vector<std::string> as, Keyword keyword)
{ {
flag = f.front(); flag = f.front();
result[flag]; // Ensure the flag exists result[flag]; // Ensure the flag exists
flag = f.back();
} }
} }
return result; return result;
} }
using test_case = std::function<void()>;
inline auto& get_test_cases() inline auto& get_test_cases()
{ {
// NOLINTNEXTLINE // NOLINTNEXTLINE
static std::vector<std::pair<std::string, std::function<void()>>> cases; static std::vector<std::pair<std::string, test_case>> cases;
return cases; return cases;
} }
inline void add_test_case(std::string name, std::function<void()> f) inline void add_test_case(std::string name, test_case f)
{ {
get_test_cases().emplace_back(std::move(name), std::move(f)); get_test_cases().emplace_back(std::move(name), std::move(f));
} }
...@@ -357,37 +390,243 @@ struct auto_register_test_case ...@@ -357,37 +390,243 @@ struct auto_register_test_case
} }
}; };
inline void run_test_case(const std::string& name, const std::function<void()>& f) struct failure_error
{ {
std::cout << "[ RUN ] " << name << std::endl; };
f();
std::cout << "[ COMPLETE ] " << name << std::endl;
}
inline void run(int argc, const char* argv[]) [[noreturn]] inline void fail() { throw failure_error{}; }
struct driver
{ {
std::vector<std::string> as(argv + 1, argv + argc); driver()
{
add_flag({"--help", "-h"}, "Show help");
add_flag({"--list", "-l"}, "List all test cases");
add_flag({"--continue", "-c"}, "Continue after failure");
add_flag({"--quiet", "-q"}, "Don't print out extra output");
}
struct argument
{
std::vector<std::string> flags = {};
std::string help = "";
int nargs = 1;
};
auto args = parse(as, [](auto &&) -> std::vector<std::string> { return {}; }); void add_arg(const std::vector<std::string>& flags, const std::string& help = "")
auto cases = args[""];
if(cases.empty())
{ {
for(auto&& tc : get_test_cases()) arguments.push_back(argument{flags, help, 1});
run_test_case(tc.first, tc.second);
} }
else
void add_flag(const std::vector<std::string>& flags, const std::string& help = "")
{ {
std::unordered_map<std::string, std::function<void()>> m(get_test_cases().begin(), arguments.push_back(argument{flags, help, 0});
get_test_cases().end()); }
for(auto&& name : cases)
void show_help(const std::string& exe) const
{
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " ";
std::cout << exe << " <test-case>... <options>" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl;
std::cout << " ";
std::cout << color::fg_green << "<test-case>..." << color::reset;
std::cout << std::endl;
std::cout << " "
<< "Test case name to run" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : arguments)
{
std::string prefix = " ";
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset << std::endl;
std::cout << " " << arg.help << std::endl;
}
}
std::ostream& out() const
{
struct null_buffer : std::streambuf
{
virtual int overflow(int c) override { return c; }
};
static null_buffer buffer;
static std::ostream null_stream(&buffer);
if(quiet)
return null_stream;
return std::cout;
}
string_map parse(int argc, const char* argv[]) const
{
std::vector<std::string> args(argv + 1, argv + argc);
string_map keys;
for(auto&& arg : arguments)
{ {
auto f = m.find(name); for(auto&& flag : arg.flags)
if(f == m.end()) {
std::cout << "[ ERROR ] Test case '" << name << "' not found." << std::endl; keys[flag] = {arg.flags.front()};
if(arg.nargs == 0)
keys[flag].push_back("");
}
}
auto result = generic_parse(args, [&](auto&& s) -> std::vector<std::string> {
if(keys.count(s) > 0)
return keys[s];
else else
run_test_case(name, f->second); return {};
});
result["__exe__"].push_back(argv[0]);
return result;
}
static std::string create_command(const string_map& args)
{
std::stringstream ss;
ss << args.at("__exe__").front();
if(args.count("") > 0)
{
for(auto&& arg : args.at(""))
ss << " \"" << arg << "\"";
}
for(auto&& p : args)
{
if(p.first == "__exe__")
continue;
if(p.first.empty())
continue;
ss << " " << p.first;
for(auto&& arg : p.second)
ss << " \"" << arg << "\"";
} }
return ss.str();
} }
static std::string fork(const std::string& name, string_map args)
{
std::string msg;
args[""] = {name};
args.erase("--continue");
args["--quiet"];
auto cmd = create_command(args);
auto r = std::system(cmd.c_str()); // NOLINT
if(r != 0)
msg = "Exited with " + std::to_string(r);
return msg;
}
void run_test_case(const std::string& name, const test_case& f, const string_map& args)
{
ran++;
out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name
<< color::reset << std::endl;
std::string msg;
if(args.count("--continue") > 0)
{
msg = fork(name, args);
}
else
{
try
{
f();
}
catch(const failure_error&)
{
msg = "Test failure";
}
}
if(msg.empty())
{
out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name
<< color::reset << std::endl;
}
else
{
failed.push_back(name);
out() << color::fg_red << "[ FAILED ] " << color::reset << color::bold << name
<< color::reset << ": " << color::fg_yellow << msg << color::reset << std::endl;
}
}
void run(int argc, const char* argv[])
{
auto args = parse(argc, argv);
if(args.count("--help") > 0)
{
show_help(args.at("__exe__").front());
return;
}
if(args.count("--list") > 0)
{
for(auto&& tc : get_test_cases())
out() << tc.first << std::endl;
return;
}
if(args.count("--quiet") > 0)
quiet = true;
auto cases = args[""];
if(cases.empty())
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second, args);
}
else
{
std::unordered_map<std::string, test_case> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& iname : cases)
{
for(auto&& name : get_case_names(iname))
{
auto f = m.find(name);
if(f == m.end())
{
out() << color::fg_red << "[ ERROR ] Test case '" << name
<< "' not found." << color::reset << std::endl;
failed.push_back(name);
}
else
run_test_case(name, f->second, args);
}
}
}
out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran"
<< color::reset << std::endl;
if(not failed.empty())
{
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size()
<< " tests failed" << color::reset << std::endl;
for(auto&& name : failed)
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name
<< color::reset << std::endl;
std::exit(1);
}
}
std::function<std::vector<std::string>(const std::string&)> get_case_names =
[](const std::string& name) -> std::vector<std::string> { return {name}; };
std::vector<argument> arguments = {};
std::vector<std::string> failed = {};
std::size_t ran = 0;
bool quiet = false;
};
inline void run(int argc, const char* argv[])
{
driver d{};
d.run(argc, argv);
} }
} // namespace test } // namespace test
...@@ -404,7 +643,7 @@ inline void run(int argc, const char* argv[]) ...@@ -404,7 +643,7 @@ inline void run(int argc, const char* argv[])
__PRETTY_FUNCTION__, \ __PRETTY_FUNCTION__, \
__FILE__, \ __FILE__, \
__LINE__, \ __LINE__, \
&std::abort) &test::fail)
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0) #define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
#include "test.hpp"
int main() {} int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -3603,7 +3603,7 @@ def slice_5arg_reverse_test(): ...@@ -3603,7 +3603,7 @@ def slice_5arg_reverse_test():
outputs=['arg_axis'], outputs=['arg_axis'],
value=axis_tensor) value=axis_tensor)
end = np.array([-1, -1]) end = np.array([-5, -1])
end_tensor = helper.make_tensor(name="end", end_tensor = helper.make_tensor(name="end",
data_type=TensorProto.INT32, data_type=TensorProto.INT32,
dims=end.shape, dims=end.shape,
...@@ -3613,7 +3613,60 @@ def slice_5arg_reverse_test(): ...@@ -3613,7 +3613,60 @@ def slice_5arg_reverse_test():
outputs=['arg_end'], outputs=['arg_end'],
value=end_tensor) value=end_tensor)
start = np.array([-5, -3]) start = np.array([-1, -3])
start_tensor = helper.make_tensor(name="start",
data_type=TensorProto.INT32,
dims=start.shape,
vals=start.astype(int))
arg_start = helper.make_node("Constant",
inputs=[],
outputs=['arg_start'],
value=start_tensor)
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [4, 2])
node = onnx.helper.make_node(
'Slice',
inputs=['0', 'arg_start', 'arg_end', 'arg_axis', 'arg_step'],
outputs=['1'])
return ([arg_step, arg_axis, arg_end, arg_start, node], [x], [y])
@onnx_test
def slice_5arg_step_test():
step = np.array([-2, 2])
step_tensor = helper.make_tensor(name="step",
data_type=TensorProto.INT32,
dims=step.shape,
vals=step.astype(int))
arg_step = helper.make_node("Constant",
inputs=[],
outputs=['arg_step'],
value=step_tensor)
axis = np.array([-1, -2])
axis_tensor = helper.make_tensor(name="axis",
data_type=TensorProto.INT32,
dims=axis.shape,
vals=axis.astype(int))
arg_axis = helper.make_node("Constant",
inputs=[],
outputs=['arg_axis'],
value=axis_tensor)
end = np.array([-5, -1])
end_tensor = helper.make_tensor(name="end",
data_type=TensorProto.INT32,
dims=end.shape,
vals=end.astype(int))
arg_end = helper.make_node("Constant",
inputs=[],
outputs=['arg_end'],
value=end_tensor)
start = np.array([-1, -3])
start_tensor = helper.make_tensor(name="start", start_tensor = helper.make_tensor(name="start",
data_type=TensorProto.INT32, data_type=TensorProto.INT32,
dims=start.shape, dims=start.shape,
......
...@@ -3331,10 +3331,11 @@ TEST_CASE(slice_5arg_reverse_test) ...@@ -3331,10 +3331,11 @@ TEST_CASE(slice_5arg_reverse_test)
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, 1}}); mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, 1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}}); mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -1}}); mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -3}}); mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -3}});
auto slice_out = mm->add_instruction( auto slice_out = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {-1, -2}}, {"starts", {0, -3}}, {"ends", {-4, -1}}}), migraphx::make_op("slice",
{{"axes", {-1, -2}}, {"starts", {-4, -3}}, {"ends", {2147483647, -1}}}),
l0); l0);
auto ret = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out); auto ret = mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out);
mm->add_return({ret}); mm->add_return({ret});
...@@ -3344,6 +3345,30 @@ TEST_CASE(slice_5arg_reverse_test) ...@@ -3344,6 +3345,30 @@ TEST_CASE(slice_5arg_reverse_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(slice_5arg_step_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 5}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-2, 2}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -2}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-5, -1}});
mm->add_literal({{migraphx::shape::int32_type, {2}}, {-1, -3}});
auto slice_out = mm->add_instruction(
migraphx::make_op("slice",
{{"axes", {-1, -2}}, {"starts", {-4, -3}}, {"ends", {2147483647, -1}}}),
l0);
auto reverse_out =
mm->add_instruction(migraphx::make_op("reverse", {{"axes", {-1}}}), slice_out);
auto step_out = mm->add_instruction(
migraphx::make_op("step", {{"axes", {-1, -2}}, {"steps", {2, 2}}}), reverse_out);
mm->add_return({step_out});
auto prog = migraphx::parse_onnx("slice_5arg_step_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(slice_max_end_test) TEST_CASE(slice_max_end_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
Barg_axis"Constant*, Barg_axis"Constant*,
value* *Baxis value* *Baxis
@arg_end"Constant*+ @arg_end"Constant*+
value**Bend value**Bend
D arg_start"Constant*- D arg_start"Constant*-
value*!*Bstart value*!*Bstart
5 5
0 0
arg_start arg_start
......
slice_5arg_step_test:
9arg_step"Constant*#
value** Bstep
Barg_axis"Constant*,
value* *Baxis
@arg_end"Constant*+
value**Bend
D arg_start"Constant*-
value*!*Bstart
5
0
arg_start
arg_end
arg_axis
arg_step1"Sliceslice_5arg_step_testZ
0


b
1


B
\ No newline at end of file
slice_5arg_test: slice_5arg_test:
0arg_step"Constant* 0arg_step"Constant*
value**Bstep value**Bstep
Barg_axis"Constant*, Barg_axis"Constant*,
...@@ -20,4 +20,4 @@ D arg_start"Constant*- ...@@ -20,4 +20,4 @@ D arg_start"Constant*-
1 1
 
 
B B
\ No newline at end of file \ No newline at end of file
...@@ -408,6 +408,85 @@ TEST_CASE(selu_test) ...@@ -408,6 +408,85 @@ TEST_CASE(selu_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(slice_test)
{
migraphx::program p = migraphx::parse_onnx("slice_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {3, 2}};
std::vector<float> data = {0, 1, 2, 3, 4, 5};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {2, 3};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_5arg_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {10, 11, 12, 13, 15, 16, 17, 18};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_reverse_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_reverse_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 13, 12, 11, 19, 18, 17, 16};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(slice_step_test)
{
migraphx::program p = migraphx::parse_onnx("slice_5arg_step_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sh_data{migraphx::shape::float_type, {5, 5}}; // start
std::vector<float> data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(sh_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {14, 12};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(upsample_test) TEST_CASE(upsample_test)
{ {
migraphx::program p = migraphx::parse_onnx("upsample_test.onnx"); migraphx::program p = migraphx::parse_onnx("upsample_test.onnx");
......
This diff is collapsed.
...@@ -44,3 +44,23 @@ void auto_print::set_terminate_handler(const std::string& name) ...@@ -44,3 +44,23 @@ void auto_print::set_terminate_handler(const std::string& name)
get_handler(tname)(); get_handler(tname)();
}); });
} }
static bool in_exception()
{
#if __cplusplus >= 201703L
return std::uncaught_exceptions() > 0;
#else
return std::uncaught_exception();
#endif
}
auto_print::~auto_print()
{
if(in_exception())
{
std::cout << std::endl;
for(const auto& tname : migraphx::get_targets())
get_handler(tname)();
}
get_handler(name) = [] {};
}
...@@ -15,10 +15,7 @@ struct auto_print ...@@ -15,10 +15,7 @@ struct auto_print
get_handler(name) = [&x] { std::cout << x << std::endl; }; get_handler(name) = [&x] { std::cout << x << std::endl; };
} }
~auto_print() ~auto_print();
{
get_handler(name) = [] {};
}
}; };
#endif #endif
#include "run_verify.hpp" #include "run_verify.hpp"
#include "auto_print.hpp" #include "auto_print.hpp"
#include "verify_program.hpp" #include "verify_program.hpp"
#include "test.hpp"
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
...@@ -121,7 +122,6 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -121,7 +122,6 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
{ {
using result_future = using result_future =
std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>; std::future<std::pair<migraphx::program, std::vector<migraphx::argument>>>;
std::cout << "[ RUN ] " << name << std::endl;
auto_print::set_terminate_handler(name); auto_print::set_terminate_handler(name);
std::vector<std::pair<std::string, result_future>> results; std::vector<std::pair<std::string, result_future>> results;
std::vector<std::string> target_names; std::vector<std::string> target_names;
...@@ -180,25 +180,27 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con ...@@ -180,25 +180,27 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
std::cout << tname << ":\n" << cp << std::endl; std::cout << tname << ":\n" << cp << std::endl;
std::cout << std::endl; std::cout << std::endl;
} }
EXPECT(passed);
} }
} }
std::set_terminate(nullptr); std::set_terminate(nullptr);
std::cout << "[ COMPLETE ] " << name << std::endl;
} }
void run_verify::run(int argc, const char* argv[]) const void run_verify::run(int argc, const char* argv[]) const
{ {
std::set<std::string> args(argv + 1, argv + argc); std::unordered_map<std::string, std::vector<std::string>> labels;
const auto& ps = get_programs(); for(auto&& p : get_programs())
for(auto&& p : ps)
{ {
if(not args.empty()) labels[p.section].push_back(p.name);
{ test::add_test_case(p.name, [=] { verify(p.name, p.get_program()); });
if(args.count(p.name) == 0 and args.count(p.section) == 0)
continue;
}
verify(p.name, p.get_program());
} }
test::driver d{};
d.get_case_names = [&](const std::string& name) -> std::vector<std::string> {
if(labels.count(name) > 0)
return labels.at(name);
return {name};
};
d.run(argc, argv);
} }
void run_verify::disable_parallel_for(const std::string& name) { info[name].parallel = false; } void run_verify::disable_parallel_for(const std::string& name) { info[name].parallel = false; }
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_slice_reverse : verify_program<test_slice_reverse>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {3, 5}};
auto x = mm->add_parameter("x", s);
auto slice_out = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {0, 2}}, {"ends", {2, -1}}}),
x);
mm->add_instruction(migraphx::make_op("reverse", {{"axes", {0}}}), slice_out);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment