debug.hpp 2.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP

namespace ck {
namespace debug {

namespace detail {
// Type trait mapping an element type T to the type a value is converted to before it is
// handed to printf (exposed as the nested alias `type`).
//
// The primary template is declared but never defined, so asking for PrintAsType<T>::type with
// a T that matches none of the specializations below is a compile-time error.
//
// The second template parameter only exists to drive SFINAE selection of the partial
// specializations; callers leave it at its default (void).
template <typename T, typename Enable = void>
struct PrintAsType;

// Floating-point element types are converted to float.
//
// The condition must be spelled std::enable_if<...>::type: that member is void when the
// condition holds (matching the default Enable argument) and does not exist otherwise.
// std::enable_if has no member named `value`, so using `::value` here would be a substitution
// failure for every T and this specialization could never be selected.
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{
    using type = float;
};

// ck::half_t is listed explicitly and is also converted to float.
// NOTE(review): presumably needed because std::is_floating_point does not report half_t as a
// floating-point type -- confirm against the half_t definition.
template <>
struct PrintAsType<ck::half_t, void>
{
    using type = float;
};

// Integral element types are converted to int. Same SFINAE pattern as the floating-point
// specialization above (enable_if<...>::type, not ::value).
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
{
    using type = int;
};
} // namespace detail

// Debug helper: print the contents of a buffer from device code, one row of elements per
// output line. Intended for dumping shared memory (the function itself only sees a pointer).
//
// Template parameters:
//   T              - element type; needs a detail::PrintAsType<T> specialization.
//   element_stride - within a row, print every element_stride-th element (default 1 = all).
//                    Must lie in [1, row_elements]; checked with static_assert.
//   row_bytes      - how many bytes' worth of elements make up one printed row (default 128),
//                    i.e. row_elements = row_bytes / sizeof(T).
//
// Parameters:
//   p_shared     - pointer to the first element to print.
//   num_elements - number of elements readable through p_shared; nothing at or beyond this
//                  index is read, so the last row may be shorter than the others.
//
// Every thread of the workgroup must call this function: it executes __syncthreads() before
// and after the printing section. Only the thread whose flattened thread index is 0 prints.
//
// Each value is converted with detail::PrintAsType<T> and printed with "%.0f", so fractional
// digits are not shown.
//
// Usage example:
//
//   debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
//
template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
__device__ void print_shared(T const* p_shared, index_t num_elements)
{
    using PrintType                = typename detail::PrintAsType<T>::type;
    constexpr index_t row_elements = row_bytes / sizeof(T);
    static_assert((element_stride >= 1 && element_stride <= row_elements),
                  "element_stride should between [1, row_elements]");

    // Flattened workgroup index over the 3D grid and flattened thread index over the 3D block.
    index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
    index_t tid =
        (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x;

    // Barrier before reading: all threads of the workgroup reach this point before thread 0
    // starts printing.
    __syncthreads();

    if(tid == 0)
    {
        printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n",
               wgid,
               row_bytes,
               element_stride);
        for(index_t i = 0; i < num_elements; i += row_elements)
        {
            // Clamp the row to the elements that actually exist so that a final partial row
            // never reads past p_shared[num_elements - 1]. i < num_elements holds here, so
            // the subtraction is positive.
            const index_t remaining = num_elements - i;
            const index_t row_len   = remaining < row_elements ? remaining : row_elements;

            printf("elem %5d: ", i);
            for(index_t j = 0; j < row_len; j += element_stride)
            {
                // "%.0f" consumes a double. Convert through PrintType first (this is the
                // conversion the trait selects, e.g. half_t -> float), then to double so that
                // integral element types (PrintType == int) are not passed as an int to a
                // floating-point conversion specifier.
                printf("%.0f ", static_cast<double>(static_cast<PrintType>(p_shared[i + j])));
            }

            printf("\n");
        }
        printf("\n");
    }

    // Barrier after printing: no thread continues until thread 0 has finished printing.
    __syncthreads();
}

} // namespace debug
} // namespace ck

#endif // UTILITY_DEBUG_HPP