"...resnet50_tensorflow.git" did not exist on "723e053b2f7eb7fa87642198a6e1f5414da87fca"
Commit 632bfff0 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Specify problem in each example

parent 1989df75
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float;
struct ExecutionConfig final struct ExecutionConfig final
{ {
...@@ -31,8 +32,18 @@ struct ExecutionConfig final ...@@ -31,8 +32,18 @@ struct ExecutionConfig final
struct Problem final
{
    // Both the tensor extents and the permutation are fixed-size triples,
    // since this example operates on 3D tensors.
    using Shape = std::array<std::size_t, 3>;
    using Axes  = Shape;

    // There is no sensible universal default problem: each example must
    // supply its own initial shape and permutation.
    Problem() = delete;

    /// @param default_shape initial tensor extents (may be overridden by CLI args)
    /// @param default_axes  initial permutation of the 3 axes
    explicit Problem(const Shape& default_shape, const Axes& default_axes)
        : shape(default_shape), axes(default_axes)
    {
    }

    Shape shape;
    Axes axes;
};
template <ck::index_t... Is> template <ck::index_t... Is>
...@@ -207,7 +218,7 @@ is_valid_axes(const Axes& axes) ...@@ -207,7 +218,7 @@ is_valid_axes(const Axes& axes)
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem) inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{ {
constexpr int num_execution_config_args = 2; constexpr int num_execution_config_args = 2;
constexpr int num_problem_args = 8; constexpr int num_problem_args = 3 + 3;
if(!(num_problem_args == size(problem.shape) + size(problem.axes))) if(!(num_problem_args == size(problem.shape) + size(problem.axes)))
{ {
...@@ -231,13 +242,14 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob ...@@ -231,13 +242,14 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
// read shape // read shape
for(std::size_t idx = 0; idx < size(problem.shape); ++idx) for(std::size_t idx = 0; idx < size(problem.shape); ++idx)
{ {
problem.shape[idx] = std::stoi(argv[idx + 3]); problem.shape[idx] = std::stoi(argv[idx + (1 + num_execution_config_args)]);
} }
// read axes // read axes
for(std::size_t idx = 0; idx < size(problem.axes); ++idx) for(std::size_t idx = 0; idx < size(problem.axes); ++idx)
{ {
problem.axes[idx] = std::stoi(argv[idx + size(problem.shape) + 3]); problem.axes[idx] =
std::stoi(argv[idx + (1 + num_execution_config_args + size(problem.shape))]);
} }
if(!is_valid_axes(problem.axes)) if(!is_valid_axes(problem.axes))
...@@ -254,8 +266,8 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob ...@@ -254,8 +266,8 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
{ {
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: time kernel (0=no, 1=yes)" << std::endl << "arg2: time kernel (0=no, 1=yes)" << std::endl
<< "arg3 ~ arg6: shape for 4D tensor" << std::endl << "arg3 ~ arg5: shape for 3D tensor" << std::endl
<< "arg7 ~ arg10: axes to permute" << std::endl; << "arg6 ~ arg8: axes to permute" << std::endl;
return false; return false;
} }
...@@ -369,7 +381,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D ...@@ -369,7 +381,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
} }
using std::size; using std::size;
if(!(is_valid_axes(axes) && size(axes) == 4)) if(!(is_valid_axes(axes) && size(axes) == 3))
{ {
return false; return false;
} }
...@@ -377,7 +389,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D ...@@ -377,7 +389,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> && static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>); detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
if(!(size(shape) == 4 && size(transposed_shape) == 4)) if(!(size(shape) == 3 && size(transposed_shape) == 3))
{ {
return false; return false;
} }
...@@ -394,7 +406,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D ...@@ -394,7 +406,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
} }
} }
std::array<std::size_t, 4> indices{}; std::array<std::size_t, 3> indices{};
if(!is_valid_indices(shape, indices)) if(!is_valid_indices(shape, indices))
{ {
return false; return false;
...@@ -403,8 +415,8 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D ...@@ -403,8 +415,8 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
do do
{ {
Dest output = 0; Dest output = 0;
functor(output, src(indices[0], indices[1], indices[2], indices[3])); functor(output, src(indices[0], indices[1], indices[2]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output; dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
} while(advance_indices(shape, indices)); } while(advance_indices(shape, indices));
return true; return true;
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include "common.hpp" #include "common.hpp"
using ADataType = F16; using ADataType = F32;
using BDataType = F16; using BDataType = F32;
// clang-format off // clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
...@@ -12,9 +12,12 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute ...@@ -12,9 +12,12 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
// ######| Type| Type| Operation| | Size| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector| // ######| Type| Type| Operation| | Size| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | |
< ADataType, BDataType, PassThrough, 4, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 3, 2, 1, 1>; < ADataType, BDataType, PassThrough, 3, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
// clang-format on // clang-format on
#include "run_permute_example.inc" #include "run_permute_example.inc"
int main(int argc, char* argv[])
{
    // Default problem for this example: shape {1, 16000, 80} permuted with
    // axes {0, 2, 1} (swap the last two dimensions); CLI args may override.
    // run_permute_example returns true on success, so negate for exit code 0.
    return !run_permute_example(argc, argv, {1, 16000, 80}, {0, 2, 1});
}
...@@ -21,8 +21,8 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem) ...@@ -21,8 +21,8 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf.ToDevice(a.mData.data()); a_device_buf.ToDevice(a.mData.data());
std::array<ck::index_t, 4> a_lengths, b_lengths; std::array<ck::index_t, 3> a_lengths, b_lengths;
std::array<ck::index_t, 4> a_strides, b_strides; std::array<ck::index_t, 3> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer(); const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer(); void* output = b_device_buf.GetDeviceBuffer();
...@@ -64,10 +64,13 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem) ...@@ -64,10 +64,13 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
return true; return true;
} }
bool run_permute_example(int argc, char* argv[]) bool run_permute_example(int argc,
char* argv[],
const Problem::Shape& default_shape,
const Problem::Axes& default_axes)
{ {
ExecutionConfig config; ExecutionConfig config;
Problem problem; Problem problem(default_shape, default_axes);
return parse_cmd_args(argc, argv, config, problem) && run_permute(config, problem); return parse_cmd_args(argc, argv, config, problem) && run_permute(config, problem);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment