Unverified Commit 6599ffca authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Report test failures from CHECK macro (#1930)

parent 110eb00c
...@@ -117,6 +117,7 @@ struct onnx_parser ...@@ -117,6 +117,7 @@ struct onnx_parser
parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false); parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const; literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const; literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const; shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
}; };
......
...@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser, ...@@ -357,10 +357,9 @@ parse_inputs(const onnx_parser& parser,
} }
shape s; shape s;
std::vector<std::size_t> dims;
if(parser.map_input_dims.count(name) > 0) if(parser.map_input_dims.count(name) > 0)
{ {
dims = parser.map_input_dims.at(name); std::vector<std::size_t> dims = parser.map_input_dims.at(name);
s = parser.parse_type(input.type(), dims); s = parser.parse_type(input.type(), dims);
} }
else if(parser.map_dyn_input_dims.count(name) > 0) else if(parser.map_dyn_input_dims.count(name) > 0)
...@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser, ...@@ -370,7 +369,7 @@ parse_inputs(const onnx_parser& parser,
} }
else else
{ {
s = parser.parse_type(input.type(), dims); s = parser.parse_type(input.type());
} }
mod_insts[name] = mod->add_parameter(name, s); mod_insts[name] = mod->add_parameter(name, s);
} }
...@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const ...@@ -553,14 +552,9 @@ literal onnx_parser::parse_tensor(const onnx::TensorProto& t) const
} }
MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type"); MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
} }
shape onnx_parser::parse_type(const onnx::TypeProto& t, shape onnx_parser::parse_type(const onnx::TypeProto& t) const
const std::vector<std::size_t>& input_dims) const
{ {
shape::type_t shape_type = get_type(t.tensor_type().elem_type()); shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(not input_dims.empty())
{
return {shape_type, input_dims};
}
std::vector<shape::dynamic_dimension> dynamic_dims; std::vector<shape::dynamic_dimension> dynamic_dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
...@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -590,6 +584,15 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return shape_from_dyn_dims(shape_type, dynamic_dims); return shape_from_dyn_dims(shape_type, dynamic_dims);
} }
shape onnx_parser::parse_type(const onnx::TypeProto& t,
const std::vector<std::size_t>& input_dims) const
{
shape::type_t shape_type = get_type(t.tensor_type().elem_type());
if(input_dims.empty())
return {shape_type};
return {shape_type, input_dims};
}
shape::type_t get_type(int dtype) shape::type_t get_type(int dtype)
{ {
switch(dtype) switch(dtype)
......
...@@ -145,15 +145,15 @@ TEST_CASE(zero_parameter) ...@@ -145,15 +145,15 @@ TEST_CASE(zero_parameter)
TEST_CASE(set_scalar_parameter) TEST_CASE(set_scalar_parameter)
{ {
auto p1 = migraphx::parse_onnx("add_bcast_test.onnx"); auto p1 = migraphx::parse_onnx("implicit_add_bcast_test.onnx");
migraphx::shape s1(migraphx_shape_float_type, {3, 4}); migraphx::shape s1(migraphx_shape_float_type, {3, 4, 1});
auto param_shapes = p1.get_parameter_shapes(); auto param_shapes = p1.get_parameter_shapes();
auto s1_orig = param_shapes["1"]; auto s1_orig = param_shapes["1"];
CHECK(bool{s1 == s1_orig}); CHECK(bool{s1 == s1_orig});
migraphx::onnx_options option; migraphx::onnx_options option;
option.set_input_parameter_shape("1", {}); option.set_input_parameter_shape("1", {});
auto p2 = migraphx::parse_onnx("add_bcast_test.onnx", option); auto p2 = migraphx::parse_onnx("implicit_add_bcast_test.onnx", option);
migraphx::shape s_scalar(migraphx_shape_float_type); migraphx::shape s_scalar(migraphx_shape_float_type);
auto param_shapes_1 = p2.get_parameter_shapes(); auto param_shapes_1 = p2.get_parameter_shapes();
auto s_scalar_after = param_shapes_1["1"]; auto s_scalar_after = param_shapes_1["1"];
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <atomic>
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
...@@ -342,11 +343,19 @@ inline std::ostream& operator<<(std::ostream& os, const color& c) ...@@ -342,11 +343,19 @@ inline std::ostream& operator<<(std::ostream& os, const color& c)
return os; return os;
} }
inline std::atomic<int>& failures()
{
// NOLINTNEXTLINE
static std::atomic<int> f = 0;
return f;
}
template <class T, class F> template <class T, class F>
void failed(T x, const char* msg, const char* func, const char* file, int line, F f) void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
{ {
if(not bool(x.value())) if(not bool(x.value()))
{ {
failures()++;
std::cout << func << std::endl; std::cout << func << std::endl;
std::cout << file << ":" << line << ":" << std::endl; std::cout << file << ":" << line << ":" << std::endl;
std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " " std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " "
...@@ -586,13 +595,21 @@ struct driver ...@@ -586,13 +595,21 @@ struct driver
{ {
try try
{ {
failures() = 0;
f(); f();
} }
// cppcheck-suppress EmptyCatchStatement
catch(const failure_error&) catch(const failure_error&)
{ {
msg = "Test failure";
} }
} }
if(msg.empty() and failures() != 0)
{
if(failures() == 1)
msg = "Test failure";
else
msg = std::to_string(failures()) + " test failures";
}
if(msg.empty()) if(msg.empty())
{ {
out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name
...@@ -683,10 +700,10 @@ inline void run(int argc, const char* argv[]) ...@@ -683,10 +700,10 @@ inline void run(int argc, const char* argv[])
#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__ #define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define CHECK(...) \ #define CHECK(...) \
test::failed( \ test::failed( \
test::capture{}->*__VA_ARGS__, #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] { \ TEST_CAPTURE(__VA_ARGS__), #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] {})
})
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define EXPECT(...) \ #define EXPECT(...) \
test::failed(TEST_CAPTURE(__VA_ARGS__), \ test::failed(TEST_CAPTURE(__VA_ARGS__), \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment