common.cu 23.5 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
/*************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"
#include "transformer_engine/transformer_engine.h"


/* Pick the FP8 storage format for a given recipe.
 *
 * e4m3_if_hybrid - only consulted for the "HYBRID" recipe; when true, E4M3 is chosen
 *                  (presumably the forward-pass case, per the original comment).
 * fp8_recipe     - recipe name; "E4M3" always maps to E4M3.
 * returns        - kFloat8E4M3 for the cases above, kFloat8E5M2 otherwise.
 */
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string &fp8_recipe) {
    const bool pure_e4m3   = fp8_recipe == "E4M3";
    const bool hybrid_e4m3 = e4m3_if_hybrid && fp8_recipe == "HYBRID";
    return (pure_e4m3 || hybrid_e4m3) ? transformer_engine::DType::kFloat8E4M3
                                      : transformer_engine::DType::kFloat8E5M2;
}

/* Wrap a raw buffer described by an NVTE shape and a TE dtype in a TensorWrapper.
 * The pointer is passed straight to the TensorWrapper constructor.
 */
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr,
    const NVTEShape& shape,
    const transformer_engine::DType type) {
  transformer_engine::TensorWrapper wrapped(data_ptr, shape, type);
  return wrapped;
}


/* Wrap a raw buffer whose shape is given as a list of dimension sizes in a
 * TensorWrapper of the requested TE dtype.
 */
transformer_engine::TensorWrapper makeTransformerEngineTensor(
    void* data_ptr,
    const std::vector<size_t>& shape,
    const transformer_engine::DType type) {
  transformer_engine::TensorWrapper wrapped(data_ptr, shape, type);
  return wrapped;
}


/* View an ATen tensor as a TE tensor: same data pointer, same sizes, and the TE
 * dtype corresponding to the tensor's scalar type.
 */
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
    // dtype is resolved first so an unsupported scalar type is reported before
    // anything else is touched.
    const transformer_engine::DType te_dtype = GetTransformerEngineDType(tensor.scalar_type());

    // ATen sizes are int64_t; TE expects size_t dimensions.
    const auto dims = tensor.sizes();
    const std::vector<size_t> te_shape(dims.begin(), dims.end());

    return makeTransformerEngineTensor(tensor.data_ptr(), te_shape, te_dtype);
}


/* Number of elements described by a shape: the product of all dimensions.
 * An empty shape yields 1 (the empty product).
 */
size_t product(const std::vector<size_t> &shape) {
    size_t total = 1;
    for (size_t dim = 0; dim < shape.size(); ++dim) {
        total *= shape[dim];
    }
    return total;
}


/* Allocate a CUDA ATen tensor whose sizes come from an NVTE shape.
 *
 * shape         - TE shape; each of its `ndim` entries becomes one tensor dimension.
 *                 A rank-0 shape is rejected via NVTE_CHECK.
 * type          - TE dtype, mapped to the ATen dtype with GetATenDType.
 * init_to_zeros - true: zero-filled (at::zeros); false: uninitialized (at::empty).
 * returns       - the newly allocated tensor.
 */
at::Tensor allocateSpace(const NVTEShape &shape,
                         const transformer_engine::DType type,
                         bool init_to_zeros) {
    NVTE_CHECK(shape.ndim > 0, "Should never reach here! func: allocateSpace");

    // ATen sizes are signed 64-bit while NVTE dimensions are size_t.
    std::vector<int64_t> sizes;
    sizes.reserve(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        sizes.push_back(static_cast<int64_t>(shape.data[i]));
    }

    const auto& cuda_type = at::CUDA(GetATenDType(type));
    // Single return path: no control flow can fall off the end of this function.
    return init_to_zeros ? at::zeros(sizes, cuda_type) : at::empty(sizes, cuda_type);
}


/* Allocate an uninitialized M x N CUDA tensor with the ATen dtype that matches
 * the given TE dtype.
 */
at::Tensor allocateTorchTensor(int M,
                               int N,
                               transformer_engine::DType dtype
) {
    const int64_t rows = M;
    const int64_t cols = N;
    return at::empty({rows, cols}, at::CUDA(GetATenDType(dtype)));
}


/* Allocate an uninitialized 1-D CUDA tensor of length M with the ATen dtype that
 * matches the given TE dtype.
 */
at::Tensor allocateTorchTensor(int M,
                               transformer_engine::DType dtype
) {
    const int64_t length = M;
    return at::empty({length}, at::CUDA(GetATenDType(dtype)));
}


/* LayerNorm forward through nvte_layernorm_fwd on the current ATen CUDA stream.
 *
 * Inputs (i): input, gamma, beta, scale, epsilon.
 * Outputs (o): z, mu, rsigma, amax, scale_inv.
 * Every buffer is passed as (raw pointer, shape, TE dtype) and wrapped in a
 * TensorWrapper here. multiProcessorCount is forwarded to the TE call unchanged.
 *
 * nvte_layernorm_fwd is invoked twice with identical arguments:
 *   1. with empty workspace/barrier wrappers, which fills in their required shape
 *      and dtype;
 *   2. after those buffers are allocated (barrier zero-initialized), to run the
 *      forward pass.
 */
void dispatch_layernorm(void* input,                                    // i
                        const std::vector<size_t>& input_shape,
                        const transformer_engine::DType input_type,
                        void* gamma,                                    // i
                        const std::vector<size_t>& gamma_shape,
                        const transformer_engine::DType gamma_type,
                        void* beta,                                     // i
                        const std::vector<size_t>& beta_shape,
                        const transformer_engine::DType beta_type,
                        void* scale,                                    // i
                        const std::vector<size_t>& scale_shape,
                        const transformer_engine::DType scale_type,
                        const float epsilon,                            // i
                        void* z,                                        // o
                        const std::vector<size_t>& z_shape,
                        const transformer_engine::DType z_type,
                        void* mu,                                       // o
                        const std::vector<size_t>& mu_shape,
                        const transformer_engine::DType mu_type,
                        void* rsigma,                                   // o
                        const std::vector<size_t>& rsigma_shape,
                        const transformer_engine::DType rsigma_type,
                        void* amax,                                     // o
                        const std::vector<size_t>& amax_shape,
                        const transformer_engine::DType amax_type,
                        void* scale_inv,                                // o
                        const std::vector<size_t>& scale_inv_shape,
                        const transformer_engine::DType scale_inv_type,
                        const int multiProcessorCount
) {
    auto input_cu     = makeTransformerEngineTensor(input, input_shape, input_type);
    auto gamma_cu     = makeTransformerEngineTensor(gamma, gamma_shape, gamma_type);
    auto beta_cu      = makeTransformerEngineTensor(beta, beta_shape, beta_type);
    auto scale_cu     = makeTransformerEngineTensor(scale, scale_shape, scale_type);
    auto z_cu         = makeTransformerEngineTensor(z, z_shape, z_type);
    auto mu_cu        = makeTransformerEngineTensor(mu, mu_shape, mu_type);
    auto rsigma_cu    = makeTransformerEngineTensor(rsigma, rsigma_shape, rsigma_type);
    auto amax_cu      = makeTransformerEngineTensor(amax, amax_shape, amax_type);
    auto scale_inv_cu = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);
    transformer_engine::TensorWrapper workspace, barrier;

    // Both calls share every argument. workspace/barrier are captured by reference,
    // so the second call sees the wrappers reassigned below.
    auto run_layernorm_fwd = [&]() {
        nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(),
                           scale_cu.data(), epsilon,
                           z_cu.data(), mu_cu.data(), rsigma_cu.data(),
                           at::cuda::getCurrentCUDAStream(), multiProcessorCount,
                           workspace.data(), barrier.data(), amax_cu.data(),
                           scale_inv_cu.data());
    };

    // This call populates workspace and barrier tensors with the required config
    run_layernorm_fwd();

    // Fill workspace and barrier (the barrier buffer is zero-initialized)
    auto workspace_data = allocateSpace(workspace.shape(),
                                        workspace.dtype());
    auto barrier_data = allocateSpace(barrier.shape(),
                                      barrier.dtype(),
                                      true);
    workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
                                            workspace.shape(),
                                            workspace.dtype());
    barrier   = makeTransformerEngineTensor(barrier_data.data_ptr(),
                                            barrier.shape(),
                                            barrier.dtype());

    // Actual call to fwd kernel
    run_layernorm_fwd();
}


/* Wrap the raw buffers as TE tensors and launch nvte_cast_transpose on the
 * current ATen CUDA stream.
 * Inputs (i): input, scale. Outputs (o): output_cast, output_transpose, amax, scale_inv.
 */
void dispatch_cast_transpose_fusion(void* input,                                            // i
                                    const std::vector<size_t>& input_shape,
                                    const transformer_engine::DType input_type,
                                    void* scale,                                            // i
                                    const std::vector<size_t>& scale_shape,
                                    const transformer_engine::DType scale_type,
                                    void* output_cast,                                      // o
                                    const std::vector<size_t>& output_cast_shape,
                                    const transformer_engine::DType output_cast_type,
                                    void* output_transpose,                                 // o
                                    const std::vector<size_t>& output_transpose_shape,
                                    const transformer_engine::DType output_transpose_type,
                                    void* amax,                                             // o
                                    const std::vector<size_t>& amax_shape,
                                    const transformer_engine::DType amax_type,
                                    void* scale_inv,                                        // o
                                    const std::vector<size_t>& scale_inv_shape,
                                    const transformer_engine::DType scale_inv_type
) {
    // Inputs.
    auto in_te        = makeTransformerEngineTensor(input, input_shape, input_type);
    auto scale_te     = makeTransformerEngineTensor(scale, scale_shape, scale_type);
    // Outputs.
    auto cast_te      = makeTransformerEngineTensor(output_cast,
                                                    output_cast_shape,
                                                    output_cast_type);
    auto transpose_te = makeTransformerEngineTensor(output_transpose,
                                                    output_transpose_shape,
                                                    output_transpose_type);
    auto amax_te      = makeTransformerEngineTensor(amax, amax_shape, amax_type);
    auto scale_inv_te = makeTransformerEngineTensor(scale_inv,
                                                    scale_inv_shape,
                                                    scale_inv_type);

    nvte_cast_transpose(in_te.data(),
                        scale_te.data(),
                        cast_te.data(),
                        transpose_te.data(),
                        amax_te.data(),
                        scale_inv_te.data(),
                        at::cuda::getCurrentCUDAStream());
}


/* Wrap the raw buffers as TE tensors and launch nvte_gelu on the current ATen
 * CUDA stream.
 * Inputs (i): input, scale. Outputs (o): output, amax, scale_inv.
 */
void dispatch_gelu(void* input,                                            // i
                   const std::vector<size_t>& input_shape,
                   const transformer_engine::DType input_type,
                   void* scale,                                            // i
                   const std::vector<size_t>& scale_shape,
                   const transformer_engine::DType scale_type,
                   void* output,                                           // o
                   const std::vector<size_t>& output_shape,
                   const transformer_engine::DType output_type,
                   void* amax,                                             // o
                   const std::vector<size_t>& amax_shape,
                   const transformer_engine::DType amax_type,
                   void* scale_inv,                                        // o
                   const std::vector<size_t>& scale_inv_shape,
                   const transformer_engine::DType scale_inv_type
) {
    auto src_te       = makeTransformerEngineTensor(input, input_shape, input_type);
    auto dst_te       = makeTransformerEngineTensor(output, output_shape, output_type);
    auto scale_te     = makeTransformerEngineTensor(scale, scale_shape, scale_type);
    auto amax_te      = makeTransformerEngineTensor(amax, amax_shape, amax_type);
    auto scale_inv_te = makeTransformerEngineTensor(scale_inv, scale_inv_shape, scale_inv_type);

    nvte_gelu(src_te.data(),
              dst_te.data(),
              scale_te.data(),
              amax_te.data(),
              scale_inv_te.data(),
              at::cuda::getCurrentCUDAStream());
}


/* Wrap input/output buffers as TE tensors and launch nvte_transpose on the
 * current ATen CUDA stream.
 */
void dispatch_transpose(void* input,                                            // i
                        const std::vector<size_t>& input_shape,
                        const transformer_engine::DType input_type,
                        void* output,                                           // o
                        const std::vector<size_t>& output_shape,
                        const transformer_engine::DType output_type
) {
    auto src_te = makeTransformerEngineTensor(input, input_shape, input_type);
    auto dst_te = makeTransformerEngineTensor(output, output_shape, output_type);

    nvte_transpose(src_te.data(),
                   dst_te.data(),
                   at::cuda::getCurrentCUDAStream());
}


/* Launch nvte_cast_transpose_dbias on the current ATen CUDA stream.
 * Inputs (i): input, scale. Outputs (o): cast_output, transposed_output, amax,
 * dbias, scale_inv.
 *
 * The TE call is made twice with the same arguments:
 *   1. with an empty workspace, which only fills in the workspace's required
 *      shape and dtype;
 *   2. after a buffer of that shape/dtype has been allocated.
 */
void dispatch_bgrad_cast_transpose_fusion(void* input,                                          // i
                                          const std::vector<size_t>& input_shape,
                                          const transformer_engine::DType input_type,
                                          void* scale,                                          // i
                                          const std::vector<size_t>& scale_shape,
                                          const transformer_engine::DType scale_type,
                                          void* cast_output,                                    // o
                                          const std::vector<size_t>& cast_output_shape,
                                          const transformer_engine::DType cast_output_type,
                                          void* transposed_output,                              // o
                                          const std::vector<size_t>& transposed_output_shape,
                                          const transformer_engine::DType transposed_output_type,
                                          void* amax,                                           // o
                                          const std::vector<size_t>& amax_shape,
                                          const transformer_engine::DType amax_type,
                                          void* dbias,                                          // o
                                          const std::vector<size_t>& dbias_shape,
                                          const transformer_engine::DType dbias_type,
                                          void* scale_inv,                                      // o
                                          const std::vector<size_t>& scale_inv_shape,
                                          const transformer_engine::DType scale_inv_type
) {
  auto in_te         = makeTransformerEngineTensor(input, input_shape, input_type);
  auto scale_te      = makeTransformerEngineTensor(scale, scale_shape, scale_type);
  auto cast_te       = makeTransformerEngineTensor(cast_output, cast_output_shape,
                                                   cast_output_type);
  auto transposed_te = makeTransformerEngineTensor(transposed_output,
                                                   transposed_output_shape,
                                                   transposed_output_type);
  auto amax_te       = makeTransformerEngineTensor(amax, amax_shape, amax_type);
  auto dbias_te      = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
  auto scale_inv_te  = makeTransformerEngineTensor(scale_inv,
                                                   scale_inv_shape,
                                                   scale_inv_type);
  transformer_engine::TensorWrapper workspace;

  // Captures workspace by reference so the second launch sees the reassigned wrapper.
  auto launch = [&]() {
    nvte_cast_transpose_dbias(in_te.data(), scale_te.data(), cast_te.data(),
                              transposed_te.data(), amax_te.data(),
                              dbias_te.data(), scale_inv_te.data(),
                              workspace.data(), at::cuda::getCurrentCUDAStream());
  };

  launch();  // query pass: only fills in the workspace shape/dtype

  auto workspace_buffer = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_buffer.data_ptr(),
                                          workspace.shape(),
                                          workspace.dtype());

  launch();  // execution pass with the allocated workspace
}


/* Launch nvte_cast_transpose_dbias_dgelu on the current ATen CUDA stream.
 * Inputs (i): input, gelu_input, scale. Outputs (o): cast_output,
 * transposed_output, amax, dbias, scale_inv.
 *
 * The TE call is made twice with the same arguments:
 *   1. with an empty workspace, which only fills in the workspace's required
 *      shape and dtype;
 *   2. after a buffer of that shape/dtype has been allocated.
 */
void dispatch_bgrad_dgelu_cast_transpose_fusion(
    void* input,                                            // i
    const std::vector<size_t>& input_shape,
    const transformer_engine::DType input_type,
    void* gelu_input,                                       // i
    const std::vector<size_t>& gelu_input_shape,
    const transformer_engine::DType gelu_input_type,
    void* scale,                                            // i
    const std::vector<size_t>& scale_shape,
    const transformer_engine::DType scale_type,
    void* cast_output,                                      // o
    const std::vector<size_t>& cast_output_shape,
    const transformer_engine::DType cast_output_type,
    void* transposed_output,                                // o
    const std::vector<size_t>& transposed_output_shape,
    const transformer_engine::DType transposed_output_type,
    void* amax,                                             // o
    const std::vector<size_t>& amax_shape,
    const transformer_engine::DType amax_type,
    void* dbias,                                            // o
    const std::vector<size_t>& dbias_shape,
    const transformer_engine::DType dbias_type,
    void* scale_inv,                                        // o
    const std::vector<size_t>& scale_inv_shape,
    const transformer_engine::DType scale_inv_type
) {
  // Inputs.
  auto in_te         = makeTransformerEngineTensor(input, input_shape, input_type);
  auto gelu_in_te    = makeTransformerEngineTensor(gelu_input, gelu_input_shape,
                                                   gelu_input_type);
  auto scale_te      = makeTransformerEngineTensor(scale, scale_shape, scale_type);
  // Outputs.
  auto cast_te       = makeTransformerEngineTensor(cast_output, cast_output_shape,
                                                   cast_output_type);
  auto transposed_te = makeTransformerEngineTensor(transposed_output,
                                                   transposed_output_shape,
                                                   transposed_output_type);
  auto amax_te       = makeTransformerEngineTensor(amax, amax_shape, amax_type);
  auto dbias_te      = makeTransformerEngineTensor(dbias, dbias_shape, dbias_type);
  auto scale_inv_te  = makeTransformerEngineTensor(scale_inv,
                                                   scale_inv_shape,
                                                   scale_inv_type);
  transformer_engine::TensorWrapper workspace;

  // Captures workspace by reference so the second launch sees the reassigned wrapper.
  auto launch = [&]() {
    nvte_cast_transpose_dbias_dgelu(in_te.data(), gelu_in_te.data(), scale_te.data(),
                                    cast_te.data(), transposed_te.data(),
                                    amax_te.data(), dbias_te.data(), scale_inv_te.data(),
                                    workspace.data(), at::cuda::getCurrentCUDAStream());
  };

  launch();  // query pass: only fills in the workspace shape/dtype

  auto workspace_buffer = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_buffer.data_ptr(),
                                          workspace.shape(),
                                          workspace.dtype());

  launch();  // execution pass with the allocated workspace
}
Tim Moon's avatar
Tim Moon committed
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432


/* Cast+transpose a batch of tensors with a single nvte_multi_cast_transpose call
 * on the current ATen CUDA stream.
 *
 * Each tensor role (input, scale, cast output, transposed output, amax, scale_inv)
 * is described by three parallel lists: raw pointers, shapes and TE dtypes.
 * Entry i of every list belongs to tensor i. All lists must have the same length,
 * which is checked before any list is indexed.
 */
void dispatch_multi_cast_transpose(
    std::vector<void*> input_dptr_list,                     // i
    const std::vector<std::vector<size_t>>& input_shape_list,
    const std::vector<transformer_engine::DType>& input_type_list,
    std::vector<void*> scale_dptr_list,                     // i
    const std::vector<std::vector<size_t>>& scale_shape_list,
    const std::vector<transformer_engine::DType>& scale_type_list,
    std::vector<void*> cast_output_dptr_list,               // o
    const std::vector<std::vector<size_t>>& cast_output_shape_list,
    const std::vector<transformer_engine::DType>& cast_output_type_list,
    std::vector<void*> transposed_output_dptr_list,         // o
    const std::vector<std::vector<size_t>>& transposed_output_shape_list,
    const std::vector<transformer_engine::DType>& transposed_output_type_list,
    std::vector<void*> amax_dptr_list,                      // o
    const std::vector<std::vector<size_t>>& amax_shape_list,
    const std::vector<transformer_engine::DType>& amax_type_list,
    std::vector<void*> scale_inv_dptr_list,                 // o
    const std::vector<std::vector<size_t>>& scale_inv_shape_list,
    const std::vector<transformer_engine::DType>& scale_inv_type_list
) {
  const size_t num_tensors = input_dptr_list.size();

  // Check tensor lists up front: the loop below indexes every list with
  // i < num_tensors, so a shorter list would otherwise be read out of bounds.
  NVTE_CHECK(input_shape_list.size() == num_tensors &&
             input_type_list.size() == num_tensors,
             "Number of input pointers, shapes and types must match");
  NVTE_CHECK(scale_dptr_list.size() == num_tensors &&
             scale_shape_list.size() == num_tensors &&
             scale_type_list.size() == num_tensors,
             "Number of input and scale tensors must match");
  NVTE_CHECK(cast_output_dptr_list.size() == num_tensors &&
             cast_output_shape_list.size() == num_tensors &&
             cast_output_type_list.size() == num_tensors,
             "Number of input and C output tensors must match");
  NVTE_CHECK(transposed_output_dptr_list.size() == num_tensors &&
             transposed_output_shape_list.size() == num_tensors &&
             transposed_output_type_list.size() == num_tensors,
             "Number of input and T output tensors must match");
  NVTE_CHECK(amax_dptr_list.size() == num_tensors &&
             amax_shape_list.size() == num_tensors &&
             amax_type_list.size() == num_tensors,
             "Number of input and AMAX tensors must match");
  NVTE_CHECK(scale_inv_dptr_list.size() == num_tensors &&
             scale_inv_shape_list.size() == num_tensors &&
             scale_inv_type_list.size() == num_tensors,
             "Number of input and scale_inv tensors must match");

  // Construct TE tensors. tensor_wrappers keeps every TensorWrapper alive until
  // after the launch; the NVTETensor lists only hold the handles from .data().
  std::vector<NVTETensor> input_list, scale_list,
    cast_output_list, transposed_output_list, amax_list, scale_inv_list;
  input_list.reserve(num_tensors);
  scale_list.reserve(num_tensors);
  cast_output_list.reserve(num_tensors);
  transposed_output_list.reserve(num_tensors);
  amax_list.reserve(num_tensors);
  scale_inv_list.reserve(num_tensors);

  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  // Six wrappers per tensor. Reserving up front means emplace_back never
  // relocates wrappers whose handles are already stored. Whether a handle would
  // survive a relocation depends on TensorWrapper's move semantics, which are not
  // visible here.
  tensor_wrappers.reserve(6 * num_tensors);
  auto make_tensor = [&tensor_wrappers](void* dptr,
                                        const std::vector<size_t>& shape,
                                        transformer_engine::DType dtype)
    -> NVTETensor {
    tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
    return tensor_wrappers.back().data();
  };
  for (size_t i = 0; i < num_tensors; ++i) {
    input_list.emplace_back(make_tensor(input_dptr_list[i],
                                        input_shape_list[i],
                                        input_type_list[i]));
    scale_list.emplace_back(make_tensor(scale_dptr_list[i],
                                        scale_shape_list[i],
                                        scale_type_list[i]));
    cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
                                              cast_output_shape_list[i],
                                              cast_output_type_list[i]));
    transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
                                                    transposed_output_shape_list[i],
                                                    transposed_output_type_list[i]));
    amax_list.emplace_back(make_tensor(amax_dptr_list[i],
                                       amax_shape_list[i],
                                       amax_type_list[i]));
    scale_inv_list.emplace_back(make_tensor(scale_inv_dptr_list[i],
                                            scale_inv_shape_list[i],
                                            scale_inv_type_list[i]));
  }

  // Launch TE kernel
  nvte_multi_cast_transpose(input_list.size(),
                            input_list.data(),
                            scale_list.data(),
                            cast_output_list.data(),
                            transposed_output_list.data(),
                            amax_list.data(),
                            scale_inv_list.data(),
                            at::cuda::getCurrentCUDAStream());
}