// ln.h: parameter structs, launch parameters and launcher registries for the
// layer_norm CUDA kernels.
#pragma once

#include <cstdint>
#include <functional>
#include <unordered_map>

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Everything a layer-norm launcher needs besides the kernel itself: the kernel
// arguments (`params`), launch-configuration sizes, and the CUDA device/stream.
//
// `Params` is the per-direction argument struct (FwdParams or BwdParams).
// Scalar and pointer members carry default initializers so that a freshly
// declared LaunchParams never holds indeterminate sizes or pointers.
template<typename Params>
struct LaunchParams{

    // Elements handled by each thread.
    // NOTE(review): presumably filled in by the launcher's configuration pass — confirm.
    size_t elts_per_thread = 0;
    // Size of the Multi-CTA workspace (presumably backing ParamsBase::workspace).
    size_t workspace_bytes = 0;
    // Number of Multi-CTA barrier counters (presumably backing ParamsBase::barrier).
    size_t barrier_size = 0;

    // Properties of the target device; not owned by this struct.
    cudaDeviceProp * props = nullptr;

    // Stream the kernel is enqueued on (null selects the default stream).
    cudaStream_t stream = nullptr;

    // Kernel arguments.
    Params params;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Fields shared by the forward and backward layer-norm parameter structs.
// The constructor initializes every member (in declaration order), so callers
// only need to set the fields relevant to the variant they launch; unused
// pointers stay null instead of holding indeterminate values.
struct ParamsBase {
    ParamsBase()
        : ctas_per_col(0)
        , rows(0)
        , cols(0)
        , x0(nullptr)
        , x1(nullptr)
        , residual(nullptr)
        , x(nullptr)
        , dmask(nullptr)
        , dmask1(nullptr)
        , mu(nullptr)
        , rs(nullptr)
        , gamma(nullptr)
        , gamma1(nullptr)
        , rowscale(nullptr)
        , colscale(nullptr)
        , x0_subset(nullptr)
        , z_subset(nullptr)
        , inverse_cols(0.f)
        , dropout_keep_p(1.f)
        , dropout_scale(1.f)
        , rowscale_const(1.f)
        , is_rms_norm(false)
        , workspace(nullptr)
        , barrier(nullptr)
    {
    }

    // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
    int ctas_per_col;

    // Input is interpreted as matrix. We normalize across columns.
    int rows;
    int cols;

    // Common data pointers.
    void *x0;
    void *x1;
    void *residual;
    void *x;
    void *dmask;
    void *dmask1;
    void *mu;
    void *rs;
    void *gamma;
    void *gamma1;
    void *rowscale;
    void *colscale;
    void *x0_subset;
    void *z_subset;

    // NOTE(review): presumably 1 / cols, set by the caller — confirm.
    float inverse_cols;

    // Dropout settings. Defaults are keep-probability 1 and scale 1,
    // i.e. presumably no dropout.
    float dropout_keep_p;
    float dropout_scale;
    // Constant row scale. NOTE(review): default 1.f assumed to be the neutral
    // multiplier — confirm against the kernels that read it.
    float rowscale_const;

    // Presumably selects RMSNorm instead of LayerNorm when true — confirm in kernels.
    bool is_rms_norm;

    // Multi-CTA workspace in gmem.
    void *workspace;

    // Multi-CTA sync barriers in gmem.
    int *barrier;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Forward-pass parameters: ParamsBase plus the LN outputs, the beta/beta1
// pointers, epsilon and the Philox RNG state.
struct FwdParams : public ParamsBase {
    FwdParams()
        : ParamsBase()
        , z(nullptr)
        , z1(nullptr)
        , beta(nullptr)
        , beta1(nullptr)
        , epsilon(0.f)
    {
    }

    // Output of LN FWD.
    void *z;
    void *z1;
    // Presumably the bias (shift) paired with gamma / gamma1 — confirm in kernels.
    void *beta;
    void *beta1;

    // Numerical-stability epsilon (presumably added to the variance).
    float epsilon;

    // Random state (presumably consumed by dropout).
    at::PhiloxCudaState philox_args;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Backward-pass parameters: ParamsBase plus the incoming gradients, the
// Dgrad/Wgrad outputs, and the workspace used for Wgrad pre-reduction.
struct BwdParams : public ParamsBase {
    BwdParams()
        : ParamsBase()
        , dz(nullptr)
        , dz1(nullptr)
        , dx(nullptr)
        , dbeta_part(nullptr)
        , dgamma_part(nullptr)
        , dbeta1_part(nullptr)
        , dgamma1_part(nullptr)
        , dcolscale_part(nullptr)
        , dx0(nullptr)
        , dx1(nullptr)
        , dresidual(nullptr)
        , dbeta(nullptr)
        , dgamma(nullptr)
        , dbeta1(nullptr)
        , dgamma1(nullptr)
        , dcolscale(nullptr)
    {
    }

    // Input: gradient wrt. LN FWD output.
    void *dz;
    void *dz1;
    // Input: gradient wrt residual.
    void *dx;

    // Workspace for Wgrad pre-reduction.
    void *dbeta_part;
    void *dgamma_part;
    void *dbeta1_part;
    void *dgamma1_part;
    void *dcolscale_part;

    // Output: Dgrad.
    void *dx0;
    void *dx1;
    void *dresidual;

    // Output: Wgrad.
    void *dbeta;
    void *dgamma;
    void *dbeta1;
    void *dgamma1;
    void *dcolscale;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Launcher signatures. The bool argument's meaning is set by the launchers.
// NOTE(review): presumably a "configure only" flag (compute launch sizes without
// launching) — confirm against the launcher implementations.
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool)>;

// Registry key as produced by Types2Key::get: the packed type ids in the high
// 32 bits and the hidden size in the low 32 bits.
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;

// Global launcher registries. They are declared extern here and defined in
// another translation unit. The registrar helpers in this header fill them.
extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS;
extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Short element-type names used throughout the type-key machinery.
typedef float fp32;
typedef half fp16;
typedef nv_bfloat16 bf16;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a supported element type to a small integer id (0..2, so it fits in 2 bits).
// The primary template has no `Value`, so using an unsupported type is a
// compile-time error.
template<typename T>
struct TypeId{};

// half -> 0
template<>
struct TypeId<fp16>{
    static constexpr uint32_t Value = 0;
};

// bfloat16 -> 1
template<>
struct TypeId<bf16>{
    static constexpr uint32_t Value = 1;
};

// float -> 2
template<>
struct TypeId<fp32>{
    static constexpr uint32_t Value = 2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Places the id of T at bit offset S of a 32-bit key.
template<typename T, int S>
struct Type2Key{
    static constexpr uint32_t Value = (TypeId<T>::Value << S);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Each tensor role owns its own 2-bit field of the type key:
//   weight -> bits 0-1, input -> bits 2-3, residual -> bits 4-5,
//   output -> bits 6-7, compute -> bits 8-9.
// (Inheritance from a struct is public by default.)
template<typename T>
struct WeightType2Key : Type2Key<T, 0>{};

template<typename T>
struct InputType2Key : Type2Key<T, 2>{};

template<typename T>
struct ResidualType2Key : Type2Key<T, 4>{};

template<typename T>
struct OutputType2Key : Type2Key<T, 6>{};

template<typename T>
struct ComputeType2Key : Type2Key<T, 8>{};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Packs the five per-role type fields into one 32-bit type key (`Value`).
// get() builds the 64-bit registry key as (type key << 32) | hidden_size.
// NOTE(review): hidden_size is assumed to fit in 32 bits; a larger value would
// overlap the type bits.
template<typename W, typename I, typename R, typename O, typename C>
struct Types2Key{
    static constexpr uint32_t Value =
          WeightType2Key<W>::Value
        | InputType2Key<I>::Value
        | ResidualType2Key<R>::Value
        | OutputType2Key<O>::Value
        | ComputeType2Key<C>::Value;

    static constexpr uint64_t get(const uint64_t hidden_size){
        return (static_cast<uint64_t>(Value) << 32) | hidden_size;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Constructing one of these adds launcher `f` to FWD_FUNCS under the key for
// the given type combination and hidden size. If that key is already present,
// the existing entry is kept (std::unordered_map does not overwrite).
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
    FwdRegistrar(FwdFunction f){
        const uint64_t registry_key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
        FWD_FUNCS.emplace(registry_key, f);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Constructing one of these adds launcher `f` to BWD_FUNCS under the key for
// the given type combination and hidden size. If that key is already present,
// the existing entry is kept (std::unordered_map does not overwrite).
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
    BwdRegistrar(BwdFunction f){
        const uint64_t registry_key = Types2Key<W, I, R, O, C>::get(HIDDEN_SIZE);
        BWD_FUNCS.emplace(registry_key, f);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Like FwdRegistrar, but registers launcher `f` in PARALLEL_FWD_FUNCS.
// NOTE(review): presumably the "parallel" variant that uses gamma1/beta1/z1 —
// confirm against the launchers that register here.
// If the key is already present, the existing entry is kept.
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdParallelRegistrar{
    FwdParallelRegistrar(FwdFunction f){
        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
        PARALLEL_FWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Like BwdRegistrar, but registers launcher `f` in PARALLEL_BWD_FUNCS.
// If the key is already present, the existing entry is kept.
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdParallelRegistrar{
    BwdParallelRegistrar(BwdFunction f){
        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
        PARALLEL_BWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace layer_norm