/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>
#include <utility>

#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"

namespace transformer_engine {

size_t typeToNumBits(const DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
                                     return TypeInfo<T>::size;);  // NOLINT(*)
}

size_t typeToSize(const DType type) {
  #if FP4_TYPE_SUPPORTED
  NVTE_CHECK(type != DType::kFloat4E2M1, "typeToSize() does not support the FP4 data type.");
  #endif
  return typeToNumBits(type) / 8;
}
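
// Illustrative arithmetic (bit widths as reported by TypeInfo): Float16 has
// typeToNumBits() == 16, so typeToSize() == 2 bytes. Float4E2M1 has
// typeToNumBits() == 4, so it has no whole-byte element size, which is why
// typeToSize() rejects it above.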

std::string to_string(const DType type) {
  switch (type) {
    case DType::kByte:
      return "Byte";
    case DType::kBFloat16:
      return "BFloat16";
    case DType::kFloat16:
      return "Float16";
    case DType::kFloat32:
      return "Float32";
    case DType::kFloat8E4M3:
      return "Float8E4M3";
    case DType::kFloat8E5M2:
      return "Float8E5M2";
    case DType::kFloat8E8M0:
      return "Float8E8M0";
    #if FP4_TYPE_SUPPORTED
    case DType::kFloat4E2M1:
      return "Float4E2M1";
    #endif
    case DType::kInt16:
      return "Int16";
    case DType::kInt32:
      return "Int32";
    case DType::kInt64:
      return "Int64";
    default:
      return concat_strings("Invalid type ", static_cast<int>(type));
  }
}

std::string to_string(const NVTEScalingMode &mode) {
  switch (mode) {
    case NVTE_DELAYED_TENSOR_SCALING:
      return "NVTE_DELAYED_TENSOR_SCALING";
    case NVTE_MXFP8_1D_SCALING:
      return "NVTE_MXFP8_1D_SCALING";
    case NVTE_BLOCK_SCALING_1D:
      return "NVTE_BLOCK_SCALING_1D";
    case NVTE_BLOCK_SCALING_2D:
      return "NVTE_BLOCK_SCALING_2D";
    case NVTE_NVFP4_1D_SCALING:
      return "NVTE_NVFP4_1D_SCALING";
    case NVTE_INVALID_SCALING:
      return "NVTE_INVALID_SCALING";
  }
  return "Invalid Scaling";
}

void CheckNoopTensor(const Tensor &t, const std::string &name) {
  if (t.data.dptr != nullptr) {
    NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
               ".");
    NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
               " noop. Expected kFloat32.");
  }
}

void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
  NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
  if (is_tensor_scaling(t.scaling_mode)) {
    // per-tensor scaling
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
                 "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
                 "\" has invalid columnwise_scale_inv shape (expected (1), got ",
                 t.columnwise_scale_inv.shape, ")");
    }
  } else {
    if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
      // Need (4, 128) alignment even for e8 scaling factor
      auto block_alignment = std::vector<size_t>{128ul, 4ul};
      size_t expected_x, expected_y, alignment;
      const size_t block_size_rowwise = 32;
      const size_t block_size_colwise = 32;

      if (t.has_data()) {
        alignment = block_alignment[0];
        expected_x =
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(1)), alignment) * alignment;
        alignment = block_alignment[1];
        expected_y =
            DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
            alignment;

        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
                   t.scale_inv.shape, ")");
      }
      if (t.has_columnwise_data()) {
        alignment = block_alignment[1];
        expected_x =
            DIVUP(DIVUP(t.flat_first_dim(), static_cast<size_t>(block_size_colwise)), alignment) *
            alignment;
        alignment = block_alignment[0];
        expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;

        const auto &expected = std::vector<size_t>{expected_x, expected_y};
        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                   "\"  has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                   t.columnwise_scale_inv.shape, ")");
      }
    } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
      if (t.has_data()) {
        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128);
        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4);
        const auto &expected = std::vector<size_t>{expected_y, expected_x};
        NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
                   "\" has invalid scale_inv shape (expected ", expected, ", got ",
                   t.scale_inv.shape, ")");
      }
      if (t.has_columnwise_data()) {
        const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128);
        const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4);
        const auto &expected = std::vector<size_t>{expected_y, expected_x};
        NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
                   "\"  has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
                   t.columnwise_scale_inv.shape, ")");
      }
    }
  }
}
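
// Worked example of the padding above (illustrative numbers): for a 300x1024
// rowwise MXFP8 tensor, one scale covers 32 elements along the last dim,
// giving a raw scale grid of (300, 32); padding the first dim to a multiple
// of 128 and the second to a multiple of 4 yields an expected scale_inv shape
// of (384, 32). For NVFP4, blocks are 16 elements wide, so the same tensor
// expects a scale_inv shape of (384, 64).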

void CheckInputTensor(const Tensor &t, const std::string &name) {
  const DType type = t.dtype();
  if (is_fp8_dtype(type) || is_int8_dtype(type)) {
    // FP8 input needs to have scale_inv
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Byte, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor input ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Byte, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else if (is_fp4_dtype(type)) {
    // TODO(ksivaman): Fix this to check for amaxes and other details.
    // For now only needed for swizzle.
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected DType::kFloat8E4M3, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ",
                 name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected DType::kFloat8E4M3, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Columnwise scale_inv is not supported for non-FP8 input ", name);
  }
  NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");

  CheckScaleTensorShape(t, name);
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.dtype();
  if (is_fp8_dtype(type) || is_int8_dtype(type)) {
    // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
    if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
      NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
                 to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
      NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
                 " (expected 1 entry, got shape=", t.amax.shape, ")");
    }
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
                     t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
                 "FP8 scaling factor output ", name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float32 or Float8E8M0, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else if (is_fp4_dtype(type)) {
    // FP4 output needs to have the scale_inv
    if (t.has_data()) {
      NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
                 "_scale_inverse must be allocated");
      NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
                 "_scale_inverse has invalid dtype "
                 "(expected Float8E4M3, got ",
                 to_string(t.scale_inv.dtype), ")");
    }
    if (t.has_columnwise_data()) {
      NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
                 "_columnwise_scale_inverse must be allocated");
      NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
                 name,
                 "_columnwise_scale_inverse has invalid dtype "
                 "(expected Float8E4M3, got ",
                 to_string(t.columnwise_scale_inv.dtype), ")");
    }
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
    // Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
    // NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Columnwise scale_inv is not supported for non-FP8 output ", name);
  }

  if (!allow_empty) {
    NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!");
  }

  CheckScaleTensorShape(t, name);
}

class TensorAllocator {
 public:
  static TensorAllocator &instance() {
    static TensorAllocator allocator;
    return allocator;
  }

  ~TensorAllocator() {}

  NVTETensor Allocate(NVTEScalingMode mode) {
    std::lock_guard<std::mutex> lock(mutex);
    if (!free_list.empty()) {
      uintptr_t index = free_list.back();
      NVTETensor ret = reinterpret_cast<NVTETensor>(index);
      free_list.pop_back();
      if (debug) {
        std::cout << "Allocated " << index
                  << " from free list. Free list size: " << free_list.size() << " and capacity "
                  << free_list.capacity() << std::endl;
      }
      // 1-based indexing
      memory[index - 1].scaling_mode = mode;
      return ret;
    }
    if (memory.size() < memory.capacity()) {
      memory.emplace_back();
      Tensor &t = memory.back();
      size = memory.size();
      // 1-based indexing
      uintptr_t index = memory.size();
      if (debug) {
        std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity "
                  << memory.capacity() << std::endl;
      }
      t.scaling_mode = mode;
      t.nvte_tensor = reinterpret_cast<NVTETensor>(index);
      return reinterpret_cast<NVTETensor>(index);
    }
    NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ",
               MAX_TENSOR_NUM, ". There is probably a memory leak in your application.");
  }

  void Free(NVTETensor t) {
    std::lock_guard<std::mutex> lock(mutex);
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
    NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
    free_list.push_back(index);
    // Clean up
    memory[index - 1].clear();
    if (debug) {
      std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity "
                << free_list.capacity() << std::endl;
    }
  }

  void Free(NVTETensor *t, size_t N) {
    std::lock_guard<std::mutex> lock(mutex);
    for (size_t i = 0; i < N; ++i) {
      uintptr_t index = reinterpret_cast<uintptr_t>(t[i]);
      if (index == 0) continue;
      NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
      free_list.push_back(index);
      // Clean up
      memory[index - 1].clear();
    }
    if (debug) {
      std::cout << "Freed range of" << N << " tensors. Free list size: " << free_list.size()
                << " and capacity " << free_list.capacity() << std::endl;
    }
  }

  Tensor *convertNVTETensor(NVTETensor t) {
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    // 1-based indexing to enable 0-initialization of NVTETensor
    // to be invalid tensor
    static_assert(nullptr == 0);
    if (index != 0 && index <= size) {
      return &(memory[index - 1]);
    }
    return nullptr;
  }

  void setDebug(bool debug) {
    std::lock_guard<std::mutex> lock(mutex);
    this->debug = debug;
  }

 private:
  TensorAllocator() {
    std::lock_guard<std::mutex> lock(mutex);
    memory.reserve(MAX_TENSOR_NUM);
  }

  std::mutex mutex;
  std::atomic<size_t> size = 0;
  // Allocate at most 20 MB for tensors
  // Should be replaced by virtual memory allocation
  const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor);
  std::vector<uintptr_t> free_list;
  std::vector<Tensor> memory;
  bool debug = false;
};
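
// Handle encoding sketch (illustrative): an NVTETensor is not a pointer but a
// 1-based slot index stored in the pointer value, so a zero-initialized handle
// is automatically invalid:
//   NVTETensor h = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);  // slot 0 -> handle 1
//   Tensor *t = convertNVTETensor(h);  // &memory[reinterpret_cast<uintptr_t>(h) - 1]
//   NVTETensor zero{};                 // convertNVTETensor(zero) == nullptr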

Tensor *convertNVTETensor(const NVTETensor t) {
  return TensorAllocator::instance().convertNVTETensor(t);
}

Tensor *convertNVTETensorCheck(const NVTETensor t) {
  Tensor *ptr = TensorAllocator::instance().convertNVTETensor(t);
  NVTE_CHECK(ptr != nullptr, "Invalid tensor.");
  return ptr;
}

}  // namespace transformer_engine

NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
  NVTETensor ret = transformer_engine::TensorAllocator::instance().Allocate(scaling_mode);
  return ret;
}

void nvte_destroy_tensor(NVTETensor tensor) {
  transformer_engine::TensorAllocator::instance().Free(tensor);
}

void nvte_destroy_tensors(NVTETensor *tensors, size_t N) {
  transformer_engine::TensorAllocator::instance().Free(tensors, N);
}

NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return kNVTEFloat32;
  return static_cast<NVTEDType>(t->dtype());
}

// Because of a HIP compiler bug, we need to disable optimizations here when
// compiling for AMD GPUs; the miscompilation is exposed by test_float8blockwisetensor.py.
// TODO: remove this once the HIP compiler bug is fixed.
#ifdef __HIP_PLATFORM_AMD__
#pragma clang optimize off
#endif
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
  NVTEShape ret;
  if (ndim == 0) {
    ret.ndim = 0;
    return ret;
  }
  NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
             "Too many dims for NVTEShape (requested: ", ndim,
             ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
  std::copy(data, data + ndim, ret.data);
  ret.ndim = ndim;
  return ret;
}
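
// Usage sketch (illustrative):
//   size_t dims[2] = {128, 64};
//   NVTEShape s = nvte_make_shape(dims, 2);  // s.ndim == 2, s.data = {128, 64}
// Requesting more dims than NVTEShape::data can hold trips the NVTE_CHECK above.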
// Re-enable optimizations after the region above (disabled due to the HIP
// compiler bug described before nvte_make_shape).
// TODO: remove this once the HIP compiler bug is fixed.
#ifdef __HIP_PLATFORM_AMD__
#pragma clang optimize on
#endif

NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    NVTE_ERROR("Invalid tensor");
  }

  // Determine tensor shape depending on tensor format
  const std::vector<size_t> &shape = t->shape();

  return nvte_make_shape(shape.data(), shape.size());
}

NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    NVTE_ERROR("Invalid tensor");
  }
  const std::vector<size_t> &shape = t->columnwise_data.shape;
  return nvte_make_shape(shape.data(), shape.size());
}

size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }

size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
  const auto &shape = nvte_tensor_shape(tensor);
  NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim,
             " in a shape array with ", shape.ndim, " entries");
  return shape.data[dim];
}

size_t nvte_tensor_numel(const NVTETensor tensor) {
  const auto &shape = nvte_tensor_shape(tensor);
  size_t numel = 1;
  for (size_t i = 0; i < shape.ndim; i++) {
    numel *= shape.data[i];
  }
  return numel;
}

size_t nvte_tensor_element_size_bits(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return 8 * sizeof(float);
  return transformer_engine::typeToNumBits(t->dtype());
}

size_t nvte_tensor_element_size(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return sizeof(float);
  NVTE_CHECK(!is_fp4_dtype(t->dtype()),
             "For FP4 types, use nvte_tensor_element_size_bits() instead.");
  return nvte_tensor_element_size_bits(tensor) / 8;
}

size_t nvte_tensor_size_bytes(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return 0;
  return (nvte_tensor_numel(tensor) * nvte_tensor_element_size_bits(tensor)) / 8;
}
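
// Illustrative arithmetic: a 128x64 tensor of DType::kFloat4E2M1 holds
// 8192 elements * 4 bits = 32768 bits = 4096 bytes. Sizes are computed in
// bits first and divided by 8 only at the end, so sub-byte types stay exact.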

void *nvte_tensor_data(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->data.dptr;
}

void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_data.dptr;
}

float *nvte_tensor_amax(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32,
             "Tensor's amax must have Float32 type!");
  return reinterpret_cast<float *>(t->amax.dptr);
}

float *nvte_tensor_scale(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32,
             "Tensor's scale must have Float32 type!");
  return reinterpret_cast<float *>(t->scale.dptr);
}

float *nvte_tensor_scale_inv(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return reinterpret_cast<float *>(t->scale_inv.dptr);
}

void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) return nullptr;
  return t->columnwise_scale_inv.dptr;
}

NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
  auto *t = transformer_engine::convertNVTETensor(tensor);
  if (t == nullptr) {
    return nvte_make_shape(nullptr, 0);
  }
  return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}

void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
                           const NVTEBasicTensor *param) {
  NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
  auto *t = transformer_engine::convertNVTETensor(*tensor);
  NVTE_CHECK(t != nullptr, "Tensor is not allocated.");
  switch (param_name) {
    case kNVTERowwiseData:
      t->data = *param;
      break;
    case kNVTEColumnwiseData:
      t->columnwise_data = *param;
      break;
    case kNVTEScale:
      t->scale = *param;
      break;
    case kNVTEAmax:
      t->amax = *param;
      break;
    case kNVTERowwiseScaleInv:
      t->scale_inv = *param;
      break;
    case kNVTEColumnwiseScaleInv:
      t->columnwise_scale_inv = *param;
      break;
    case kNVTEColumnwiseAmax:
      t->columnwise_amax = *param;
      break;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
}

NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
  if (tensor == nullptr) {
    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  switch (param_name) {
    case kNVTERowwiseData:
      return t.data;
    case kNVTEColumnwiseData:
      return t.columnwise_data;
    case kNVTEScale:
      return t.scale;
    case kNVTEAmax:
      return t.amax;
    case kNVTERowwiseScaleInv:
      return t.scale_inv;
    case kNVTEColumnwiseScaleInv:
      return t.columnwise_scale_inv;
    case kNVTEColumnwiseAmax:
      return t.columnwise_amax;
    default:
      NVTE_ERROR("Unknown tensor parameter!");
  }
}
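
// Usage sketch (illustrative): the param API moves NVTEBasicTensor views in
// and out of a tensor without copying device memory, e.g. to alias the
// rowwise data of one tensor into another:
//   NVTEBasicTensor rowwise = nvte_get_tensor_param(src, kNVTERowwiseData);
//   nvte_set_tensor_param(&dst, kNVTERowwiseData, &rowwise);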

NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
  if (tensor == nullptr) {
    return NVTE_DELAYED_TENSOR_SCALING;
  }
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  return t.scaling_mode;
}

void nvte_tensor_pack_create(NVTETensorPack *pack) {
  for (int i = 0; i < pack->MAX_SIZE; i++) {
    pack->tensors[i] =
        transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING);
  }
}

void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
  transformer_engine::TensorAllocator::instance().Free(pack->tensors, pack->MAX_SIZE);
}

void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
  if (tensor == nullptr) return;
  const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
  // Zero out tensor data if allocated
  if (t.data.dptr != nullptr) {
    const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
    NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
  }
  // Set amax to 0 if allocated
  if (t.amax.dptr != nullptr) {
    NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
  }
}

NVTEQuantizationConfig nvte_create_quantization_config() {
  return new transformer_engine::QuantizationConfig;
}

void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, void *buf,
                                            size_t size_in_bytes, size_t *size_written) {
  // Write attribute size
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
  *size_written = attr_size;

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");

  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      std::memcpy(buf, &config_.amax_epsilon, attr_size);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      std::memcpy(buf, &config_.noop_tensor, attr_size);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
      break;
    // (These cases mirror nvte_set_quantization_config_attribute below.)
    case kNVTEQuantizationConfigRNGState:
      std::memcpy(buf, &config_.rng_state, attr_size);
      break;
    case kNVTEQuantizationConfigNVFP42DQuantization:
      std::memcpy(buf, &config_.nvfp4_2d_quantization, attr_size);
      break;
    case kNVTEQuantizationConfigStochasticRounding:
      std::memcpy(buf, &config_.stochastic_rounding, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
                                            NVTEQuantizationConfigAttribute attr, const void *buf,
                                            size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
             "Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for quantization config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // Read from buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
  auto &config_ = *reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
  switch (attr) {
    case kNVTEQuantizationConfigForcePow2Scales:
      std::memcpy(&config_.force_pow_2_scales, buf, attr_size);
      break;
    case kNVTEQuantizationConfigAmaxEpsilon:
      std::memcpy(&config_.amax_epsilon, buf, attr_size);
      break;
    case kNVTEQuantizationConfigNoopTensor:
      std::memcpy(&config_.noop_tensor, buf, attr_size);
      break;
    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
      std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
      break;
    case kNVTEQuantizationConfigRNGState:
      std::memcpy(&config_.rng_state, buf, attr_size);
      break;
    case kNVTEQuantizationConfigNVFP42DQuantization:
      std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
      break;
    case kNVTEQuantizationConfigStochasticRounding:
      std::memcpy(&config_.stochastic_rounding, buf, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
  if (config != nullptr) {
    delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
  }
}
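
// Lifecycle sketch (illustrative):
//   NVTEQuantizationConfig cfg = nvte_create_quantization_config();
//   bool force_pow2 = true;
//   nvte_set_quantization_config_attribute(cfg, kNVTEQuantizationConfigForcePow2Scales,
//                                          &force_pow2, sizeof(force_pow2));
//   ... pass cfg to a quantization call ...
//   nvte_destroy_quantization_config(cfg);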

int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
  return true;
#else
  int deviceComputeCapability =
      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());

  // Note: this is a temporary restriction that should be lifted in the future.
  // (Remove this note once that is done.)
  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
         deviceComputeCapability >= 130;
#endif
}