debug.hpp 2.49 KB
Newer Older
Umang Yadav's avatar
Umang Yadav committed
1
2
3

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
Chao Liu's avatar
Chao Liu committed
4
// SPDX-License-Identifier: MIT
Illia Silin's avatar
Illia Silin committed
5
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
Chao Liu's avatar
Chao Liu committed
6

7
8
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
Umang Yadav's avatar
Umang Yadav committed
9
#include "type.hpp"
10
11
12
13
namespace ck {
namespace debug {

namespace detail {
Umang Yadav's avatar
Umang Yadav committed
14
// Primary template of the printing trait. It is only declared, never defined:
// the specializations in this namespace (floating-point types, ck::half_t and
// integral types) provide the actual `type` alias and `Print` function, so
// using PrintAsType with any other T is a compile-time error.
// `Enable` is the SFINAE slot filled in by the std::enable_if-based partial
// specializations.
template <typename T, typename Enable = void> struct PrintAsType;
15
16

template <typename T>
Umang Yadav's avatar
Umang Yadav committed
17
18
19
20
21
22
struct PrintAsType<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  using type = float;
  __host__ __device__ static void Print(const T &p) {
    printf("%.3f ", static_cast<type>(p));
  }
23
24
};

Umang Yadav's avatar
Umang Yadav committed
25
26
27
28
29
// Full specialization for ck::half_t: the value is converted to float and
// printed in the same "%.3f " format used for the floating-point types.
template <> struct PrintAsType<ck::half_t, void> {
  // Type the half-precision value is converted to before printing.
  using type = float;

  // Prints p, converted to float, with three digits after the decimal point
  // and a trailing space (no newline).
  __host__ __device__ static void Print(const ck::half_t &p) {
    const type as_float = static_cast<type>(p);
    printf("%.3f ", as_float);
  }
};

template <typename T>
Umang Yadav's avatar
Umang Yadav committed
33
34
35
36
37
38
struct PrintAsType<T,
                   typename std::enable_if<std::is_integral<T>::value>::type> {
  using type = int;
  __host__ __device__ static void Print(const T &p) {
    printf("%d ", static_cast<type>(p));
  }
39
40
41
};
} // namespace detail

Umang Yadav's avatar
Umang Yadav committed
42
43
44
// Print at runtime the data in shared memory in 128 bytes per row format given
// shared mem pointer and the number of elements. Can optionally specify strides
// between elements and how many bytes' worth of data per row.
45
46
47
//
// Usage example:
//
Umang Yadav's avatar
Umang Yadav committed
48
49
//   debug::print_shared(a_block_buf.p_data_,
//   index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
50
51
//
template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
Umang Yadav's avatar
Umang Yadav committed
52
53
54
55
__device__ void print_shared(T const *p_shared, index_t num_elements) {
  constexpr index_t row_elements = row_bytes / sizeof(T);
  static_assert((element_stride >= 1 && element_stride <= row_elements),
                "element_stride should between [1, row_elements]");
56

Umang Yadav's avatar
Umang Yadav committed
57
58
59
60
  index_t wgid =
      blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
  index_t tid = (threadIdx.z * (blockDim.x * blockDim.y)) +
                (threadIdx.y * blockDim.x) + threadIdx.x;
61

Umang Yadav's avatar
Umang Yadav committed
62
  __syncthreads();
63

Umang Yadav's avatar
Umang Yadav committed
64
65
66
67
68
69
70
71
  if (tid == 0) {
    printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", wgid,
           row_bytes, element_stride);
    for (index_t i = 0; i < num_elements; i += row_elements) {
      printf("elem %5d: ", i);
      for (index_t j = 0; j < row_elements; j += element_stride) {
        detail::PrintAsType<T>::Print(p_shared[i + j]);
      }
72

Umang Yadav's avatar
Umang Yadav committed
73
      printf("\n");
74
    }
Umang Yadav's avatar
Umang Yadav committed
75
76
    printf("\n");
  }
77

Umang Yadav's avatar
Umang Yadav committed
78
  __syncthreads();
79
80
81
82
83
84
}

} // namespace debug
} // namespace ck

#endif // UTILITY_DEBUG_HPP
Umang Yadav's avatar
Umang Yadav committed
85
86

#pragma clang diagnostic pop