"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "9f375dfcdb2affd0cd5376af9c3638eea31f5d00"
Commit e76bd729 authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'master' into batchnorm_onnx

parents 4e4460dd 6bb6b72e
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/errors.hpp> #include <migraph/errors.hpp>
#include <migraph/argument.hpp>
namespace migraph { namespace migraph {
......
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#include <migraph/program.hpp>
namespace migraph {
template <class T>
struct check_context
{
struct op
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
}
};
std::string name() const { return "check_context"; }
void apply(program& p) const { p.insert_instruction(p.begin(), op{}); }
};
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraph/shape.hpp>
#include <algorithm>
namespace migraph {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
MIGRAPH_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
MIGRAPH_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
template <class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return f(s) == key; });
}
template <class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_CONTEXT_HPP #ifndef MIGRAPH_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP #define MIGRAPH_GUARD_CONTEXT_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
#ifndef MIGRAPH_GUARD_ERASE_HPP #ifndef MIGRAPH_GUARD_ERASE_HPP
#define MIGRAPH_GUARD_ERASE_HPP #define MIGRAPH_GUARD_ERASE_HPP
#include <algorithm>
namespace migraph { namespace migraph {
/** /**
......
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
...@@ -3,98 +3,13 @@ ...@@ -3,98 +3,13 @@
#include <array> #include <array>
#include <migraph/operation.hpp> #include <migraph/operation.hpp>
#include <migraph/check_shapes.hpp>
#include <migraph/stringutils.hpp> #include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp> #include <migraph/streamutils.hpp>
#include <cmath> #include <cmath>
namespace migraph { namespace migraph {
struct check_shapes
{
const std::vector<shape>* shapes;
const std::string name;
check_shapes(const std::vector<shape>& s) : shapes(&s) {}
template <class Op>
check_shapes(const std::vector<shape>& s, const Op& op) : shapes(&s), name(op.name())
{
}
std::string prefix() const
{
if(name.empty())
return "";
else
return name + ": ";
}
const check_shapes& has(std::size_t n) const
{
assert(shapes != nullptr);
if(shapes->size() != n)
MIGRAPH_THROW(prefix() + "Wrong number of arguments: expected " + std::to_string(n) +
" but given " + std::to_string(shapes->size()));
return *this;
}
const check_shapes& only_dims(std::size_t n) const
{
assert(shapes != nullptr);
if(!shapes->empty())
{
if(shapes->front().lens().size() != n)
MIGRAPH_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
}
return *this;
}
const check_shapes& same_shape() const
{
if(!this->same([](const shape& s) { return s; }))
MIGRAPH_THROW(prefix() + "Shapes do not match");
return *this;
}
const check_shapes& same_type() const
{
if(!this->same([](const shape& s) { return s.type(); }))
MIGRAPH_THROW(prefix() + "Types do not match");
return *this;
}
const check_shapes& same_dims() const
{
if(!this->same([](const shape& s) { return s.lens(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
const check_shapes& same_ndims() const
{
if(!this->same([](const shape& s) { return s.lens().size(); }))
MIGRAPH_THROW(prefix() + "Dimensions do not match");
return *this;
}
template <class F>
bool same(F f) const
{
assert(shapes != nullptr);
if(shapes->empty())
return true;
auto&& key = f(shapes->front());
return this->all_of([&](const shape& s) { return f(s) == key; });
}
template <class Predicate>
bool all_of(Predicate p) const
{
assert(shapes != nullptr);
return std::all_of(shapes->begin(), shapes->end(), p);
}
};
struct not_computable struct not_computable
{ {
argument compute(context&, shape, std::vector<argument>) const argument compute(context&, shape, std::vector<argument>) const
...@@ -112,6 +27,14 @@ struct batch_norm_inference ...@@ -112,6 +27,14 @@ struct batch_norm_inference
std::string name() const { return "batch_norm_inference"; } std::string name() const { return "batch_norm_inference"; }
enum bn_infer_mode_t
{
per_activation,
spatial,
};
bn_infer_mode_t bn_mode = spatial;
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
...@@ -579,20 +502,6 @@ struct outline ...@@ -579,20 +502,6 @@ struct outline
argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; } argument compute(context&, shape, std::vector<argument>) const { return {s, nullptr}; }
}; };
template <class T>
struct check_context
{
std::string name() const { return "check_context"; }
shape compute_shape(std::vector<shape>) const { return {}; }
argument compute(context& ctx, shape, std::vector<argument>) const
{
T* x = any_cast<T>(&ctx);
if(x == nullptr)
MIGRAPH_THROW(std::string("Unexpected context type: ") + ctx.type_id().name());
return {};
}
};
} // namespace migraph } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_PASS_HPP #ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP #define MIGRAPH_GUARD_PASS_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_RANGES_HPP
#include <algorithm>
namespace migraph { namespace migraph {
template <class C, class T> template <class C, class T>
......
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
...@@ -56,6 +56,8 @@ struct cpu_batch_norm_inference ...@@ -56,6 +56,8 @@ struct cpu_batch_norm_inference
auto image_height = output_shape.lens()[2]; auto image_height = output_shape.lens()[2];
auto image_width = output_shape.lens()[3]; auto image_width = output_shape.lens()[3];
if(op.bn_mode == batch_norm_inference::spatial)
{
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)( visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) { [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
...@@ -66,6 +68,22 @@ struct cpu_batch_norm_inference ...@@ -66,6 +68,22 @@ struct cpu_batch_norm_inference
bias(c); bias(c);
}); });
}); });
}
if(op.bn_mode == batch_norm_inference::per_activation)
{
visit_all(output, input, mini_batch_mean, mini_batch_mean, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
result(n, c, h, w) = gamma(c, h, w) *
(buffer(n, c, h, w) - mean(c, h, w)) /
std::sqrt(variance(c, h, w) + epsilon) +
bias(c, h, w);
});
});
}
return output; return output;
} }
......
...@@ -249,7 +249,6 @@ struct miopen_apply ...@@ -249,7 +249,6 @@ struct miopen_apply
void apply() void apply()
{ {
prog->insert_instruction(prog->begin(), check_context<context>{});
for(auto it = prog->begin(); it != prog->end(); it++) for(auto it = prog->begin(); it != prog->end(); it++)
{ {
if(it->op.name() == "convolution") if(it->op.name() == "convolution")
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
#include <migraph/gpu/lowering.hpp> #include <migraph/gpu/lowering.hpp>
#include <migraph/gpu/write_literals.hpp> #include <migraph/gpu/write_literals.hpp>
#include <migraph/gpu/context.hpp> #include <migraph/gpu/context.hpp>
#include <migraph/check_context.hpp>
namespace migraph { namespace migraph {
namespace gpu { namespace gpu {
std::vector<pass> target::get_passes(migraph::context&) const std::vector<pass> target::get_passes(migraph::context&) const
{ {
return {lowering{}, write_literals{}}; return {lowering{}, write_literals{}, check_context<context>{}};
} }
std::string target::name() const { return "miopen"; } std::string target::name() const { return "miopen"; }
......
...@@ -106,9 +106,31 @@ if(MIGRAPH_ENABLE_GPU) ...@@ -106,9 +106,31 @@ if(MIGRAPH_ENABLE_GPU)
endforeach() endforeach()
endif() endif()
# Onnx test
add_executable(test_onnx onnx/onnx_test.cpp) add_executable(test_onnx onnx/onnx_test.cpp)
target_link_libraries(test_onnx migraph_onnx) target_link_libraries(test_onnx migraph_onnx)
target_include_directories(test_onnx PUBLIC include) target_include_directories(test_onnx PUBLIC include)
add_test(NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx) add_test(NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_dependencies(tests test_onnx) add_dependencies(tests test_onnx)
add_dependencies(check test_onnx) add_dependencies(check test_onnx)
function(test_header NAME HEADER)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
"#include <${HEADER}>\nint main() {}\n"
)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/header-static-include-${NAME}.cpp
"#include <${HEADER}>\n"
)
add_test_executable(${NAME}
${CMAKE_CURRENT_BINARY_DIR}/header-main-include-${NAME}.cpp
${CMAKE_CURRENT_BINARY_DIR}/header-static-include-${NAME}.cpp
)
endfunction()
file(GLOB HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraph/*.hpp)
foreach(HEADER ${HEADERS})
get_filename_component(BASE_NAME ${HEADER} NAME_WE)
test_header(header_${BASE_NAME} migraph/${BASE_NAME}.hpp)
endforeach()
#ifndef MIGRAPH_GUARD_CONTEXT_HPP #ifndef MIGRAPH_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP #define MIGRAPH_GUARD_CONTEXT_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_OPERAND_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
#ifndef MIGRAPH_GUARD_PASS_HPP #ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP #define MIGRAPH_GUARD_PASS_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert>
#include <string> #include <string>
#include <functional> #include <functional>
#include <memory> #include <memory>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment