// gemm.cuh
#pragma once

// Per-thread GEMM entry point (placeholder: the body is not implemented yet,
// so calling it currently has no effect).
//
// The matrix-descriptor arguments (ThreadMatrixA/B/C), the transpose tags
// (Constant<bool, TransX>) and the Accumulator are taken as unnamed by-value
// parameters: only their types are visible to the function, their values are
// never read.
//
// Template parameters:
//   ThreadMatrixA/B/C - types of the unnamed descriptor arguments for A, B, C.
//   TransA/B/C        - compile-time flags carried by the Constant<bool, ...> tags
//                       (presumably "operand is transposed" - TODO confirm).
//   FloatA/B/C        - pointee types of the three data pointers.
//   Accumulator       - type of the trailing unnamed argument.
//
// Parameters:
//   p_a_thread, p_b_thread - const pointers (the pointer itself is const, the
//                            pointee is not) to the A and B operands.
//   p_c_thread             - pointer to the C operand.
template <class ThreadMatrixA,
          class ThreadMatrixB,
          class ThreadMatrixC,
          bool TransA,
          bool TransB,
          bool TransC,
          class FloatA,
          class FloatB,
          class FloatC,
          class Accumulator>
__device__ void threadwise_gemm(ThreadMatrixA,
                                Constant<bool, TransA>,
                                FloatA* const p_a_thread,
                                ThreadMatrixB,
                                Constant<bool, TransB>,
                                FloatB* const p_b_thread,
                                ThreadMatrixC,
                                Constant<bool, TransC>,
                                FloatC* p_c_thread,
                                Accumulator)
{
    // do something
}

// Skeleton of a block-level, strided, batched GEMM functor.
//
// Current state: the real index/offset computation is compiled out with
// `#if 0`. With the active code paths, the constructor sets both thread
// offsets to 0, CalculateThreadMatrixCIndex() always returns {0, 0, 0}, and
// run() is an empty placeholder.
//
// Template parameters:
//   ThreadMatrixStrideC - must be > 0 (enforced by a static_assert in the
//                         constructor); the only parameter used by active code.
//   BlockSize, BlockMatrixA, BlockMatrixB, ThreadMatrixC, TransA, TransB,
//   BatchSize, BatchPerThread
//                       - referenced only by the disabled `#if 0` code.
//   TransC, BlockMatrixStrideA, BlockMatrixStrideB, KPerLoop, Accumulator
//                       - not referenced anywhere inside this struct yet.
template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          bool TransA,
          bool TransB,
          bool TransC,
          unsigned BlockMatrixStrideA,
          unsigned BlockMatrixStrideB,
          unsigned ThreadMatrixStrideC,
          unsigned BatchSize,
          unsigned BatchPerThread,
          unsigned KPerLoop,
          class Accumulator>
struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
{
    // Presumably this thread's offsets into the A and B block matrices
    // (TODO confirm). They are not computed yet and are always 0.
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;

    // Result type of CalculateThreadMatrixCIndex(): a (batch, row, col) triple.
    struct MatrixIndex
    {
        unsigned batch_begin;
        unsigned row_begin;
        unsigned col_begin;
    };

    // Checks ThreadMatrixStrideC at compile time and initialises the thread
    // offsets (to 0 for now; see the disabled draft below).
    __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
    {
        static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!");

#if 0
        // Disabled work-in-progress draft of the offset computation.
        // NOTE(review): MPerThread, NPerThread and KPerThread are not declared
        // in this struct; they must be provided before this can be enabled.
        constexpr auto a_block_desc = BlockMatrixA{};
        constexpr auto b_block_desc = BlockMatrixB{};

        constexpr unsigned a_thread_row = (!TransA) ? MPerThread : KPerThread;
        constexpr unsigned a_thread_col = (!TransA) ? KPerThread : MPerThread;
        constexpr unsigned b_thread_row = (!TransB) ? KPerThread : NPerThread;
        constexpr unsigned b_thread_col = (!TransB) ? NPerThread : KPerThread;

        constexpr auto a_thread_desc = ConstantMatrixDescriptor<a_thread_row, a_thread_col>{};
        constexpr auto b_thread_desc = ConstantMatrixDescriptor<b_thread_row, b_thread_col>{};
        constexpr auto c_thread_desc = ConstantMatrixDescriptor<MPerThread, NPerThread>{};

        constexpr unsigned m_block = (!TransA) ? a_block_desc.NRow() : a_block_desc.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block_desc.NCol() : b_block_desc.NRow();

        constexpr unsigned m_thread = (!TransA) ? a_thread_desc.NRow() : a_thread_desc.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread_desc.NCol() : b_thread_desc.NRow();

        // Ceil-divide the block extents by the per-thread extents.
        constexpr unsigned num_threads_per_row   = (m_block + m_thread - 1) / m_thread;
        constexpr unsigned num_threads_per_col   = (n_block + n_thread - 1) / n_thread;
        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;

        static_assert(BlockSize >= ((BatchSize + BatchPerThread - 1) / BatchPerThread) *
                                       num_threads_per_batch,
                      "not enough thread!");

        const auto mtx_c_index = CalculateThreadMatrixCIndex(get_thread_local_id());

        // mMyThreadOffsetA = xxx;
        // mMyThreadOffsetB = xxx;
#else
        mMyThreadOffsetA = 0;
        mMyThreadOffsetB = 0;
#endif
    }

    // Returns the (batch, row, col) index for the given thread id.
    // The active branch ignores thread_id and always returns {0, 0, 0}.
    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
    {
#if 0
        // Disabled work-in-progress draft.
        // NOTE(review): BlockMatrixC, ThreadMatrixA and ThreadMatrixB are not
        // template parameters of this struct, and `batch_id` below is never
        // declared; this branch cannot compile as written.
        constexpr auto a_block = BlockMatrixA{};
        constexpr auto b_block = BlockMatrixB{};
        constexpr auto c_block = BlockMatrixC{};

        constexpr auto a_thread = ThreadMatrixA{};
        constexpr auto b_thread = ThreadMatrixB{};
        constexpr auto c_thread = ThreadMatrixC{};

        constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol();
        constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow();

        constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol();
        constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow();

        constexpr unsigned num_threads_per_row   = (m_block + m_thread - 1) / m_thread;
        constexpr unsigned num_threads_per_col   = (n_block + n_thread - 1) / n_thread;
        constexpr unsigned num_threads_per_batch = num_threads_per_row * num_threads_per_col;

        // this is wrong, need fix
        const unsigned batch_begin = thread_id / (num_threads_per_batch)*BatchPerThread;
        const unsigned tmp = thread_id - batch_id * (num_threads_per_row * num_threads_per_col);
        const unsigned thread_matrix_row_id = tmp / num_threads_per_row;
        const unsigned thread_matrix_col_id = tmp - thread_matrix_row_id * num_threads_per_row;

        return MatrixIndex{
            batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread};
#else
        return MatrixIndex{0, 0, 0};
#endif
    }

    // Runs the GEMM (placeholder: the body is not implemented yet).
    //   p_a_block, p_b_block - const pointers to the A and B operands.
    //   p_c_thread           - pointer to the C operand.
    template <class FloatA, class FloatB, class FloatC>
    __device__ void run(FloatA* const p_a_block, FloatB* const p_b_block, FloatC* p_c_thread) const
    {
        // do something
    }
};