"src/vscode:/vscode.git/clone" did not exist on "32ddc1acb4dcf9c07a1c4418278278af120d02d0"
Commit b52e7149 authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'master' into onnx_fixes

parents 23ef67a7 abe4092b
...@@ -73,8 +73,6 @@ struct program ...@@ -73,8 +73,6 @@ struct program
argument eval(parameter_map params) const; argument eval(parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
bool has_instruction(instruction_ref ins) const; bool has_instruction(instruction_ref ins) const;
instruction_ref begin(); instruction_ref begin();
...@@ -84,6 +82,10 @@ struct program ...@@ -84,6 +82,10 @@ struct program
void compile(const target& t); void compile(const target& t);
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private: private:
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
}; };
......
...@@ -66,7 +66,7 @@ inline std::string remove_prefix(std::string s, std::string prefix) ...@@ -66,7 +66,7 @@ inline std::string remove_prefix(std::string s, std::string prefix)
} }
template <class Range> template <class Range>
inline std::string to_string(const Range& r) inline std::string to_string_range(const Range& r)
{ {
std::stringstream ss; std::stringstream ss;
if(!r.empty()) if(!r.empty())
...@@ -77,6 +77,14 @@ inline std::string to_string(const Range& r) ...@@ -77,6 +77,14 @@ inline std::string to_string(const Range& r)
return ss.str(); return ss.str();
} }
template <class T>
inline std::string to_string(const T& x)
{
std::stringstream ss;
ss << x;
return ss.str();
}
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraph/stringutils.hpp> #include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <iostream> #include <iostream>
#include <sstream>
#include <algorithm> #include <algorithm>
namespace migraph { namespace migraph {
...@@ -190,6 +191,8 @@ argument program::eval(std::unordered_map<std::string, argument> params) const ...@@ -190,6 +191,8 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
return result; return result;
} }
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p) std::ostream& operator<<(std::ostream& os, const program& p)
{ {
std::unordered_map<const instruction*, std::string> names; std::unordered_map<const instruction*, std::string> names;
......
...@@ -126,8 +126,8 @@ bool operator!=(const shape& x, const shape& y) { return !(x == y); } ...@@ -126,8 +126,8 @@ bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x) std::ostream& operator<<(std::ostream& os, const shape& x)
{ {
os << x.type_string() << ", "; os << x.type_string() << ", ";
os << "{" << to_string(x.lens()) << "}, "; os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string(x.strides()) << "}"; os << "{" << to_string_range(x.strides()) << "}";
return os; return os;
} }
......
...@@ -24,11 +24,11 @@ T zero(const T&) ...@@ -24,11 +24,11 @@ T zero(const T&)
// args[1] -> mini batch mean // args[1] -> mini batch mean
// args[2] -> mini batch variance // args[2] -> mini batch variance
// args[3] -> gamma // args[3] -> gamma
// args[4] -> beta // args[4] -> bias
// //
// The equation to compute batch norm for inference is: // The equation to compute batch norm for inference is:
// //
// output[i] = beta + gamma * (input[i] + mean) / sqrt(variance + epsilon) // output[i] = bias + gamma * (input[i] + mean) / sqrt(variance + epsilon)
// //
// the input data format should be nchw // the input data format should be nchw
// //
...@@ -46,17 +46,26 @@ struct cpu_batch_norm_inference ...@@ -46,17 +46,26 @@ struct cpu_batch_norm_inference
double epsilon = op.epsilon; double epsilon = op.epsilon;
auto input = args[0]; auto input = args[0];
auto mini_batch_mean = args[1].at<float>(); auto mini_batch_mean = args[1];
auto mini_batch_variance = args[2].at<float>(); auto mini_batch_variance = args[2];
auto gamma = args[3].at<float>(); auto arg_gamma = args[3];
auto beta = args[4].at<float>(); auto arg_bias = args[4];
visit_all(output, input)([&](auto result, auto buffer) { auto num_batch = output_shape.lens()[0];
std::transform(buffer.begin(), buffer.end(), result.begin(), [&](auto x) { auto num_channels = output_shape.lens()[1];
return gamma * (x - mini_batch_mean) / std::sqrt(mini_batch_variance + epsilon) + auto image_height = output_shape.lens()[2];
beta; auto image_width = output_shape.lens()[3];
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
std::sqrt(variance(c) + epsilon) +
bias(c);
});
}); });
});
return output; return output;
} }
......
...@@ -9,19 +9,40 @@ ...@@ -9,19 +9,40 @@
void batch_norm_inference_test() void batch_norm_inference_test()
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {4}}; const size_t width = 2, height = 2, channels = 4, batches = 2;
auto x = p.add_literal(migraph::literal{s, {1, 2, 3, 4}}); const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f,
auto gamma = p.add_literal(migraph::literal{s, {1}}); bias_val = 1.0f;
auto beta = p.add_literal(migraph::literal{s, {0}}); const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
auto mean = p.add_literal(migraph::literal{s, {0}});
auto variance = p.add_literal(migraph::literal{s, {1}}); migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, gamma, beta); migraph::shape vars{migraph::shape::float_type, {channels}};
std::vector<float> x_data(width * height * channels * batches);
std::vector<float> scale_data(channels);
std::vector<float> bias_data(channels);
std::vector<float> mean_data(channels);
std::vector<float> variance_data(channels);
std::fill(x_data.begin(), x_data.end(), x_val);
std::fill(mean_data.begin(), mean_data.end(), mean_val);
std::fill(variance_data.begin(), variance_data.end(), variance_val);
std::fill(scale_data.begin(), scale_data.end(), scale_val);
std::fill(bias_data.begin(), bias_data.end(), bias_val);
auto x = p.add_literal(migraph::literal{s, x_data});
auto scale = p.add_literal(migraph::literal{vars, scale_data});
auto bias = p.add_literal(migraph::literal{vars, bias_data});
auto mean = p.add_literal(migraph::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> result_vector(4);
std::vector<float> result_vector(width * height * channels * batches);
std::vector<float> gold(width * height * channels * batches);
std::fill(gold.begin(), gold.end(), output_val);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
1 / (1 + 1.0e-6), 2 / (1 + 1.0e-6), 3 / (1 + 1.0e-6), 4 / (1 + 1.0e-6)};
EXPECT(test::verify_range(result_vector, gold)); EXPECT(test::verify_range(result_vector, gold));
} }
......
...@@ -48,9 +48,9 @@ struct expression ...@@ -48,9 +48,9 @@ struct expression
decltype(auto) value() const { return Operator::call(lhs, rhs); }; decltype(auto) value() const { return Operator::call(lhs, rhs); };
}; };
// TODO: Remove rvalue references
template <class T, class U, class Operator> template <class T, class U, class Operator>
expression<typename std::decay<T>::type, typename std::decay<U>::type, Operator> expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
make_expression(T&& rhs, U&& lhs, Operator)
{ {
return {std::forward<T>(rhs), std::forward<U>(lhs)}; return {std::forward<T>(rhs), std::forward<U>(lhs)};
} }
...@@ -58,10 +58,11 @@ make_expression(T&& rhs, U&& lhs, Operator) ...@@ -58,10 +58,11 @@ make_expression(T&& rhs, U&& lhs, Operator)
template <class T> template <class T>
struct lhs_expression; struct lhs_expression;
// TODO: Remove rvalue reference
template <class T> template <class T>
lhs_expression<typename std::decay<T>::type> make_lhs_expression(T&& lhs) lhs_expression<T> make_lhs_expression(T&& lhs)
{ {
return lhs_expression<typename std::decay<T>::type>{std::forward<T>(lhs)}; return lhs_expression<T>{std::forward<T>(lhs)};
} }
template <class T> template <class T>
......
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
migraph::program create_program()
{
migraph::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type});
auto sum = p.add_instruction(sum_op{}, x, y);
auto one = p.add_literal(1);
p.add_instruction(sum_op{}, sum, one);
return p;
}
void program_equality()
{
migraph::program x = create_program();
migraph::program y = create_program();
EXPECT(x == y);
}
int main() { program_equality(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment