Commit 19147f59 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use more specific method to write example

parent 665b73ff
...@@ -24,7 +24,7 @@ void host_elementwise4D(HostTensorB& B, ...@@ -24,7 +24,7 @@ void host_elementwise4D(HostTensorB& B,
bool run_elementwise_permute(const ExecutionConfig& config, const Problem& problem) bool run_elementwise_permute(const ExecutionConfig& config, const Problem& problem)
{ {
std::size_t N = 4, C = 16, H = 32, W = 32; auto [N, C, H, W] = problem.shape;
std::vector<std::size_t> nchw = {N, C, H, W}; std::vector<std::size_t> nchw = {N, C, H, W};
std::vector<std::size_t> nhwc = {N, H, W, C}; std::vector<std::size_t> nhwc = {N, H, W, C};
Tensor<ADataType> a(nchw); Tensor<ADataType> a(nchw);
...@@ -48,19 +48,19 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl ...@@ -48,19 +48,19 @@ bool run_elementwise_permute(const ExecutionConfig& config, const Problem& probl
std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin()); std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin());
std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin()); std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin());
auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto permute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer( auto argument =
ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); permute.MakeArgument(ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{});
if(!broadcastPermute.IsSupportedArgument(argument.get())) if(!permute.IsSupportedArgument(argument))
{ {
throw std::runtime_error( std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
"The runtime parameters seems not supported by the device instance, exiting!"); << std::endl;
return false;
}; };
auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); auto invoker = permute.MakeInvoker();
float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
StreamConfig{nullptr, config.time_kernel});
std::cout << "Perf: " << ave_time << " ms" << std::endl; std::cout << "Perf: " << ave_time << " ms" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment