Commit 10b5fe9a authored by Khalique's avatar Khalique
Browse files

formatting

parent b98a28cc
...@@ -157,9 +157,7 @@ bool shape::scalar() const ...@@ -157,9 +157,7 @@ bool shape::scalar() const
{ {
assert(this->lens().size() == this->strides().size()); assert(this->lens().size() == this->strides().size());
// if any stride > 0, then accumulate will return false // if any stride > 0, then accumulate will return false
return std::accumulate(this->strides().begin(), return std::accumulate(this->strides().begin(), this->strides().end(), std::size_t(0)) == 0;
this->strides().end(),
std::size_t(0)) == 0;
} }
bool shape::standard() const { return impl->m_standard; } bool shape::standard() const { return impl->m_standard; }
......
...@@ -333,7 +333,8 @@ nary(hipStream_t stream, const argument& result, const argument& arg1, const arg ...@@ -333,7 +333,8 @@ nary(hipStream_t stream, const argument& result, const argument& arg1, const arg
{ {
return [=](auto f) { return [=](auto f) {
// TODO: Check result and arg1 shape is the same // TODO: Check result and arg1 shape is the same
if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and not arg2.get_shape().scalar()) if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
not arg2.get_shape().scalar())
{ {
auto not_zero = [](auto x) { return x != 0; }; auto not_zero = [](auto x) { return x != 0; };
const auto& strides = arg2.get_shape().strides(); const auto& strides = arg2.get_shape().strides();
......
...@@ -543,36 +543,46 @@ void imagescaler_test() ...@@ -543,36 +543,46 @@ void imagescaler_test()
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {1, 3, 2, 2}}; migraph::shape s{migraph::shape::float_type, {1, 3, 2, 2}};
auto img = p.add_literal(migraph::literal{s, { auto img = p.add_literal(migraph::literal{s,
0.2, 0.3, {0.2,
0.5, 0.4, 0.3,
0.5,
0.7, 0.8, 0.4,
0.1, 0.9,
0.7,
0.15, 0.25, 0.8,
0.35, 0.45 0.1,
}}); 0.9,
auto scale_val = p.add_literal(2.f);
0.15,
0.25,
0.35,
0.45}});
auto scale_val = p.add_literal(2.f);
auto scaled_tensor = p.add_instruction(migraph::op::scalar{s}, scale_val); auto scaled_tensor = p.add_instruction(migraph::op::scalar{s}, scale_val);
auto img_scaled = p.add_instruction(migraph::op::mul{}, img, scaled_tensor); auto img_scaled = p.add_instruction(migraph::op::mul{}, img, scaled_tensor);
auto bias_vals = p.add_literal(migraph::literal{migraph::shape{migraph::shape::float_type, {3}}, {0.01, 0.02, 0.03}}); auto bias_vals = p.add_literal(
migraph::literal{migraph::shape{migraph::shape::float_type, {3}}, {0.01, 0.02, 0.03}});
auto bias_bcast = p.add_instruction(migraph::op::broadcast{1, s}, bias_vals); auto bias_bcast = p.add_instruction(migraph::op::broadcast{1, s}, bias_vals);
p.add_instruction(migraph::op::add{}, img_scaled, bias_bcast); p.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = { std::vector<float> gold = {0.41,
0.41, 0.61, 0.61,
1.01, 0.81, 1.01,
0.81,
1.42, 1.62,
0.22, 1.82, 1.42,
1.62,
0.33, 0.53, 0.22,
0.73, 0.93 1.82,
};
0.33,
0.53,
0.73,
0.93};
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
......
...@@ -194,8 +194,8 @@ struct test_scale ...@@ -194,8 +194,8 @@ struct test_scale
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
auto x = p.add_parameter("x", s); auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", migraph::shape::float_type); auto y = p.add_parameter("y", migraph::shape::float_type);
auto scale = p.add_instruction(migraph::op::scalar{s}, y); auto scale = p.add_instruction(migraph::op::scalar{s}, y);
p.add_instruction(migraph::op::mul{}, x, scale); p.add_instruction(migraph::op::mul{}, x, scale);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment