"docs/source/en/vscode:/vscode.git/clone" did not exist on "1a6fa69ab610586dad912c2b8d72bef9e3f209ee"
Commit 15965dfc authored by Astha Rai's avatar Astha Rai
Browse files

changed dimension for debugging

parent 3a9e6db3
...@@ -26,8 +26,8 @@ using DeviceElementwisePermuteInstance = ...@@ -26,8 +26,8 @@ using DeviceElementwisePermuteInstance =
1, 1,
8, 8,
8, 8,
ck::Sequence<8>, ck::Sequence<1>,
ck::Sequence<8>>; ck::Sequence<1>>;
template <typename HostTensorA, typename HostTensorB, typename Functor> template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc, void host_elementwise4D(HostTensorB& B_nhwc,
...@@ -50,8 +50,9 @@ int main() ...@@ -50,8 +50,9 @@ int main()
bool do_verification = true; bool do_verification = true;
bool time_kernel = true; bool time_kernel = true;
std::vector<std::size_t> nchw = {4, 4, 8, 8}; std::vector<std::size_t> nchw = {4, 8, 4, 8};
std::vector<std::size_t> nhwc = {4, 8, 8, 4}; std::vector<std::size_t> nhwc = {4, 4, 8, 8};
Tensor<ADataType> a(nchw); Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc); Tensor<BDataType> b(nhwc);
...@@ -61,23 +62,26 @@ int main() ...@@ -61,23 +62,26 @@ int main()
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data()); a_device_buf.ToDevice(a.mData.data());
//LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl;
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()}; std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths; std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]), std::array<ck::index_t, 4> a_strides = {static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]),
static_cast<int>(nchw[2] * nchw[3]), static_cast<int>(nhwc[2]),
static_cast<int>(nchw[3]), 1,
static_cast<int>(nhwc[1] * nhwc[2])};
std::array<ck::index_t, 4> b_strides = {static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]),
static_cast<int>(nhwc[2]*nhwc[3]),
static_cast<int>(nhwc[3]),
1}; 1};
std::array<ck::index_t, 4> b_strides = {
static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]), 1, 32, 4};
// std::cout << "Length: " << ab_lengths << std::endl; // std::cout << "Length: " << ab_lengths << std::endl;
// std::cout << "A stride: " << a_strides << std::endl; // std::cout << "A stride: " << a_strides << std::endl;
// std::cout << "B stride: " << b_strides << std::endl; // std::cout << "B stride: " << b_strides << std::endl;
std::copy(nchw.begin(), nchw.end(), ab_lengths.begin()); std::copy(nhwc.begin(), nhwc.end(), ab_lengths.begin());
// std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin()); // std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin());
// std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin()); // std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin());
...@@ -101,12 +105,12 @@ int main() ...@@ -101,12 +105,12 @@ int main()
if(do_verification) if(do_verification)
{ {
b_device_buf.FromDevice(b.mData.data()); b_device_buf.FromDevice(b.mData.data());
//LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl;
Tensor<BDataType> host_b(nhwc); Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>( host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nchw, PassThrough{}); host_b, a, nchw, PassThrough{});
//LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl; LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
pass &= pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment