Commit d53443d5 authored by Po-Yen, Chen
Browse files

Use more meaningful names in permute bundle example

parent 7b135645
......@@ -9,39 +9,51 @@ bool run_permute_bundle(const Problem& problem)
using std::begin, std::end;
const auto& shape = problem.shape;
ck::remove_cvref_t<decltype(shape)> transposed_shape;
transpose_shape(problem.shape, problem.axes, begin(transposed_shape));
const auto& input_bundle_shape = problem.shape;
const auto& input_bundle_axes = problem.axes;
Tensor<BundleType> a(shape);
Tensor<BundleType> b(transposed_shape);
ck::remove_cvref_t<decltype(input_bundle_shape)> output_bundle_shape;
transpose_shape(input_bundle_shape, input_bundle_axes, begin(output_bundle_shape));
Tensor<BundleType> input_bundle_tensor(input_bundle_shape);
Tensor<BundleType> output_bundle_tensor(output_bundle_shape);
// initialize tensor by assigning DataType values
using std::data, std::size;
ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(ck::span<DataType>{
reinterpret_cast<DataType*>(data(a)), a.GetElementSpaceSize() * NumElemsInBundle});
ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(
ck::span<DataType>{reinterpret_cast<DataType*>(data(input_bundle_tensor)),
input_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle});
DeviceMem a_device_buf(a.GetElementSpaceSizeInBytes());
DeviceMem b_device_buf(b.GetElementSpaceSizeInBytes());
DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes());
DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes());
a_device_buf.ToDevice(data(a));
input_device_buf.ToDevice(data(input_bundle_tensor));
std::array<ck::index_t, Problem::NumDim> a_lengths, b_lengths;
std::array<ck::index_t, Problem::NumDim> a_strides, b_strides;
std::array<ck::index_t, Problem::NumDim> input_bundle_lengths, output_bundle_lengths;
std::array<ck::index_t, Problem::NumDim> input_bundle_strides, output_bundle_strides;
const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer();
const void* input_bundle_data = input_device_buf.GetDeviceBuffer();
void* output_bundle_data = output_device_buf.GetDeviceBuffer();
std::copy(begin(shape), end(shape), begin(a_lengths));
std::copy(begin(a.GetStrides()), end(a.GetStrides()), begin(a_strides));
std::copy(begin(transposed_shape), end(transposed_shape), begin(b_lengths));
std::copy(begin(b.GetStrides()), end(b.GetStrides()), begin(b_strides));
std::copy(begin(input_bundle_shape), end(input_bundle_shape), begin(input_bundle_lengths));
std::copy(begin(input_bundle_tensor.GetStrides()),
end(input_bundle_tensor.GetStrides()),
begin(input_bundle_strides));
std::copy(begin(output_bundle_shape), end(output_bundle_shape), begin(output_bundle_lengths));
std::copy(begin(output_bundle_tensor.GetStrides()),
end(output_bundle_tensor.GetStrides()),
begin(output_bundle_strides));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument(
a_lengths, a_strides, b_lengths, b_strides, input, output, PassThrough{});
auto argument = permute.MakeArgument(input_bundle_lengths,
input_bundle_strides,
output_bundle_lengths,
output_bundle_strides,
input_bundle_data,
output_bundle_data,
PassThrough{});
if(!permute.IsSupportedArgument(argument))
{
......@@ -55,28 +67,30 @@ bool run_permute_bundle(const Problem& problem)
std::cout << "Perf: " << ave_time << " ms" << std::endl;
b_device_buf.FromDevice(data(b));
output_device_buf.FromDevice(data(output_bundle_tensor));
// extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
const auto extended_shape = extend_shape(shape, NumElemsInBundle);
const auto extended_axes = extend_axes(problem.axes);
const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle);
const auto input_axes = extend_axes(input_bundle_axes);
ck::remove_cvref_t<decltype(extended_shape)> transposed_extended_shape;
transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));
ck::remove_cvref_t<decltype(input_shape)> output_shape;
transpose_shape(input_shape, input_axes, begin(output_shape));
Tensor<DataType> extended_a(extended_shape);
std::memcpy(data(extended_a), data(a), a.GetElementSpaceSizeInBytes());
Tensor<DataType> input_tensor(input_shape);
std::memcpy(data(input_tensor),
data(input_bundle_tensor),
input_bundle_tensor.GetElementSpaceSizeInBytes());
Tensor<DataType> extended_host_b(transposed_extended_shape);
if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
Tensor<DataType> output_tensor(output_shape);
if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
{
return false;
}
return ck::utils::check_err(
ck::span<const DataType>{reinterpret_cast<DataType*>(data(b)),
b.GetElementSpaceSize() * NumElemsInBundle},
ck::span<const DataType>{extended_host_b.mData},
ck::span<const DataType>{reinterpret_cast<DataType*>(data(output_bundle_tensor)),
output_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle},
ck::span<const DataType>{output_tensor},
"Error: incorrect results in output tensor",
1e-6,
1e-6);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment