/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>
8

9
#include <algorithm>
10
11
#include <atomic>
#include <climits>
12
#include <cstring>
13
#include <iostream>
14
#include <mutex>
15
16
#include <optional>
#include <string>
17
#include <utility>
18
#include <vector>
19

Przemek Tredak's avatar
Przemek Tredak committed
20
#include "common.h"
21
#include "common/util/cuda_runtime.h"
22
#include "common/util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
23
24
25

namespace transformer_engine {

26
size_t typeToNumBits(const DType type) {
27
28
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
                                     return TypeInfo<T>::size;);  // NOLINT(*)
Przemek Tredak's avatar
Przemek Tredak committed
29
30
}

31
32
33
34
/*! \brief Width of a single element of `type`, in whole bytes.
 *
 *  FP4 elements are narrower than a byte, so they are rejected here; use typeToNumBits()
 *  for them instead.
 */
size_t typeToSize(const DType type) {
  NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() Does not support FP4 data type.");
  const size_t bits_per_element = typeToNumBits(type);
  return bits_per_element / 8;
}
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

/*! \brief Human-readable name of a DType.
 *
 *  Unknown enumerator values produce "Invalid type <int value>" instead of failing.
 */
std::string to_string(const DType type) {
  const char *label = nullptr;
  switch (type) {
    case DType::kByte:
      label = "Byte";
      break;
    case DType::kBFloat16:
      label = "BFloat16";
      break;
    case DType::kFloat16:
      label = "Float16";
      break;
    case DType::kFloat32:
      label = "Float32";
      break;
    case DType::kFloat8E4M3:
      label = "Float8E4M3";
      break;
    case DType::kFloat8E5M2:
      label = "Float8E5M2";
      break;
    case DType::kFloat8E8M0:
      label = "Float8E8M0";
      break;
    case DType::kFloat4E2M1:
      label = "Float4E2M1";
      break;
    case DType::kInt16:
      label = "Int16";
      break;
    case DType::kInt32:
      label = "Int32";
      break;
    case DType::kInt64:
      label = "Int64";
      break;
    default:
      break;
  }
  if (label == nullptr) {
    return concat_strings("Invalid type ", static_cast<int>(type));
  }
  return label;
}

std::string to_string(const NVTEScalingMode &mode) {
  switch (mode) {
    case NVTE_DELAYED_TENSOR_SCALING:
68
      return "NVTE_DELAYED_TENSOR_SCALING";
69
    case NVTE_MXFP8_1D_SCALING:
70
      return "NVTE_MXFP8_1D_SCALING";
71
72
73
74
    case NVTE_BLOCK_SCALING_1D:
      return "NVTE_BLOCK_SCALING_1D";
    case NVTE_BLOCK_SCALING_2D:
      return "NVTE_BLOCK_SCALING_2D";
75
76
    case NVTE_NVFP4_1D_SCALING:
      return "NVTE_NVFP4_1D_SCALING";
77
    case NVTE_INVALID_SCALING:
78
      return "NVTE_INVALID_SCALING";
79
80
81
82
83
  }
  return "Invalid Scaling";
}

void CheckNoopTensor(const Tensor &t, const std::string &name) {
84
  if (t.data.has_data()) {
85
86
87
88
89
90
91
92
93
94
    NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
               ".");
    NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
               " noop. Expected kFloat32.");
  }
}

void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
  NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
  if (is_tensor_scaling(t.scaling_mode)) {
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
    if (is_fp8_dtype(t.dtype())) {
      // FP8 tensor with tensor scaling
      if (t.has_data()) {
        NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected 1 entry, got ", t.scale_inv.shape,
                   ")");
      }
      if (t.has_columnwise_data()) {
        NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
                   "\" has invalid columnwise_scale_inv shape (expected 1 entry, got ",
                   t.columnwise_scale_inv.shape, ")");
      }
    } else {
      // High-precision tensor
      if (t.has_data()) {
        NVTE_CHECK(t.scale_inv.numel() == 0, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected 0 entries, got ", t.scale_inv.shape,
                   ")");
      }
      if (t.has_columnwise_data()) {
        NVTE_CHECK(t.columnwise_scale_inv.numel() == 0, "Tensor \"", name,
                   "\" has invalid columnwise_scale_inv shape (expected 0 entries, got ",
                   t.columnwise_scale_inv.shape, ")");
      }
119
120
    }
  } else {
121
    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
122
123
124
      // Need (4, 128) alignment even for e8 scaling factor
      auto block_alignment = std::vector<size_t>{128ul, 4ul};
      size_t expected_x, expected_y, alignment;
125
      const size_t block_size_rowwise = 32;
126
      const size_t block_size_colwise = 32;
127
128
129
130
131
132
133

      if (t.has_data()) {
        alignment = block_alignment[0];
        expected_x =
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
        alignment = block_alignment[1];
        expected_y =
134
135
            DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
            alignment;
136

137
138
139
140
141
142
143
144
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
                   t.scale_inv.shape, ")");
      }
      if (t.has_columnwise_data()) {
        alignment = block_alignment[1];
        expected_x =
145
146
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(block_size_colwise)), alignment) *
            alignment;
147
148
        alignment = block_alignment[0];
        expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
149

150
151
152
153
154
        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                   "\"  has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                   t.columnwise_scale_inv.shape, ")");
      }
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
    } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
      if (t.has_data()) {
        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128);
        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4);
        const auto &expected = std::vector<size_t>{expected_y, expected_x};
        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
                   t.scale_inv.shape, ")");
      }
      if (t.has_columnwise_data()) {
        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128);
        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4);
        const auto &expected = std::vector<size_t>{expected_y, expected_x};
        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                   "\"  has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                   t.columnwise_scale_inv.shape, ")");
      }
172
173
    }
  }
174
175
176
}

void CheckInputTensor(const Tensor &t, const std::string &name) {
177
  const DType type = t.dtype();
178
179
  if (is_fp8_dtype(type)) {
    // FP8 input needs to have scale_inv
180
    if (t.has_data()) {
181
      NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor input ", name,
182
183
184
185
186
187
188
189
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Byte, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
190
      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor input ", name,
191
192
193
194
195
196
197
198
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Byte, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
199
200
201
202
  } else if (is_fp4_dtype(type)) {
    // TODO(ksivaman): Fix this to check for amaxes and other details.
    // For now only needed for swizzle.
    if (t.has_data()) {
203
      NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor input ", name,
204
205
206
207
208
209
210
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected DType::kFloat8E4M3, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
211
      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor input ", name,
212
213
214
215
216
217
218
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ",
                 name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected DType::kFloat8E4M3, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
219
  } else {
220
221
222
223
    NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 input ", name);
    NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ", name);
    NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ",
               name);
224
  }
225
226
227
  NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");

  CheckScaleTensorShape(t, name);
228
229
230
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
231
  const DType type = t.dtype();
232
  if (is_fp8_dtype(type)) {
233
    // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
234
    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.has_data()) {
235
236
      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
                 to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
237
      NVTE_CHECK(t.amax.numel() == 1, "Invalid shape of amax in output ", name,
238
239
240
                 " (expected 1 entry, got shape=", t.amax.shape, ")");
    }
    if (t.has_data()) {
241
      NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor output ", name,
242
243
244
245
246
247
248
249
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
250
      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor output ", name,
251
252
253
254
255
256
257
258
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
259
260
261
  } else if (is_fp4_dtype(type)) {
    // FP4 output needs to have the scale_inv
    if (t.has_data()) {
262
      NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor output ", name,
263
264
265
266
267
268
269
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float8E4M3, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
270
      NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor output ", name,
271
272
273
274
275
276
277
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
                 name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float8E4M3, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
278
  } else {
279
280
281
282
    NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 output ", name);
    NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 output ", name);
    NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ",
               name);
283
284
285
  }

  if (!allow_empty) {
286
    NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
287
  }
288
289

  CheckScaleTensorShape(t, name);
290
291
}

292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
/*! \brief Validate the bookkeeping metadata of a grouped tensor.
 *
 *  The following conditions are checked, in order:
 *  - the group is non-empty;
 *  - each optional per-tensor array (first_dims, last_dims, tensor_offsets), when present, is
 *    a 1D Int64 array with num_tensors entries;
 *  - tensor_offsets is present whenever member shapes are not all identical;
 *  - logical_shape is a 2D shape with positive extents;
 *  - every populated data buffer is 1D and holds logical_shape[0] * logical_shape[1] elements.
 *
 *  \throws via NVTE_CHECK on the first violated condition.
 */
void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name) {
  NVTE_CHECK(t.num_tensors > 0, "Grouped tensor ", name, " has no tensors!");

  // Every per-tensor metadata array is optional: an absent array means that quantity is
  // uniform across the group. A present one must hold one Int64 per member tensor.
  const auto validate_optional_array = [&t, &name](const SimpleTensor &arr,
                                                   const char *arr_name) {
    if (!arr.has_data()) {
      return;
    }
    NVTE_CHECK(arr.shape.size() == 1, "Grouped tensor ", name, " ", arr_name, " must be 1D");
    NVTE_CHECK(arr.dtype == DType::kInt64, "Grouped tensor ", name, " ", arr_name,
               " must have dtype Int64");
    NVTE_CHECK(arr.shape[0] == t.num_tensors, "Grouped tensor ", name, " ", arr_name, " size (",
               arr.shape[0], ") must equal num_tensors (", t.num_tensors, ")");
  };
  validate_optional_array(t.first_dims, "first_dims");
  validate_optional_array(t.last_dims, "last_dims");
  validate_optional_array(t.tensor_offsets, "tensor_offsets");

  // Offsets are only predictable when every member has the same shape; otherwise they must
  // be supplied explicitly.
  if (!t.all_same_shape()) {
    NVTE_CHECK(
        t.tensor_offsets.dptr != nullptr, "Grouped tensor ", name,
        " must have tensor_offsets when any dimension varies (first_dims or last_dims is set)");
  }

  // The group as a whole is described by a 2D logical shape.
  NVTE_CHECK(t.logical_shape.ndim == 2, "Grouped tensor ", name, " logical_shape must be 2D");
  NVTE_CHECK(t.logical_shape.data[0] > 0 && t.logical_shape.data[1] > 0, "Grouped tensor ", name,
             " logical_shape must have positive dimensions");

  // Data buffers are stored flattened.
  if (t.has_data()) {
    NVTE_CHECK(t.data.shape.size() == 1, "Grouped tensor ", name, " data must be 1D");
  }
  if (t.has_columnwise_data()) {
    NVTE_CHECK(t.columnwise_data.shape.size() == 1, "Grouped tensor ", name,
               " columnwise_data must be 1D");
  }

  // ...and must cover exactly the logical shape.
  const size_t expected_numel = t.logical_shape.data[0] * t.logical_shape.data[1];
  if (t.has_data()) {
    NVTE_CHECK(t.data.numel() == expected_numel, "Grouped tensor ", name, " data size (",
               t.data.numel(), ") must match logical_shape size (", expected_numel, ")");
  }
  if (t.has_columnwise_data()) {
    NVTE_CHECK(t.columnwise_data.numel() == expected_numel, "Grouped tensor ", name,
               " columnwise_data size (", t.columnwise_data.numel(),
               ") must match logical_shape size (", expected_numel, ")");
  }
}

// Helper function to check scale_inv for both input and output
static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name, bool is_output) {
  const char *tensor_type = is_output ? "output" : "input";

  // Helper to check scale_inv for both rowwise and columnwise layouts
  auto check_scales = [&](DType expected_dtype) {
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.has_data(), tensor_type, " ", name,
                 " rowwise scale_inv must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == expected_dtype, tensor_type, " ", name,
                 " rowwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
                 ", got ", to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.has_data(), tensor_type, " ", name,
                 " columnwise scale_inv must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == expected_dtype, tensor_type, " ", name,
                 " columnwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
                 ", got ", to_string(t.columnwise_scale_inv.dtype), ")");
    }
  };

  // Determine expected dtype based on data type and scaling mode
  if (is_fp8_dtype(t.dtype()) && is_tensor_scaling(t.scaling_mode)) {
    check_scales(DType::kFloat32);
  } else if (is_mxfp8_scaling(t.scaling_mode)) {
    check_scales(DType::kFloat8E8M0);
  } else if (is_nvfp4_scaling(t.scaling_mode)) {
    check_scales(DType::kFloat8E4M3);
  } else {
    // Non-quantized types should not have scale/scale_inv
    NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv not supported for non-quantized ", tensor_type,
               " ", name);
    NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv not supported for non-quantized ",
               tensor_type, " ", name);
  }
}

/*! \brief Validate a grouped tensor consumed as an input.
 *
 *  The tensor must have rowwise and/or columnwise data. Its scale_inv buffers and its
 *  grouping metadata are then validated, in that order.
 */
void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name) {
  NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input grouped tensor ", name,
             " not allocated");
  constexpr bool kIsOutput = false;
  CheckGroupedScaleInv(t, name, kIsOutput);
  CheckGroupedTensorShapeArrays(t, name);
}

/*! \brief Validate a grouped tensor produced as an output.
 *
 *  Unless `allow_empty` is set, rowwise and/or columnwise data must be allocated. When data is
 *  present, FP8 delayed-scaling outputs must provide a Float32 amax, and the scale_inv buffers
 *  are validated. The grouping metadata is always validated.
 */
void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name, bool allow_empty) {
  const bool is_allocated = t.has_data() || t.has_columnwise_data();
  if (!allow_empty) {
    NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output grouped tensor ", name,
               " not allocated");
  }

  if (is_allocated) {
    // Delayed scaling writes an amax, so it must exist and be Float32.
    const bool needs_amax =
        is_fp8_dtype(t.dtype()) && t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING;
    if (needs_amax) {
      NVTE_CHECK(t.amax.has_data(), "Output ", name, " amax must be allocated");
      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Output ", name, " amax must be Float32");
    }
    CheckGroupedScaleInv(t, name, /*is_output=*/true);
  }

  CheckGroupedTensorShapeArrays(t, name);
}

414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
/*! \brief Process-wide pool backing the opaque NVTETensor handles.
 *
 *  A handle is a 1-based index into `memory` cast to NVTETensor, so a zero/nullptr handle is
 *  never a valid tensor. `memory` reserves MAX_TENSOR_NUM slots up front and Allocate() never
 *  grows it past that capacity, so a slot's address stays fixed once created. Freed slots are
 *  recycled most-recently-freed first through `free_list`.
 *
 *  Allocate/Free/setDebug serialize on `mutex`. convertNVTETensor does not take the lock; it
 *  bounds the index by the atomic `size`.
 */
class TensorAllocator {
 public:
  static TensorAllocator &instance() {
    static TensorAllocator allocator;
    return allocator;
  }

  ~TensorAllocator() {}

  /*! \brief Hand out a tensor handle whose scaling mode is `mode`.
   *
   *  Reuses the most recently freed slot if there is one, otherwise creates a new slot.
   *  \throws via NVTE_ERROR once all MAX_TENSOR_NUM slots are in use.
   */
  NVTETensor Allocate(NVTEScalingMode mode) {
    std::lock_guard<std::mutex> lock(mutex);
    if (!free_list.empty()) {
      uintptr_t index = free_list.back();
      NVTETensor ret = reinterpret_cast<NVTETensor>(index);
      free_list.pop_back();
      if (debug) {
        std::cout << "Allocated " << index
                  << " from free list. Free list size: " << free_list.size() << " and capacity "
                  << free_list.capacity() << std::endl;
      }
      // 1-based indexing. Also restore the slot's back-reference to its own handle, exactly as
      // the fresh-allocation path below (and GroupedTensorAllocator) does, so a recycled slot
      // never depends on what Tensor::clear() left in nvte_tensor.
      Tensor &t = memory[index - 1];
      t.scaling_mode = mode;
      t.nvte_tensor = ret;
      return ret;
    }
    if (memory.size() < memory.capacity()) {
      memory.emplace_back();
      Tensor &t = memory.back();
      size = memory.size();
      // 1-based indexing
      uintptr_t index = memory.size();
      if (debug) {
        std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity "
                  << memory.capacity() << std::endl;
      }
      t.scaling_mode = mode;
      t.nvte_tensor = reinterpret_cast<NVTETensor>(index);
      return reinterpret_cast<NVTETensor>(index);
    }
    NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ",
               MAX_TENSOR_NUM, ". There is probably a memory leak in your application.");
  }

  /*! \brief Return the slot behind `t` to the pool. A null handle is ignored.
   *
   *  NOTE(review): freeing the same handle twice pushes it onto free_list twice, after which
   *  two Allocate() calls would return the same handle; callers must not double-free.
   */
  void Free(NVTETensor t) {
    std::lock_guard<std::mutex> lock(mutex);
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
    NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
    free_list.push_back(index);
    // Clean up
    memory[index - 1].clear();
    if (debug) {
      std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity "
                << free_list.capacity() << std::endl;
    }
  }

  /*! \brief Free `N` handles from the array `t` under a single lock; null entries are skipped. */
  void Free(NVTETensor *t, size_t N) {
    std::lock_guard<std::mutex> lock(mutex);
    for (size_t i = 0; i < N; ++i) {
      uintptr_t index = reinterpret_cast<uintptr_t>(t[i]);
      if (index == 0) continue;
      NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
      free_list.push_back(index);
      // Clean up
      memory[index - 1].clear();
    }
    if (debug) {
      std::cout << "Freed range of " << N << " tensors. Free list size: " << free_list.size()
                << " and capacity " << free_list.capacity() << std::endl;
    }
  }

  /*! \brief Map a handle to its Tensor, or nullptr for a null or out-of-range handle. */
  Tensor *convertNVTETensor(NVTETensor t) {
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    // 1-based indexing to enable 0-initialization of NVTETensor
    // to be invalid tensor
    static_assert(nullptr == 0);
    if (index != 0 && index <= size) {
      return &(memory[index - 1]);
    }
    return nullptr;
  }

  /*! \brief Toggle allocation/free tracing to std::cout. */
  void setDebug(bool debug) {
    std::lock_guard<std::mutex> lock(mutex);
    this->debug = debug;
  }

 private:
  TensorAllocator() {
    std::lock_guard<std::mutex> lock(mutex);
    memory.reserve(MAX_TENSOR_NUM);
  }

  std::mutex mutex;
  // Number of slots created so far; read without the lock by convertNVTETensor.
  std::atomic<size_t> size{0};
  // Allocate at most 20 MB for tensors
  // Should be replaced by virtual memory allocation
  const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor);
  std::vector<uintptr_t> free_list;
  std::vector<Tensor> memory;
  bool debug = false;
};

/*! \brief Resolve an NVTETensor handle; null or out-of-range handles yield nullptr. */
Tensor *convertNVTETensor(const NVTETensor t) {
  TensorAllocator &pool = TensorAllocator::instance();
  return pool.convertNVTETensor(t);
}

/*! \brief Like convertNVTETensor(), but an unresolvable handle is an error instead of nullptr.
 *  \throws via NVTE_CHECK when `t` does not name a live slot.
 */
Tensor *convertNVTETensorCheck(const NVTETensor t) {
  Tensor *ptr = convertNVTETensor(t);
  NVTE_CHECK(ptr != nullptr, "Invalid tensor.");
  return ptr;
}

528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
// GroupedTensor allocator - similar pattern to TensorAllocator
/*! \brief Process-wide pool backing the opaque NVTEGroupedTensor handles.
 *
 *  A handle is a 1-based index into `memory`, so a zero/nullptr handle is never valid.
 *  MAX_GROUPED_TENSOR_NUM slots are reserved up front and Allocate() never grows `memory` past
 *  that capacity, so slot addresses stay fixed. Freed slots are recycled most-recently-freed
 *  first through `free_list`. Allocate/Free serialize on `mutex`; lookups rely on the atomic
 *  `size` instead of the lock.
 */
class GroupedTensorAllocator {
 public:
  static GroupedTensorAllocator &instance() {
    static GroupedTensorAllocator allocator;
    return allocator;
  }

  ~GroupedTensorAllocator() {}

  /*! \brief Hand out a grouped-tensor handle describing `num_tensors` members.
   *  \throws via NVTE_ERROR once all MAX_GROUPED_TENSOR_NUM slots are in use.
   */
  NVTEGroupedTensor Allocate(NVTEScalingMode mode, size_t num_tensors, NVTEShape logical_shape) {
    std::lock_guard<std::mutex> lock(mutex);
    if (!free_list.empty()) {
      uintptr_t index = free_list.back();
      NVTEGroupedTensor ret = reinterpret_cast<NVTEGroupedTensor>(index);
      free_list.pop_back();
      // 1-based indexing - fully reinitialize the tensor to avoid stale data
      GroupedTensor &slot = memory[index - 1];
      slot.scaling_mode = mode;
      slot.num_tensors = num_tensors;
      slot.logical_shape = logical_shape;
      slot.nvte_tensor = ret;
      return ret;
    }
    if (memory.size() < memory.capacity()) {
      memory.emplace_back(mode, num_tensors);
      GroupedTensor &t = memory.back();
      size = memory.size();
      // 1-based indexing
      uintptr_t index = memory.size();
      t.logical_shape = logical_shape;
      t.nvte_tensor = reinterpret_cast<NVTEGroupedTensor>(index);
      return reinterpret_cast<NVTEGroupedTensor>(index);
    }
    NVTE_ERROR(
        "Cannot allocate a new NVTEGroupedTensor. Maximum number of grouped tensors reached: ",
        MAX_GROUPED_TENSOR_NUM, ". There is probably a memory leak in your application.");
  }

  /*! \brief Return the slot behind `t` to the pool. A null handle is ignored. */
  void Free(NVTEGroupedTensor t) {
    std::lock_guard<std::mutex> lock(mutex);
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
    NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
    free_list.push_back(index);
    // Clean up
    memory[index - 1].clear();
  }

  /*! \brief Map a handle to its GroupedTensor, or nullptr for a null/out-of-range handle. */
  GroupedTensor *convertNVTEGroupedTensor(NVTEGroupedTensor t) {
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    // 1-based indexing to enable 0-initialization of NVTEGroupedTensor
    // to be invalid tensor
    static_assert(nullptr == 0);
    if (index != 0 && index <= size) {
      return &(memory[index - 1]);
    }
    return nullptr;
  }

 private:
  GroupedTensorAllocator() {
    std::lock_guard<std::mutex> lock(mutex);
    memory.reserve(MAX_GROUPED_TENSOR_NUM);
  }

  std::mutex mutex;
  // Number of slots created so far; read without the lock by convertNVTEGroupedTensor.
  std::atomic<size_t> size{0};
  // Allocate at most 20 MB for grouped tensors
  const size_t MAX_GROUPED_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(GroupedTensor);
  std::vector<uintptr_t> free_list;
  std::vector<GroupedTensor> memory;
};

/*! \brief Resolve an NVTEGroupedTensor handle; null or out-of-range handles yield nullptr. */
GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor t) {
  GroupedTensorAllocator &pool = GroupedTensorAllocator::instance();
  return pool.convertNVTEGroupedTensor(t);
}

/*! \brief Like convertNVTEGroupedTensor(), but an unresolvable handle is an error.
 *  \throws via NVTE_CHECK when `t` does not name a live slot.
 */
GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor t) {
  GroupedTensor *ptr = convertNVTEGroupedTensor(t);
  NVTE_CHECK(ptr != nullptr, "Invalid grouped tensor.");
  return ptr;
}

Przemek Tredak's avatar
Przemek Tredak committed
611
612
}  // namespace transformer_engine

613
/* Allocate a tensor handle from the process-wide pool; release it with nvte_destroy_tensor. */
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
  auto &pool = transformer_engine::TensorAllocator::instance();
  return pool.Allocate(scaling_mode);
}

/* Return a handle to the pool. Null handles are ignored; otherwise the slot is cleared. */
void nvte_destroy_tensor(NVTETensor tensor) {
  auto &pool = transformer_engine::TensorAllocator::instance();
  pool.Free(tensor);
}

/* Return `N` handles from `tensors` to the pool under a single lock; null entries are skipped. */
void nvte_destroy_tensors(NVTETensor *tensors, size_t N) {
  auto &pool = transformer_engine::TensorAllocator::instance();
  pool.Free(tensors, N);
}

/* Element dtype of `tensor`; an unknown/null handle reports kNVTEFloat32 instead of failing. */
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  return t != nullptr ? static_cast<NVTEDType>(t->dtype()) : kNVTEFloat32;
}

632
633
634
635
636
637
638
639
640
/* Build an NVTEShape of rank `ndim` from `data`.
 *
 * A null `data` yields a zero-filled shape of the requested rank. Entries beyond `ndim` are
 * left unset. Throws if `ndim` exceeds the fixed capacity of NVTEShape::data.
 */
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
  NVTEShape ret;
  if (ndim != 0) {
    NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
               "Too many dims for NVTEShape (requested: ", ndim,
               ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
    if (data != nullptr) {
      std::copy(data, data + ndim, ret.data);
    } else {
      std::fill(ret.data, ret.data + ndim, 0);
    }
  }
  ret.ndim = ndim;
  return ret;
}

Przemek Tredak's avatar
Przemek Tredak committed
650
/* Shape of `tensor` as reported by Tensor::shape(), which picks it according to the tensor
 * format. Throws for an unknown/null handle. */
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t != nullptr) {
    const std::vector<size_t> &dims = t->shape();
    return nvte_make_shape(dims.data(), dims.size());
  }
  NVTE_ERROR("Invalid tensor");
}

662
/* Shape of the columnwise data buffer of `tensor`. Throws for an unknown/null handle. */
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t != nullptr) {
    const std::vector<size_t> &dims = t->columnwise_data.shape;
    return nvte_make_shape(dims.data(), dims.size());
  }
  NVTE_ERROR("Invalid tensor");
}

671
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
672
673

size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
674
675
676
677
  const auto &shape = nvte_tensor_shape(tensor);
  NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
             " in a shape array with ", shape.ndim, " entries");
  return shape.data[dim];
678
679
680
}

size_t nvte_tensor_numel(const NVTETensor tensor) {
681
  const auto &shape = nvte_tensor_shape(tensor);
682
  size_t numel = 1;
683
684
  for (size_t i = 0; i < shape.ndim; i++) {
    numel *= shape.data[i];
685
686
687
688
  }
  return numel;
}

689
690
691
692
693
694
/* Bits per element of `tensor`; an unknown/null handle is treated as Float32 (32 bits). */
size_t nvte_tensor_element_size_bits(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  return t != nullptr ? transformer_engine::typeToNumBits(t->dtype()) : 8 * sizeof(float);
}

695
size_t nvte_tensor_element_size(const NVTETensor tensor) {
696
697
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return sizeof(float);
698
699
700
701
702
703
704
705
706
  NVTE_CHECK(!is_fp4_dtype(t->dtype()),
             "For FP4 type please use the nvte_tensor_element_size_bits.");
  return nvte_tensor_element_size_bits(tensor) / 8;
}

/* Size of the tensor's (rowwise-shaped) data in bytes: numel * bits-per-element / 8, with
 * integer division truncating any trailing partial byte. Returns 0 for an unknown handle. */
size_t nvte_tensor_size_bytes(const NVTETensor tensor) {
  if (transformer_engine::convertNVTETensor(tensor) == nullptr) {
    return 0;
  }
  const size_t total_bits = nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor);
  return total_bits / 8;
}

Przemek Tredak's avatar
Przemek Tredak committed
709
void *nvte_tensor_data(const NVTETensor tensor) {
710
711
712
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->data.dptr;
713
714
}

715
void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
716
717
718
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_data.dptr;
719
720
}

721
/* Pointer to the amax buffer as float*, or nullptr for an unknown/null handle.
 * Throws if the amax buffer is not declared as Float32. */
float *nvte_tensor_amax(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t != nullptr) {
    // The buffer is exposed as float*, so its declared dtype must match.
    NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32,
               "Tensor's amax must have Float32 type!");
    return reinterpret_cast<float *>(t->amax.dptr);
  }
  return nullptr;
}

// Typed pointer to the tensor's scale buffer.
// Returns nullptr when the handle does not resolve to a tensor; raises an error
// if the scale buffer is not Float32.
float *nvte_tensor_scale(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t != nullptr) {
    NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32,
               "Tensor's scale must have Float32 type!");
    return reinterpret_cast<float *>(t->scale.dptr);
  }
  return nullptr;
}

float *nvte_tensor_scale_inv(const NVTETensor tensor) {
738
739
740
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return reinterpret_cast<float *>(t->scale_inv.dptr);
Przemek Tredak's avatar
Przemek Tredak committed
741
}
cyanguwa's avatar
cyanguwa committed
742

743
void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
744
745
746
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_scale_inv.dptr;
747
748
749
}

// Shape of the rowwise scale-inverse buffer.
// When the handle does not resolve to a tensor, the placeholder
// nvte_make_shape(nullptr, 1) is returned.
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
  auto *tensor_ptr = transformer_engine::convertNVTETensor(tensor);
  if (tensor_ptr != nullptr) {
    const auto &dims = tensor_ptr->scale_inv.shape;
    return nvte_make_shape(dims.data(), dims.size());
  }
  return nvte_make_shape(nullptr, 1);
}

// Overwrite one buffer-valued parameter of a tensor with a copy of `*param`.
//
//   tensor     - pointer to the tensor handle; must not be NULL and must
//                resolve to an allocated tensor.
//   param_name - which buffer to replace (data, scales, amaxes, ...).
//   param      - new value; must not be NULL.
//
// Parameters not listed in the switch (for example kNVTEWithGEMMSwizzledScales,
// which is only handled by nvte_set_tensor_param_v2) raise an error.
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
                           const NVTEBasicTensor *param) {
  NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
  auto *t = transformer_engine::convertNVTETensor(*tensor);
  NVTE_CHECK(t != nullptr, "Tensor is not allocated.");
  // Every branch below dereferences `param`; reject NULL up front
  // (matches nvte_set_grouped_tensor_param).
  NVTE_CHECK(param != nullptr, "Tensor param can't be NULL.");

  switch (param_name) {
    case kNVTERowwiseData:
      t->data = *param;
      break;
    case kNVTEColumnwiseData:
      t->columnwise_data = *param;
      break;
    case kNVTEScale:
      t->scale = *param;
      break;
    case kNVTEAmax:
      t->amax = *param;
      break;
    case kNVTERowwiseScaleInv:
      t->scale_inv = *param;
      break;
    case kNVTEColumnwiseScaleInv:
      t->columnwise_scale_inv = *param;
      break;
    case kNVTEColumnwiseAmax:
      t->columnwise_amax = *param;
      break;
    default:
      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
                 "). Consider using nvte_set_tensor_param_v2 instead.");
  }
}

// Return a copy of one buffer-valued parameter of a tensor.
//
//   tensor     - tensor handle. A NULL handle yields the placeholder
//                {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)}.
//   param_name - which buffer to read.
//
// Parameters not listed in the switch (e.g. kNVTEWithGEMMSwizzledScales, which
// only nvte_get_tensor_param_v2 handles) raise an error pointing at the v2 getter.
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
  if (tensor == nullptr) {
    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  switch (param_name) {
    case kNVTERowwiseData:
      return t.data;
    case kNVTEColumnwiseData:
      return t.columnwise_data;
    case kNVTEScale:
      return t.scale;
    case kNVTEAmax:
      return t.amax;
    case kNVTERowwiseScaleInv:
      return t.scale_inv;
    case kNVTEColumnwiseScaleInv:
      return t.columnwise_scale_inv;
    case kNVTEColumnwiseAmax:
      return t.columnwise_amax;
    default:
      // This is the getter: direct callers to the matching v2 *getter*.
      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
                 "). Consider using nvte_get_tensor_param_v2 instead.");
  }
}

// Generic tensor-parameter setter reading the new value from a raw buffer.
//
//   tensor        - tensor handle; must not be NULL.
//   param         - parameter id; must be < kNVTENumTensorParams.
//   buf           - source buffer; must not be NULL and must hold at least
//                   Tensor::attr_sizes[param] bytes (checked via size_in_bytes).
//
// Buffer-valued parameters are read as an NVTEBasicTensor. The GEMM-swizzle
// flag is read as a single uint8_t, because bool size is implementation-dependent.
void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
                              size_t size_in_bytes) {
  NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast<int>(param),
             ")");
  NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
  auto &dst = *transformer_engine::convertNVTETensorCheck(tensor);
  const auto &attr_size = transformer_engine::Tensor::attr_sizes[param];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for tensor parameter "
             "(parameter ",
             static_cast<int>(param), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // View of the caller's buffer as a basic tensor descriptor.
  const auto source_tensor = [buf]() -> const NVTEBasicTensor & {
    return *reinterpret_cast<const NVTEBasicTensor *>(buf);
  };

  switch (param) {
    case kNVTERowwiseData:
      dst.data = source_tensor();
      break;
    case kNVTEColumnwiseData:
      dst.columnwise_data = source_tensor();
      break;
    case kNVTEScale:
      dst.scale = source_tensor();
      break;
    case kNVTEAmax:
      dst.amax = source_tensor();
      break;
    case kNVTERowwiseScaleInv:
      dst.scale_inv = source_tensor();
      break;
    case kNVTEColumnwiseScaleInv:
      dst.columnwise_scale_inv = source_tensor();
      break;
    case kNVTEColumnwiseAmax:
      dst.columnwise_amax = source_tensor();
      break;
    case kNVTEWithGEMMSwizzledScales: {
      const uint8_t raw_flag = *reinterpret_cast<const uint8_t *>(buf);
      dst.with_gemm_swizzled_scales = static_cast<bool>(raw_flag);
      break;
    }
    default:
      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
  }
}

void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf,
                              size_t size_in_bytes, size_t *size_written) {
  using namespace transformer_engine;

  // Check param
  NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast<int>(param),
             ")");

  // Write attribute size if provided
  const auto &attr_size = Tensor::attr_sizes[param];
  if (size_written != nullptr) {
    *size_written = attr_size;
  }

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for tensor parameter "
             "(parameter ",
             static_cast<int>(param), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");

  // Get C++ tensor
  const Tensor *t = convertNVTETensor(tensor);
  std::optional<Tensor> dummy;
  if (t == nullptr) {
    // Make dummy tensor if provided tensor is invalid
    dummy.emplace();
    t = &(*dummy);
  }

  // Write to buffer
  switch (param) {
    case kNVTERowwiseData: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->data);
      break;
    }
    case kNVTEColumnwiseData: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_data);
      break;
    }
    case kNVTEScale: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->scale);
      break;
    }
    case kNVTEAmax: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->amax);
      break;
    }
    case kNVTERowwiseScaleInv: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->scale_inv);
      break;
    }
    case kNVTEColumnwiseScaleInv: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_scale_inv);
      break;
    }
    case kNVTEColumnwiseAmax: {
      NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
      *basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_amax);
      break;
    }
    case kNVTEWithGEMMSwizzledScales:
      *reinterpret_cast<uint8_t *>(buf) = static_cast<uint8_t>(t->with_gemm_swizzled_scales);
      break;
    default:
      NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
953
954
955
956
  }
}

// Scaling mode of the tensor. A NULL handle reports NVTE_DELAYED_TENSOR_SCALING
// instead of raising an error.
NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
  if (tensor != nullptr) {
    return transformer_engine::convertNVTETensorCheck(tensor)->scaling_mode;
  }
  return NVTE_DELAYED_TENSOR_SCALING;
}

964
// Fill every slot of the pack (pack->MAX_SIZE entries) with a freshly allocated
// tensor using NVTE_DELAYED_TENSOR_SCALING. Release with nvte_tensor_pack_destroy.
void nvte_tensor_pack_create(NVTETensorPack *pack) {
  auto &allocator = transformer_engine::TensorAllocator::instance();
  for (int slot = 0; slot < pack->MAX_SIZE; ++slot) {
    pack->tensors[slot] = allocator.Allocate(NVTE_DELAYED_TENSOR_SCALING);
  }
}

971
// Return all pack->MAX_SIZE tensors of the pack to the tensor allocator.
void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
  auto &allocator = transformer_engine::TensorAllocator::instance();
  allocator.Free(pack->tensors, pack->MAX_SIZE);
}
974
975

void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
976
977
  if (tensor == nullptr) return;
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
978

979
980
  // Zero out tensor data if allocated
  if (t.data.dptr != nullptr) {
981
982
983
984
    const auto size = t.data.buffer_size_bytes();
    if (size > 0) {
      NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size, stream));
    }
985
  }
986
987

  // Zero out amax if allocated
988
  if (t.amax.dptr != nullptr) {
989
990
991
992
    const auto size = t.amax.buffer_size_bytes();
    if (size > 0) {
      NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, size, stream));
    }
993
994
  }
}
995
996
997
998
999
1000
1001
1002

// Heap-allocate a QuantizationConfig and hand it out as an opaque handle.
// The caller owns it and must release it with nvte_destroy_quantization_config.
NVTEQuantizationConfig nvte_create_quantization_config() {
  auto *config = new transformer_engine::QuantizationConfig;
  return config;
}

// Read one attribute of a quantization config into a caller buffer.
//
//   attr         - attribute id; must be < kNVTEQuantizationConfigNumAttributes.
//   buf          - destination; if NULL, only the size is reported.
//   size_in_bytes- capacity of buf; must be >= QuantizationConfig::attr_sizes[attr].
//   size_written - optional; receives QuantizationConfig::attr_sizes[attr].
//
// Boolean attributes are written as a single uint8_t. Other attributes are
// copied byte-for-byte (attr_size bytes). The deprecated block-scale tensor
// format attribute always reads back as Float8BlockScaleTensorFormat::INVALID.
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, void *buf,
                                            size_t size_in_bytes, size_t *size_written) {
  using namespace transformer_engine;

  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  const auto &attr_size = QuantizationConfig::attr_sizes[attr];
  if (size_written != nullptr) {
    *size_written = attr_size;
  }

  // NULL buffer: size query only.
  if (buf == nullptr) {
    return;
  }

  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");

  // bool size is implementation-dependent, so flags cross the API as uint8_t.
  const auto put_flag = [buf](bool flag) {
    *static_cast<uint8_t *>(buf) = static_cast<uint8_t>(flag);
  };
  // Everything else is copied verbatim.
  const size_t nbytes = attr_size;
  const auto put_raw = [buf, nbytes](const void *src) { std::memcpy(buf, src, nbytes); };

  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  const QuantizationConfig &cfg = *reinterpret_cast<const QuantizationConfig *>(config);

  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      put_flag(cfg.force_pow_2_scales);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      put_raw(&cfg.amax_epsilon);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      put_raw(&cfg.noop_tensor);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: {
      // Deprecated: always reports INVALID.
      const auto invalid = Float8BlockScaleTensorFormat::INVALID;
      put_raw(&invalid);
      break;
    }
    case kNVTEQuantizationConfigRNGState:
      put_raw(&cfg.rng_state);
      break;
    case kNVTEQuantizationConfigNVFP42DQuantization:
      put_flag(cfg.nvfp4_2d_quantization);
      break;
    case kNVTEQuantizationConfigStochasticRounding:
      put_flag(cfg.stochastic_rounding);
      break;
    case kNVTEQuantizationConfigUseFastMath:
      put_flag(cfg.use_fast_math);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

// Write one attribute of a quantization config from a caller buffer.
//
//   attr          - attribute id; must be < kNVTEQuantizationConfigNumAttributes.
//   buf           - source; must not be NULL.
//   size_in_bytes - size of buf; must be >= QuantizationConfig::attr_sizes[attr].
//
// Boolean attributes are read from a single uint8_t. Other attributes are
// copied byte-for-byte (attr_size bytes). The deprecated block-scale tensor
// format attribute is accepted and ignored.
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, const void *buf,
                                            size_t size_in_bytes) {
  using namespace transformer_engine;

  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  const auto &attr_size = QuantizationConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // bool size is implementation-dependent, so flags cross the API as uint8_t.
  const auto take_flag = [buf](bool &flag) {
    flag = static_cast<bool>(*static_cast<const uint8_t *>(buf));
  };
  // Everything else is copied verbatim.
  const size_t nbytes = attr_size;
  const auto take_raw = [buf, nbytes](void *dst) { std::memcpy(dst, buf, nbytes); };

  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  QuantizationConfig &cfg = *reinterpret_cast<QuantizationConfig *>(config);

  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      take_flag(cfg.force_pow_2_scales);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      take_raw(&cfg.amax_epsilon);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      take_raw(&cfg.noop_tensor);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      // Deprecated: accepted and ignored.
      break;
    case kNVTEQuantizationConfigRNGState:
      take_raw(&cfg.rng_state);
      break;
    case kNVTEQuantizationConfigNVFP42DQuantization:
      take_flag(cfg.nvfp4_2d_quantization);
      break;
    case kNVTEQuantizationConfigStochasticRounding:
      take_flag(cfg.stochastic_rounding);
      break;
    case kNVTEQuantizationConfigUseFastMath:
      take_flag(cfg.use_fast_math);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

// Release a config obtained from nvte_create_quantization_config.
// A NULL handle is accepted: deleting a null pointer is a no-op.
void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
  auto *owned = reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
  delete owned;
}
1127
1128

// Returns 1 if the current device's SM architecture is in [100, 120) or >= 130,
// otherwise 0. The answer is memoized per device index: the per-device
// std::once_flag ensures sm_arch is queried at most once per device. The
// memo tables are sized by num_devices() on the first call.
int nvte_is_non_tn_fp8_gemm_supported() {
  const int device_count = transformer_engine::cuda::num_devices();
  static std::vector<int> supported(device_count, -1);
  static std::vector<std::once_flag> computed(device_count);
  const int device = transformer_engine::cuda::current_device();
  std::call_once(computed[device], [&]() {
    const int arch = transformer_engine::cuda::sm_arch(device);
    // Note: this is temporary restriction and should be lifted in the future.
    // (remove the note once it's done.)
    const bool in_sm100_range = arch >= 100 && arch < 120;
    const bool at_least_sm130 = arch >= 130;
    supported[device] = in_sm100_range || at_least_sm130;
  });
  return supported[device];
}
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213

// Grouped Tensor C API implementations
// Allocate a grouped tensor holding `num_tensors` members with a 2-D logical
// shape whose two dimensions are both positive.
NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
                                             NVTEShape logical_shape) {
  NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0");
  NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
  NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0,
             "Logical shape must have positive dimensions");
  auto &allocator = transformer_engine::GroupedTensorAllocator::instance();
  return allocator.Allocate(scaling_mode, num_tensors, logical_shape);
}

// Return a grouped tensor created by nvte_create_grouped_tensor to its allocator.
void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor) {
  auto &allocator = transformer_engine::GroupedTensorAllocator::instance();
  allocator.Free(tensor);
}

// Overwrite one buffer of a grouped tensor with a copy of `*param`.
//
//   tensor     - pointer to the grouped tensor handle; must not be NULL and must
//                resolve to an allocated grouped tensor.
//   param_name - which buffer to replace.
//   param      - new value; must not be NULL.
//
// first_dims, last_dims and tensor_offsets must have dtype Int64. The dtype is
// validated BEFORE the field is overwritten, so a rejected call leaves the
// grouped tensor unchanged instead of holding a half-applied, invalid buffer.
void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
                                   const NVTEBasicTensor *param) {
  NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL.");
  auto *t = transformer_engine::convertNVTEGroupedTensor(*tensor);
  NVTE_CHECK(t != nullptr, "Grouped tensor is not allocated.");
  NVTE_CHECK(param != nullptr, "Grouped tensor param can't be NULL.");

  // Dtype of the incoming buffer, expressed in the internal enum for comparison.
  const bool param_is_int64 =
      static_cast<transformer_engine::DType>(param->dtype) == transformer_engine::DType::kInt64;

  switch (param_name) {
    case kNVTEGroupedRowwiseData:
      t->data = *param;
      break;
    case kNVTEGroupedColumnwiseData:
      t->columnwise_data = *param;
      break;
    case kNVTEGroupedScale:
      t->scale = *param;
      break;
    case kNVTEGroupedAmax:
      t->amax = *param;
      break;
    case kNVTEGroupedRowwiseScaleInv:
      t->scale_inv = *param;
      break;
    case kNVTEGroupedColumnwiseScaleInv:
      t->columnwise_scale_inv = *param;
      break;
    case kNVTEGroupedColumnwiseAmax:
      t->columnwise_amax = *param;
      break;
    case kNVTEGroupedFirstDims:
      NVTE_CHECK(param_is_int64, "first_dims must have dtype Int64");
      t->first_dims = *param;
      break;
    case kNVTEGroupedLastDims:
      NVTE_CHECK(param_is_int64, "last_dims must have dtype Int64");
      t->last_dims = *param;
      break;
    case kNVTEGroupedTensorOffsets:
      NVTE_CHECK(param_is_int64, "tensor_offsets must have dtype Int64");
      t->tensor_offsets = *param;
      break;
    default:
      NVTE_ERROR("Unknown grouped tensor parameter!");
  }
}

// Return a copy of one buffer of a grouped tensor.
// A NULL handle yields the placeholder {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
// an unknown parameter id raises an error.
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
                                              NVTEGroupedTensorParam param_name) {
  if (tensor == nullptr) {
    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
  }
  const auto *group = transformer_engine::convertNVTEGroupedTensorCheck(tensor);

  switch (param_name) {
    case kNVTEGroupedRowwiseData:
      return group->data;
    case kNVTEGroupedColumnwiseData:
      return group->columnwise_data;
    case kNVTEGroupedScale:
      return group->scale;
    case kNVTEGroupedAmax:
      return group->amax;
    case kNVTEGroupedRowwiseScaleInv:
      return group->scale_inv;
    case kNVTEGroupedColumnwiseScaleInv:
      return group->columnwise_scale_inv;
    case kNVTEGroupedColumnwiseAmax:
      return group->columnwise_amax;
    case kNVTEGroupedFirstDims:
      return group->first_dims;
    case kNVTEGroupedLastDims:
      return group->last_dims;
    case kNVTEGroupedTensorOffsets:
      return group->tensor_offsets;
    default:
      NVTE_ERROR("Unknown grouped tensor parameter!");
  }
}

// Number of member tensors in the group; 0 if the handle does not resolve.
size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor) {
  auto *group = transformer_engine::convertNVTEGroupedTensor(tensor);
  return group != nullptr ? group->num_tensors : 0;
}

// Dtype of the grouped tensor as an NVTEDType; kNVTEFloat32 if the handle does not resolve.
NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor) {
  auto *group = transformer_engine::convertNVTEGroupedTensor(tensor);
  if (group != nullptr) {
    return static_cast<NVTEDType>(group->dtype());
  }
  return kNVTEFloat32;
}

// Scaling mode of the grouped tensor; a NULL handle reports NVTE_DELAYED_TENSOR_SCALING.
NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) {
  if (tensor != nullptr) {
    return transformer_engine::convertNVTEGroupedTensorCheck(tensor)->scaling_mode;
  }
  return NVTE_DELAYED_TENSOR_SCALING;
}

// Logical shape recorded for the grouped tensor at creation time.
// A NULL handle yields the placeholder nvte_make_shape(nullptr, 1).
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) {
  if (tensor != nullptr) {
    return transformer_engine::convertNVTEGroupedTensorCheck(tensor)->logical_shape;
  }
  return nvte_make_shape(nullptr, 1);
}