/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>

#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"

namespace transformer_engine {

size_t typeToNumBits(const DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
                                     return TypeInfo<T>::size;);  // NOLINT(*)
}

size_t typeToSize(const DType type) {
#if FP4_TYPE_SUPPORTED
  NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() does not support the FP4 data type.");
#endif
  return typeToNumBits(type) / 8;
}
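
// Example (illustrative): DType::kFloat32 is 32 bits wide, so typeToSize()
// returns 4. FP4 elements occupy only 4 bits (half a byte), which a whole
// byte count cannot represent; callers must use typeToNumBits() for them.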

std::string to_string(const DType type) {
  switch (type) {
    case DType::kByte:
      return "Byte";
    case DType::kBFloat16:
      return "BFloat16";
    case DType::kFloat16:
      return "Float16";
    case DType::kFloat32:
      return "Float32";
    case DType::kFloat8E4M3:
      return "Float8E4M3";
    case DType::kFloat8E5M2:
      return "Float8E5M2";
    case DType::kFloat8E8M0:
      return "Float8E8M0";
#if FP4_TYPE_SUPPORTED
    case DType::kFloat4E2M1:
      return "Float4E2M1";
#endif
    case DType::kInt16:
      return "Int16";
    case DType::kInt32:
      return "Int32";
    case DType::kInt64:
      return "Int64";
    default:
      return concat_strings("Invalid type ", static_cast<int>(type));
  }
}

std::string to_string(const NVTEScalingMode &mode) {
  switch (mode) {
    case NVTE_DELAYED_TENSOR_SCALING:
      return "NVTE_DELAYED_TENSOR_SCALING";
    case NVTE_MXFP8_1D_SCALING:
      return "NVTE_MXFP8_1D_SCALING";
    case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
      return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING";
    case NVTE_INVALID_SCALING:
      return "NVTE_INVALID_SCALING";
  }
  return "Invalid Scaling";
}

void CheckNoopTensor(const Tensor &t, const std::string &name) {
  if (t.data.dptr != nullptr) {
    NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
               ".");
    NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
               " noop. Expected kFloat32.");
  }
}

void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
  NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
  if (is_tensor_scaling(t.scaling_mode)) {
    // per-tensor scaling
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
                 "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
                 "\" has invalid columnwise_scale_inv shape (expected (1), got ",
                 t.columnwise_scale_inv.shape, ")");
    }
  } else {
    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING ||
        t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) {
      // Need (4, 128) alignment even for e8 scaling factor
      auto block_alignment = std::vector<size_t>{128ul, 4ul};
      size_t expected_x, expected_y, alignment;
      const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16;
      const size_t block_size_colwise = 32;
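      // Worked example (illustrative): for a 1024x1024 MXFP8 tensor, the
      // expected row-wise scale_inv shape is
      // (round_up(1024, 128), round_up(1024 / 32, 4)) = (1024, 32);
      // the column-wise factors transpose this to (32, 1024).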

      if (t.has_data()) {
        alignment = block_alignment[0];
        expected_x =
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
        alignment = block_alignment[1];
        expected_y =
            DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
            alignment;
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
                   t.scale_inv.shape, ")");
      }
      if (t.has_columnwise_data()) {
        alignment = block_alignment[1];
        expected_x =
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(block_size_colwise)), alignment) *
            alignment;
        alignment = block_alignment[0];
        expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                   "\"  has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                   t.columnwise_scale_inv.shape, ")");
      }
    }
  }
}

void CheckInputTensor(const Tensor &t, const std::string &name) {
  const DType type = t.dtype();
  if (is_fp8_dtype(type) || is_int8_dtype(type)) {
    // FP8/INT8 input needs to have scale_inv
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Byte, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Byte, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 input ", name);
  }
  NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");

  CheckScaleTensorShape(t, name);
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.dtype();
  if (is_fp8_dtype(type) || is_int8_dtype(type)) {
    // FP8/INT8 output needs to have scale, scale_inv and (if delayed scaling) amax
    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
                 to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
      NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
                 " (expected 1 entry, got shape=", t.amax.shape, ")");
    }
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
    // Note: amax is supported for non-FP8 output as it can be fused into the computation
    //       and later used for quantization with no need to compute it separately
    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 input ", name);
  }

  if (!allow_empty) {
    NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
  }

  CheckScaleTensorShape(t, name);
}

class TensorAllocator {
 public:
  static TensorAllocator &instance() {
    static TensorAllocator allocator;
    return allocator;
  }

  ~TensorAllocator() {}

  NVTETensor Allocate(NVTEScalingMode mode) {
    std::lock_guard<std::mutex> lock(mutex);
    if (!free_list.empty()) {
      uintptr_t index = free_list.back();
      NVTETensor ret = reinterpret_cast<NVTETensor>(index);
      free_list.pop_back();
      if (debug) {
        std::cout << "Allocated " << index
                  << " from free list. Free list size: " << free_list.size() << " and capacity "
                  << free_list.capacity() << std::endl;
      }
      // 1-based indexing
      memory[index - 1].scaling_mode = mode;
      return ret;
    }
    if (memory.size() < memory.capacity()) {
      memory.emplace_back();
      Tensor &t = memory.back();
      size = memory.size();
      // 1-based indexing
      uintptr_t index = memory.size();
      if (debug) {
        std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity "
                  << memory.capacity() << std::endl;
      }
      t.scaling_mode = mode;
      t.nvte_tensor = reinterpret_cast<NVTETensor>(index);
      return reinterpret_cast<NVTETensor>(index);
    }
    NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ",
               MAX_TENSOR_NUM, ". There is probably a memory leak in your application.");
  }

  void Free(NVTETensor t) {
    std::lock_guard<std::mutex> lock(mutex);
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
    NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
    free_list.push_back(index);
    // Clean up
    memory[index - 1].clear();
    if (debug) {
      std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity "
                << free_list.capacity() << std::endl;
    }
  }

  void Free(NVTETensor *t, size_t N) {
    std::lock_guard<std::mutex> lock(mutex);
    for (size_t i = 0; i < N; ++i) {
      uintptr_t index = reinterpret_cast<uintptr_t>(t[i]);
      if (index == 0) continue;
      NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
      free_list.push_back(index);
      // Clean up
      memory[index - 1].clear();
    }
    if (debug) {
      std::cout << "Freed range of" << N << " tensors. Free list size: " << free_list.size()
                << " and capacity " << free_list.capacity() << std::endl;
    }
  }

  Tensor *convertNVTETensor(NVTETensor t) {
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    // 1-based indexing to enable 0-initialization of NVTETensor
    // to be invalid tensor
    static_assert(nullptr == 0);
    if (index != 0 && index <= size) {
      return &(memory[index - 1]);
    }
    return nullptr;
  }

  void setDebug(bool debug) {
    std::lock_guard<std::mutex> lock(mutex);
    this->debug = debug;
  }

 private:
  TensorAllocator() {
    std::lock_guard<std::mutex> lock(mutex);
    memory.reserve(MAX_TENSOR_NUM);
  }

  std::mutex mutex;
  std::atomic<size_t> size;
  // Allocate at most 20 MB for tensors
  // Should be replaced by virtual memory allocation
  const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor);
  std::vector<uintptr_t> free_list;
  std::vector<Tensor> memory;
  bool debug = false;
};
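
// Usage sketch (illustrative): handles produced by the allocator are 1-based
// indices into the pooled storage, so a zero-initialized NVTETensor is always
// an invalid handle. A typical round trip through the C API below looks like:
//
//   NVTETensor h = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
//   Tensor *t = transformer_engine::convertNVTETensorCheck(h);
//   // ... populate t->data, t->scale_inv, ... on the Tensor ...
//   nvte_destroy_tensor(h);  // returns the slot to the free list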

Tensor *convertNVTETensor(const NVTETensor t) {
  return TensorAllocator::instance().convertNVTETensor(t);
}

Tensor *convertNVTETensorCheck(const NVTETensor t) {
  Tensor *ptr = TensorAllocator::instance().convertNVTETensor(t);
  NVTE_CHECK(ptr != nullptr, "Invalid tensor.");
  return ptr;
}

}  // namespace transformer_engine

NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
  NVTETensor ret = transformer_engine::TensorAllocator::instance().Allocate(scaling_mode);
  return ret;
}

void nvte_destroy_tensor(NVTETensor tensor) {
  transformer_engine::TensorAllocator::instance().Free(tensor);
}

void nvte_destroy_tensors(NVTETensor *tensors, size_t N) {
  transformer_engine::TensorAllocator::instance().Free(tensors, N);
}

NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return kNVTEFloat32;
  return static_cast<NVTEDType>(t->dtype());
}

NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
  NVTEShape ret;
  if (ndim == 0) {
    ret.ndim = 0;
    return ret;
  }
  NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
             "Too many dims for NVTEShape (requested: ", ndim,
             ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
  std::copy(data, data + ndim, ret.data);
  ret.ndim = ndim;
  return ret;
}

NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    NVTE_ERROR("Invalid tensor");
  }

  // Determine tensor shape depending on tensor format
  const std::vector<size_t> &shape = t->shape();

  return nvte_make_shape(shape.data(), shape.size());
}

NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    NVTE_ERROR("Invalid tensor");
  }
  const std::vector<size_t> &shape = t->columnwise_data.shape;
  return nvte_make_shape(shape.data(), shape.size());
}

size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }

size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
  const auto &shape = nvte_tensor_shape(tensor);
  NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
             " in a shape array with ", shape.ndim, " entries");
  return shape.data[dim];
}

size_t nvte_tensor_numel(const NVTETensor tensor) {
  const auto &shape = nvte_tensor_shape(tensor);
  size_t numel = 1;
  for (size_t i = 0; i < shape.ndim; i++) {
    numel *= shape.data[i];
  }
  return numel;
}

size_t nvte_tensor_element_size_bits(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return 8 * sizeof(float);
  return transformer_engine::typeToNumBits(t->dtype());
}

size_t nvte_tensor_element_size(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return sizeof(float);
  NVTE_CHECK(!is_fp4_dtype(t->dtype()),
             "For FP4 dtypes, use nvte_tensor_element_size_bits instead.");
  return nvte_tensor_element_size_bits(tensor) / 8;
}

size_t nvte_tensor_size_bytes(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return 0;
  return (nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor)) / 8;
}
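
// Example (illustrative): a tensor holding 1024 FP4 elements reports
// nvte_tensor_element_size_bits() == 4, so nvte_tensor_size_bytes() returns
// 1024 * 4 / 8 == 512; sizes are accumulated in bits to stay exact for
// sub-byte dtypes.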

void *nvte_tensor_data(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->data.dptr;
}

void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_data.dptr;
}

float *nvte_tensor_amax(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32,
             "Tensor's amax must have Float32 type!");
  return reinterpret_cast<float *>(t->amax.dptr);
}

float *nvte_tensor_scale(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32,
             "Tensor's scale must have Float32 type!");
  return reinterpret_cast<float *>(t->scale.dptr);
}

float *nvte_tensor_scale_inv(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return reinterpret_cast<float *>(t->scale_inv.dptr);
}

void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_scale_inv.dptr;
}

NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    return nvte_make_shape(nullptr, 0);
  }
  return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}

void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
                           const NVTEBasicTensor *param) {
  NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
  auto *t = transformer_engine::convertNVTETensor(*tensor);
  NVTE_CHECK(t != nullptr, "Tensor is not allocated.");
  switch (param_name) {
    case kNVTERowwiseData:
      t->data = *param;
      break;
    case kNVTEColumnwiseData:
      t->columnwise_data = *param;
      break;
    case kNVTEScale:
      t->scale = *param;
      break;
    case kNVTEAmax:
      t->amax = *param;
      break;
    case kNVTERowwiseScaleInv:
      t->scale_inv = *param;
      break;
    case kNVTEColumnwiseScaleInv:
      t->columnwise_scale_inv = *param;
      break;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
}

NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
  if (tensor == nullptr) {
    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  switch (param_name) {
    case kNVTERowwiseData:
      return t.data;
    case kNVTEColumnwiseData:
      return t.columnwise_data;
    case kNVTEScale:
      return t.scale;
    case kNVTEAmax:
      return t.amax;
    case kNVTERowwiseScaleInv:
      return t.scale_inv;
    case kNVTEColumnwiseScaleInv:
      return t.columnwise_scale_inv;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
}
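
// Usage sketch (illustrative; h, rows, cols, and dev_ptr are hypothetical):
// NVTEBasicTensor aggregates {data_ptr, dtype, shape}, so attaching row-wise
// data to an existing tensor handle could look like:
//
//   size_t dims[2] = {rows, cols};
//   NVTEBasicTensor data{dev_ptr, kNVTEFloat32, nvte_make_shape(dims, 2)};
//   nvte_set_tensor_param(&h, kNVTERowwiseData, &data);
//   NVTEBasicTensor back = nvte_get_tensor_param(h, kNVTERowwiseData);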

NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
  if (tensor == nullptr) {
    return NVTE_DELAYED_TENSOR_SCALING;
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  return t.scaling_mode;
}

void nvte_tensor_pack_create(NVTETensorPack *pack) {
  for (int i = 0; i < pack->MAX_SIZE; i++) {
    pack->tensors[i] =
        transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING);
  }
}

void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
  transformer_engine::TensorAllocator::instance().Free(pack->tensors, pack->MAX_SIZE);
}
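
// Usage sketch (illustrative; assumes NVTETensorPack exposes a `size` member
// alongside `tensors` and MAX_SIZE, per the public header):
//
//   NVTETensorPack pack;
//   nvte_tensor_pack_create(&pack);   // pre-allocates MAX_SIZE tensors
//   // ... a kernel fills pack.tensors[0 .. pack.size - 1] ...
//   nvte_tensor_pack_destroy(&pack);  // frees all MAX_SIZE handles at once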

void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
  if (tensor == nullptr) return;
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  // Zero out tensor data if allocated
  if (t.data.dptr != nullptr) {
    const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
    cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream);
  }
  // Set amax to 0 if allocated
  if (t.amax.dptr != nullptr) {
    cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream);
  }
}

NVTEQuantizationConfig nvte_create_quantization_config() {
  return new transformer_engine::QuantizationConfig;
}

void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, void *buf,
                                            size_t size_in_bytes, size_t *size_written) {
  // Write attribute size
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
  *size_written = attr_size;

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");

  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      std::memcpy(buf, &config_.amax_epsilon, attr_size);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      std::memcpy(buf, &config_.noop_tensor, attr_size);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, const void *buf,
                                            size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // Read from buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  auto &config_ = *reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      std::memcpy(&config_.force_pow_2_scales, buf, attr_size);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      std::memcpy(&config_.amax_epsilon, buf, attr_size);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      std::memcpy(&config_.noop_tensor, buf, attr_size);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
  if (config != nullptr) {
    delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
  }
}
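
// Usage sketch (illustrative; assumes amax_epsilon is stored as a float):
// attributes are copied in and out as raw bytes of a fixed per-attribute size.
//
//   NVTEQuantizationConfig cfg = nvte_create_quantization_config();
//   float eps = 0.0f;
//   nvte_set_quantization_config_attribute(
//       cfg, kNVTEQuantizationConfigAmaxEpsilon, &eps, sizeof(eps));
//   size_t written = 0;
//   nvte_get_quantization_config_attribute(
//       cfg, kNVTEQuantizationConfigAmaxEpsilon, &eps, sizeof(eps), &written);
//   nvte_destroy_quantization_config(cfg);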

int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
  return true;
#else
  int deviceComputeCapability =
      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());

  // Note: this is a temporary restriction and should be lifted in the future.
  // (Remove this note once it's done.)
  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
         deviceComputeCapability >= 130;
#endif
}