"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "66de221409f8d17222348a7c0ca60f8322e278b4"
Commit 06fd9eaa authored by Astha Rai's avatar Astha Rai
Browse files

updating formatting

parent e73a2cb7
...@@ -34,11 +34,11 @@ void host_elementwise4D(HostTensorB& B_nchwd, const HostTensorA& A_ncdhw, Functo ...@@ -34,11 +34,11 @@ void host_elementwise4D(HostTensorB& B_nchwd, const HostTensorA& A_ncdhw, Functo
for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c)
for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d)
for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h)
for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w)
{ {
auto a_val = A_ncdhw(n, c, d, h, w); auto a_val = A_ncdhw(n, c, d, h, w);
functor(B_nchwd(n, c, h, w, d), a_val); functor(B_nchwd(n, c, h, w, d), a_val);
} }
} }
int main() int main()
...@@ -46,8 +46,6 @@ int main() ...@@ -46,8 +46,6 @@ int main()
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
//std::vector<std::size_t> ncdhw = {16, 128, 32, 64, 16};
//std::vector<std::size_t> nchwd = {16, 128, 64, 16, 32};
std::vector<std::size_t> ncdhw = {16, 8, 8, 8, 8}; std::vector<std::size_t> ncdhw = {16, 8, 8, 8, 8};
std::vector<std::size_t> nchwd = {16, 8, 8, 8, 8}; std::vector<std::size_t> nchwd = {16, 8, 8, 8, 8};
Tensor<ADataType> a(ncdhw); Tensor<ADataType> a(ncdhw);
...@@ -64,16 +62,18 @@ int main() ...@@ -64,16 +62,18 @@ int main()
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 5> ab_lengths; std::array<ck::index_t, 5> ab_lengths;
std::array<ck::index_t, 5> a_strides = {static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), std::array<ck::index_t, 5> a_strides = {
static_cast<int>(ncdhw[2] * ncdhw[3] * ncdhw[4]), static_cast<int>(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[3] * ncdhw[4]), static_cast<int>(ncdhw[2] * ncdhw[3] * ncdhw[4]),
static_cast<int>(ncdhw[4]), static_cast<int>(ncdhw[3] * ncdhw[4]),
1}; static_cast<int>(ncdhw[4]),
std::array<ck::index_t, 5> b_strides = {static_cast<int>(nchwd[1] * nchwd[2] * nchwd[3] * nchwd[4]), 1};
static_cast<int>(nchwd[2] * nchwd[3] * nchwd[4]), std::array<ck::index_t, 5> b_strides = {
1, static_cast<int>(nchwd[1] * nchwd[2] * nchwd[3] * nchwd[4]),
static_cast<int>(nchwd[3] * nchwd[4]), static_cast<int>(nchwd[2] * nchwd[3] * nchwd[4]),
static_cast<int>(nchwd[4])}; 1,
static_cast<int>(nchwd[3] * nchwd[4]),
static_cast<int>(nchwd[4])};
ck::ranges::copy(ncdhw, ab_lengths.begin()); ck::ranges::copy(ncdhw, ab_lengths.begin());
...@@ -95,17 +95,15 @@ int main() ...@@ -95,17 +95,15 @@ int main()
broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]; std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4];
std::size_t num_btype = sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + std::size_t num_btype =
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) +
sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
//LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "A : ", a.mData, ",") << std::endl;
//LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
//std::cout << "A: " << a.mData.data() << std::endl;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl; << std::endl;
...@@ -117,8 +115,8 @@ int main() ...@@ -117,8 +115,8 @@ int main()
Tensor<BDataType> host_b(nchwd); Tensor<BDataType> host_b(nchwd);
host_elementwise4D(host_b, a, PassThrough{}); host_elementwise4D(host_b, a, PassThrough{});
//LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "B : ", b.mData, ",") << std::endl;
//LogRangeAsType<float>(std::cout << "Host B : ", host_b.mData, ",") << std::endl; // LogRangeAsType<float>(std::cout << "Host B : ", host_b.mData, ",") << std::endl;
pass &= pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment