Commit ff6a04fd authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use more meaningful names in permute element examples

parent d53443d5
......@@ -3,16 +3,16 @@
#include "common.hpp"
using ADataType = F16;
using BDataType = F16;
using InDataType = F16;
using OutDataType = F16;
// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
< ADataType, BDataType, PassThrough, 3, 256, 1, 32, 32, 3, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 2, 1>;
// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
< InDataType, OutDataType, PassThrough, 3, 256, 1, 32, 32, 3, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 2, 1>;
// clang-format on
#include "run_permute_element_example.inc"
......
......@@ -3,16 +3,16 @@
#include "common.hpp"
using ADataType = F16;
using BDataType = F16;
using InDataType = F16;
using OutDataType = F16;
// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermute
// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
< ADataType, BDataType, PassThrough, 3, 128, 4, 16, 8, 6, S<2, 16, 4>, S<0, 1, 2>, 2, 1, 2, 1>;
// ######| InData| OutData| Elementwise| NumDim| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| Type| Type| Operation| | Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
< InDataType, OutDataType, PassThrough, 3, 128, 4, 16, 8, 6, S<2, 16, 4>, S<0, 1, 2>, 2, 1, 2, 1>;
// clang-format on
#include "run_permute_element_example.inc"
......
......@@ -7,37 +7,46 @@ bool run_permute_element(const Problem& problem)
{
using std::begin, std::end;
const auto& shape = problem.shape;
ck::remove_cvref_t<decltype(shape)> transposed_shape;
transpose_shape(problem.shape, problem.axes, begin(transposed_shape));
const auto& input_shape = problem.shape;
const auto& input_axes = problem.axes;
Tensor<ADataType> a(shape);
Tensor<BDataType> b(transposed_shape);
ck::remove_cvref_t<decltype(input_shape)> output_shape;
transpose_shape(input_shape, input_axes, begin(output_shape));
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a);
Tensor<InDataType> input_tensor(input_shape);
Tensor<OutDataType> output_tensor(output_shape);
DeviceMem a_device_buf(a.GetElementSpaceSizeInBytes());
DeviceMem b_device_buf(b.GetElementSpaceSizeInBytes());
ck::utils::FillUniformDistribution<InDataType>{-1.f, 1.f}(input_tensor);
DeviceMem input_device_buf(input_tensor.GetElementSpaceSizeInBytes());
DeviceMem output_device_buf(output_tensor.GetElementSpaceSizeInBytes());
using std::data;
a_device_buf.ToDevice(data(a));
input_device_buf.ToDevice(data(input_tensor));
std::array<ck::index_t, Problem::NumDim> a_lengths, b_lengths;
std::array<ck::index_t, Problem::NumDim> a_strides, b_strides;
std::array<ck::index_t, Problem::NumDim> input_lengths, output_lengths;
std::array<ck::index_t, Problem::NumDim> input_strides, output_strides;
const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer();
const void* input_data = input_device_buf.GetDeviceBuffer();
void* output_data = output_device_buf.GetDeviceBuffer();
std::copy(begin(shape), end(shape), begin(a_lengths));
std::copy(begin(a.GetStrides()), end(a.GetStrides()), begin(a_strides));
std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths));
std::copy(begin(b.GetStrides()), end(b.GetStrides()), begin(b_strides));
std::copy(begin(input_shape), end(input_shape), begin(input_lengths));
std::copy(
begin(input_tensor.GetStrides()), end(input_tensor.GetStrides()), begin(input_strides));
std::copy(begin(output_shape), end(output_shape), begin(output_lengths));
std::copy(
begin(output_tensor.GetStrides()), end(output_tensor.GetStrides()), begin(output_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument(
a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{});
auto argument = permute.MakeArgument(input_lengths,
input_strides,
output_lengths,
output_strides,
input_data,
output_data,
PassThrough{});
if(!permute.IsSupportedArgument(argument))
{
......@@ -51,16 +60,19 @@ bool run_permute_element(const Problem& problem)
std::cout << "Perf: " << ave_time << " ms" << std::endl;
b_device_buf.FromDevice(data(b));
output_device_buf.FromDevice(data(output_tensor));
Tensor<BDataType> host_b(transposed_shape);
if(!host_permute(a, problem.axes, PassThrough{}, host_b))
Tensor<OutDataType> output_tensor_host(output_shape);
if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor_host))
{
return false;
}
return ck::utils::check_err(
b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-6, 1e-6);
return ck::utils::check_err(output_tensor.mData,
output_tensor_host.mData,
"Error: incorrect results in output tensor",
1e-6,
1e-6);
}
bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment