"vscode:/vscode.git/clone" did not exist on "76696dca558267999abf3e7c29e1a256cbcb407a"
Commit 95f21ea5 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use new methods to simplify code

parent 0cfbefce
...@@ -19,18 +19,18 @@ bool run_permute_bundle(const Problem& problem) ...@@ -19,18 +19,18 @@ bool run_permute_bundle(const Problem& problem)
// initialize tensor by assigning DataType values // initialize tensor by assigning DataType values
using std::data, std::size; using std::data, std::size;
{ {
auto* const elems = reinterpret_cast<DataType*>(data(a.mData)); auto* const elems = reinterpret_cast<DataType*>(data(a));
ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}( ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(
elems, elems + (size(a.mData) * NumElemsInBundle)); elems, elems + (a.GetElementSpaceSize() * NumElemsInBundle));
} }
DeviceMem a_device_buf(sizeof(BundleType) * a.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a.GetElementSpaceSizeInBytes());
DeviceMem b_device_buf(sizeof(BundleType) * b.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b.GetElementSpaceSizeInBytes());
a_device_buf.ToDevice(data(a.mData)); a_device_buf.ToDevice(data(a));
std::array<ck::index_t, 3> a_lengths, b_lengths; std::array<ck::index_t, Problem::NumDim> a_lengths, b_lengths;
std::array<ck::index_t, 3> a_strides, b_strides; std::array<ck::index_t, Problem::NumDim> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer(); const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer(); void* output = b_device_buf.GetDeviceBuffer();
...@@ -58,7 +58,7 @@ bool run_permute_bundle(const Problem& problem) ...@@ -58,7 +58,7 @@ bool run_permute_bundle(const Problem& problem)
std::cout << "Perf: " << ave_time << " ms" << std::endl; std::cout << "Perf: " << ave_time << " ms" << std::endl;
b_device_buf.FromDevice(data(b.mData)); b_device_buf.FromDevice(data(b));
// extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle] // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
const auto extended_shape = extend_shape(shape, NumElemsInBundle); const auto extended_shape = extend_shape(shape, NumElemsInBundle);
...@@ -68,8 +68,7 @@ bool run_permute_bundle(const Problem& problem) ...@@ -68,8 +68,7 @@ bool run_permute_bundle(const Problem& problem)
transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape)); transpose_shape(extended_shape, extended_axes, begin(transposed_extended_shape));
Tensor<DataType> extended_a(extended_shape); Tensor<DataType> extended_a(extended_shape);
std::memcpy( std::memcpy(data(extended_a), data(a), a.GetElementSpaceSizeInBytes());
data(extended_a.mData), data(a.mData), sizeof(BundleType) * a.mDesc.GetElementSpaceSize());
Tensor<DataType> extended_host_b(transposed_extended_shape); Tensor<DataType> extended_host_b(transposed_extended_shape);
if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b)) if(!host_permute(extended_a, extended_axes, PassThrough{}, extended_host_b))
...@@ -78,16 +77,15 @@ bool run_permute_bundle(const Problem& problem) ...@@ -78,16 +77,15 @@ bool run_permute_bundle(const Problem& problem)
} }
return ck::utils::check_err( return ck::utils::check_err(
ck::span<const DataType>{reinterpret_cast<DataType*>(data(b.mData)), ck::span<const DataType>{reinterpret_cast<DataType*>(data(b)),
b.mDesc.GetElementSpaceSize() * NumElemsInBundle}, b.GetElementSpaceSize() * NumElemsInBundle},
ck::span<const DataType>{extended_host_b.mData}, ck::span<const DataType>{extended_host_b.mData},
"Error: incorrect results in output tensor", "Error: incorrect results in output tensor",
1e-6, 1e-6,
1e-6); 1e-6);
} }
bool run_permute_bundle_example(const Problem::Shape& default_shape, bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
const Problem::Axes& default_axes)
{ {
return run_permute_bundle(Problem{default_shape, default_axes}); return run_permute_bundle(Problem{shape, axes});
} }
...@@ -16,14 +16,14 @@ bool run_permute_element(const Problem& problem) ...@@ -16,14 +16,14 @@ bool run_permute_element(const Problem& problem)
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(begin(a.mData), end(a.mData)); ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(begin(a.mData), end(a.mData));
DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(a.GetElementSpaceSizeInBytes());
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(b.GetElementSpaceSizeInBytes());
using std::data; using std::data;
a_device_buf.ToDevice(data(a.mData)); a_device_buf.ToDevice(data(a));
std::array<ck::index_t, 3> a_lengths, b_lengths; std::array<ck::index_t, Problem::NumDim> a_lengths, b_lengths;
std::array<ck::index_t, 3> a_strides, b_strides; std::array<ck::index_t, Problem::NumDim> a_strides, b_strides;
const void* input = a_device_buf.GetDeviceBuffer(); const void* input = a_device_buf.GetDeviceBuffer();
void* output = b_device_buf.GetDeviceBuffer(); void* output = b_device_buf.GetDeviceBuffer();
...@@ -51,7 +51,7 @@ bool run_permute_element(const Problem& problem) ...@@ -51,7 +51,7 @@ bool run_permute_element(const Problem& problem)
std::cout << "Perf: " << ave_time << " ms" << std::endl; std::cout << "Perf: " << ave_time << " ms" << std::endl;
b_device_buf.FromDevice(data(b.mData)); b_device_buf.FromDevice(data(b));
Tensor<BDataType> host_b(transposed_shape); Tensor<BDataType> host_b(transposed_shape);
if(!host_permute(a, problem.axes, PassThrough{}, host_b)) if(!host_permute(a, problem.axes, PassThrough{}, host_b))
...@@ -63,8 +63,7 @@ bool run_permute_element(const Problem& problem) ...@@ -63,8 +63,7 @@ bool run_permute_element(const Problem& problem)
b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-6, 1e-6); b.mData, host_b.mData, "Error: incorrect results in output tensor", 1e-6, 1e-6);
} }
bool run_permute_element_example(const Problem::Shape& default_shape, bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
const Problem::Axes& default_axes)
{ {
return run_permute_element(Problem{default_shape, default_axes}); return run_permute_element(Problem{shape, axes});
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment