Commit b52e7149 authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'master' into onnx_fixes

parents 23ef67a7 abe4092b
......@@ -73,8 +73,6 @@ struct program
argument eval(parameter_map params) const;
friend std::ostream& operator<<(std::ostream& os, const program& p);
bool has_instruction(instruction_ref ins) const;
instruction_ref begin();
......@@ -84,6 +82,10 @@ struct program
void compile(const target& t);
friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
private:
std::unique_ptr<program_impl> impl;
};
......
......@@ -66,7 +66,7 @@ inline std::string remove_prefix(std::string s, std::string prefix)
}
template <class Range>
inline std::string to_string(const Range& r)
inline std::string to_string_range(const Range& r)
{
std::stringstream ss;
if(!r.empty())
......@@ -77,6 +77,14 @@ inline std::string to_string(const Range& r)
return ss.str();
}
template <class T>
inline std::string to_string(const T& x)
{
std::stringstream ss;
ss << x;
return ss.str();
}
} // namespace migraph
#endif
......@@ -2,6 +2,7 @@
#include <migraph/stringutils.hpp>
#include <migraph/instruction.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
namespace migraph {
......@@ -190,6 +191,8 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
return result;
}
bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }
std::ostream& operator<<(std::ostream& os, const program& p)
{
std::unordered_map<const instruction*, std::string> names;
......
......@@ -126,8 +126,8 @@ bool operator!=(const shape& x, const shape& y) { return !(x == y); }
std::ostream& operator<<(std::ostream& os, const shape& x)
{
os << x.type_string() << ", ";
os << "{" << to_string(x.lens()) << "}, ";
os << "{" << to_string(x.strides()) << "}";
os << "{" << to_string_range(x.lens()) << "}, ";
os << "{" << to_string_range(x.strides()) << "}";
return os;
}
......
......@@ -24,11 +24,11 @@ T zero(const T&)
// args[1] -> mini batch mean
// args[2] -> mini batch variance
// args[3] -> gamma
// args[4] -> beta
// args[4] -> bias
//
// The equation to compute batch norm for inference is:
//
// output[i] = beta + gamma * (input[i] + mean) / sqrt(variance + epsilon)
// output[i] = bias + gamma * (input[i] + mean) / sqrt(variance + epsilon)
//
// the input data format should be nchw
//
......@@ -46,17 +46,26 @@ struct cpu_batch_norm_inference
double epsilon = op.epsilon;
auto input = args[0];
auto mini_batch_mean = args[1].at<float>();
auto mini_batch_variance = args[2].at<float>();
auto gamma = args[3].at<float>();
auto beta = args[4].at<float>();
visit_all(output, input)([&](auto result, auto buffer) {
std::transform(buffer.begin(), buffer.end(), result.begin(), [&](auto x) {
return gamma * (x - mini_batch_mean) / std::sqrt(mini_batch_variance + epsilon) +
beta;
auto mini_batch_mean = args[1];
auto mini_batch_variance = args[2];
auto arg_gamma = args[3];
auto arg_bias = args[4];
auto num_batch = output_shape.lens()[0];
auto num_channels = output_shape.lens()[1];
auto image_height = output_shape.lens()[2];
auto image_width = output_shape.lens()[3];
visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
[&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
dfor(num_batch, num_channels, image_height, image_width)(
[&](std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
result(n, c, h, w) = gamma(c) * (buffer(n, c, h, w) - mean(c)) /
std::sqrt(variance(c) + epsilon) +
bias(c);
});
});
});
return output;
}
......
......@@ -9,19 +9,40 @@
void batch_norm_inference_test()
{
migraph::program p;
migraph::shape s{migraph::shape::float_type, {4}};
auto x = p.add_literal(migraph::literal{s, {1, 2, 3, 4}});
auto gamma = p.add_literal(migraph::literal{s, {1}});
auto beta = p.add_literal(migraph::literal{s, {0}});
auto mean = p.add_literal(migraph::literal{s, {0}});
auto variance = p.add_literal(migraph::literal{s, {1}});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, gamma, beta);
const size_t width = 2, height = 2, channels = 4, batches = 2;
const float x_val = 8.0f, mean_val = 2.0f, variance_val = 4.0f, scale_val = 2.0f,
bias_val = 1.0f;
const float output_val = scale_val * (x_val - mean_val) / (std::sqrt(variance_val)) + bias_val;
migraph::shape s{migraph::shape::float_type, {batches, channels, height, width}};
migraph::shape vars{migraph::shape::float_type, {channels}};
std::vector<float> x_data(width * height * channels * batches);
std::vector<float> scale_data(channels);
std::vector<float> bias_data(channels);
std::vector<float> mean_data(channels);
std::vector<float> variance_data(channels);
std::fill(x_data.begin(), x_data.end(), x_val);
std::fill(mean_data.begin(), mean_data.end(), mean_val);
std::fill(variance_data.begin(), variance_data.end(), variance_val);
std::fill(scale_data.begin(), scale_data.end(), scale_val);
std::fill(bias_data.begin(), bias_data.end(), bias_val);
auto x = p.add_literal(migraph::literal{s, x_data});
auto scale = p.add_literal(migraph::literal{vars, scale_data});
auto bias = p.add_literal(migraph::literal{vars, bias_data});
auto mean = p.add_literal(migraph::literal{vars, mean_data});
auto variance = p.add_literal(migraph::literal{vars, variance_data});
p.add_instruction(migraph::batch_norm_inference{}, x, mean, variance, scale, bias);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> result_vector(4);
std::vector<float> result_vector(width * height * channels * batches);
std::vector<float> gold(width * height * channels * batches);
std::fill(gold.begin(), gold.end(), output_val);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
1 / (1 + 1.0e-6), 2 / (1 + 1.0e-6), 3 / (1 + 1.0e-6), 4 / (1 + 1.0e-6)};
EXPECT(test::verify_range(result_vector, gold));
}
......
......@@ -48,9 +48,9 @@ struct expression
decltype(auto) value() const { return Operator::call(lhs, rhs); };
};
// TODO: Remove rvalue references
template <class T, class U, class Operator>
expression<typename std::decay<T>::type, typename std::decay<U>::type, Operator>
make_expression(T&& rhs, U&& lhs, Operator)
expression<T, U, Operator> make_expression(T&& rhs, U&& lhs, Operator)
{
return {std::forward<T>(rhs), std::forward<U>(lhs)};
}
......@@ -58,10 +58,11 @@ make_expression(T&& rhs, U&& lhs, Operator)
template <class T>
struct lhs_expression;
// TODO: Remove rvalue reference
template <class T>
lhs_expression<typename std::decay<T>::type> make_lhs_expression(T&& lhs)
lhs_expression<T> make_lhs_expression(T&& lhs)
{
return lhs_expression<typename std::decay<T>::type>{std::forward<T>(lhs)};
return lhs_expression<T>{std::forward<T>(lhs)};
}
template <class T>
......
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <sstream>
#include "test.hpp"
#include <basic_ops.hpp>
migraph::program create_program()
{
migraph::program p;
auto x = p.add_parameter("x", {migraph::shape::int64_type});
auto y = p.add_parameter("y", {migraph::shape::int64_type});
auto sum = p.add_instruction(sum_op{}, x, y);
auto one = p.add_literal(1);
p.add_instruction(sum_op{}, sum, one);
return p;
}
void program_equality()
{
migraph::program x = create_program();
migraph::program y = create_program();
EXPECT(x == y);
}
int main() { program_equality(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment