#include "stream_manager.h"
#include "utils/cublas_wrapper.h"


/*
    Column-wise reduction: sums the m rows of an m x n row-major matrix
    into an n-element result. To be launched with 32 x 32 thread blocks,
    each block covering 32 consecutive columns.
*/
template <typename scalar_t>
__global__ 
void column_reduce(const scalar_t * matrix, scalar_t * result, 
    int m /* lines */, int n /* columns*/) {
    
    // https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
    extern __shared__ unsigned char my_smem[];
    scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);

    // normal tid
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    
    // transposed tid for shared memory
    int new_tid = threadIdx.y + threadIdx.x * blockDim.y;

    // true x value in the matrix
    int real_x = threadIdx.x + blockDim.x * blockIdx.x;
    
    // linear index of this thread's first element, and the stride (in
    // elements) covered by one full block-height of rows
    int i = real_x + n * threadIdx.y;
    const int it = n * blockDim.y;
    int offset = it;
    scalar_t accumulator = 0;

    if (threadIdx.y < m && real_x < n) {
        // each thread sums every blockDim.y-th row of its column
        accumulator = matrix[i];
        while (i + offset < n*m) {
            accumulator += matrix[i + offset];
            offset += it;
        }
    }

    // save column reduction data in a transposed way
    sdata[new_tid] = accumulator;
    __syncthreads();

    // tree-reduce the blockDim.y partial sums of each column, which the
    // transposed store placed in 32 consecutive shared-memory slots;
    // reading into a register before writing avoids a read/write race
    // between threads within an iteration
    for (int t = 16; t > 0; t >>= 1) {
        scalar_t v = (tid + t < 32 * 32) ? sdata[tid + t] : scalar_t(0);
        __syncthreads();
        sdata[tid] += v;
        __syncthreads();
    }
    
    if (threadIdx.y == 0 && real_x < n) 
        result[real_x] = sdata[new_tid];
    
}
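
/*
    A minimal usage sketch (illustration only, not part of the original
    file): summing the columns of an 8 x 100 row-major float matrix.
    Here d_matrix and d_result are assumed to be device pointers (input
    matrix and 100-element output) and stream a valid cudaStream_t:

        dim3 block(32, 32);
        dim3 grid((100 + 31) / 32, 1);
        column_reduce<float>
            <<<grid, block, sizeof(float) * 32 * 32, stream>>>
            (d_matrix, d_result, 8, 100);
*/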

template <typename scalar_t>
void fmoe_cuda_linear_forward_impl(
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* output_buf,
        const bool has_bias,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    // beta = 1 makes each GEMM accumulate onto output_buf, which is
    // assumed to have been pre-filled with the bias by the caller
    scalar_t alpha = 1, beta = has_bias ? 1 : 0;

    smgr->syncTorch();
    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                out_feat, expert_count[i], in_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                input_buf + ptr * in_feat, in_feat,
                &beta,
                output_buf + out_feat * ptr, out_feat
                ));

        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
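
/*
    In row-major terms, the loop above computes, for each expert i with
    c = expert_count[i] assigned rows:

        output[ptr : ptr + c, :] = input[ptr : ptr + c, :] @ weight_i^T

    where weight_i is stored as [out_feat, in_feat]. cuBLAS is
    column-major, so the call feeds it the transposed problem and the
    row-major result falls out via the T(B) x T(A) = T(C) identity noted
    in the loop. With has_bias, beta = 1 makes the GEMM accumulate onto
    output_buf, which is assumed to already hold the bias.
*/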

template <typename scalar_t>
void fmoe_cuda_linear_backward_impl(
        const scalar_t* grad_output_buf,
        const scalar_t* input_buf,
        const scalar_t* weight,
        const long* expert_count,
        scalar_t* grad_input_buf,
        scalar_t* grad_weight,
        scalar_t* grad_bias,
        const bool has_bias,
        const size_t batch_size,
        const size_t in_feat,
        const size_t out_feat,
        const size_t num_expert,
        CudaStreamManager* smgr) {
    smgr->syncTorch();
    scalar_t alpha = 1, beta = 0;

    // launch configuration for the bias-gradient column reduction:
    // one 32 x 32 block per 32 output features
    dim3 block_threads(32, 32);
    dim3 grid_threads((out_feat + 31) / 32, 1);

    for (int i = 0, ptr = 0; i < num_expert; ++i) {
        if (expert_count[i] == 0) {
            // no tokens were routed to this expert: zero its gradients
            cudaMemset(grad_weight + i * in_feat * out_feat, 0,
                    sizeof(scalar_t) * in_feat * out_feat);
            if (has_bias) {
                cudaMemset(grad_bias + i * out_feat, 0,
                        sizeof(scalar_t) * out_feat);
            }
            continue;
        }
        // Use T(B) x T(A) = T(C) to produce row-major C

        // Backward w.r.t. input: grad_input = grad_output @ weight (row-major view)
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                in_feat, expert_count[i], out_feat,
                &alpha,
                weight + i * in_feat * out_feat, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_input_buf + in_feat * ptr, in_feat
                ));

        // Backward w.r.t. weight: grad_weight = grad_output^T @ input (row-major view)
        checkCudaErrors(cublasXgemm(
                smgr->handle(i),
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                in_feat, out_feat, expert_count[i],
                &alpha,
                input_buf + in_feat * ptr, in_feat,
                grad_output_buf + ptr * out_feat, out_feat,
                &beta,
                grad_weight + i * in_feat * out_feat, in_feat
                ));
        
        // Backward w.r.t. bias: grad_bias = column-wise sum of grad_output
        if (has_bias) {
            column_reduce
            <<<grid_threads, block_threads, sizeof(scalar_t)*1024, smgr->stream(i)>>>
            (
                grad_output_buf + ptr * out_feat,
                grad_bias + i * out_feat,
                expert_count[i],
                out_feat
            );
        }

        ptr += expert_count[i];
    }
    smgr->sync(num_expert);
}
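
/*
    Row-major shapes of the three gradients computed above for one
    expert i with c = expert_count[i] rows (a sketch of the math the
    cuBLAS calls and the kernel encode):

        grad_input  [c, in_feat]        = grad_output [c, out_feat] @ weight_i [out_feat, in_feat]
        grad_weight [out_feat, in_feat] = grad_output^T [out_feat, c] @ input [c, in_feat]
        grad_bias   [out_feat]          = column-wise sum of grad_output [c, out_feat]

    Each expert's work is issued on smgr->stream(i), so gradients for
    different experts can overlap on the device.
*/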