"examples/dreambooth/train_dreambooth_lora.py" did not exist on "4bf675f4652759b42280103cb84ab0101cf23382"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP

namespace ck {
namespace debug {

namespace detail {
// Primary template of the element-printing trait. It is only declared here; the
// specializations in this namespace supply the `type` alias and the static `Print` function.
// The defaulted second parameter is the slot those specializations fill with std::enable_if.
template <class T, class SfinaeSlot = void>
struct PrintAsType;

template <typename T>
15
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
16
17
{
    using type = float;
18
    __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
19
20
21
22
23
24
};

template <>
struct PrintAsType<ck::half_t, void>
{
    using type = float;
25
26
27
28
    __host__ __device__ static void Print(const ck::half_t& p)
    {
        printf("%.3f ", static_cast<type>(p));
    }
29
30
31
};

template <typename T>
32
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
33
34
{
    using type = int;
35
    __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
};
} // namespace detail

// Print, at runtime from inside a kernel, `num_elements` values starting at `p_shared`, laid
// out as one output row per `row_bytes` bytes of data (128 by default). Intended for
// inspecting a workgroup's shared-memory buffer; the function itself only dereferences the
// pointer it is given.
//
// Template parameters:
//   T              - element type; detail::PrintAsType<T> must be defined for it.
//   element_stride - within a row, every element_stride-th element is printed (1 prints all).
//                    Must lie in [1, row_bytes / sizeof(T)].
//   row_bytes      - how many bytes' worth of data make up one printed row.
//
// Parameters:
//   p_shared     - pointer to the first element to print.
//   num_elements - number of valid elements at p_shared; nothing at or beyond this index is
//                  read, so the last row may be shorter than the others.
//
// The function executes __syncthreads() before and after printing, so every thread of the
// workgroup has to call it. Only the thread whose flattened index is 0 prints.
//
// Usage example:
//
//   debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
//
template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
__device__ void print_shared(T const* p_shared, index_t num_elements)
{
    constexpr index_t row_elements = row_bytes / sizeof(T);
    static_assert((element_stride >= 1 && element_stride <= row_elements),
                  "element_stride should be between [1, row_elements]");

    // Flattened workgroup id and thread id (x varies fastest, then y, then z).
    index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
    index_t tid =
        (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x;

    // Block-wide barrier before the buffer is read by the printing thread.
    __syncthreads();

    if(tid == 0)
    {
        printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n",
               wgid,
               row_bytes,
               element_stride);
        for(index_t i = 0; i < num_elements; i += row_elements)
        {
            printf("elem %5d: ", i);

            // Number of valid elements from the start of this row. Bounding j by it keeps the
            // final, possibly partial, row from reading past num_elements. Written as a
            // subtraction so that i + j is never formed beyond the valid range.
            const index_t remaining = num_elements - i;
            for(index_t j = 0; j < row_elements && j < remaining; j += element_stride)
            {
                detail::PrintAsType<T>::Print(p_shared[i + j]);
            }

            printf("\n");
        }
        printf("\n");
    }

    // Block-wide barrier after printing: no thread leaves the function until thread 0 is done.
    __syncthreads();
}

} // namespace debug
} // namespace ck

#endif // UTILITY_DEBUG_HPP