"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "bee2e1e973dba9a2ce06b7057dfb4404f6fac17f"
Commit 0fa35b29 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use AsSpan() to shorten check_err() calls

parent c0c1d247
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <type_traits> #include <type_traits>
#include <utility>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
......
...@@ -19,9 +19,7 @@ bool run_permute_bundle(const Problem& problem) ...@@ -19,9 +19,7 @@ bool run_permute_bundle(const Problem& problem)
// initialize tensor by assigning DataType values // initialize tensor by assigning DataType values
using std::data, std::size; using std::data, std::size;
ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}( ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(input_bundle_tensor.AsSpan<DataType>());
ck::span<DataType>{reinterpret_cast<DataType*>(data(input_bundle_tensor)),
input_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle});
DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes()); DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes());
DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes()); DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes());
...@@ -72,13 +70,11 @@ bool run_permute_bundle(const Problem& problem) ...@@ -72,13 +70,11 @@ bool run_permute_bundle(const Problem& problem)
return false; return false;
} }
return ck::utils::check_err( return ck::utils::check_err(output_bundle_tensor.AsSpan<const DataType>(),
ck::span<const DataType>{reinterpret_cast<DataType*>(data(output_bundle_tensor)), output_tensor.AsSpan<const DataType>(),
output_bundle_tensor.GetElementSpaceSize() * NumElemsInBundle}, "Error: incorrect results in output tensor",
ck::span<const DataType>{output_tensor}, 1e-6,
"Error: incorrect results in output tensor", 1e-6);
1e-6,
1e-6);
} }
bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes) bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
......
...@@ -57,8 +57,8 @@ bool run_permute_element(const Problem& problem) ...@@ -57,8 +57,8 @@ bool run_permute_element(const Problem& problem)
return false; return false;
} }
return ck::utils::check_err(output_tensor.mData, return ck::utils::check_err(output_tensor.AsSpan<const OutDataType>(),
output_tensor_host.mData, output_tensor_host.AsSpan<const OutDataType>(),
"Error: incorrect results in output tensor", "Error: incorrect results in output tensor",
1e-6, 1e-6,
1e-6); 1e-6);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment