"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "8fcf177ed9921ff8513ad7e1adfd956520c5242b"
Commit 75831d9e authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Allow specify problem 'axes' through command line argument

parent 8e71cad0
...@@ -37,12 +37,30 @@ using S = ck::Sequence<Is...>; ...@@ -37,12 +37,30 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <typename RandomAccessRange>
inline bool is_valid_axes(RandomAccessRange axes)
{
using std::empty;
if(empty(axes))
{
return false;
}
using std::begin, std::end;
std::vector<std::size_t> copy(begin(axes), end(axes));
std::sort(begin(copy), end(copy));
const auto last = std::unique(begin(copy), end(copy));
return (last == end(copy)) && (*begin(copy) == 0) && (*std::prev(last) == size(axes) - 1);
}
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem) inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{ {
constexpr int num_execution_config_args = 2; constexpr int num_execution_config_args = 2;
constexpr int num_problem_args = 8; constexpr int num_problem_args = 8;
assert(num_problem_args == problem.shape.size() + problem.axes.size()); assert(num_problem_args == size(problem.shape) + size(problem.axes));
if(argc == 1) if(argc == 1)
{ {
...@@ -59,15 +77,25 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob ...@@ -59,15 +77,25 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
config.time_kernel = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[2]);
// read shape // read shape
for(std::size_t idx = 0; idx < problem.shape.size(); ++idx) for(std::size_t idx = 0; idx < size(problem.shape); ++idx)
{ {
problem.shape[idx] = std::stoi(argv[idx + 3]); problem.shape[idx] = std::stoi(argv[idx + 3]);
} }
// read axes // read axes
for(std::size_t idx = 0; idx < problem.axes.size(); ++idx) for(std::size_t idx = 0; idx < size(problem.axes); ++idx)
{
problem.axes[idx] = std::stoi(argv[idx + size(problem.shape) + 3]);
}
if(!is_valid_axes(problem.axes))
{ {
problem.axes[idx] = std::stoi(argv[idx + problem.shape.size() + 3]); std::cerr << "invalid axes: ";
std::copy(begin(problem.axes),
end(problem.axes),
std::ostream_iterator<std::size_t>(std::cerr, " "));
std::cerr << std::endl;
return false;
} }
} }
else else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment