Commit 632bfff0 authored by Po-Yen, Chen
Browse files

Specify problem in each examples

parent 1989df75
......@@ -22,6 +22,7 @@
#include "ck/library/utility/host_tensor_generator.hpp"
// Short element-type aliases; the example sources in this change select
// ADataType/BDataType from these.
using F16 = ck::half_t;
using F32 = float;
struct ExecutionConfig final
{
......@@ -31,8 +32,18 @@ struct ExecutionConfig final
/// Describes one permute problem: the extents of a rank-3 tensor (`shape`)
/// and the axis order to apply to it (`axes`).
///
/// The default constructor is deleted on purpose: each example has to supply
/// its own default shape and axes instead of inheriting a shared hard-coded one.
struct Problem final
{
    using Shape = std::array<std::size_t, 3>;
    using Axes  = Shape;

    Problem() = delete;

    /// @param default_shape initial tensor extents; parse_cmd_args overwrites
    ///                      them element by element when it reads the shape arguments
    /// @param default_axes  initial axis order; likewise overwritten by parse_cmd_args
    explicit Problem(const Shape& default_shape, const Axes& default_axes)
        : shape(default_shape), axes(default_axes)
    {
    }

    Shape shape;
    Axes axes;
};
template <ck::index_t... Is>
......@@ -207,7 +218,7 @@ is_valid_axes(const Axes& axes)
inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Problem& problem)
{
constexpr int num_execution_config_args = 2;
constexpr int num_problem_args = 8;
constexpr int num_problem_args = 3 + 3;
if(!(num_problem_args == size(problem.shape) + size(problem.axes)))
{
......@@ -231,13 +242,14 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
// read shape
for(std::size_t idx = 0; idx < size(problem.shape); ++idx)
{
problem.shape[idx] = std::stoi(argv[idx + 3]);
problem.shape[idx] = std::stoi(argv[idx + (1 + num_execution_config_args)]);
}
// read axes
for(std::size_t idx = 0; idx < size(problem.axes); ++idx)
{
problem.axes[idx] = std::stoi(argv[idx + size(problem.shape) + 3]);
problem.axes[idx] =
std::stoi(argv[idx + (1 + num_execution_config_args + size(problem.shape))]);
}
if(!is_valid_axes(problem.axes))
......@@ -254,8 +266,8 @@ inline bool parse_cmd_args(int argc, char* argv[], ExecutionConfig& config, Prob
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: time kernel (0=no, 1=yes)" << std::endl
<< "arg3 ~ arg6: shape for 4D tensor" << std::endl
<< "arg7 ~ arg10: axes to permute" << std::endl;
<< "arg3 ~ arg5: shape for 3D tensor" << std::endl
<< "arg6 ~ arg8: axes to permute" << std::endl;
return false;
}
......@@ -369,7 +381,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
using std::size;
if(!(is_valid_axes(axes) && size(axes) == 4))
if(!(is_valid_axes(axes) && size(axes) == 3))
{
return false;
}
......@@ -377,7 +389,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
if(!(size(shape) == 4 && size(transposed_shape) == 4))
if(!(size(shape) == 3 && size(transposed_shape) == 3))
{
return false;
}
......@@ -394,7 +406,7 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
}
}
std::array<std::size_t, 4> indices{};
std::array<std::size_t, 3> indices{};
if(!is_valid_indices(shape, indices))
{
return false;
......@@ -403,8 +415,8 @@ host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<D
do
{
Dest output = 0;
functor(output, src(indices[0], indices[1], indices[2], indices[3]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
functor(output, src(indices[0], indices[1], indices[2]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
} while(advance_indices(shape, indices));
return true;
......
......@@ -3,8 +3,8 @@
#include "common.hpp"
// Element types handed to DevicePermute below as its first two template
// arguments. Each alias is declared exactly once; declaring it again with a
// different type (F16 and F32) is a conflicting redeclaration.
// NOTE(review): A/B presumably denote the input and output tensors — confirm.
using ADataType = F32;
using BDataType = F32;
// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
......@@ -12,9 +12,12 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
// ######| Type| Type| Operation| | Size| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | |
< ADataType, BDataType, PassThrough, 4, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 3, 2, 1, 1>;
< ADataType, BDataType, PassThrough, 3, 256, 128, 128, 0, S<1, 16, 16>, S<0, 1, 2>, 2, 1, 1, 1>;
// clang-format on
#include "run_permute_example.inc"
// Entry point of this example. Command-line arguments are forwarded to
// run_permute_example together with this example's default problem
// (shape {1, 16000, 80}, axes {0, 2, 1}). The boolean result is negated so
// that a `true` return maps to exit status 0.
int main(int argc, char* argv[])
{
    return !run_permute_example(argc, argv, {1, 16000, 80}, {0, 2, 1});
}
......@@ -21,8 +21,8 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
a_device_buf.ToDevice(a.mData.data());
std::array<ck::index_t, 4> a_lengths, b_lengths;
std::array<ck::index_t, 4> a_strides, b_strides;
std::array<ck::index_t, 3> a_lengths, b_lengths;
std::array<ck::index_t, 3> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer();
......@@ -64,10 +64,13 @@ bool run_permute(const ExecutionConfig& config, const Problem& problem)
return true;
}
bool run_permute_example(int argc, char* argv[])
bool run_permute_example(int argc,
char* argv[],
const Problem::Shape& default_shape,
const Problem::Axes& default_axes)
{
ExecutionConfig config;
Problem problem;
Problem problem(default_shape, default_axes);
return parse_cmd_args(argc, argv, config, problem) && run_permute(config, problem);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment