Commit 95f21ea5 authored by Po-Yen, Chen
Browse files

Use new methods to simplify code

parent 0cfbefce
......@@ -19,18 +19,18 @@ bool run_permute_bundle(const Problem& problem)
// initialize tensor by assigning DataType values
using std::data, std::size;
{
auto* const elems = reinterpret_cast<DataType*>(data(a.mData));
auto* const elems = reinterpret_cast<DataType*>(data(a));
ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(
elems, elems + (size(a.mData) * NumElemsInBundle));
elems, elems + (a.GetElementSpaceSize() * NumElemsInBundle));
}
DeviceMem a_device_buf(sizeof(BundleType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BundleType) * b.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a.GetElementSpaceSizeInBytes());
DeviceMem b_device_buf(b.GetElementSpaceSizeInBytes());
a_device_buf.ToDevice(data(a.mData));
a_device_buf.ToDevice(data(a));
std::array<ck::index_t, 3> a_lengths, b_lengths;
std::array<ck::index_t, 3> a_strides, b_strides;
std::array<ck::index_t, Problem::NumDim> a_lengths, b_lengths;
std::array<ck::index_t, Problem::NumDim> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer();
......@@ -58,7 +58,7 @@ bool run_permute_bundle(const Problem& problem)
std::cout << "Perf: " << ave_time << " ms" << std::endl;
b_device_buf.FromDevice(data(b.mData));
b_device_buf.FromDevice(data(b));
// extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
const auto extended_shape = extend_shape(shape, NumElemsInBundle);
......@@ -68,8 +68,7 @@ bool run_permute_bundle(const Problem& problem)
transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));
Tensor<DataType> extended_a(extended_shape);
std::memcpy(
data(extended_a.mData), data(a.mData), sizeof(BundleType) * a.mDesc.GetElementSpaceSize());
std::memcpy(data(extended_a), data(a), a.GetElementSpaceSizeInBytes());
Tensor<DataType> extended_host_b(transposed_extended_shape);
if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
......@@ -78,16 +77,15 @@ bool run_permute_bundle(const Problem& problem)
}
return ck::utils::check_err(
ck::span<const DataType>{reinterpret_cast<DataType*>(data(b.mData)),
b.mDesc.GetElementSpaceSize() * NumElemsInBundle},
ck::span<const DataType>{reinterpret_cast<DataType*>(data(b)),
b.GetElementSpaceSize() * NumElemsInBundle},
ck::span<const DataType>{extended_host_b.mData},
"Error: incorrect results in output tensor",
1e-6,
1e-6);
}
// Example entry point for the bundle permute: builds a Problem from the given
// shape and axes, passes it to run_permute_bundle, and returns that call's
// bool result unchanged.
// NOTE(review): this listing is a diff with its +/- markers stripped, so the
// pre-commit lines (parameters named default_shape/default_axes) and the
// post-commit lines (parameters named shape/axes) both appear below; only one
// signature/return pair exists in the real file.
bool run_permute_bundle_example(const Problem::Shape& default_shape,
const Problem::Axes& default_axes)
bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
{
return run_permute_bundle(Problem{default_shape, default_axes});
return run_permute_bundle(Problem{shape, axes});
}
......@@ -16,14 +16,14 @@ bool run_permute_element(const Problem& problem)
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(begin(a.mData), end(a.mData));
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(a.GetElementSpaceSizeInBytes());
DeviceMem b_device_buf(b.GetElementSpaceSizeInBytes());
using std::data;
a_device_buf.ToDevice(data(a.mData));
a_device_buf.ToDevice(data(a));
std::array<ck::index_t, 3> a_lengths, b_lengths;
std::array<ck::index_t, 3> a_strides, b_strides;
std::array<ck::index_t, Problem::NumDim> a_lengths, b_lengths;
std::array<ck::index_t, Problem::NumDim> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer();
......@@ -51,7 +51,7 @@ bool run_permute_element(const Problem& problem)
std::cout << "Perf: " << ave_time << " ms" << std::endl;
b_device_buf.FromDevice(data(b.mData));
b_device_buf.FromDevice(data(b));
Tensor<BDataType> host_b(transposed_shape);
if(!host_permute(a, problem.axes, PassThrough{}, host_b))
......@@ -63,8 +63,7 @@ bool run_permute_element(const Problem& problem)
b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-6, 1e-6);
}
// Example entry point for the element-wise permute: builds a Problem from the
// given shape and axes, passes it to run_permute_element, and returns that
// call's bool result unchanged.
// NOTE(review): this listing is a diff with its +/- markers stripped, so the
// pre-commit lines (parameters named default_shape/default_axes) and the
// post-commit lines (parameters named shape/axes) both appear below; only one
// signature/return pair exists in the real file.
bool run_permute_element_example(const Problem::Shape& default_shape,
const Problem::Axes& default_axes)
bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
{
return run_permute_element(Problem{default_shape, default_axes});
return run_permute_element(Problem{shape, axes});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment