Commit 15965dfc authored by Astha Rai's avatar Astha Rai
Browse files

changed dimension for debugging

parent 3a9e6db3
......@@ -26,8 +26,8 @@ using DeviceElementwisePermuteInstance =
1,
8,
8,
ck::Sequence<8>,
ck::Sequence<8>>;
ck::Sequence<1>,
ck::Sequence<1>>;
template <typename HostTensorA, typename HostTensorB, typename Functor>
void host_elementwise4D(HostTensorB& B_nhwc,
......@@ -50,8 +50,9 @@ int main()
bool do_verification = true;
bool time_kernel = true;
std::vector<std::size_t> nchw = {4, 4, 8, 8};
std::vector<std::size_t> nhwc = {4, 8, 8, 4};
std::vector<std::size_t> nchw = {4, 8, 4, 8};
std::vector<std::size_t> nhwc = {4, 4, 8, 8};
Tensor<ADataType> a(nchw);
Tensor<BDataType> b(nhwc);
......@@ -61,23 +62,26 @@ int main()
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a.mData.data());
//LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "Tensor a : ", a.mData, ",") << std::endl;
std::array<const void*, 1> input = {a_device_buf.GetDeviceBuffer()};
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 4> ab_lengths;
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nchw[1] * nchw[2] * nchw[3]),
static_cast<int>(nchw[2] * nchw[3]),
static_cast<int>(nchw[3]),
std::array<ck::index_t, 4> a_strides = {static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]),
static_cast<int>(nhwc[2]),
1,
static_cast<int>(nhwc[1] * nhwc[2])};
std::array<ck::index_t, 4> b_strides = {static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]),
static_cast<int>(nhwc[2]*nhwc[3]),
static_cast<int>(nhwc[3]),
1};
std::array<ck::index_t, 4> b_strides = {
static_cast<int>(nhwc[1] * nhwc[2] * nhwc[3]), 1, 32, 4};
// std::cout << "Length: " << ab_lengths << std::endl;
// std::cout << "A stride: " << a_strides << std::endl;
// std::cout << "B stride: " << b_strides << std::endl;
std::copy(nchw.begin(), nchw.end(), ab_lengths.begin());
std::copy(nhwc.begin(), nhwc.end(), ab_lengths.begin());
// std::copy(a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end(), a_strides.begin());
// std::copy(b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end(), b_strides.begin());
......@@ -101,12 +105,12 @@ int main()
if(do_verification)
{
b_device_buf.FromDevice(b.mData.data());
//LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "Tensor b : ", b.mData, ",") << std::endl;
Tensor<BDataType> host_b(nhwc);
host_elementwise4D<Tensor<ADataType>, Tensor<BDataType>, PassThrough>(
host_b, a, nchw, PassThrough{});
//LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "Host b : ", host_b.mData, ",") << std::endl;
pass &=
ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment