"docs/vscode:/vscode.git/clone" did not exist on "ac6aee81002701454194fcc088bb5839e22dbea3"
Commit 4443608d authored by wsttiger's avatar wsttiger
Browse files

Fixed for Paul

parent 78c21ef5
...@@ -284,34 +284,6 @@ struct cpu_contiguous ...@@ -284,34 +284,6 @@ struct cpu_contiguous
struct cpu_concat struct cpu_concat
{ {
// Small helper that converts between a flat element position and a
// per-dimension index, using the lengths and strides copied from a shape.
struct tensor_descriptor
{
    tensor_descriptor() = default;

    // Copies the dimension lengths and strides out of `s`.
    tensor_descriptor(const shape& s) : lens(s.lens()), strides(s.strides()) {}

    // Splits `idx` into one index per dimension: for each stride, in order,
    // the quotient becomes that dimension's index and the remainder is
    // carried on to the next dimension.
    //
    // Every stride must be non-zero, since each one is used as a divisor.
    // NOTE(review): the quotient/remainder walk presumably expects strides in
    // descending order (packed layout) -- confirm against the callers.
    std::vector<std::size_t> multi(std::size_t idx) const
    {
        const std::size_t ndim = strides.size();
        std::vector<std::size_t> result(ndim);
        std::size_t remaining = idx;
        for(std::size_t d = 0; d < ndim; d++)
        {
            result[d] = remaining / strides[d];
            remaining = remaining % strides[d];
        }
        return result;
    }

    // Returns the sum over all dimensions of s[i] * strides[i].
    //
    // `s` must not have more entries than `strides`. It is taken by const
    // reference so that calling this once per element does not copy the
    // index vector each time.
    std::size_t linear(const std::vector<std::size_t>& s) const
    {
        std::size_t offset = 0;
        for(std::size_t i = 0; i < s.size(); i++)
            offset += s[i] * strides[i];
        return offset;
    }

    std::vector<std::size_t> lens;    // length of each dimension
    std::vector<std::size_t> strides; // stride of each dimension
};
op::concat op; op::concat op;
std::string name() const { return "cpu::concat"; } std::string name() const { return "cpu::concat"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
...@@ -330,7 +302,6 @@ struct cpu_concat ...@@ -330,7 +302,6 @@ struct cpu_concat
} }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
...@@ -338,13 +309,10 @@ struct cpu_concat ...@@ -338,13 +309,10 @@ struct cpu_concat
auto argl = args[l]; auto argl = args[l];
std::size_t nelements = argl.get_shape().elements(); std::size_t nelements = argl.get_shape().elements();
visit_all(result, argl)([&](auto output, auto input) { visit_all(result, argl)([&](auto output, auto input) {
auto* outptr = output.data() + coffsets[l]; auto slice_shape = shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
const auto* inptr = input.data(); auto slice = make_view(slice_shape, output.data()+coffsets[l]);
tensor_descriptor desc_input(input.get_shape()); for(std::size_t i = 0; i < nelements; i++) {
tensor_descriptor desc_output(output.get_shape()); slice[i] = input[i];
for(std::size_t i = 0; i < nelements; i++)
{
outptr[desc_output.linear(desc_input.multi(i))] = inptr[i];
} }
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment