parallel_linear.cuh 4.81 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
#include "stream_manager.h"
#include "utils/cublas_wrapper.h"


5
6
7
/*
    Accumulator type used by column_reduce for its per-thread partial sums:
    double when reducing doubles, float for every other scalar type.
*/
template <typename scalar_t>
struct column_reduce_acc { typedef float type; };

template <>
struct column_reduce_acc<double> { typedef double type; };

/*
    Column-wise sum of a row-major m x n matrix:
        result[c] = sum over r of matrix[r * n + c]

    Launch requirements:
      - 2-D thread blocks of 32 x 32 threads; the shared-memory reduction
        below is hard-wired to 32 partial sums per column.
      - one block per 32 consecutive columns, i.e. gridDim.x >= ceil(n / 32).
      - 32 * 32 * sizeof(scalar_t) bytes of dynamic shared memory.

    matrix : m x n input, row-major (element (r, c) is read at r * n + c)
    result : n outputs, one per column
    m      : number of rows ("lines")
    n      : number of columns
*/
template <typename scalar_t>
__global__
void column_reduce(const scalar_t * matrix, scalar_t * result,
    int m /* lines */, int n /* columns*/) {

    // Number of partial sums kept per column (must equal blockDim.y).
    const int tile = 32;

    // Dynamic shared memory cannot be declared with a templated element type,
    // so it is declared as raw bytes and reinterpreted.
    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
    extern __shared__ unsigned char my_smem[];
    scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);

    // flat thread index inside the block
    const int tid = threadIdx.x + threadIdx.y * blockDim.x;

    // transposed index: the partial sums of one column end up contiguous
    // in shared memory (slots x * blockDim.y .. x * blockDim.y + blockDim.y - 1)
    const int new_tid = threadIdx.y + threadIdx.x * blockDim.y;

    // column of the matrix handled by this thread
    const int real_x = threadIdx.x + blockDim.x * blockIdx.x;

    // Each thread sums rows threadIdx.y, threadIdx.y + blockDim.y, ... of its
    // column. The element offset is computed in size_t so that matrices with
    // more than INT_MAX elements do not overflow the index.
    typename column_reduce_acc<scalar_t>::type accumulator = 0;
    if (real_x < n) {
        for (int row = threadIdx.y; row < m; row += blockDim.y) {
            accumulator += matrix[static_cast<size_t>(row) * n + real_x];
        }
    }

    // save column reduction data in a transposed way
    sdata[new_tid] = accumulator;
    __syncthreads();

    // Tree reduction inside every 32-slot segment. Only the lower `t` slots
    // of a segment are updated in each step: a slot that is being read in a
    // step is never written in that same step, and no thread reads past the
    // end of its own segment.
    const int lane = tid & (tile - 1);
    for (int t = tile / 2; t > 0; t >>= 1) {
        if (lane < t)
            sdata[tid] += sdata[tid + t];
        __syncthreads();
    }

    // Slot 0 of each segment (new_tid with threadIdx.y == 0) holds the total.
    if (threadIdx.y == 0 && real_x < n)
        result[real_x] = sdata[new_tid];

}

/*
    Forward pass: one GEMM per expert over that expert's contiguous slice of
    rows in input_buf, written to the matching rows of output_buf.

    input_buf    : input rows, in_feat values per row, grouped by expert
    weight       : per-expert weights, in_feat * out_feat values per expert
    expert_count : number of rows owned by each expert (read on the host)
    output_buf   : output rows, out_feat values per row
    has_bias     : when true the GEMM runs with beta = 1, i.e. it adds to what
                   is already stored in output_buf (presumably the bias,
                   pre-filled by the caller -- confirm against the caller);
                   when false beta = 0 and output_buf is overwritten
    in_feat      : input features per row
    out_feat     : output features per row
    num_expert   : number of experts
    smgr         : supplies the cuBLAS handle used for each expert and is
                   synchronised once all GEMMs have been issued
*/
template <typename scalar_t>
void fmoe_cuda_linear_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* output_buf,
        const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1;
    scalar_t beta = has_bias ? 1 : 0;

    // first_row is the index of the first row that belongs to the current expert
    int first_row = 0;
    for (int expert = 0; expert < num_expert; ++expert) {
        const long n_rows = expert_count[expert];
        if (n_rows != 0) {
            const scalar_t* expert_weight = weight + expert * in_feat * out_feat;
            const scalar_t* expert_input = input_buf + first_row * in_feat;
            scalar_t* expert_output = output_buf + out_feat * first_row;

            // cuBLAS is column-major, so the row-major product is obtained by
            // computing its transpose: T(B) x T(A) = T(C).
            checkCudaErrors(cublasXgemm(
                    smgr->handle(expert),
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    out_feat, n_rows, in_feat,
                    &alpha,
                    expert_weight, in_feat,
                    expert_input, in_feat,
                    &beta,
                    expert_output, out_feat
                    ));

            first_row += n_rows;
        }
    }
    smgr->sync(num_expert);
}

/*
    Backward pass of the per-expert linear layers. For every expert that owns
    at least one row it issues two GEMMs (gradient w.r.t. the input rows and
    gradient w.r.t. the expert's weight) and, when has_bias is set, a
    column_reduce launch that sums the expert's grad_output rows into
    grad_bias. Experts with no rows get their gradient slices zero-filled.

    grad_output_buf : gradient of the outputs, out_feat values per row
    input_buf       : forward inputs, in_feat values per row
    weight          : per-expert weights, in_feat * out_feat values per expert
    expert_count    : number of rows owned by each expert (read on the host)
    grad_input_buf  : receives the input gradient, in_feat values per row
    grad_weight     : receives in_feat * out_feat values per expert
    grad_bias       : receives out_feat values per expert; only touched when
                      has_bias is true
    has_bias        : whether the bias gradient has to be produced
    batch_size      : not used by this function
    in_feat         : input features per row
    out_feat        : output features per row
    num_expert      : number of experts
    smgr            : supplies cuBLAS handles / streams and is synchronised
                      once all work has been issued
*/
template <typename scalar_t>
void fmoe_cuda_linear_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        scalar_t* grad_bias,
        const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    scalar_t alpha = 1, beta = 0;

    // Launch shape for the bias reduction: 32 x 32 threads per block, one
    // block per 32 output features (rounded up).
    dim3 block_threads(32, 32);
    dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            // No rows were routed to this expert: its gradients are zero.
            checkCudaErrors(cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                    sizeof(scalar_t) * in_feat * out_feat));
            // grad_bias is only produced when has_bias is set (see the
            // column_reduce launch below), so it is only zeroed in that case.
            if (has_bias) {
                checkCudaErrors(cudaMemset(grad_bias + i * out_feat, 0,
                        sizeof(scalar_t) * out_feat));
            }
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C

        // Backward input: g_i = w @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * ptr, in_feat
                ));

        // Backward weight: g_w = i @ g_o
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * ptr, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_weight + i * in_feat * out_feat, in_feat
                ));

        if (has_bias) {
            // Backward bias: column-wise sum of this expert's grad_output rows.
            // The kernel only reads grad_output_buf and writes grad_bias, so
            // it does not depend on the two GEMMs above.
            // NOTE(review): every launch goes to stream(0) while the GEMMs use
            // handle(i); confirm smgr->sync(num_expert) also waits on stream(0).
            column_reduce
            <<<grid_threads, block_threads, sizeof(scalar_t)*1024, smgr->stream(0)>>>
            (
                grad_output_buf + ptr * out_feat,
                grad_bias + i * out_feat,
                expert_count[i],
                out_feat
            );
            // A kernel launch returns no status; launch-configuration errors
            // only become visible through cudaGetLastError().
            checkCudaErrors(cudaGetLastError());
        }

        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
162