Commit 96a3a8be authored by wsttiger's avatar wsttiger
Browse files

Refactored compute_offsets in concat

parent 82a6c6a5
...@@ -318,6 +318,19 @@ struct concat ...@@ -318,6 +318,19 @@ struct concat
{ {
std::size_t axis = 0; std::size_t axis = 0;
std::string name() const { return "concat"; } std::string name() const { return "concat"; }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
{
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[axis] += arg.get_shape().lens()[axis];
}
return offsets;
}
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
if(inputs.empty()) if(inputs.empty())
......
...@@ -287,23 +287,10 @@ struct cpu_concat ...@@ -287,23 +287,10 @@ struct cpu_concat
op::concat op; op::concat op;
std::string name() const { return "cpu::concat"; } std::string name() const { return "cpu::concat"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
{
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[op.axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[op.axis] += arg.get_shape().lens()[op.axis];
}
return offsets;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
std::vector<std::size_t> coffsets = compute_offsets(output_shape, args); std::vector<std::size_t> coffsets = op.compute_offsets(output_shape, args);
for(std::size_t l = 0; l < args.size(); l++) for(std::size_t l = 0; l < args.size(); l++)
{ {
auto argl = args[l]; auto argl = args[l];
......
...@@ -14,24 +14,10 @@ shape hip_concat::compute_shape(std::vector<shape> inputs) const ...@@ -14,24 +14,10 @@ shape hip_concat::compute_shape(std::vector<shape> inputs) const
return op.compute_shape(inputs); return op.compute_shape(inputs);
} }
std::vector<std::size_t> hip_concat::compute_offsets(const shape& output_shape,
const std::vector<argument> args) const
{
std::vector<std::size_t> offsets;
std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
offset[op.axis] = 0;
for(const auto& arg : args)
{
offsets.push_back(output_shape.index(offset));
offset[op.axis] += arg.get_shape().lens()[op.axis];
}
return offsets;
}
argument argument
hip_concat::compute(context&, const shape& output_shape, const std::vector<argument>& args) const hip_concat::compute(context&, const shape& output_shape, const std::vector<argument>& args) const
{ {
std::vector<std::size_t> offsets = compute_offsets(output_shape, args); std::vector<std::size_t> offsets = op.compute_offsets(output_shape, args);
return device::concat(output_shape, args, offsets); return device::concat(output_shape, args, offsets);
} }
......
...@@ -26,8 +26,6 @@ struct hip_concat ...@@ -26,8 +26,6 @@ struct hip_concat
std::string name() const { return "gpu::concat"; } std::string name() const { return "gpu::concat"; }
shape compute_shape(std::vector<shape> inputs) const; shape compute_shape(std::vector<shape> inputs) const;
std::vector<std::size_t> compute_offsets(const shape& output_shape,
const std::vector<argument> args) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment