"...zh/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "af769881d37fe916afef2c47279f66c79f5f2714"
Commit b472cdf6 authored by Astha Rai
Browse files

added working example for 5D input using 1D kernel

parent d52ec016
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
add_example_executable(example_elementwise_permute_5D elementwise_permute_5D.cpp)
add_example_executable(example_elementwise_permute_5D_2d elementwise_permute_5D_2d.cpp)
endif() endif()
...@@ -45,8 +45,10 @@ int main() ...@@ -45,8 +45,10 @@ int main()
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
std::vector<std::size_t> nchw = {16, 128, 32, 64}; //std::vector<std::size_t> nchw = {16, 128, 32, 64};
std::vector<std::size_t> nhwc = {16, 32, 64, 128}; //std::vector<std::size_t> nhwc = {16, 32, 64, 128};
std::vector<std::size_t> nchw = {16, 8, 8, 8};
std::vector<std::size_t> nhwc = {16, 8, 8, 8};
Tensor<ADataType> a(nchw); Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc); Tensor<BDataType> b(nhwc);
...@@ -99,7 +101,8 @@ int main() ...@@ -99,7 +101,8 @@ int main()
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
//LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
//LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
bool pass = true; bool pass = true;
if(do_verification) if(do_verification)
...@@ -107,6 +110,8 @@ int main() ...@@ -107,6 +110,8 @@ int main()
b_device_buf.FromDevice(b.mData.data()); b_device_buf.FromDevice(b.mData.data());
Tensor<BDataType> host_b(nhwc); Tensor<BDataType> host_b(nhwc);
host_elementwise4D(host_b, a, PassThrough{}); host_elementwise4D(host_b, a, PassThrough{});
//LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
//LogRangeAsType<float>(std::cout << "B : ", host_b.mData, ",") << std::endl;
pass &= pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment