Commit 71d850ae authored by Scott Thornton's avatar Scott Thornton
Browse files

Formatting

parent 0c68e1cd
...@@ -289,24 +289,27 @@ struct concat ...@@ -289,24 +289,27 @@ struct concat
std::string name() const { return "concat"; } std::string name() const { return "concat"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if (inputs.empty()) if(inputs.empty())
{ {
MIGRAPH_THROW("Number of input tensors should exceed 0"); MIGRAPH_THROW("Number of input tensors should exceed 0");
} }
const auto& first_shape_lens = inputs.front().lens(); const auto& first_shape_lens = inputs.front().lens();
const auto& type = inputs.front().type(); const auto& type = inputs.front().type();
for (std::size_t l = 0; l < first_shape_lens.size(); l++) { for(std::size_t l = 0; l < first_shape_lens.size(); l++)
if (l != axis) { {
if (!std::all_of(inputs.begin(), inputs.end(), [&] (auto s) { if(l != axis)
return s.lens()[l] == first_shape_lens[l];})) {
{ if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
MIGRAPH_THROW("Non-axis dimensions should match"); return s.lens()[l] == first_shape_lens[l];
}))
{
MIGRAPH_THROW("Non-axis dimensions should match");
} }
} }
} }
std::size_t new_dim_axis = 0; std::size_t new_dim_axis = 0;
for (const auto& input : inputs) for(const auto& input : inputs)
{ {
const auto& lens = input.lens(); const auto& lens = input.lens();
new_dim_axis += lens[axis]; new_dim_axis += lens[axis];
......
...@@ -282,14 +282,12 @@ struct cpu_contiguous ...@@ -282,14 +282,12 @@ struct cpu_contiguous
} }
}; };
struct cpu_concat struct cpu_concat
{ {
struct tensor_descriptor struct tensor_descriptor
{ {
tensor_descriptor() = default; tensor_descriptor() = default;
tensor_descriptor(const shape& s) tensor_descriptor(const shape& s) : lens(s.lens()), strides(s.strides()) {}
: lens(s.lens()), strides(s.strides()) {}
std::vector<std::size_t> multi(size_t idx) const std::vector<std::size_t> multi(size_t idx) const
{ {
std::size_t sz = strides.size(); std::size_t sz = strides.size();
...@@ -304,12 +302,11 @@ struct cpu_concat ...@@ -304,12 +302,11 @@ struct cpu_concat
} }
size_t linear(std::vector<std::size_t> s) const size_t linear(std::vector<std::size_t> s) const
{ {
//return std::inner_product(s.begin(), s.end(), strides.begin(), 0); // return std::inner_product(s.begin(), s.end(), strides.begin(), 0);
size_t idx = 0; size_t idx = 0;
for(size_t i = 0; i < s.size(); i++) for(size_t i = 0; i < s.size(); i++)
idx += s[i] * strides[i]; idx += s[i] * strides[i];
return idx; return idx;
} }
std::vector<std::size_t> lens; std::vector<std::size_t> lens;
std::vector<std::size_t> strides; std::vector<std::size_t> strides;
...@@ -318,12 +315,13 @@ struct cpu_concat ...@@ -318,12 +315,13 @@ struct cpu_concat
op::concat op; op::concat op;
std::string name() const { return "cpu::concat"; } std::string name() const { return "cpu::concat"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<std::size_t> compute_offsets(const shape& output_shape, const std::vector<argument> args) const std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
{ {
std::vector<std::size_t> offsets; std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(),0); std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[op.axis] = 0; offset[op.axis] = 0;
for (const auto& arg : args) for(const auto& arg : args)
{ {
offsets.push_back(output_shape.index(offset)); offsets.push_back(output_shape.index(offset));
offset[op.axis] += arg.get_shape().lens()[op.axis]; offset[op.axis] += arg.get_shape().lens()[op.axis];
...@@ -335,17 +333,17 @@ struct cpu_concat ...@@ -335,17 +333,17 @@ struct cpu_concat
argument result{output_shape}; argument result{output_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
for (std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
std::cout << argl << std::endl; std::cout << argl << std::endl;
std::size_t nelements = argl.get_shape().elements(); std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
auto* outptr = output.data() + coffsets[l]; auto* outptr = output.data() + coffsets[l];
const auto* inptr = input.data(); const auto* inptr = input.data();
tensor_descriptor desc_input(input.get_shape()); tensor_descriptor desc_input(input.get_shape());
tensor_descriptor desc_output(output.get_shape()); tensor_descriptor desc_output(output.get_shape());
for (std::size_t i = 0; i < nelements; i++) for(std::size_t i = 0; i < nelements; i++)
{ {
outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; outptr[desc_output.linear(desc_input.multi(i))] = inptr[i];
} }
...@@ -630,7 +628,7 @@ struct cpu_apply ...@@ -630,7 +628,7 @@ struct cpu_apply
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>(); extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
apply_map["concat"] = extend_op<cpu_concat, op::concat>(); apply_map["concat"] = extend_op<cpu_concat, op::concat>();
apply_map["identity"] = simple_op<cpu_unary<identity_op>>(); apply_map["identity"] = simple_op<cpu_unary<identity_op>>();
apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>(); apply_map["tanh"] = simple_op<cpu_unary<tanh_op>>();
......
...@@ -50,7 +50,7 @@ void slice_test() ...@@ -50,7 +50,7 @@ void slice_test()
void concat_test() void concat_test()
{ {
migraph::program p; migraph::program p;
std::size_t axis = 1; std::size_t axis = 1;
std::vector<int> data0 = {0, 1, 5, 6}; std::vector<int> data0 = {0, 1, 5, 6};
std::vector<int> data1 = {2, 3, 4, 5, 6, 7}; std::vector<int> data1 = {2, 3, 4, 5, 6, 7};
migraph::shape s0{migraph::shape::int32_type, {2, 2}}; migraph::shape s0{migraph::shape::int32_type, {2, 2}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment