Commit 4dab86fe authored by Astha Rai's avatar Astha Rai
Browse files

update profiler and client example tensor layouts

parent 0f8c6a60
...@@ -48,8 +48,8 @@ int main() ...@@ -48,8 +48,8 @@ int main()
auto size = N * C * D * H * W; auto size = N * C * D * H * W;
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D}; std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, D * H * W, 1, D * H, D}; std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, C, H, W, D
SimpleDeviceMem a_dev_buf(sizeof(ADataType) * size); SimpleDeviceMem a_dev_buf(sizeof(ADataType) * size);
SimpleDeviceMem b_dev_buf(sizeof(BDataType) * size); SimpleDeviceMem b_dev_buf(sizeof(BDataType) * size);
......
...@@ -73,7 +73,7 @@ int main() ...@@ -73,7 +73,7 @@ int main()
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D}; std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, C, H, W, D std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto broadcastPermute = DeviceElementwisePermuteInstance{};
auto argument = broadcastPermute.MakeArgumentPointer( auto argument = broadcastPermute.MakeArgumentPointer(
......
...@@ -38,9 +38,9 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -38,9 +38,9 @@ bool profile_gemm_splitk_impl(int do_verification,
bool pass = true; bool pass = true;
std::vector<std::size_t> ncdhw = {N, C, D, H, W}; std::vector<std::size_t> ncdhw = {N, C, D, H, W};
std::vector<std::size_t> nchwd = {N, C, H, W, D}; std::vector<std::size_t> ndhwc = {N, D, H, W, C};
Tensor<ADataType> a(ncdhw); Tensor<ADataType> a(ncdhw);
Tensor<BDataType> b(nchwd); Tensor<BDataType> b(ndhwc);
// a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); // a.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
...@@ -48,8 +48,8 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -48,8 +48,8 @@ bool profile_gemm_splitk_impl(int do_verification,
std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()}; std::array<void*, 1> output = {b_device_buf.GetDeviceBuffer()};
std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D}; std::array<ck::index_t, 5> ab_lengths{N, C, H, W, D};
std::array<ck::index_t, 5> a_strides = {C * D * H * W, D * H * W, 1, D * H, D}; std::array<ck::index_t, 5> a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W
std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; std::array<ck::index_t, 5> b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C
std::cout << "A: " << a.mDesc << std::endl; std::cout << "A: " << a.mDesc << std::endl;
std::cout << "B: " << b.mDesc << std::endl; std::cout << "B: " << b.mDesc << std::endl;
...@@ -192,4 +192,4 @@ return pass; ...@@ -192,4 +192,4 @@ return pass;
} }
} // namespace profiler } // namespace profiler
} // namespace ck } // namespace ck
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment