Commit bfe4a900 authored by Shucai Xiao

add a multi() method to shape to simplify code.

parent 4479fc3c
@@ -99,6 +99,8 @@ struct shape
/// Map element index to space index
std::size_t index(std::size_t i) const;
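/// Map element index to multi-dimensional index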
std::vector<std::size_t> multi(std::size_t i) const;
/// Returns true if the shape is packed with no padding
bool packed() const;
/// Returns true if the shape has been transposed. That is the strides are not in descending
...
@@ -138,6 +138,24 @@ std::size_t shape::index(std::size_t i) const
return result;
}
}
std::vector<std::size_t> shape::multi(std::size_t i) const
{
assert(this->standard());
std::vector<std::size_t> indices(lens().size());
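// Decompose the linear index: each output coordinate is (i / stride) % len for its dimension.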
std::transform(strides().begin(),
strides().end(),
lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
return indices;
}
bool shape::packed() const { return this->elements() == this->element_space(); }
bool shape::transposed() const
...
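A minimal usage sketch of the new method (illustration only, not part of the commit), assuming MIGraphX's standard packed, row-major strides and the existing vector overload of shape::index(); the shape s below is a hypothetical example:

// Not part of the commit: expected behavior of shape::multi() for a standard shape.
// For lens {2, 3, 4} the row-major strides are {12, 4, 1}.
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
auto idx = s.multi(7); // 7 = 0*12 + 1*4 + 3*1, so idx == {0, 1, 3}
// multi() inverts the linear index; s.index(idx) maps {0, 1, 3} back to 7.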
@@ -529,22 +529,6 @@ struct cpu_softmax
std::string name() const { return "cpu::softmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
{
std::vector<std::size_t> indices(s.lens().size());
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
@@ -559,8 +543,7 @@ struct cpu_softmax
std::numeric_limits<value_type>::lowest());
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
par_for(batch_shape.elements(), [&](auto i) {
auto idx = this->compute_batch_indices(i, batch_shape);
auto idx = batch_shape.multi(i);
for(size_t j = 0; j < n_dims; ++j)
{
idx[op.axis] = j;
@@ -604,22 +587,6 @@ struct cpu_logsoftmax
std::string name() const { return "cpu::logsoftmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
{
std::vector<std::size_t> indices(s.lens().size());
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
@@ -637,7 +604,7 @@ struct cpu_logsoftmax
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
par_for(batch_shape.elements(), [&](auto i) {
auto idx = this->compute_batch_indices(i, batch_shape);
auto idx = batch_shape.multi(i);
for(size_t j = 0; j < n_dims; ++j)
{
idx[op.axis] = j;
...