"test/vscode:/vscode.git/clone" did not exist on "e4e99a49bfdb1bdf8d620aa01ea5608028e390a8"
Commit 60ab70d8 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Generalize variable naming in example code

parent 31d758fb
...@@ -5,12 +5,14 @@ ...@@ -5,12 +5,14 @@
bool run_elementwise_permute(const ExecutionConfig& config, const Problem& problem) bool run_elementwise_permute(const ExecutionConfig& config, const Problem& problem)
{ {
const auto& nchw = problem.shape; using std::begin, std::end;
std::vector<std::size_t> nhwc;
transpose_shape(problem.shape, problem.axes, std::back_inserter(nhwc));
Tensor<ADataType> a(nchw); const auto& shape = problem.shape;
Tensor<BDataType> b(nhwc); ck::remove_cvref_t<decltype(shape)> transposed_shape;
transpose_shape(problem.shape, problem.axes, begin(transposed_shape));
Tensor<ADataType> a(shape);
Tensor<BDataType> b(transposed_shape);
std::iota(begin(a.mData), end(a.mData), 0); std::iota(begin(a.mData), end(a.mData), 0);
...@@ -23,12 +25,11 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl ...@@ -23,12 +25,11 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths; std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides; std::array<ck::index_t, 4> a_strides, b_strides;
std::array<ck::index_t, 4> b_strides;
std::copy(nchw.begin(), nchw.end(), ab_lengths.begin()); std::copy(begin(shape), end(shape), begin(ab_lengths));
std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin()); std::copy(begin(a.mDesc.GetStrides()), end(a.mDesc.GetStrides()), begin(a_strides));
std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin()); std::copy(begin(b.mDesc.GetStrides()), end(b.mDesc.GetStrides()), begin(b_strides));
auto permute = DeviceElementwisePermuteInstance{}; auto permute = DeviceElementwisePermuteInstance{};
auto argument = auto argument =
...@@ -48,7 +49,7 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl ...@@ -48,7 +49,7 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
if(config.do_verification) if(config.do_verification)
{ {
Tensor<BDataType> host_b(nhwc); Tensor<BDataType> host_b(transposed_shape);
host_elementwise_permute(a, problem.axes, PassThrough{}, host_b); host_elementwise_permute(a, problem.axes, PassThrough{}, host_b);
b_device_buf.FromDevice(b.mData.data()); b_device_buf.FromDevice(b.mData.data());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment