"include/ck/utility/get_id.hpp" did not exist on "21f7e9f103231b01889ff30ed9f016fc89d3a669"
Commit f8fa90bd authored by Shucai Xiao's avatar Shucai Xiao
Browse files

simplify code.

parent 9a591bbb
...@@ -651,21 +651,6 @@ struct cpu_argmax ...@@ -651,21 +651,6 @@ struct cpu_argmax
std::string name() const { return "cpu::argmax"; } std::string name() const { return "cpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
{
std::vector<std::size_t> indices(s.lens().size());
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -677,7 +662,7 @@ struct cpu_argmax ...@@ -677,7 +662,7 @@ struct cpu_argmax
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = this->compute_batch_indices(i, batch_shape); auto data_idx = batch_shape.multi(i);
auto max_val = input[i]; auto max_val = input[i];
int64_t max_index = 0; int64_t max_index = 0;
for(size_t j = 1; j < batch_item_num; ++j) for(size_t j = 1; j < batch_item_num; ++j)
...@@ -712,21 +697,6 @@ struct cpu_argmin ...@@ -712,21 +697,6 @@ struct cpu_argmin
std::string name() const { return "cpu::argmin"; } std::string name() const { return "cpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
{
std::vector<std::size_t> indices(s.lens().size());
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -738,7 +708,7 @@ struct cpu_argmin ...@@ -738,7 +708,7 @@ struct cpu_argmin
result.visit([&](auto output) { result.visit([&](auto output) {
args[0].visit([&](auto input) { args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) { par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = this->compute_batch_indices(i, batch_shape); auto data_idx = batch_shape.multi(i);
auto min_val = input[i]; auto min_val = input[i];
int64_t min_index = 0; int64_t min_index = 0;
for(size_t j = 1; j < batch_item_num; ++j) for(size_t j = 1; j < batch_item_num; ++j)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment