/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>

#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"

namespace transformer_engine {

size_t typeToNumBits(const DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
                                     return TypeInfo<T>::size;);  // NOLINT(*)
}

size_t typeToSize(const DType type) {
  NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() does not support the FP4 data type.");
  return typeToNumBits(type) / 8;
}
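// Example: typeToNumBits(DType::kFloat8E4M3) == 8, so typeToSize returns
// 8 / 8 = 1 byte; kFloat32 gives 32 / 8 = 4 bytes. kFloat4E2M1 is rejected
// above because a 4-bit element would truncate to 0 bytes in integer
// division; use typeToNumBits() for sub-byte types instead.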

std::string to_string(const DType type) {
  switch (type) {
    case DType::kByte:
      return "Byte";
    case DType::kBFloat16:
      return "BFloat16";
    case DType::kFloat16:
      return "Float16";
    case DType::kFloat32:
      return "Float32";
    case DType::kFloat8E4M3:
      return "Float8E4M3";
    case DType::kFloat8E5M2:
      return "Float8E5M2";
    case DType::kFloat8E8M0:
      return "Float8E8M0";
    case DType::kFloat4E2M1:
      return "Float4E2M1";
    case DType::kInt16:
      return "Int16";
    case DType::kInt32:
      return "Int32";
    case DType::kInt64:
      return "Int64";
    default:
      return concat_strings("Invalid type ", static_cast<int>(type));
  }
}

std::string to_string(const NVTEScalingMode &mode) {
  switch (mode) {
    case NVTE_DELAYED_TENSOR_SCALING:
      return "NVTE_DELAYED_TENSOR_SCALING";
    case NVTE_MXFP8_1D_SCALING:
      return "NVTE_MXFP8_1D_SCALING";
    case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
      return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
    case NVTE_INVALID_SCALING:
      return "NVTE_INVALID_SCALING";
  }
  return "Invalid Scaling";
}

void CheckNoopTensor(const Tensor &t, const std::string &name) {
  if (t.data.dptr != nullptr) {
    NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
               ".");
    NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
               " noop. Expected kFloat32.");
  }
}
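// Note: a noop tensor, when provided, is a single-element float32 flag on
// the device; the kernels that consume it conventionally skip their work
// when the flag is nonzero. This helper only enforces the shape/dtype
// contract.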

void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
  NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
  if (is_tensor_scaling(t.scaling_mode)) {
    // per-tensor scaling
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
                 "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
                 "\" has invalid columnwise_scale_inv shape (expected (1), got ",
                 t.columnwise_scale_inv.shape, ")");
    }
  } else {
    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
        t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
      // Need (4, 128) alignment even for e8 scaling factor
      auto block_alignment = std::vector<size_t>{128ul, 4ul};
      size_t expected_x, expected_y, alignment;
      const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
      const size_t block_size_colwise = 32;

      if (t.has_data()) {
        alignment = block_alignment[0];
        expected_x =
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
        alignment = block_alignment[1];
        expected_y =
            DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
            alignment;
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
                   t.scale_inv.shape, ")");
      }
      if (t.has_columnwise_data()) {
        alignment = block_alignment[1];
        expected_x =
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(block_size_colwise)), alignment) *
            alignment;
        alignment = block_alignment[0];
        expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                   t.columnwise_scale_inv.shape, ")");
      }
    }
  }
}
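// Worked example for CheckScaleTensorShape (MXFP8 row-wise, 32-element
// blocks): data of shape (100, 100) requires scale_inv of shape
// (DIVUP(100, 128) * 128, DIVUP(DIVUP(100, 32), 4) * 4) = (128, 4);
// an already-aligned (1024, 1024) tensor requires (1024, 32).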

void CheckInputTensor(const Tensor &t, const std::string &name) {
  const DType type = t.dtype();
  if (is_fp8_dtype(type)) {
    // FP8 input needs to have scale_inv
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Columnwise scale_inv is not supported for non-FP8 input ", name);
  }

  NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");

  CheckScaleTensorShape(t, name);
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.dtype();
  if (is_fp8_dtype(type)) {
    // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
                 to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
      NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
                 " (expected 1 entry, got shape=", t.amax.shape, ")");
    }
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
    // Note: amax is supported for non-FP8 output as it can be fused into the computation
    //       and later used for quantization with no need to compute it separately
    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Columnwise scale_inv is not supported for non-FP8 output ", name);
  }

  if (!allow_empty) {
    NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
  }

  CheckScaleTensorShape(t, name);
}

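// Pool allocator backing NVTETensor handles. A handle is not a pointer:
// it is a 1-based index into the `memory` vector cast to the opaque
// NVTETensor type, so a zero-initialized NVTETensor is always invalid.
// Freed indices are recycled through `free_list`, and `memory` is reserved
// up front at MAX_TENSOR_NUM entries so Tensor objects never move and raw
// indices remain stable.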
class TensorAllocator {
 public:
  static TensorAllocator &instance() {
    static TensorAllocator allocator;
    return allocator;
  }

  ~TensorAllocator() {}

  NVTETensor Allocate(NVTEScalingMode mode) {
    std::lock_guard<std::mutex> lock(mutex);
    if (!free_list.empty()) {
      uintptr_t index = free_list.back();
      NVTETensor ret = reinterpret_cast<NVTETensor>(index);
      free_list.pop_back();
      if (debug) {
        std::cout << "Allocated " << index
                  << " from free list. Free list size: " << free_list.size() << " and capacity "
                  << free_list.capacity() << std::endl;
      }
      // 1-based indexing
      memory[index - 1].scaling_mode = mode;
      return ret;
    }
    if (memory.size() < memory.capacity()) {
      memory.emplace_back();
      Tensor &t = memory.back();
      size = memory.size();
      // 1-based indexing
      uintptr_t index = memory.size();
      if (debug) {
        std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity "
                  << memory.capacity() << std::endl;
      }
      t.scaling_mode = mode;
      t.nvte_tensor = reinterpret_cast<NVTETensor>(index);
      return reinterpret_cast<NVTETensor>(index);
    }
    NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ",
               MAX_TENSOR_NUM, ". There is probably a memory leak in your application.");
  }

  void Free(NVTETensor t) {
    std::lock_guard<std::mutex> lock(mutex);
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
    NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
    free_list.push_back(index);
    // Clean up
    memory[index - 1].clear();
    if (debug) {
      std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity "
                << free_list.capacity() << std::endl;
    }
  }

  void Free(NVTETensor *t, size_t N) {
    std::lock_guard<std::mutex> lock(mutex);
    for (size_t i = 0; i < N; ++i) {
      uintptr_t index = reinterpret_cast<uintptr_t>(t[i]);
      if (index == 0) continue;
      NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
      free_list.push_back(index);
      // Clean up
      memory[index - 1].clear();
    }
    if (debug) {
      std::cout << "Freed range of" << N << " tensors. Free list size: " << free_list.size()
                << " and capacity " << free_list.capacity() << std::endl;
    }
  }

  Tensor *convertNVTETensor(NVTETensor t) {
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    // 1-based indexing to enable 0-initialization of NVTETensor
    // to be invalid tensor
    static_assert(nullptr == 0);
    if (index != 0 && index <= size) {
      return &(memory[index - 1]);
    }
    return nullptr;
  }

  void setDebug(bool debug) {
    std::lock_guard<std::mutex> lock(mutex);
    this->debug = debug;
  }

 private:
  TensorAllocator() {
    std::lock_guard<std::mutex> lock(mutex);
    memory.reserve(MAX_TENSOR_NUM);
  }

  std::mutex mutex;
  std::atomic<size_t> size = 0;
  // Allocate at most 20 MB for tensors
  // Should be replaced by virtual memory allocation
  const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor);
  std::vector<uintptr_t> free_list;
  std::vector<Tensor> memory;
  bool debug = false;
};

Tensor *convertNVTETensor(const NVTETensor t) {
  return TensorAllocator::instance().convertNVTETensor(t);
}

Tensor *convertNVTETensorCheck(const NVTETensor t) {
  Tensor *ptr = TensorAllocator::instance().convertNVTETensor(t);
  NVTE_CHECK(ptr != nullptr, "Invalid tensor.");
  return ptr;
}

}  // namespace transformer_engine

NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
  NVTETensor ret = transformer_engine::TensorAllocator::instance().Allocate(scaling_mode);
  return ret;
}

void nvte_destroy_tensor(NVTETensor tensor) {
  transformer_engine::TensorAllocator::instance().Free(tensor);
}

void nvte_destroy_tensors(NVTETensor *tensors, size_t N) {
  transformer_engine::TensorAllocator::instance().Free(tensors, N);
}
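// Typical handle lifecycle (illustrative sketch):
//   NVTETensor t = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
//   ...  // attach buffers via nvte_set_tensor_param, pass t to NVTE calls
//   nvte_destroy_tensor(t);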

NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return kNVTEFloat32;
  return static_cast<NVTEDType>(t->dtype());
}

NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
  NVTEShape ret;
  if (ndim == 0) {
    ret.ndim = 0;
    return ret;
  }
  NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
             "Too many dims for NVTEShape (requested: ", ndim,
             ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
  std::copy(data, data + ndim, ret.data);
  ret.ndim = ndim;
  return ret;
}
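// Example: const size_t dims[] = {128, 256};
//          NVTEShape s = nvte_make_shape(dims, 2);  // s.ndim == 2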

NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    NVTE_ERROR("Invalid tensor");
  }

  // Determine tensor shape depending on tensor format
  const std::vector<size_t> &shape = t->shape();

  return nvte_make_shape(shape.data(), shape.size());
}

NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    NVTE_ERROR("Invalid tensor");
  }
  const std::vector<size_t> &shape = t->columnwise_data.shape;
  return nvte_make_shape(shape.data(), shape.size());
}

size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }

size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
  const auto &shape = nvte_tensor_shape(tensor);
  NVTE_CHECK(dim < shape.ndim, "Attempted to access index ", dim,
             " in a shape array with ", shape.ndim, " entries");
  return shape.data[dim];
}

size_t nvte_tensor_numel(const NVTETensor tensor) {
  const auto &shape = nvte_tensor_shape(tensor);
  size_t numel = 1;
  for (size_t i = 0; i < shape.ndim; i++) {
    numel *= shape.data[i];
  }
  return numel;
}

size_t nvte_tensor_element_size_bits(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return 8 * sizeof(float);
  return transformer_engine::typeToNumBits(t->dtype());
}

size_t nvte_tensor_element_size(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return sizeof(float);
  NVTE_CHECK(!is_fp4_dtype(t->dtype()),
             "For the FP4 data type, use nvte_tensor_element_size_bits() instead.");
  return nvte_tensor_element_size_bits(tensor) / 8;
}

size_t nvte_tensor_size_bytes(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return 0;
  return (nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor)) / 8;
}
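// Note: computing from bits keeps sub-byte types exact, e.g. 1024 FP4
// (4-bit) elements occupy 1024 * 4 / 8 = 512 bytes.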

void *nvte_tensor_data(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->data.dptr;
}

void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_data.dptr;
}

float *nvte_tensor_amax(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32,
             "Tensor's amax must have Float32 type!");
  return reinterpret_cast<float *>(t->amax.dptr);
}

float *nvte_tensor_scale(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32,
             "Tensor's scale must have Float32 type!");
  return reinterpret_cast<float *>(t->scale.dptr);
}

float *nvte_tensor_scale_inv(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return reinterpret_cast<float *>(t->scale_inv.dptr);
}

void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_scale_inv.dptr;
}

NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    return nvte_make_shape(nullptr, 0);
  }
  return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}

void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
                           const NVTEBasicTensor *param) {
  NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
  auto *t = transformer_engine::convertNVTETensor(*tensor);
  NVTE_CHECK(t != nullptr, "Tensor is not allocated.");
  switch (param_name) {
    case kNVTERowwiseData:
      t->data = *param;
      break;
    case kNVTEColumnwiseData:
      t->columnwise_data = *param;
      break;
    case kNVTEScale:
      t->scale = *param;
      break;
    case kNVTEAmax:
      t->amax = *param;
      break;
    case kNVTERowwiseScaleInv:
      t->scale_inv = *param;
      break;
    case kNVTEColumnwiseScaleInv:
      t->columnwise_scale_inv = *param;
      break;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
}
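// Illustrative use (dev_ptr and dims are placeholders):
//   const size_t dims[] = {rows, cols};
//   NVTEBasicTensor data{dev_ptr, kNVTEFloat8E4M3, nvte_make_shape(dims, 2)};
//   nvte_set_tensor_param(&t, kNVTERowwiseData, &data);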

NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
  if (tensor == nullptr) {
    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  switch (param_name) {
    case kNVTERowwiseData:
      return t.data;
    case kNVTEColumnwiseData:
      return t.columnwise_data;
    case kNVTEScale:
      return t.scale;
    case kNVTEAmax:
      return t.amax;
    case kNVTERowwiseScaleInv:
      return t.scale_inv;
    case kNVTEColumnwiseScaleInv:
      return t.columnwise_scale_inv;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
}

NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
  if (tensor == nullptr) {
    return NVTE_DELAYED_TENSOR_SCALING;
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  return t.scaling_mode;
}

void nvte_tensor_pack_create(NVTETensorPack *pack) {
  for (int i = 0; i < pack->MAX_SIZE; i++) {
    pack->tensors[i] =
        transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING);
  }
}

void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
  transformer_engine::TensorAllocator::instance().Free(pack->tensors, pack->MAX_SIZE);
}

void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
  if (tensor == nullptr) return;
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  // Zero out tensor data if allocated
  if (t.data.dptr != nullptr) {
    const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
    NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
  }
  // Set amax to 0 if allocated
  if (t.amax.dptr != nullptr) {
    NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
  }
}

NVTEQuantizationConfig nvte_create_quantization_config() {
  return new transformer_engine::QuantizationConfig;
}

void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, void *buf,
                                            size_t size_in_bytes, size_t *size_written) {
  // Write attribute size
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
  *size_written = attr_size;

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");

  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      std::memcpy(buf, &config_.amax_epsilon, attr_size);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      std::memcpy(buf, &config_.noop_tensor, attr_size);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, const void *buf,
                                            size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // Read from buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  auto &config_ = *reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      std::memcpy(&config_.force_pow_2_scales, buf, attr_size);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      std::memcpy(&config_.amax_epsilon, buf, attr_size);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      std::memcpy(&config_.noop_tensor, buf, attr_size);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
  if (config != nullptr) {
    delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
  }
}
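// Illustrative use (assuming the attribute is a bool, per attr_sizes):
//   NVTEQuantizationConfig cfg = nvte_create_quantization_config();
//   bool force = true;
//   nvte_set_quantization_config_attribute(
//       cfg, kNVTEQuantizationConfigForcePow2Scales, &force, sizeof(force));
//   ...  // pass cfg to a quantization call
//   nvte_destroy_quantization_config(cfg);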

int nvte_is_non_tn_fp8_gemm_supported() {
  int deviceComputeCapability =
      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());

  // Note: this is a temporary restriction and should be lifted in the future.
  // (Remove this note once that is done.)
  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
         deviceComputeCapability >= 130;
}