Commit 3272b22e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 94e3a2e4
...@@ -51,8 +51,7 @@ TEST_CASE(if_pl_test) ...@@ -51,8 +51,7 @@ TEST_CASE(if_pl_test)
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
auto lens = output.get_shape().lengths(); auto lens = output.get_shape().lengths();
auto elem_num = auto elem_num = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
float* data_ptr = reinterpret_cast<float*>(output.data()); float* data_ptr = reinterpret_cast<float*>(output.data());
std::vector<float> ret(data_ptr, data_ptr + elem_num); std::vector<float> ret(data_ptr, data_ptr + elem_num);
...@@ -100,8 +99,7 @@ TEST_CASE(loop_test) ...@@ -100,8 +99,7 @@ TEST_CASE(loop_test)
auto outputs = p.eval(pp); auto outputs = p.eval(pp);
auto output = outputs[0]; auto output = outputs[0];
auto lens = output.get_shape().lengths(); auto lens = output.get_shape().lengths();
auto elem_num = auto elem_num = std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
std::accumulate(lens.begin(), lens.end(), 1, std::multiplies<int>());
float* data_ptr = reinterpret_cast<float*>(output.data()); float* data_ptr = reinterpret_cast<float*>(output.data());
std::vector<std::vector<float>> ret; std::vector<std::vector<float>> ret;
ret.push_back({data_ptr, data_ptr + elem_num}); ret.push_back({data_ptr, data_ptr + elem_num});
......
...@@ -749,8 +749,7 @@ TEST_CASE(concat_test) ...@@ -749,8 +749,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
} }
{ {
...@@ -774,8 +773,7 @@ TEST_CASE(concat_test) ...@@ -774,8 +773,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({2, 6})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({6, 1})));
} }
{ {
...@@ -799,8 +797,7 @@ TEST_CASE(concat_test) ...@@ -799,8 +797,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
} }
{ {
...@@ -824,8 +821,7 @@ TEST_CASE(concat_test) ...@@ -824,8 +821,7 @@ TEST_CASE(concat_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2}))); EXPECT(migraphx::verify_range(result.get_shape().lens(), std::vector<int>({6, 2})));
EXPECT( EXPECT(migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
migraphx::verify_range(result.get_shape().strides(), std::vector<int>({2, 1})));
} }
} }
......
...@@ -82,8 +82,7 @@ struct stream_free_op ...@@ -82,8 +82,7 @@ struct stream_free_op
struct wait_event struct wait_event
{ {
std::shared_ptr<std::vector<int>> wait_for = std::shared_ptr<std::vector<int>> wait_for = std::make_shared<std::vector<int>>();
std::make_shared<std::vector<int>>();
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
...@@ -104,8 +103,7 @@ struct wait_event ...@@ -104,8 +103,7 @@ struct wait_event
using instruction_map = std::unordered_map<migraphx::instruction_ref, int>; using instruction_map = std::unordered_map<migraphx::instruction_ref, int>;
using int_map = std::unordered_map<int, int>; using int_map = std::unordered_map<int, int>;
using wait_map = using wait_map = std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<int>>>;
std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<int>>>;
struct schedule_model_test struct schedule_model_test
{ {
...@@ -211,10 +209,7 @@ std::vector<T> unique(std::vector<T> x) ...@@ -211,10 +209,7 @@ std::vector<T> unique(std::vector<T> x)
return x; return x;
} }
std::vector<int> get_wait_for(std::vector<int> wait_for) std::vector<int> get_wait_for(std::vector<int> wait_for) { return unique(std::move(wait_for)); }
{
return unique(std::move(wait_for));
}
std::vector<int> get_wait_for(int wait_on, std::vector<int> wait_for) std::vector<int> get_wait_for(int wait_on, std::vector<int> wait_for)
{ {
......
...@@ -337,15 +337,12 @@ TEST_CASE(test_shape4_nonpacked) ...@@ -337,15 +337,12 @@ TEST_CASE(test_shape4_nonpacked)
std::array<int, 4> offsets = {{5, 10, 0, 6}}; std::array<int, 4> offsets = {{5, 10, 0, 6}};
std::array<int, 4> adj_lens = {{0, 0, 0, 0}}; std::array<int, 4> adj_lens = {{0, 0, 0, 0}};
std::transform( std::transform(lens.begin(), lens.end(), offsets.begin(), adj_lens.begin(), std::plus<int>());
lens.begin(), lens.end(), offsets.begin(), adj_lens.begin(), std::plus<int>());
// adj_lens should be: { 105, 42, 8, 14 } // adj_lens should be: { 105, 42, 8, 14 }
std::vector<int> strides(4); std::vector<int> strides(4);
strides.back() = 1; strides.back() = 1;
std::partial_sum(adj_lens.rbegin(), std::partial_sum(
adj_lens.rend() - 1, adj_lens.rbegin(), adj_lens.rend() - 1, strides.rbegin() + 1, std::multiplies<int>());
strides.rbegin() + 1,
std::multiplies<int>());
migraphx::shape s{migraphx::shape::float_type, lens, strides}; migraphx::shape s{migraphx::shape::float_type, lens, strides};
EXPECT(not s.standard()); EXPECT(not s.standard());
......
...@@ -20,8 +20,7 @@ ...@@ -20,8 +20,7 @@
#include "test.hpp" #include "test.hpp"
migraphx::program migraphx::program parse_tf(const std::string& name,
parse_tf(const std::string& name,
bool is_nhwc, bool is_nhwc,
const std::unordered_map<std::string, std::vector<int>>& dim_params = {}, const std::unordered_map<std::string, std::vector<int>>& dim_params = {},
const std::vector<std::string>& output_node_names = {}) const std::vector<std::string>& output_node_names = {})
......
...@@ -6,10 +6,8 @@ ...@@ -6,10 +6,8 @@
struct test_conv_bn_add : verify_program<test_conv_bn_add> struct test_conv_bn_add : verify_program<test_conv_bn_add>
{ {
static migraphx::instruction_ref add_bn(migraphx::module& m, static migraphx::instruction_ref
migraphx::instruction_ref x, add_bn(migraphx::module& m, migraphx::instruction_ref x, int channels, int seed = 1)
int channels,
int seed = 1)
{ {
migraphx::shape vars{migraphx::shape::float_type, {channels}}; migraphx::shape vars{migraphx::shape::float_type, {channels}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed))); auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment