"lightx2v/models/vscode:/vscode.git/clone" did not exist on "3e67df1c040f57a7aa325cb2d0a4c2c47686604c"
ln_api.cpp 20.1 KB
Newer Older
1
2
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>

#include "ln.h"

/*

Supported Type combinations:

input  residual   compute   weights   output
============================================
fp32     fp32      fp32      fp32      fp32
fp16     fp32      fp32      fp32      fp16
fp16     fp16      fp32      fp32      fp16
bf16     fp32      fp32      fp32      bf16
bf16     bf16      fp32      fp32      bf16
fp16     fp16      fp32      fp16      fp16
bf16     bf16      fp32      bf16      bf16

Remarks:
Output type = Input type
Compute always in FP32

*/

namespace layer_norm {

// Create registries and provide runtime versions of config hash functions.

FwdRegistry FWD_FUNCS;
BwdRegistry BWD_FUNCS;

////////////////////////////////////////////////////////////////////////////////////////////////////

uint32_t get_type_id(torch::Dtype dtype){
    if( dtype == torch::kFloat16 ) {
        return TypeId<fp16>::Value;
    } else if( dtype == torch::kBFloat16 ) {
        return TypeId<bf16>::Value;
    } else if( dtype == torch::kFloat32 ) {
        return TypeId<fp32>::Value;
    } else {
        TORCH_CHECK(false, "Type not supported: ", dtype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

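// The registry key packs the five 2-bit dtype ids into the high 32 bits
// (bits 32..41) and the hidden size into the low 32 bits, so every
// (dtype combination, hidden_size) pair maps to a unique launcher entry.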
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
    using namespace layer_norm;
    uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
    uint64_t launcher_key = (type_key << 32) | hidden_size;
    return launcher_key;
}

}  // namespace layer_norm

////////////////////////////////////////////////////////////////////////////////////////////////////

layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
    auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
    if( iter != layer_norm::FWD_FUNCS.end() ) {
        return iter->second;
    } else {
        TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
    auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
    if( iter != layer_norm::BWD_FUNCS.end() ) {
        return iter->second;
    } else {
        TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0,      // Input: BxSxhidden_size
                                           c10::optional<const at::Tensor> &x1_,      // Residual: BxSxhidden_size
                                           const at::Tensor &gamma,   // hidden_size
                                           c10::optional<const at::Tensor> &beta_,   // hidden_size
                                           c10::optional<const at::Tensor> &rowscale_,      // BxS
                                           c10::optional<const at::Tensor> &colscale_,      // hidden_size
                                           c10::optional<const at::Tensor> &x0_subset_,      // BxS
                                           c10::optional<const at::Tensor> &z_subset_,      // BxS
                                           const float dropout_p,
                                           const float epsilon,
                                           const float rowscale_const,
                                           const int64_t z_numrows,
                                           c10::optional<at::Generator> gen_,
                                           bool residual_in_fp32=false,
                                           bool is_rms_norm=false
) {
    auto itype = x0.scalar_type();
    auto rtype = x1_.has_value()
        ? x1_.value().scalar_type()
        : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
    auto wtype = gamma.scalar_type();
    auto otype = itype;
    auto ctype = torch::kFloat32;
    auto mtype = torch::kUInt8;

    TORCH_CHECK(x0.is_cuda());
    TORCH_CHECK(gamma.is_cuda());

    TORCH_CHECK(x0.is_contiguous());
    // c10::IntArrayRef does not own its storage, so we construct a named vector first.
    // Constructing IntArrayRef({...}) directly would leave it pointing at freed memory,
    // because the temporary initializer list is deallocated immediately.
    std::vector<int64_t> sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)};
    auto sizes = c10::IntArrayRef(sizes_vec);
    TORCH_CHECK(x0.dim() == 2);
    TORCH_CHECK(sizes.size() == 2);

    const int rows = sizes[0];
    const int cols = sizes[1];
    auto hidden_size = gamma.numel();

    if (beta_.has_value()) {
        auto beta = beta_.value();
        TORCH_CHECK(beta.dtype() == wtype);
        TORCH_CHECK(beta.is_cuda());
        TORCH_CHECK(beta.is_contiguous());
        TORCH_CHECK(gamma.sizes() == beta.sizes());
    }

    if (x1_.has_value()) {
        auto x1 = x1_.value();
        TORCH_CHECK(x1.is_cuda());
        TORCH_CHECK(x1.is_contiguous());
        TORCH_CHECK(x1.sizes() == sizes);
    }

    if (rowscale_.has_value()) {
        auto rowscale = rowscale_.value();
        TORCH_CHECK(rowscale.is_cuda());
        TORCH_CHECK(rowscale.is_contiguous());
        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
        TORCH_CHECK(rowscale.dtype() == itype);
    }

    if (colscale_.has_value()) {
        auto colscale = colscale_.value();
        TORCH_CHECK(colscale.is_cuda());
        TORCH_CHECK(colscale.is_contiguous());
        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
        TORCH_CHECK(colscale.dtype() == wtype);
    }

    if (x0_subset_.has_value()) {
        auto x0_subset = x0_subset_.value();
        TORCH_CHECK(x0_subset.is_cuda());
        TORCH_CHECK(x0_subset.is_contiguous());
        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);

        TORCH_CHECK(z_subset_.has_value());
        auto z_subset = z_subset_.value();
        TORCH_CHECK(z_subset.is_cuda());
        TORCH_CHECK(z_subset.is_contiguous());
        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
    }

    TORCH_CHECK(hidden_size == cols);
    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));

    TORCH_CHECK(epsilon >= 0.f);

    // Switch to x0's device; otherwise the kernel would be launched on cuda:0.
    // Cast to char to avoid a compiler warning about narrowing.
    at::cuda::CUDAGuard device_guard{(char)x0.get_device()};

    auto opts = x0.options();

    bool save_x = x1_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype);
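    // x (the value fed into the normalization) only needs to be saved for the
    // backward pass when it cannot be recomputed from x0 alone: residual add,
    // dropout, row/col scaling, subsetting, or an input/residual dtype mismatch.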
    at::Tensor x;
    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
    at::Tensor dmask;
    if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); }
    auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype));
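    // When z_subset is given, the output keeps only z_numrows rows; the caller
    // supplies that count since it is not cheaply inferable from the index
    // tensor on device.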

    auto mu = torch::empty({ rows }, opts.dtype(ctype));
    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));

    layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;

    launch_params.props = at::cuda::getCurrentDeviceProperties();
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
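    // Kernels are only instantiated for a fixed set of hidden sizes, so the
    // requested size is rounded up to the bucket granularity when looking up
    // the launcher (e.g. hidden_size 1000 falls in the 1024 bucket).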
    // Request the kernel launcher.
    auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));

    // Query the kernel-specific launch parameters.
    launcher(launch_params, true);

    at::Tensor workspace, barrier;

    // Set the kernel runtime parameters.
    layer_norm::FwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x0 = x0.data_ptr();
    params.x = save_x ? x.data_ptr() : nullptr;
    params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr;
    params.z = z.data_ptr();
    params.epsilon = epsilon;
    params.dropout_scale = 1.f / (1.f - dropout_p);
    params.inverse_cols = 1.f / float(params.cols);
    params.rowscale_const = rowscale_const;
    params.is_rms_norm = is_rms_norm;

    if (dropout_p > 0.f) {
        // Number of random values consumed per thread; used to advance the
        // Philox counter in the CUDA RNG state.
        int64_t counter_offset = launch_params.elts_per_thread;

        // See Note [Acquire lock when using random generators]
        {
            std::lock_guard<std::mutex> lock(gen->mutex_);
            params.philox_args = gen->philox_cuda_state(counter_offset);
        }
    }

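    // A nonzero barrier_size indicates the selected kernel spreads each row
    // across several CTAs; they exchange partial results through the workspace
    // and synchronize on the barrier.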
    if( launch_params.barrier_size > 0 ) {
        auto options = x0.options();
        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr<int>();
    }

    // Launch the kernel.
    launcher(launch_params, false);

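    // x and dmask are returned as undefined tensors when they were not
    // allocated above; callers must check save_x / dropout_p accordingly.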
    return { z, x, dmask, mu, rsigma };
}

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz,     // BxSxhidden_size
                                           c10::optional<const at::Tensor> &dx_,     // BxSxhidden_size
                                           const at::Tensor &x,      // BxSxhidden_size
                                           c10::optional<const at::Tensor> &x0_,     // BxSxhidden_size
                                           c10::optional<const at::Tensor> &dmask_,  // BxSxhidden_size
                                           const at::Tensor &mu,     // BxS, FP32!
                                           const at::Tensor &rsigma, // BxS, FP32!
                                           const at::Tensor &gamma,   // hidden_size
                                           c10::optional<const at::Tensor> &rowscale_,      // BxS
                                           c10::optional<const at::Tensor> &colscale_,      // hidden_size
                                           c10::optional<const at::Tensor> &x0_subset_,      // BxS
                                           c10::optional<const at::Tensor> &z_subset_,      // BxS
                                           const float dropout_p,
                                           const float rowscale_const,
                                           const int64_t x0_numrows,
                                           const bool has_residual,
                                           bool is_rms_norm=false
) {

    auto itype = dz.scalar_type();
    auto rtype = x.scalar_type();
    auto wtype = gamma.scalar_type();
    auto otype = itype;
    auto ctype = torch::kFloat32;
    auto mtype = torch::kUInt8;

    if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }

    TORCH_CHECK(dz.dtype() == otype);
    TORCH_CHECK(mu.dtype() == ctype);
    TORCH_CHECK(rsigma.dtype() == ctype);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(dz.is_cuda());
    TORCH_CHECK(mu.is_cuda());
    TORCH_CHECK(rsigma.is_cuda());
    TORCH_CHECK(gamma.is_cuda());

    TORCH_CHECK(x.is_contiguous());
    TORCH_CHECK(dz.is_contiguous());

    auto sizes = x.sizes();
    TORCH_CHECK(sizes.size() == 2);
    auto rows = sizes[0];
    auto cols = sizes[1];
    TORCH_CHECK(dz.dim() == 2);
    TORCH_CHECK(dz.size(1) == cols);

    // c10::IntArrayRef does not own its storage, so we construct a named vector first.
    // Constructing IntArrayRef({...}) directly would leave it pointing at freed memory,
    // because the temporary initializer list is deallocated immediately.
    std::vector<int64_t> x0_sizes_vec {!x0_subset_.has_value() ? rows : x0_numrows, cols};
    auto x0_sizes = c10::IntArrayRef(x0_sizes_vec);

    if (dx_.has_value()) {
        auto dx = dx_.value();
        TORCH_CHECK(dx.dtype() == rtype);
        TORCH_CHECK(dx.is_cuda());
        TORCH_CHECK(dx.is_contiguous());
        TORCH_CHECK(dx.sizes() == sizes);
    }

    if (dmask_.has_value()) {
        auto dmask = dmask_.value();
        TORCH_CHECK(dmask.dtype() == mtype);
        TORCH_CHECK(dmask.is_cuda());
        TORCH_CHECK(dmask.is_contiguous());
        TORCH_CHECK(dmask.sizes() == x0_sizes);
    }

    if (rowscale_.has_value()) {
        auto rowscale = rowscale_.value();
        TORCH_CHECK(rowscale.is_cuda());
        TORCH_CHECK(rowscale.is_contiguous());
        TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows});
        TORCH_CHECK(rowscale.dtype() == itype);
    }

    if (colscale_.has_value()) {
        auto colscale = colscale_.value();
        TORCH_CHECK(colscale.is_cuda());
        TORCH_CHECK(colscale.is_contiguous());
        TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols});
        TORCH_CHECK(colscale.dtype() == wtype);

        TORCH_CHECK(x0_.has_value());
        auto x0 = x0_.value();
        TORCH_CHECK(x0.is_cuda());
        TORCH_CHECK(x0.is_contiguous());
        TORCH_CHECK(x0.sizes() == x0_sizes);
        TORCH_CHECK(x0.dtype() == itype);
    }

    if (x0_subset_.has_value()) {
        auto x0_subset = x0_subset_.value();
        TORCH_CHECK(x0_subset.is_cuda());
        TORCH_CHECK(x0_subset.is_contiguous());
        TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows});
        TORCH_CHECK(x0_subset.dtype() == torch::kInt32);

        TORCH_CHECK(z_subset_.has_value());
        auto z_subset = z_subset_.value();
        TORCH_CHECK(z_subset.is_cuda());
        TORCH_CHECK(z_subset.is_contiguous());
        TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows});
        TORCH_CHECK(z_subset.dtype() == torch::kInt32);
    }

    auto hidden_size = gamma.numel();
    TORCH_CHECK(hidden_size == cols);
    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));

    TORCH_CHECK(mu.numel() == rows);
    TORCH_CHECK(mu.sizes() == rsigma.sizes());

    TORCH_CHECK(gamma.numel() == cols);

    // Switch to dz's device; otherwise the kernel would be launched on cuda:0.
    // Cast to char to avoid a compiler warning about narrowing.
    at::cuda::CUDAGuard device_guard{(char)dz.get_device()};

    auto opts = x.options();

    auto dx0 = torch::empty(x0_sizes, opts.dtype(itype));
    at::Tensor dx1;
    if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
    auto dgamma = torch::empty_like(gamma);
    auto dbeta = torch::empty_like(gamma);
    at::Tensor dcolscale;
    if (colscale_.has_value()) {
        dcolscale = torch::empty_like(colscale_.value());
    }

    layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    launch_params.props = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
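    // Same hidden-size bucketing as in dropout_add_ln_fwd above.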
    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));

    launcher(launch_params, true);

    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
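    // The weight gradients are reduced in two stages: each group of CTAs
    // writes per-column fp32 partials into the *_part buffers, and a finalize
    // step reduces the ctas_per_col partials into dgamma/dbeta.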
    at::Tensor dcolscale_part;
    if (colscale_.has_value()) {
        dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    }
    at::Tensor workspace, barrier;

    layer_norm::BwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x = x.data_ptr();
    params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
    params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.dz = dz.data_ptr();
    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
    params.dx0 = dx0.data_ptr();
    params.dbeta = dbeta.data_ptr();
    params.dgamma = dgamma.data_ptr();
    params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
    params.dbeta_part = dbeta_part.data_ptr();
    params.dgamma_part = dgamma_part.data_ptr();
    params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
    params.dropout_scale = 1.f / (1.f - dropout_p);
    params.inverse_cols = 1.f / float(params.cols);
    params.rowscale_const = rowscale_const;
    params.is_rms_norm = is_rms_norm;

    if( launch_params.barrier_size > 0 ) {
        // TODO Any way to avoid this?
        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr<int>();
    }

    launcher(launch_params, false);

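    // Gradient layout: { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part },
    // plus { dcolscale, dcolscale_part } when colscale was given. dx1 is an
    // undefined tensor when has_residual is false.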
    std::vector<at::Tensor> result = { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
    if (colscale_.has_value()) {
        result.push_back(dcolscale);
        result.push_back(dcolscale_part);
    }
    return result;
}
////////////////////////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.doc() = "CUDA DropoutAddLayerNorm";
  m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
        py::arg("x0"), py::arg("x1"), py::arg("gamma"), py::arg("beta"),
        py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
        py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
        py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
  m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
        py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
        py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
        py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
        py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
}
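
// Minimal C++ usage sketch (illustrative only; assumes a CUDA build, a hidden
// size with a registered kernel, and the fp16-weights row of the type table
// above):
//
//   auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
//   auto x0 = torch::randn({8, 1024}, opts);
//   auto gamma = torch::ones({1024}, opts);
//   c10::optional<const at::Tensor> x1, beta, rowscale, colscale, x0_subset, z_subset;
//   c10::optional<at::Generator> gen;
//   auto out = dropout_add_ln_fwd(x0, x1, gamma, beta, rowscale, colscale,
//                                 x0_subset, z_subset, /*dropout_p=*/0.1f,
//                                 /*epsilon=*/1e-5f, /*rowscale_const=*/1.f,
//                                 /*z_numrows=*/0, gen);
//   // out = { z, x, dmask, mu, rsigma }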